Files changed (1) hide show
  1. app.py +146 -324
app.py CHANGED
@@ -8,369 +8,191 @@ import numpy as np
8
  import cv2
9
  import os
10
 
11
- # Workaround for Gradio API schema bug
12
- # Monkey-patch to handle the schema generation error gracefully
13
- try:
14
- import gradio_client.utils as client_utils
15
- original_get_type = client_utils.get_type
16
-
17
- def patched_get_type(schema):
18
- if isinstance(schema, bool):
19
- return "bool"
20
- return original_get_type(schema)
21
-
22
- client_utils.get_type = patched_get_type
23
- except:
24
- pass # If patching fails, continue anyway
25
-
26
  # Model configuration
 
27
  MODEL_PATH = "robust_galaxy_model.pth"
28
- NUM_CLASSES = 2
29
  CLASS_NAMES = ["Elliptical", "Spiral"]
30
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
 
32
- # Image preprocessing
 
 
33
  preprocess = transforms.Compose([
34
  transforms.Resize((224, 224)),
35
  transforms.ToTensor(),
36
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
 
 
37
  ])
38
 
39
- # Load model
 
 
40
  def get_model(num_classes=2):
41
  model = models.resnet18(weights=None)
42
  model.fc = nn.Linear(model.fc.in_features, num_classes)
43
  return model
44
 
45
  def load_model():
46
- model = get_model(NUM_CLASSES)
47
  if os.path.exists(MODEL_PATH):
48
- try:
49
- state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
50
- model.load_state_dict(state_dict)
51
- print(f"Model loaded successfully from {MODEL_PATH}")
52
- except Exception as e:
53
- print(f"Error loading model: {e}")
54
- print("Using untrained model")
55
  else:
56
- print(f"Model file not found at {MODEL_PATH}. Using untrained model.")
57
  model.to(DEVICE)
58
  model.eval()
59
  return model
60
 
61
- # Load model - handle errors gracefully
62
- model = None
63
- try:
64
- model = load_model()
65
- print("Model loaded successfully")
66
- except Exception as e:
67
- print(f"Failed to load model: {e}")
68
- import traceback
69
- traceback.print_exc()
70
- # Create a dummy model as fallback
71
- model = get_model(NUM_CLASSES).to(DEVICE)
72
- model.eval()
73
- print("Using untrained model as fallback")
74
 
75
- # Grad-CAM implementation
 
 
76
  class GradCAM:
77
  def __init__(self, model, target_layer):
78
  self.model = model
79
  self.target_layer = target_layer
80
  self.gradients = None
81
  self.activations = None
82
- self.hook_handles = []
83
-
84
  def save_activation(self, module, input, output):
85
  self.activations = output.detach()
86
-
87
  def save_gradient(self, module, grad_input, grad_output):
88
  self.gradients = grad_output[0].detach()
89
-
90
- def generate_cam(self, input_image, target_class=None):
91
- # Register hooks
92
- forward_handle = self.target_layer.register_forward_hook(self.save_activation)
93
- backward_handle = self.target_layer.register_full_backward_hook(self.save_gradient)
94
-
95
- try:
96
- # Forward pass
97
- model_output = self.model(input_image)
98
-
99
- if target_class is None:
100
- target_class = model_output.argmax(dim=1).item()
101
-
102
- # Backward pass
103
- self.model.zero_grad()
104
- class_score = model_output[0, target_class]
105
- class_score.backward(retain_graph=False)
106
-
107
- if self.gradients is None or self.activations is None:
108
- return np.zeros((7, 7)) # Default size for ResNet layer4
109
-
110
- gradients = self.gradients[0]
111
- activations = self.activations[0]
112
-
113
- # Global average pooling of gradients
114
- weights = gradients.mean(dim=(1, 2), keepdim=True)
115
- cam = (weights * activations).sum(dim=0)
116
-
117
- # Apply ReLU and normalize
118
- cam = F.relu(cam)
119
- cam = cam - cam.min()
120
- if cam.max() > 0:
121
- cam = cam / cam.max()
122
-
123
- return cam.detach().cpu().numpy()
124
- finally:
125
- # Remove hooks
126
- forward_handle.remove()
127
- backward_handle.remove()
128
- self.gradients = None
129
- self.activations = None
130
-
131
- def overlay_heatmap(image, heatmap, alpha=0.4):
132
- """Overlay heatmap on original image"""
133
- heatmap_resized = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
134
- heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
135
- output = cv2.addWeighted(image, 1 - alpha, heatmap_colored, alpha, 0)
136
- return output
137
-
138
- def predict_galaxy(image):
139
- """Predict galaxy morphology and generate Grad-CAM"""
140
  if image is None:
