Files changed (1) hide show
  1. app.py +93 -178
app.py CHANGED
@@ -67,7 +67,6 @@ except Exception as e:
67
  print(f"Failed to load model: {e}")
68
  import traceback
69
  traceback.print_exc()
70
- # Create a dummy model as fallback
71
  model = get_model(NUM_CLASSES).to(DEVICE)
72
  model.eval()
73
  print("Using untrained model as fallback")
@@ -79,7 +78,6 @@ class GradCAM:
79
  self.target_layer = target_layer
80
  self.gradients = None
81
  self.activations = None
82
- self.hook_handles = []
83
 
84
  def save_activation(self, module, input, output):
85
  self.activations = output.detach()
@@ -88,33 +86,23 @@ class GradCAM:
88
  self.gradients = grad_output[0].detach()
89
 
90
  def generate_cam(self, input_image, target_class=None):
91
- # Register hooks
92
  forward_handle = self.target_layer.register_forward_hook(self.save_activation)
93
  backward_handle = self.target_layer.register_full_backward_hook(self.save_gradient)
94
 
95
  try:
96
- # Forward pass
97
  model_output = self.model(input_image)
98
 
99
  if target_class is None:
100
  target_class = model_output.argmax(dim=1).item()
101
 
102
- # Backward pass
103
  self.model.zero_grad()
104
  class_score = model_output[0, target_class]
105
  class_score.backward(retain_graph=False)
106
 
107
- if self.gradients is None or self.activations is None:
108
- return np.zeros((7, 7)) # Default size for ResNet layer4
109
-
110
  gradients = self.gradients[0]
111
  activations = self.activations[0]
112
-
113
- # Global average pooling of gradients
114
  weights = gradients.mean(dim=(1, 2), keepdim=True)
115
  cam = (weights * activations).sum(dim=0)
116
-
117
- # Apply ReLU and normalize
118
  cam = F.relu(cam)
119
  cam = cam - cam.min()
120
  if cam.max() > 0:
@@ -122,176 +110,113 @@ class GradCAM:
122
 
123
  return cam.detach().cpu().numpy()
124
  finally:
125
- # Remove hooks
126
  forward_handle.remove()
127
  backward_handle.remove()
128
  self.gradients = None
129
  self.activations = None
130
 
131
  def overlay_heatmap(image, heatmap, alpha=0.4):
132
- """Overlay heatmap on original image"""
133
  heatmap_resized = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
134
  heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
135
- output = cv2.addWeighted(image, 1 - alpha, heatmap_colored, alpha, 0)
136
- return output
137
 
138
  def predict_galaxy(image):
139
- """Predict galaxy morphology and generate Grad-CAM"""
140
  if image is None:
141
  return None, "Please upload an image."
142
 
143
  if model is None:
144
  return None, "Error: Model not loaded. Please check the logs."
145
 
146
- try:
147
- # Ensure model is in eval mode
148
- model.eval()
149
-
150
- # Convert image to PIL if it's not already
151
- if isinstance(image, np.ndarray):
152
- image = Image.fromarray(image.astype('uint8'))
153
- elif not isinstance(image, Image.Image):
154
- image = Image.open(image) if hasattr(image, 'read') else Image.fromarray(np.array(image))
155
-
156
- # Ensure image is RGB
157
- if image.mode != 'RGB':
158
- image = image.convert('RGB')
159
-
160
- # Preprocess image
161
- img_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
162
- img_tensor.requires_grad = True
163
-
164
- # Get prediction
165
- with torch.set_grad_enabled(True):
166
- outputs = model(img_tensor)
167
- probs = F.softmax(outputs, dim=1)
168
- pred_class = outputs.argmax(dim=1).item()
169
- confidence = probs[0][pred_class].item()
170
-
171
- # Generate Grad-CAM
172
- try:
173
- gradcam = GradCAM(model, model.layer4)
174
- cam = gradcam.generate_cam(img_tensor, pred_class)
175
- except Exception as cam_error:
176
- print(f"Grad-CAM error: {cam_error}")
177
- import traceback
178
- traceback.print_exc()
179
- # If Grad-CAM fails, just return the original image
180
- cam = None
181
-
182
- # Prepare original image for overlay
183
- img_np = np.array(image)
184
- img_resized = cv2.resize(img_np, (224, 224))
185
-
186
- # Create overlay if Grad-CAM succeeded
187
- if cam is not None:
188
- try:
189
- overlay = overlay_heatmap(img_resized, cam)
190
- overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
191
- overlay_pil = Image.fromarray(overlay_rgb)
192
- except Exception as overlay_error:
193
- print(f"Overlay error: {overlay_error}")
194
- overlay_pil = image.resize((224, 224))
195
- else:
196
- overlay_pil = image.resize((224, 224))
197
-
198
- # Format results
199
- result_text = f"Predicted Class: {CLASS_NAMES[pred_class]}\nConfidence: {confidence:.2%}"
200
-
201
- # Ensure we return PIL Image
202
- if not isinstance(overlay_pil, Image.Image):
203
- overlay_pil = Image.fromarray(np.array(overlay_pil))
204
-
205
- return overlay_pil, str(result_text)
206
- except Exception as e:
207
- import traceback
208
- error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
209
- print(error_msg) # Print for debugging
210
- return None, f"Error: {str(e)}"
211
 
