Files changed (1) hide show
  1. app.py +324 -146
app.py CHANGED
@@ -8,191 +8,369 @@ import numpy as np
8
  import cv2
9
  import os
10
 
11
- # ======================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Model configuration
13
- # ======================
14
  MODEL_PATH = "robust_galaxy_model.pth"
 
15
  CLASS_NAMES = ["Elliptical", "Spiral"]
16
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
- # ======================
19
- # Preprocessing
20
- # ======================
21
  preprocess = transforms.Compose([
22
  transforms.Resize((224, 224)),
23
  transforms.ToTensor(),
24
- transforms.Normalize(
25
- mean=[0.485, 0.456, 0.406],
26
- std=[0.229, 0.224, 0.225]
27
- )
28
  ])
29
 
30
- # ======================
31
- # Model loading
32
- # ======================
33
  def get_model(num_classes=2):
34
  model = models.resnet18(weights=None)
35
  model.fc = nn.Linear(model.fc.in_features, num_classes)
36
  return model
37
 
38
  def load_model():
39
- model = get_model()
40
  if os.path.exists(MODEL_PATH):
41
- state_dict = torch.load(MODEL_PATH, map_location="cpu")
42
- model.load_state_dict(state_dict)
43
- print("✅ Model loaded")
 
 
 
 
44
  else:
45
- print("⚠️ Model not found, using untrained model")
46
  model.to(DEVICE)
47
  model.eval()
48
  return model
49
 
50
- model = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # ======================
53
- # Grad-CAM
54
- # ======================
55
  class GradCAM:
56
  def __init__(self, model, target_layer):
57
  self.model = model
58
  self.target_layer = target_layer
59
  self.gradients = None
60
  self.activations = None
61
-
 
62
  def save_activation(self, module, input, output):
63
  self.activations = output.detach()
64
-
65
  def save_gradient(self, module, grad_input, grad_output):
66
  self.gradients = grad_output[0].detach()