141
- return None, "Please upload an image."
142
-
143
- if model is None:
144
- return None, "Error: Model not loaded. Please check the logs."
145
-
146
- try:
147
- # Ensure model is in eval mode
148
- model.eval()
149
-
150
- # Convert image to PIL if it's not already
151
- if isinstance(image, np.ndarray):
152
- image = Image.fromarray(image.astype('uint8'))
153
- elif not isinstance(image, Image.Image):
154
- image = Image.open(image) if hasattr(image, 'read') else Image.fromarray(np.array(image))
155
-
156
- # Ensure image is RGB
157
- if image.mode != 'RGB':
158
- image = image.convert('RGB')
159
-
160
- # Preprocess image
161
- img_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
162
- img_tensor.requires_grad = True
163
-
164
- # Get prediction
165
- with torch.set_grad_enabled(True):
166
- outputs = model(img_tensor)
167
- probs = F.softmax(outputs, dim=1)
168
- pred_class = outputs.argmax(dim=1).item()
169
- confidence = probs[0][pred_class].item()
170
-
171
- # Generate Grad-CAM
172
- try:
173
- gradcam = GradCAM(model, model.layer4)
174
- cam = gradcam.generate_cam(img_tensor, pred_class)
175
- except Exception as cam_error:
176
- print(f"Grad-CAM error: {cam_error}")
177
- import traceback
178
- traceback.print_exc()
179
- # If Grad-CAM fails, just return the original image
180
- cam = None
181
-
182
- # Prepare original image for overlay
183
- img_np = np.array(image)
184
- img_resized = cv2.resize(img_np, (224, 224))
185
-
186
- # Create overlay if Grad-CAM succeeded
187
- if cam is not None:
188
- try:
189
- overlay = overlay_heatmap(img_resized, cam)
190
- overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
191
- overlay_pil = Image.fromarray(overlay_rgb)
192
- except Exception as overlay_error:
193
- print(f"Overlay error: {overlay_error}")
194
- overlay_pil = image.resize((224, 224))
195
- else:
196
- overlay_pil = image.resize((224, 224))
197
-
198
- # Format results
199
- result_text = f"Predicted Class: {CLASS_NAMES[pred_class]}\nConfidence: {confidence:.2%}"
200
-
201
- # Ensure we return PIL Image
202
- if not isinstance(overlay_pil, Image.Image):
203
- overlay_pil = Image.fromarray(np.array(overlay_pil))
204
-
205
- return overlay_pil, str(result_text)
206
- except Exception as e:
207
- import traceback
208
- error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
209
- print(error_msg) # Print for debugging
210
- return None, f"Error: {str(e)}"
211
-
212
- # Custom CSS for black background and white text
213
  custom_css = """
214
- .gradio-container {
215
- background-color: #000000 !important;
216
- color: #ffffff !important;
217
- }
218
- body {
219
- background-color: #000000 !important;
220
- color: #ffffff !important;
221
- }
222
- .gradio-container * {
223
- color: #ffffff !important;
224
- }
225
- h1, h2, h3, h4, p, label, span, div {
226
- color: #ffffff !important;
227
- }
228
- .gr-markdown, .gr-markdown * {
229
- color: #ffffff !important;
230
- }
231
- .gr-button {
232
- background-color: #333333 !important;
233
- color: #ffffff !important;
234
- border: 1px solid #555555 !important;
235
- }
236
- .gr-button:hover {
237
- background-color: #555555 !important;
238
- }
239
- .gr-textbox, .gr-textbox input, .gr-textbox textarea {
240
- background-color: #1a1a1a !important;
241
- color: #ffffff !important;
242
- border: 1px solid #555555 !important;
243
- }
244
- .gr-image {
245
- background-color: #000000 !important;
246
- border: none !important;
247
- padding: 0 !important;
248
- margin: 0 !important;
249
- }
250
- .gr-image img {
251
- border: none !important;
252
- box-shadow: none !important;
253
- background-color: #000000 !important;
254
- }
255
- .gr-image-container, .image-container, .image-wrapper {
256
- border: none !important;
257
- background-color: #000000 !important;
258
- padding: 0 !important;
259
- margin: 0 !important;
260
- }
261
- .gr-image .toolbar, .gr-image .image-controls {
262
- display: none !important;
263
- }
264
- .gr-image label, .gr-image .label-wrap {
265
- display: none !important;
266
- }
267
- .gr-box {
268
- border: none !important;
269
- background-color: #000000 !important;
270
- }
271
- .panel, .panel-header {
272
- background-color: #000000 !important;
273
- border: none !important;
274
- }
275
  """