212
- # Custom CSS for black background and white text
 
 
213
  custom_css = """
214
- .gradio-container {
215
- background-color: #000000 !important;
216
- color: #ffffff !important;
217
- }
218
- body {
219
- background-color: #000000 !important;
220
- color: #ffffff !important;
221
- }
222
- .gradio-container * {
223
- color: #ffffff !important;
224
- }
225
- h1, h2, h3, h4, p, label, span, div {
226
- color: #ffffff !important;
227
- }
228
- .gr-markdown, .gr-markdown * {
229
- color: #ffffff !important;
230
- }
231
- .gr-button {
232
- background-color: #333333 !important;
233
- color: #ffffff !important;
234
- border: 1px solid #555555 !important;
235
- }
236
- .gr-button:hover {
237
- background-color: #555555 !important;
238
- }
239
- .gr-textbox, .gr-textbox input, .gr-textbox textarea {
240
- background-color: #1a1a1a !important;
241
- color: #ffffff !important;
242
- border: 1px solid #555555 !important;
243
- }
244
- .gr-image {
245
- background-color: #000000 !important;
246
- border: none !important;
247
- padding: 0 !important;
248
- margin: 0 !important;
249
- }
250
- .gr-image img {
251
- border: none !important;
252
- box-shadow: none !important;
253
- background-color: #000000 !important;
254
- }
255
- .gr-image-container, .image-container, .image-wrapper {
256
- border: none !important;
257
- background-color: #000000 !important;
258
- padding: 0 !important;
259
- margin: 0 !important;
260
- }
261
- .gr-image .toolbar, .gr-image .image-controls {
262
- display: none !important;
263
- }
264
- .gr-image label, .gr-image .label-wrap {
265
- display: none !important;
266
- }
267
- .gr-box {
268
- border: none !important;
269
- background-color: #000000 !important;
270
- }
271
- .panel, .panel-header {
272
- background-color: #000000 !important;
273
- border: none !important;
274
- }
275
  """
276
 
277
- # Create Gradio interface
278
- # Note: There's a known Gradio bug with API schema generation that causes errors
279
- # The app will still work for classification, but API endpoints may fail
280
  with gr.Blocks(css=custom_css) as demo:
281
- # Landing Section
282
  with gr.Column():
283
- landing_img = gr.Image(value="landing.jpg", height=500, show_label=False, container=False)
284
- landing_text = gr.Markdown("""
285
- <div style="text-align: center; padding: 30px; color: white; background-color: #000000; width: 100%; display: flex; flex-direction: column; align-items: center; justify-content: center;">
286
- <h1 style="font-size: 96px; font-weight: bold; margin: 0 auto 30px auto; text-align: center; width: 100%;">Galaxy Morphology AI</h1>
287
- <p style="font-size: 56px; font-weight: normal; margin: 0 auto; text-align: center; width: 100%;">Classify galaxies with state-of-the-art deep learning</p>
288
  </div>
289
  """)
290
-
291
- # Spacing between sections
292
  gr.Markdown("<div style='height: 60px;'></div>")
293
-
294
- # How Astrophysicists Use This Section
295
  with gr.Row():
296
  with gr.Column(scale=1):
297
  gr.Markdown("""
@@ -310,16 +235,14 @@ with gr.Blocks(css=custom_css) as demo:
310
  extragalactic astronomy.
311
  """)
312
  with gr.Column(scale=1):
313
- astro_img = gr.Image(value="astro.jpg", show_label=False, container=False, height=400)
314
- gr.Markdown("<p style='text-align: center; color: white; margin-top: 10px;'>Astrophysics Research</p>")
315
-
316
- # Spacing between sections
317
  gr.Markdown("<div style='height: 60px;'></div>")