67
-
68
- def generate(self, x, class_idx):
69
- h1 = self.target_layer.register_forward_hook(self.save_activation)
70
- h2 = self.target_layer.register_full_backward_hook(self.save_gradient)
71
-
72
- out = self.model(x)
73
- score = out[0, class_idx]
74
- self.model.zero_grad()
75
- score.backward()
76
-
77
- weights = self.gradients.mean(dim=(2, 3), keepdim=True)
78
- cam = (weights * self.activations).sum(dim=1)
79
- cam = F.relu(cam)
80
- cam = cam - cam.min()
81
- cam = cam / cam.max()
82
-
83
- h1.remove()
84
- h2.remove()
85
-
86
- return cam[0].cpu().numpy()
87
-
88
- def overlay_heatmap(img, cam):
89
- cam = cv2.resize(cam, (img.shape[1], img.shape[0]))
90
- heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
91
- return cv2.addWeighted(img, 0.6, heatmap, 0.4, 0)
92
-
93
- # ======================
94
- # Prediction function
95
- # ======================
96
- def predict(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  if image is None:
98
- return None, "⚠️ Please upload an image."
99
-
100
- if not isinstance(image, Image.Image):
101
- image = Image.fromarray(image)
102
-
103
- image = image.convert("RGB")
104
- x = preprocess(image).unsqueeze(0).to(DEVICE)
105
- x.requires_grad_(True)
106
-
107
- outputs = model(x)
108
- probs = F.softmax(outputs, dim=1)[0]
109
- pred_idx = probs.argmax().item()
110
- confidence = probs[pred_idx].item()
111
-
112
- cam = GradCAM(model, model.layer4).generate(x, pred_idx)
113
-
114
- img_np = np.array(image.resize((224, 224)))
115
- overlay = overlay_heatmap(img_np, cam)
116
- overlay = Image.fromarray(overlay)
117
-
118
- result_md = f"""
119
- ### 🌌 Prediction
120
- **Class:** `{CLASS_NAMES[pred_idx]}`
121
- **Confidence:** `{confidence*100:.2f}%`
122
-
123
- **Class Probabilities**
124
- - Elliptical: `{probs[0]*100:.2f}%`
125
- - Spiral: `{probs[1]*100:.2f}%`
126
- """
127
-
128
- return overlay, result_md
129
-
130
- # ======================
131
- # Clean Dark UI CSS
132
- # ======================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  custom_css = """
134
- body, .gradio-container {
135
- background-color: #000000;
136
- color: #ffffff;
137
- }
138
-
139
- .gr-image, .gr-image img {
140
- background: #000000 !important;
141
- border: none !important;
142
- }
143
-
144
- .gr-button {
145
- background-color: #222 !important;
146
- color: white !important;
147
- border-radius: 8px;
148
- }
149
-
150
- .gr-button:hover {
151
- background-color: #444 !important;
152
- }
153
-
154
- .gr-markdown {
155
- color: white !important;
156
- }
157
-
158
- footer { display: none !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  """
160
 
161
- # ======================
162
- # UI
163
- # ======================
164
  with gr.Blocks(css=custom_css) as demo:
165
- gr.Markdown(
166
- "# 🌌 Galaxy Morphology Classification\n"
167
- "Upload a galaxy image to classify its morphology and visualize model attention."
168
- )
169
-
 
 
 
 
 
 
 
 
 
170
  with gr.Row():
171
- input_img = gr.Image(
172
- label=None,
173
- type="pil",
174
- show_label=False,
175
- container=False
176
- )
177
-
178
- output_img = gr.Image(
179
- label=None,
180
- show_label=False,
181
- container=False
182
- )
183
-
184
- result_md = gr.Markdown()
185
-
186
- classify_btn = gr.Button("Classify Galaxy")
187
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  classify_btn.click(
189
- fn=predict,
190
- inputs=input_img,
191
- outputs=[output_img, result_md]
 
192
  )
193
-
194
- # ======================
195
- # Launch
196
- # ======================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  if __name__ == "__main__":
198
- demo.launch()
 
 
 
 
 
 
8
  import cv2
9
  import os
10
 
11
+ # Workaround for Gradio API schema bug
12
+ # Monkey-patch to handle the schema generation error gracefully
13
+ try:
14
+ import gradio_client.utils as client_utils
15
+ original_get_type = client_utils.get_type
16
+
17
+ def patched_get_type(schema):
18
+ if isinstance(schema, bool):
19
+ return "bool"
20
+ return original_get_type(schema)
21
+
22
+ client_utils.get_type = patched_get_type
23
+ except:
24
+ pass # If patching fails, continue anyway
25
+
26
  # Model configuration
 
27
  MODEL_PATH = "robust_galaxy_model.pth"
28
+ NUM_CLASSES = 2
29
  CLASS_NAMES = ["Elliptical", "Spiral"]
30
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
 
32
+ # Image preprocessing
 
 
33
  preprocess = transforms.Compose([
34
  transforms.Resize((224, 224)),
35
  transforms.ToTensor(),
36
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
 
 
37
  ])
38
 
39
+ # Load model
 
 
40
  def get_model(num_classes=2):
41
  model = models.resnet18(weights=None)
42
  model.fc = nn.Linear(model.fc.in_features, num_classes)
43
  return model
44
 
45
  def load_model():
46
+ model = get_model(NUM_CLASSES)
47
  if os.path.exists(MODEL_PATH):
48
+ try:
49
+ state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
50
+ model.load_state_dict(state_dict)
51
+ print(f"Model loaded successfully from {MODEL_PATH}")
52
+ except Exception as e:
53
+ print(f"Error loading model: {e}")
54
+ print("Using untrained model")
55
  else:
56
+ print(f"Model file not found at {MODEL_PATH}. Using untrained model.")
57
  model.to(DEVICE)
58
  model.eval()
59
  return model
60
 
61
+ # Load model - handle errors gracefully
62
+ model = None
63
+ try:
64
+ model = load_model()
65
+ print("Model loaded successfully")
66
+ except Exception as e:
67
+ print(f"Failed to load model: {e}")
68
+ import traceback
69
+ traceback.print_exc()
70
+ # Create a dummy model as fallback
71
+ model = get_model(NUM_CLASSES).to(DEVICE)
72
+ model.eval()
73
+ print("Using untrained model as fallback")
74
 
75
+ # Grad-CAM implementation
 
 
76
  class GradCAM:
77
  def __init__(self, model, target_layer):
78
  self.model = model
79
  self.target_layer = target_layer
80
  self.gradients = None
81
  self.activations = None
82
+ self.hook_handles = []
83
+
84
  def save_activation(self, module, input, output):
85
  self.activations = output.detach()
86
+
87
  def save_gradient(self, module, grad_input, grad_output):
88
  self.gradients = grad_output[0].detach()
89
+
90
+ def generate_cam(self, input_image, target_class=None):
91
+ # Register hooks
92
+ forward_handle = self.target_layer.register_forward_hook(self.save_activation)
93
+ backward_handle = self.target_layer.register_full_backward_hook(self.save_gradient)
94
+
95
+ try:
96
+ # Forward pass
97
+ model_output = self.model(input_image)
98
+
99
+ if target_class is None:
100
+ target_class = model_output.argmax(dim=1).item()
101
+
102
+ # Backward pass
103
+ self.model.zero_grad()
104
+ class_score = model_output[0, target_class]
105
+ class_score.backward(retain_graph=False)
106
+
107
+ if self.gradients is None or self.activations is None:
108
+ return np.zeros((7, 7)) # Default size for ResNet layer4
109
+
110
+ gradients = self.gradients[0]
111
+ activations = self.activations[0]
112
+
113
+ # Global average pooling of gradients
114
+ weights = gradients.mean(dim=(1, 2), keepdim=True)
115
+ cam = (weights * activations).sum(dim=0)
116
+
117
+ # Apply ReLU and normalize
118
+ cam = F.relu(cam)
119
+ cam = cam - cam.min()
120
+ if cam.max() > 0:
121
+ cam = cam / cam.max()
122
+
123
+ return cam.detach().cpu().numpy()
124
+ finally:
125
+ # Remove hooks
126
+ forward_handle.remove()
127
+ backward_handle.remove()
128
+ self.gradients = None
129
+ self.activations = None
130
+
131
+ def overlay_heatmap(image, heatmap, alpha=0.4):
132
+ """Overlay heatmap on original image"""
133
+ heatmap_resized = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
134
+ heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
135
+ output = cv2.addWeighted(image, 1 - alpha, heatmap_colored, alpha, 0)
136
+ return output
137
+
138
+ def predict_galaxy(image):
139
+ """Predict galaxy morphology and generate Grad-CAM"""
140
  if image is None:
141
+ return None, "Please upload an image."
142
+
143
+ if model is None:
144
+ return None, "Error: Model not loaded. Please check the logs."
145
+
146
+ try:
147
+ # Ensure model is in eval mode
148
+ model.eval()
149
+
150
+ # Convert image to PIL if it's not already
151
+ if isinstance(image, np.ndarray):
152
+ image = Image.fromarray(image.astype('uint8'))
153
+ elif not isinstance(image, Image.Image):
154
+ image = Image.open(image) if hasattr(image, 'read') else Image.fromarray(np.array(image))
155
+
156
+ # Ensure image is RGB
157
+ if image.mode != 'RGB':
158
+ image = image.convert('RGB')
159
+
160
+ # Preprocess image
161
+ img_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
162
+ img_tensor.requires_grad = True
163
+
164
+ # Get prediction
165
+ with torch.set_grad_enabled(True):
166
+ outputs = model(img_tensor)
167
+ probs = F.softmax(outputs, dim=1)
168
+ pred_class = outputs.argmax(dim=1).item()
169
+ confidence = probs[0][pred_class].item()
170
+
171
+ # Generate Grad-CAM
172
+ try:
173
+ gradcam = GradCAM(model, model.layer4)
174
+ cam = gradcam.generate_cam(img_tensor, pred_class)
175
+ except Exception as cam_error:
176
+ print(f"Grad-CAM error: {cam_error}")
177
+ import traceback
178
+ traceback.print_exc()
179
+ # If Grad-CAM fails, just return the original image
180
+ cam = None
181
+
182
+ # Prepare original image for overlay
183
+ img_np = np.array(image)
184
+ img_resized = cv2.resize(img_np, (224, 224))
185
+
186
+ # Create overlay if Grad-CAM succeeded
187
+ if cam is not None:
188
+ try:
189
+ overlay = overlay_heatmap(img_resized, cam)
190
+ overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
191
+ overlay_pil = Image.fromarray(overlay_rgb)
192
+ except Exception as overlay_error:
193
+ print(f"Overlay error: {overlay_error}")
194
+ overlay_pil = image.resize((224, 224))
195
+ else:
196
+ overlay_pil = image.resize((224, 224))
197
+
198
+ # Format results
199
+ result_text = f"Predicted Class: {CLASS_NAMES[pred_class]}\nConfidence: {confidence:.2%}"
200
+
201
+ # Ensure we return PIL Image
202
+ if not isinstance(overlay_pil, Image.Image):
203
+ overlay_pil = Image.fromarray(np.array(overlay_pil))
204
+
205
+ return overlay_pil, str(result_text)
206
+ except Exception as e:
207
+ import traceback
208
+ error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
209
+ print(error_msg) # Print for debugging
210
+ return None, f"Error: {str(e)}"
211
+
212
+ # Custom CSS for black background and white text
213
  custom_css = """
214
+ .gradio-container {
215
+ background-color: #000000 !important;
216
+ color: #ffffff !important;
217
+ }
218
+ body {
219
+ background-color: #000000 !important;
220
+ color: #ffffff !important;
221
+ }
222
+ .gradio-container * {
223
+ color: #ffffff !important;
224
+ }
225
+ h1, h2, h3, h4, p, label, span, div {
226
+ color: #ffffff !important;
227
+ }
228
+ .gr-markdown, .gr-markdown * {
229
+ color: #ffffff !important;
230
+ }
231
+ .gr-button {
232
+ background-color: #333333 !important;
233
+ color: #ffffff !important;
234
+ border: 1px solid #555555 !important;
235
+ }
236
+ .gr-button:hover {
237
+ background-color: #555555 !important;
238
+ }
239
+ .gr-textbox, .gr-textbox input, .gr-textbox textarea {
240
+ background-color: #1a1a1a !important;
241
+ color: #ffffff !important;
242
+ border: 1px solid #555555 !important;
243
+ }
244
+ .gr-image {
245
+ background-color: #000000 !important;
246
+ border: none !important;
247
+ padding: 0 !important;
248
+ margin: 0 !important;
249
+ }
250
+ .gr-image img {
251
+ border: none !important;
252
+ box-shadow: none !important;
253
+ background-color: #000000 !important;
254
+ }
255
+ .gr-image-container, .image-container, .image-wrapper {
256
+ border: none !important;
257
+ background-color: #000000 !important;
258
+ padding: 0 !important;
259
+ margin: 0 !important;
260
+ }
261
+ .gr-image .toolbar, .gr-image .image-controls {
262
+ display: none !important;
263
+ }
264
+ .gr-image label, .gr-image .label-wrap {
265
+ display: none !important;
266
+ }
267
+ .gr-box {
268
+ border: none !important;
269
+ background-color: #000000 !important;
270
+ }
271
+ .panel, .panel-header {
272
+ background-color: #000000 !important;
273
+ border: none !important;
274
+ }
275
  """
276
 
277
+ # Create Gradio interface
278
+ # Note: There's a known Gradio bug with API schema generation that causes errors
279
+ # The app will still work for classification, but API endpoints may fail
280
  with gr.Blocks(css=custom_css) as demo:
281
+ # Landing Section
282
+ with gr.Column():
283
+ landing_img = gr.Image(value="landing.jpg", height=500, show_label=False, container=False)
284
+ landing_text = gr.Markdown("""
285
+ <div style="text-align: center; padding: 30px; color: white; background-color: #000000; width: 100%; display: flex; flex-direction: column; align-items: center; justify-content: center;">
286
+ <h1 style="font-size: 96px; font-weight: bold; margin: 0 auto 30px auto; text-align: center; width: 100%;">Galaxy Morphology AI</h1>
287
+ <p style="font-size: 56px; font-weight: normal; margin: 0 auto; text-align: center; width: 100%;">Classify galaxies with state-of-the-art deep learning</p>
288
+ </div>
289
+ """)
290
+
291
+ # Spacing between sections
292
+ gr.Markdown("<div style='height: 60px;'></div>")
293
+
294
+ # How Astrophysicists Use This Section
295
  with gr.Row():
296
+ with gr.Column(scale=1):
297
+ gr.Markdown("""
298
+ # How Astrophysicists Use This
299
+
300
+ Galaxy morphology classification is a fundamental tool in modern astrophysics.
301
+ By automatically identifying whether a galaxy is elliptical or spiral, researchers
302
+ can analyze large datasets from telescopes like the Hubble Space Telescope and
303
+ the James Webb Space Telescope. This classification helps understand galaxy
304
+ formation, evolution, and the distribution of matter in the universe.
305
+
306
+ The deep learning model uses convolutional neural networks to identify key
307
+ features in galaxy images, such as spiral arms, central bulges, and overall
308
+ structure. This automated classification enables astronomers to process millions
309
+ of galaxy images efficiently, accelerating discoveries in cosmology and
310
+ extragalactic astronomy.
311
+ """)
312
+ with gr.Column(scale=1):
313
+ astro_img = gr.Image(value="astro.jpg", show_label=False, container=False, height=400)
314
+ gr.Markdown("<p style='text-align: center; color: white; margin-top: 10px;'>Astrophysics Research</p>")
315
+
316
+ # Spacing between sections
317
+ gr.Markdown("<div style='height: 60px;'></div>")
318
+
319
+ # Classification Section
320
+ gr.Markdown("# Galaxy Morphology Classification")
321
+ gr.Markdown("Upload a galaxy image to classify its morphology and visualize the model's attention using Grad-CAM.")
322
+
323
+ with gr.Row():
324
+ with gr.Column():
325
+ input_image = gr.Image(label="Upload Galaxy Image")
326
+ classify_btn = gr.Button("Classify Galaxy")
327
+
328
+ with gr.Column():
329
+ output_image = gr.Image(label="Grad-CAM Visualization")
330
+ result_text = gr.Textbox(label="Classification Result")
331
+
332
+ # Register the classification function
333
+ # Disable API to avoid Gradio schema generation bug
334
  classify_btn.click(
335
+ fn=predict_galaxy,
336
+ inputs=[input_image],
337
+ outputs=[output_image, result_text],
338
+ api_name=False
339
  )
340
+
341
+ # Spacing between sections
342
+ gr.Markdown("<div style='height: 60px;'></div>")
343
+
344
+ # Dark Energy Section
345
+ gr.Markdown("""
346
+ # Understanding Dark Energy Through Galaxy Morphology
347
+
348
+ Galaxy morphology classification plays a crucial role in understanding dark energy,
349
+ one of the most profound mysteries in modern cosmology. Dark energy is the
350
+ mysterious force driving the accelerated expansion of the universe, and its nature
351
+ remains one of the biggest questions in physics.
352
+
353
+ By classifying large numbers of galaxies and mapping their distribution across
354
+ cosmic time, astronomers can trace the expansion history of the universe.
355
+ Different galaxy types (elliptical vs spiral) form and evolve differently, and
356
+ their relative abundances at different redshifts provide clues about the universe's
357
+ evolution. The distribution and clustering of these galaxies help measure the
358
+ large-scale structure of the universe, which is directly influenced by dark energy.
359
+
360
+ Automated classification systems like this one enable the analysis of millions of
361
+ galaxies from current and future surveys, such as the Vera C. Rubin Observatory's
362
+ Legacy Survey of Space and Time (LSST). These massive datasets will provide
363
+ unprecedented precision in measuring dark energy's properties and understanding
364
+ its role in the fate of the universe.
365
+ """)
366
+
367
+ # Launch the demo
368
+ # For Hugging Face Spaces, Gradio will automatically detect and launch the demo
369
+ # The API error is a known Gradio bug - the app will still work for classification
370
  if __name__ == "__main__":
371
+ try:
372
+ demo.launch(show_api=False)
373
+ except Exception as e:
374
+ # If launch fails, try without API
375
+ print(f"Launch error (non-critical): {e}")
376
+ demo.launch()