276
 
277
- # Create Gradio interface
278
- # Note: There's a known Gradio bug with API schema generation that causes errors
279
- # The app will still work for classification, but API endpoints may fail
280
  with gr.Blocks(css=custom_css) as demo:
281
- # Landing Section
282
- with gr.Column():
283
- landing_img = gr.Image(value="landing.jpg", height=500, show_label=False, container=False)
284
- landing_text = gr.Markdown("""
285
- <div style="text-align: center; padding: 30px; color: white; background-color: #000000; width: 100%; display: flex; flex-direction: column; align-items: center; justify-content: center;">
286
- <h1 style="font-size: 96px; font-weight: bold; margin: 0 auto 30px auto; text-align: center; width: 100%;">Galaxy Morphology AI</h1>
287
- <p style="font-size: 56px; font-weight: normal; margin: 0 auto; text-align: center; width: 100%;">Classify galaxies with state-of-the-art deep learning</p>
288
- </div>
289
- """)
290
-
291
- # Spacing between sections
292
- gr.Markdown("<div style='height: 60px;'></div>")
293
-
294
- # How Astrophysicists Use This Section
295
- with gr.Row():
296
- with gr.Column(scale=1):
297
- gr.Markdown("""
298
- # How Astrophysicists Use This
299
-
300
- Galaxy morphology classification is a fundamental tool in modern astrophysics.
301
- By automatically identifying whether a galaxy is elliptical or spiral, researchers
302
- can analyze large datasets from telescopes like the Hubble Space Telescope and
303
- the James Webb Space Telescope. This classification helps understand galaxy
304
- formation, evolution, and the distribution of matter in the universe.
305
-
306
- The deep learning model uses convolutional neural networks to identify key
307
- features in galaxy images, such as spiral arms, central bulges, and overall
308
- structure. This automated classification enables astronomers to process millions
309
- of galaxy images efficiently, accelerating discoveries in cosmology and
310
- extragalactic astronomy.
311
- """)
312
- with gr.Column(scale=1):
313
- astro_img = gr.Image(value="astro.jpg", show_label=False, container=False, height=400)
314
- gr.Markdown("<p style='text-align: center; color: white; margin-top: 10px;'>Astrophysics Research</p>")
315
-
316
- # Spacing between sections
317
- gr.Markdown("<div style='height: 60px;'></div>")
318
-
319
- # Classification Section
320
- gr.Markdown("# Galaxy Morphology Classification")
321
- gr.Markdown("Upload a galaxy image to classify its morphology and visualize the model's attention using Grad-CAM.")
322
-
323
  with gr.Row():
324
- with gr.Column():
325
- input_image = gr.Image(label="Upload Galaxy Image")
326
- classify_btn = gr.Button("Classify Galaxy")
327
-
328
- with gr.Column():
329
- output_image = gr.Image(label="Grad-CAM Visualization")
330
- result_text = gr.Textbox(label="Classification Result")
331
-
332
- # Register the classification function
333
- # Disable API to avoid Gradio schema generation bug
 
 
 
 
 
 
 