318
-
319
- # Classification Section
320
  gr.Markdown("# Galaxy Morphology Classification")
321
  gr.Markdown("Upload a galaxy image to classify its morphology and visualize the model's attention using Grad-CAM.")
322
-
323
  with gr.Row():
324
  with gr.Column():
325
  input_image = gr.Image(label="Upload Galaxy Image")
@@ -327,21 +250,17 @@ with gr.Blocks(css=custom_css) as demo:
327
 
328
  with gr.Column():
329
  output_image = gr.Image(label="Grad-CAM Visualization")
330
- result_text = gr.Textbox(label="Classification Result")
331
-
332
- # Register the classification function
333
- # Disable API to avoid Gradio schema generation bug
334
  classify_btn.click(
335
  fn=predict_galaxy,
336
  inputs=[input_image],
337
  outputs=[output_image, result_text],
338
  api_name=False
339
  )
340
-
341
- # Spacing between sections
342
  gr.Markdown("<div style='height: 60px;'></div>")
343
-
344
- # Dark Energy Section
345
  gr.Markdown("""
346
  # Understanding Dark Energy Through Galaxy Morphology
347
 
@@ -364,13 +283,9 @@ with gr.Blocks(css=custom_css) as demo:
364
  its role in the fate of the universe.
365
  """)
366
 
367
- # Launch the demo
368
- # For Hugging Face Spaces, Gradio will automatically detect and launch the demo
369
- # The API error is a known Gradio bug - the app will still work for classification
370
  if __name__ == "__main__":
371
  try:
372
  demo.launch(show_api=False)
373
- except Exception as e:
374
- # If launch fails, try without API
375
- print(f"Launch error (non-critical): {e}")
376
  demo.launch()
 
67
  print(f"Failed to load model: {e}")
68
  import traceback
69
  traceback.print_exc()
 
70
  model = get_model(NUM_CLASSES).to(DEVICE)
71
  model.eval()
72
  print("Using untrained model as fallback")
 
78
  self.target_layer = target_layer
79
  self.gradients = None
80
  self.activations = None
 
81
 
82
  def save_activation(self, module, input, output):
83
  self.activations = output.detach()
 
86
  self.gradients = grad_output[0].detach()
87
 
88
  def generate_cam(self, input_image, target_class=None):
 
89
  forward_handle = self.target_layer.register_forward_hook(self.save_activation)
90
  backward_handle = self.target_layer.register_full_backward_hook(self.save_gradient)
91
 
92
  try:
 
93
  model_output = self.model(input_image)
94
 
95
  if target_class is None:
96
  target_class = model_output.argmax(dim=1).item()
97
 
 
98
  self.model.zero_grad()
99
  class_score = model_output[0, target_class]
100
  class_score.backward(retain_graph=False)
101
 
 
 
 
102
  gradients = self.gradients[0]
103
  activations = self.activations[0]
 
 
104
  weights = gradients.mean(dim=(1, 2), keepdim=True)
105
  cam = (weights * activations).sum(dim=0)
 
 
106
  cam = F.relu(cam)
107
  cam = cam - cam.min()
108
  if cam.max() > 0:
 
110
 
111
  return cam.detach().cpu().numpy()
112
  finally:
 
113
  forward_handle.remove()
114
  backward_handle.remove()
115
  self.gradients = None
116
  self.activations = None
117
 
118
  def overlay_heatmap(image, heatmap, alpha=0.4):
 
119
  heatmap_resized = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
120
  heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
121
+ return cv2.addWeighted(image, 1 - alpha, heatmap_colored, alpha, 0)
 
122
 
123
  def predict_galaxy(image):
 
124
  if image is None:
125
  return None, "Please upload an image."
126
 
127
  if model is None:
128
  return None, "Error: Model not loaded. Please check the logs."
129
 
130
+ model.eval()
131
+
132
+ if isinstance(image, np.ndarray):
133
+ image = Image.fromarray(image.astype("uint8"))
134
+ elif not isinstance(image, Image.Image):
135
+ image = Image.open(image)
136
+
137
+ if image.mode != "RGB":
138
+ image = image.convert("RGB")
139
+
140
+ img_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
141
+ img_tensor.requires_grad = True
142
+
143
+ outputs = model(img_tensor)
144
+ probs = F.softmax(outputs, dim=1)
145
+ pred_class = outputs.argmax(dim=1).item()
146
+ confidence = probs[0][pred_class].item()
147
+
148
+ gradcam = GradCAM(model, model.layer4)
149
+ cam = gradcam.generate_cam(img_tensor, pred_class)
150
+
151
+ img_np = np.array(image.resize((224, 224)))
152
+ overlay = overlay_heatmap(img_np, cam)
153
+ overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
154
+ overlay_pil = Image.fromarray(overlay_rgb)
155
+
156
+ result_text = f"Predicted Class: {CLASS_NAMES[pred_class]}\nConfidence: {confidence:.2%}"
157
+
158
+ return overlay_pil, result_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
+ # =========================
161
+ # Custom CSS
162
+ # =========================
163
  custom_css = """
164
+ .gradio-container {
165
+ background-color: #000000 !important;
166
+ color: #ffffff !important;
167
+ }
168
+ body {
169
+ background-color: #000000 !important;
170
+ color: #ffffff !important;
171
+ }
172
+
173
+ /* 🔴 FIX 1: REMOVED unsafe global selector */
174
+ /* .gradio-container * { color: #ffffff !important; } */
175
+
176
+ h1, h2, h3, h4, p, label, span, div {
177
+ color: #ffffff !important;
178
+ }
179
+ .gr-markdown, .gr-markdown * {
180
+ color: #ffffff !important;
181
+ }
182
+ .gr-button {
183
+ background-color: #333333 !important;
184
+ color: #ffffff !important;
185
+ border: 1px solid #555555 !important;
186
+ }
187
+ .gr-button:hover {
188
+ background-color: #555555 !important;
189
+ }
190
+ .gr-textbox, .gr-textbox input, .gr-textbox textarea {
191
+ background-color: #1a1a1a !important;
192
+ color: #ffffff !important;
193
+ border: 1px solid #555555 !important;
194
+ }
195
+ .gr-image {
196
+ background-color: #000000 !important;
197
+ border: none !important;
198
+ }
199
+ .gr-image img {
200
+ background-color: #000000 !important;
201
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  """
203
 
204
+ # =========================
205
+ # UI
206
+ # =========================
207
  with gr.Blocks(css=custom_css) as demo:
208
+
209
  with gr.Column():
210
+ gr.Image(value="landing.jpg", height=500, show_label=False, container=False)
211
+ gr.Markdown("""
212
+ <div style="text-align: center; padding: 30px;">
213
+ <h1 style="font-size: 96px; font-weight: bold;">Galaxy Morphology AI</h1>
214
+ <p style="font-size: 56px;">Classify galaxies with state-of-the-art deep learning</p>
215
  </div>
216
  """)
217
+
 
218
  gr.Markdown("<div style='height: 60px;'></div>")
219
+
 
220
  with gr.Row():
221
  with gr.Column(scale=1):
222
  gr.Markdown("""
 
235
  extragalactic astronomy.
236
  """)
237
  with gr.Column(scale=1):
238
+ gr.Image(value="astro.jpg", show_label=False, container=False, height=400)
239
+ gr.Markdown("<p style='text-align: center;'>Astrophysics Research</p>")
240
+
 
241
  gr.Markdown("<div style='height: 60px;'></div>")
242
+
 
243
  gr.Markdown("# Galaxy Morphology Classification")
244
  gr.Markdown("Upload a galaxy image to classify its morphology and visualize the model's attention using Grad-CAM.")
245
+
246
  with gr.Row():
247
  with gr.Column():
248
  input_image = gr.Image(label="Upload Galaxy Image")
 
250
 
251
  with gr.Column():
252
  output_image = gr.Image(label="Grad-CAM Visualization")
253
+ result_text = gr.Markdown() # 🔴 FIX 2: Textbox → Markdown (read-only)
254
+
 
 
255
  classify_btn.click(
256
  fn=predict_galaxy,
257
  inputs=[input_image],
258
  outputs=[output_image, result_text],
259
  api_name=False
260
  )
261
+
 
262
  gr.Markdown("<div style='height: 60px;'></div>")
263
+
 
264
  gr.Markdown("""
265
  # Understanding Dark Energy Through Galaxy Morphology
266
 
 
283
  its role in the fate of the universe.
284
  """)
285
 
286
+ # Launch
 
 
287
  if __name__ == "__main__":
288
  try:
289
  demo.launch(show_api=False)
290
+ except Exception:
 
 
291
  demo.launch()