334
  classify_btn.click(
335
- fn=predict_galaxy,
336
- inputs=[input_image],
337
- outputs=[output_image, result_text],
338
- api_name=False
339
  )
340
-
341
- # Spacing between sections
342
- gr.Markdown("<div style='height: 60px;'></div>")
343
-
344
- # Dark Energy Section
345
- gr.Markdown("""
346
- # Understanding Dark Energy Through Galaxy Morphology
347
-
348
- Galaxy morphology classification plays a crucial role in understanding dark energy,
349
- one of the most profound mysteries in modern cosmology. Dark energy is the
350
- mysterious force driving the accelerated expansion of the universe, and its nature
351
- remains one of the biggest questions in physics.
352
-
353
- By classifying large numbers of galaxies and mapping their distribution across
354
- cosmic time, astronomers can trace the expansion history of the universe.
355
- Different galaxy types (elliptical vs spiral) form and evolve differently, and
356
- their relative abundances at different redshifts provide clues about the universe's
357
- evolution. The distribution and clustering of these galaxies help measure the
358
- large-scale structure of the universe, which is directly influenced by dark energy.
359
-
360
- Automated classification systems like this one enable the analysis of millions of
361
- galaxies from current and future surveys, such as the Vera C. Rubin Observatory's
362
- Legacy Survey of Space and Time (LSST). These massive datasets will provide
363
- unprecedented precision in measuring dark energy's properties and understanding
364
- its role in the fate of the universe.
365
- """)
366
-
367
- # Launch the demo
368
- # For Hugging Face Spaces, Gradio will automatically detect and launch the demo
369
- # The API error is a known Gradio bug - the app will still work for classification
370
  if __name__ == "__main__":
371
- try:
372
- demo.launch(show_api=False)
373
- except Exception as e:
374
- # If launch fails, try without API
375
- print(f"Launch error (non-critical): {e}")
376
- demo.launch()
 
8
  import cv2
9
  import os
10
 
11
+ # ======================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Model configuration
13
+ # ======================
14
  MODEL_PATH = "robust_galaxy_model.pth"
 
15
  CLASS_NAMES = ["Elliptical", "Spiral"]
16
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
+ # ======================
19
+ # Preprocessing
20
+ # ======================
21
  preprocess = transforms.Compose([
22
  transforms.Resize((224, 224)),
23
  transforms.ToTensor(),
24
+ transforms.Normalize(
25
+ mean=[0.485, 0.456, 0.406],
26
+ std=[0.229, 0.224, 0.225]
27
+ )
28
  ])
29
 
30
+ # ======================
31
+ # Model loading
32
+ # ======================
33
  def get_model(num_classes=2):
34
  model = models.resnet18(weights=None)
35
  model.fc = nn.Linear(model.fc.in_features, num_classes)
36
  return model
37
 
38
  def load_model():
39
+ model = get_model()
40
  if os.path.exists(MODEL_PATH):
41
+ state_dict = torch.load(MODEL_PATH, map_location="cpu")
42
+ model.load_state_dict(state_dict)
43
+ print("✅ Model loaded")
 
 
 
 
44
  else:
45
+ print("⚠️ Model not found, using untrained model")
46
  model.to(DEVICE)
47
  model.eval()
48
  return model
49
 
50
+ model = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ # ======================
53
+ # Grad-CAM
54
+ # ======================
55
  class GradCAM:
56
  def __init__(self, model, target_layer):
57
  self.model = model
58
  self.target_layer = target_layer
59
  self.gradients = None
60
  self.activations = None
61
+
 
62
  def save_activation(self, module, input, output):
63
  self.activations = output.detach()
64
+
65
  def save_gradient(self, module, grad_input, grad_output):
66
  self.gradients = grad_output[0].detach()
67
+
68
+ def generate(self, x, class_idx):
69
+ h1 = self.target_layer.register_forward_hook(self.save_activation)
70
+ h2 = self.target_layer.register_full_backward_hook(self.save_gradient)
71
+
72
+ out = self.model(x)
73
+ score = out[0, class_idx]
74
+ self.model.zero_grad()
75
+ score.backward()
76
+
77
+ weights = self.gradients.mean(dim=(2, 3), keepdim=True)
78
+ cam = (weights * self.activations).sum(dim=1)
79
+ cam = F.relu(cam)
80
+ cam = cam - cam.min()
81
+ cam = cam / cam.max()
82
+
83
+ h1.remove()
84
+ h2.remove()
85
+
86
+ return cam[0].cpu().numpy()
87
+
88
+ def overlay_heatmap(img, cam):
89
+ cam = cv2.resize(cam, (img.shape[1], img.shape[0]))
90
+ heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
91
+ return cv2.addWeighted(img, 0.6, heatmap, 0.4, 0)
92
+
93
+ # ======================
94
+ # Prediction function
95
+ # ======================
96
+ def predict(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  if image is None:
98
+ return None, "⚠️ Please upload an image."
99
+
100
+ if not isinstance(image, Image.Image):
101
+ image = Image.fromarray(image)
102
+
103
+ image = image.convert("RGB")
104
+ x = preprocess(image).unsqueeze(0).to(DEVICE)
105
+ x.requires_grad_(True)
106
+
107
+ outputs = model(x)
108
+ probs = F.softmax(outputs, dim=1)[0]
109
+ pred_idx = probs.argmax().item()
110
+ confidence = probs[pred_idx].item()
111
+
112
+ cam = GradCAM(model, model.layer4).generate(x, pred_idx)
113
+
114
+ img_np = np.array(image.resize((224, 224)))
115
+ overlay = overlay_heatmap(img_np, cam)
116
+ overlay = Image.fromarray(overlay)
117
+
118
+ result_md = f"""
119
+ ### 🌌 Prediction
120
+ **Class:** `{CLASS_NAMES[pred_idx]}`
121
+ **Confidence:** `{confidence*100:.2f}%`
122
+
123
+ **Class Probabilities**
124
+ - Elliptical: `{probs[0]*100:.2f}%`
125
+ - Spiral: `{probs[1]*100:.2f}%`
126
+ """
127
+
128
+ return overlay, result_md
129
+
130
+ # ======================
131
+ # Clean Dark UI CSS
132
+ # ======================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  custom_css = """
134
+ body, .gradio-container {
135
+ background-color: #000000;
136
+ color: #ffffff;
137
+ }
138
+
139
+ .gr-image, .gr-image img {
140
+ background: #000000 !important;
141
+ border: none !important;
142
+ }
143
+
144
+ .gr-button {
145
+ background-color: #222 !important;
146
+ color: white !important;
147
+ border-radius: 8px;
148
+ }
149
+
150
+ .gr-button:hover {
151
+ background-color: #444 !important;
152
+ }
153
+
154
+ .gr-markdown {
155
+ color: white !important;
156
+ }
157
+
158
+ footer { display: none !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  """
160
 
161
+ # ======================
162
+ # UI
163
+ # ======================
164
  with gr.Blocks(css=custom_css) as demo:
165
+ gr.Markdown(
166
+ "# 🌌 Galaxy Morphology Classification\n"
167
+ "Upload a galaxy image to classify its morphology and visualize model attention."
168
+ )
169
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  with gr.Row():
171
+ input_img = gr.Image(
172
+ label=None,
173
+ type="pil",
174
+ show_label=False,
175
+ container=False
176
+ )
177
+
178
+ output_img = gr.Image(
179
+ label=None,
180
+ show_label=False,
181
+ container=False
182
+ )
183
+
184
+ result_md = gr.Markdown()
185
+
186
+ classify_btn = gr.Button("Classify Galaxy")
187
+
188
  classify_btn.click(
189
+ fn=predict,
190
+ inputs=input_img,
191
+ outputs=[output_img, result_md]
 
192
  )
193
+
194
+ # ======================
195
+ # Launch
196
+ # ======================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  if __name__ == "__main__":
198
+ demo.launch()