update app.py
#3
by tiffany101 - opened
app.py
CHANGED
|
@@ -67,7 +67,6 @@ except Exception as e:
|
|
| 67 |
print(f"Failed to load model: {e}")
|
| 68 |
import traceback
|
| 69 |
traceback.print_exc()
|
| 70 |
-
# Create a dummy model as fallback
|
| 71 |
model = get_model(NUM_CLASSES).to(DEVICE)
|
| 72 |
model.eval()
|
| 73 |
print("Using untrained model as fallback")
|
|
@@ -79,7 +78,6 @@ class GradCAM:
|
|
| 79 |
self.target_layer = target_layer
|
| 80 |
self.gradients = None
|
| 81 |
self.activations = None
|
| 82 |
-
self.hook_handles = []
|
| 83 |
|
| 84 |
def save_activation(self, module, input, output):
|
| 85 |
self.activations = output.detach()
|
|
@@ -88,33 +86,23 @@ class GradCAM:
|
|
| 88 |
self.gradients = grad_output[0].detach()
|
| 89 |
|
| 90 |
def generate_cam(self, input_image, target_class=None):
|
| 91 |
-
# Register hooks
|
| 92 |
forward_handle = self.target_layer.register_forward_hook(self.save_activation)
|
| 93 |
backward_handle = self.target_layer.register_full_backward_hook(self.save_gradient)
|
| 94 |
|
| 95 |
try:
|
| 96 |
-
# Forward pass
|
| 97 |
model_output = self.model(input_image)
|
| 98 |
|
| 99 |
if target_class is None:
|
| 100 |
target_class = model_output.argmax(dim=1).item()
|
| 101 |
|
| 102 |
-
# Backward pass
|
| 103 |
self.model.zero_grad()
|
| 104 |
class_score = model_output[0, target_class]
|
| 105 |
class_score.backward(retain_graph=False)
|
| 106 |
|
| 107 |
-
if self.gradients is None or self.activations is None:
|
| 108 |
-
return np.zeros((7, 7)) # Default size for ResNet layer4
|
| 109 |
-
|
| 110 |
gradients = self.gradients[0]
|
| 111 |
activations = self.activations[0]
|
| 112 |
-
|
| 113 |
-
# Global average pooling of gradients
|
| 114 |
weights = gradients.mean(dim=(1, 2), keepdim=True)
|
| 115 |
cam = (weights * activations).sum(dim=0)
|
| 116 |
-
|
| 117 |
-
# Apply ReLU and normalize
|
| 118 |
cam = F.relu(cam)
|
| 119 |
cam = cam - cam.min()
|
| 120 |
if cam.max() > 0:
|
|
@@ -122,176 +110,113 @@ class GradCAM:
|
|
| 122 |
|
| 123 |
return cam.detach().cpu().numpy()
|
| 124 |
finally:
|
| 125 |
-
# Remove hooks
|
| 126 |
forward_handle.remove()
|
| 127 |
backward_handle.remove()
|
| 128 |
self.gradients = None
|
| 129 |
self.activations = None
|
| 130 |
|
| 131 |
def overlay_heatmap(image, heatmap, alpha=0.4):
|
| 132 |
-
"""Overlay heatmap on original image"""
|
| 133 |
heatmap_resized = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
|
| 134 |
heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
|
| 135 |
-
|
| 136 |
-
return output
|
| 137 |
|
| 138 |
def predict_galaxy(image):
|
| 139 |
-
"""Predict galaxy morphology and generate Grad-CAM"""
|
| 140 |
if image is None:
|
| 141 |
return None, "Please upload an image."
|
| 142 |
|
| 143 |
if model is None:
|
| 144 |
return None, "Error: Model not loaded. Please check the logs."
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
except Exception as cam_error:
|
| 176 |
-
print(f"Grad-CAM error: {cam_error}")
|
| 177 |
-
import traceback
|
| 178 |
-
traceback.print_exc()
|
| 179 |
-
# If Grad-CAM fails, just return the original image
|
| 180 |
-
cam = None
|
| 181 |
-
|
| 182 |
-
# Prepare original image for overlay
|
| 183 |
-
img_np = np.array(image)
|
| 184 |
-
img_resized = cv2.resize(img_np, (224, 224))
|
| 185 |
-
|
| 186 |
-
# Create overlay if Grad-CAM succeeded
|
| 187 |
-
if cam is not None:
|
| 188 |
-
try:
|
| 189 |
-
overlay = overlay_heatmap(img_resized, cam)
|
| 190 |
-
overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
|
| 191 |
-
overlay_pil = Image.fromarray(overlay_rgb)
|
| 192 |
-
except Exception as overlay_error:
|
| 193 |
-
print(f"Overlay error: {overlay_error}")
|
| 194 |
-
overlay_pil = image.resize((224, 224))
|
| 195 |
-
else:
|
| 196 |
-
overlay_pil = image.resize((224, 224))
|
| 197 |
-
|
| 198 |
-
# Format results
|
| 199 |
-
result_text = f"Predicted Class: {CLASS_NAMES[pred_class]}\nConfidence: {confidence:.2%}"
|
| 200 |
-
|
| 201 |
-
# Ensure we return PIL Image
|
| 202 |
-
if not isinstance(overlay_pil, Image.Image):
|
| 203 |
-
overlay_pil = Image.fromarray(np.array(overlay_pil))
|
| 204 |
-
|
| 205 |
-
return overlay_pil, str(result_text)
|
| 206 |
-
except Exception as e:
|
| 207 |
-
import traceback
|
| 208 |
-
error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
|
| 209 |
-
print(error_msg) # Print for debugging
|
| 210 |
-
return None, f"Error: {str(e)}"
|
| 211 |
|
| 212 |
-
#
|
|
|
|
|
|
|
| 213 |
custom_css = """
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
box-shadow: none !important;
|
| 253 |
-
background-color: #000000 !important;
|
| 254 |
-
}
|
| 255 |
-
.gr-image-container, .image-container, .image-wrapper {
|
| 256 |
-
border: none !important;
|
| 257 |
-
background-color: #000000 !important;
|
| 258 |
-
padding: 0 !important;
|
| 259 |
-
margin: 0 !important;
|
| 260 |
-
}
|
| 261 |
-
.gr-image .toolbar, .gr-image .image-controls {
|
| 262 |
-
display: none !important;
|
| 263 |
-
}
|
| 264 |
-
.gr-image label, .gr-image .label-wrap {
|
| 265 |
-
display: none !important;
|
| 266 |
-
}
|
| 267 |
-
.gr-box {
|
| 268 |
-
border: none !important;
|
| 269 |
-
background-color: #000000 !important;
|
| 270 |
-
}
|
| 271 |
-
.panel, .panel-header {
|
| 272 |
-
background-color: #000000 !important;
|
| 273 |
-
border: none !important;
|
| 274 |
-
}
|
| 275 |
"""
|
| 276 |
|
| 277 |
-
#
|
| 278 |
-
#
|
| 279 |
-
#
|
| 280 |
with gr.Blocks(css=custom_css) as demo:
|
| 281 |
-
|
| 282 |
with gr.Column():
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
<div style="text-align: center; padding: 30px;
|
| 286 |
-
<h1 style="font-size: 96px; font-weight: bold;
|
| 287 |
-
<p style="font-size: 56px;
|
| 288 |
</div>
|
| 289 |
""")
|
| 290 |
-
|
| 291 |
-
# Spacing between sections
|
| 292 |
gr.Markdown("<div style='height: 60px;'></div>")
|
| 293 |
-
|
| 294 |
-
# How Astrophysicists Use This Section
|
| 295 |
with gr.Row():
|
| 296 |
with gr.Column(scale=1):
|
| 297 |
gr.Markdown("""
|
|
@@ -310,16 +235,14 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 310 |
extragalactic astronomy.
|
| 311 |
""")
|
| 312 |
with gr.Column(scale=1):
|
| 313 |
-
|
| 314 |
-
gr.Markdown("<p style='text-align: center;
|
| 315 |
-
|
| 316 |
-
# Spacing between sections
|
| 317 |
gr.Markdown("<div style='height: 60px;'></div>")
|
| 318 |
-
|
| 319 |
-
# Classification Section
|
| 320 |
gr.Markdown("# Galaxy Morphology Classification")
|
| 321 |
gr.Markdown("Upload a galaxy image to classify its morphology and visualize the model's attention using Grad-CAM.")
|
| 322 |
-
|
| 323 |
with gr.Row():
|
| 324 |
with gr.Column():
|
| 325 |
input_image = gr.Image(label="Upload Galaxy Image")
|
|
@@ -327,21 +250,17 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 327 |
|
| 328 |
with gr.Column():
|
| 329 |
output_image = gr.Image(label="Grad-CAM Visualization")
|
| 330 |
-
result_text = gr.
|
| 331 |
-
|
| 332 |
-
# Register the classification function
|
| 333 |
-
# Disable API to avoid Gradio schema generation bug
|
| 334 |
classify_btn.click(
|
| 335 |
fn=predict_galaxy,
|
| 336 |
inputs=[input_image],
|
| 337 |
outputs=[output_image, result_text],
|
| 338 |
api_name=False
|
| 339 |
)
|
| 340 |
-
|
| 341 |
-
# Spacing between sections
|
| 342 |
gr.Markdown("<div style='height: 60px;'></div>")
|
| 343 |
-
|
| 344 |
-
# Dark Energy Section
|
| 345 |
gr.Markdown("""
|
| 346 |
# Understanding Dark Energy Through Galaxy Morphology
|
| 347 |
|
|
@@ -364,13 +283,9 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 364 |
its role in the fate of the universe.
|
| 365 |
""")
|
| 366 |
|
| 367 |
-
# Launch
|
| 368 |
-
# For Hugging Face Spaces, Gradio will automatically detect and launch the demo
|
| 369 |
-
# The API error is a known Gradio bug - the app will still work for classification
|
| 370 |
if __name__ == "__main__":
|
| 371 |
try:
|
| 372 |
demo.launch(show_api=False)
|
| 373 |
-
except Exception
|
| 374 |
-
# If launch fails, try without API
|
| 375 |
-
print(f"Launch error (non-critical): {e}")
|
| 376 |
demo.launch()
|
|
|
|
| 67 |
print(f"Failed to load model: {e}")
|
| 68 |
import traceback
|
| 69 |
traceback.print_exc()
|
|
|
|
| 70 |
model = get_model(NUM_CLASSES).to(DEVICE)
|
| 71 |
model.eval()
|
| 72 |
print("Using untrained model as fallback")
|
|
|
|
| 78 |
self.target_layer = target_layer
|
| 79 |
self.gradients = None
|
| 80 |
self.activations = None
|
|
|
|
| 81 |
|
| 82 |
def save_activation(self, module, input, output):
|
| 83 |
self.activations = output.detach()
|
|
|
|
| 86 |
self.gradients = grad_output[0].detach()
|
| 87 |
|
| 88 |
def generate_cam(self, input_image, target_class=None):
|
|
|
|
| 89 |
forward_handle = self.target_layer.register_forward_hook(self.save_activation)
|
| 90 |
backward_handle = self.target_layer.register_full_backward_hook(self.save_gradient)
|
| 91 |
|
| 92 |
try:
|
|
|
|
| 93 |
model_output = self.model(input_image)
|
| 94 |
|
| 95 |
if target_class is None:
|
| 96 |
target_class = model_output.argmax(dim=1).item()
|
| 97 |
|
|
|
|
| 98 |
self.model.zero_grad()
|
| 99 |
class_score = model_output[0, target_class]
|
| 100 |
class_score.backward(retain_graph=False)
|
| 101 |
|
|
|
|
|
|
|
|
|
|
| 102 |
gradients = self.gradients[0]
|
| 103 |
activations = self.activations[0]
|
|
|
|
|
|
|
| 104 |
weights = gradients.mean(dim=(1, 2), keepdim=True)
|
| 105 |
cam = (weights * activations).sum(dim=0)
|
|
|
|
|
|
|
| 106 |
cam = F.relu(cam)
|
| 107 |
cam = cam - cam.min()
|
| 108 |
if cam.max() > 0:
|
|
|
|
| 110 |
|
| 111 |
return cam.detach().cpu().numpy()
|
| 112 |
finally:
|
|
|
|
| 113 |
forward_handle.remove()
|
| 114 |
backward_handle.remove()
|
| 115 |
self.gradients = None
|
| 116 |
self.activations = None
|
| 117 |
|
| 118 |
def overlay_heatmap(image, heatmap, alpha=0.4):
|
|
|
|
| 119 |
heatmap_resized = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
|
| 120 |
heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
|
| 121 |
+
return cv2.addWeighted(image, 1 - alpha, heatmap_colored, alpha, 0)
|
|
|
|
| 122 |
|
| 123 |
def predict_galaxy(image):
|
|
|
|
| 124 |
if image is None:
|
| 125 |
return None, "Please upload an image."
|
| 126 |
|
| 127 |
if model is None:
|
| 128 |
return None, "Error: Model not loaded. Please check the logs."
|
| 129 |
|
| 130 |
+
model.eval()
|
| 131 |
+
|
| 132 |
+
if isinstance(image, np.ndarray):
|
| 133 |
+
image = Image.fromarray(image.astype("uint8"))
|
| 134 |
+
elif not isinstance(image, Image.Image):
|
| 135 |
+
image = Image.open(image)
|
| 136 |
+
|
| 137 |
+
if image.mode != "RGB":
|
| 138 |
+
image = image.convert("RGB")
|
| 139 |
+
|
| 140 |
+
img_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
|
| 141 |
+
img_tensor.requires_grad = True
|
| 142 |
+
|
| 143 |
+
outputs = model(img_tensor)
|
| 144 |
+
probs = F.softmax(outputs, dim=1)
|
| 145 |
+
pred_class = outputs.argmax(dim=1).item()
|
| 146 |
+
confidence = probs[0][pred_class].item()
|
| 147 |
+
|
| 148 |
+
gradcam = GradCAM(model, model.layer4)
|
| 149 |
+
cam = gradcam.generate_cam(img_tensor, pred_class)
|
| 150 |
+
|
| 151 |
+
img_np = np.array(image.resize((224, 224)))
|
| 152 |
+
overlay = overlay_heatmap(img_np, cam)
|
| 153 |
+
overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
|
| 154 |
+
overlay_pil = Image.fromarray(overlay_rgb)
|
| 155 |
+
|
| 156 |
+
result_text = f"Predicted Class: {CLASS_NAMES[pred_class]}\nConfidence: {confidence:.2%}"
|
| 157 |
+
|
| 158 |
+
return overlay_pil, result_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
+
# =========================
|
| 161 |
+
# Custom CSS
|
| 162 |
+
# =========================
|
| 163 |
custom_css = """
|
| 164 |
+
.gradio-container {
|
| 165 |
+
background-color: #000000 !important;
|
| 166 |
+
color: #ffffff !important;
|
| 167 |
+
}
|
| 168 |
+
body {
|
| 169 |
+
background-color: #000000 !important;
|
| 170 |
+
color: #ffffff !important;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
/* 🔴 FIX 1: REMOVED unsafe global selector */
|
| 174 |
+
/* .gradio-container * { color: #ffffff !important; } */
|
| 175 |
+
|
| 176 |
+
h1, h2, h3, h4, p, label, span, div {
|
| 177 |
+
color: #ffffff !important;
|
| 178 |
+
}
|
| 179 |
+
.gr-markdown, .gr-markdown * {
|
| 180 |
+
color: #ffffff !important;
|
| 181 |
+
}
|
| 182 |
+
.gr-button {
|
| 183 |
+
background-color: #333333 !important;
|
| 184 |
+
color: #ffffff !important;
|
| 185 |
+
border: 1px solid #555555 !important;
|
| 186 |
+
}
|
| 187 |
+
.gr-button:hover {
|
| 188 |
+
background-color: #555555 !important;
|
| 189 |
+
}
|
| 190 |
+
.gr-textbox, .gr-textbox input, .gr-textbox textarea {
|
| 191 |
+
background-color: #1a1a1a !important;
|
| 192 |
+
color: #ffffff !important;
|
| 193 |
+
border: 1px solid #555555 !important;
|
| 194 |
+
}
|
| 195 |
+
.gr-image {
|
| 196 |
+
background-color: #000000 !important;
|
| 197 |
+
border: none !important;
|
| 198 |
+
}
|
| 199 |
+
.gr-image img {
|
| 200 |
+
background-color: #000000 !important;
|
| 201 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
"""
|
| 203 |
|
| 204 |
+
# =========================
|
| 205 |
+
# UI
|
| 206 |
+
# =========================
|
| 207 |
with gr.Blocks(css=custom_css) as demo:
|
| 208 |
+
|
| 209 |
with gr.Column():
|
| 210 |
+
gr.Image(value="landing.jpg", height=500, show_label=False, container=False)
|
| 211 |
+
gr.Markdown("""
|
| 212 |
+
<div style="text-align: center; padding: 30px;">
|
| 213 |
+
<h1 style="font-size: 96px; font-weight: bold;">Galaxy Morphology AI</h1>
|
| 214 |
+
<p style="font-size: 56px;">Classify galaxies with state-of-the-art deep learning</p>
|
| 215 |
</div>
|
| 216 |
""")
|
| 217 |
+
|
|
|
|
| 218 |
gr.Markdown("<div style='height: 60px;'></div>")
|
| 219 |
+
|
|
|
|
| 220 |
with gr.Row():
|
| 221 |
with gr.Column(scale=1):
|
| 222 |
gr.Markdown("""
|
|
|
|
| 235 |
extragalactic astronomy.
|
| 236 |
""")
|
| 237 |
with gr.Column(scale=1):
|
| 238 |
+
gr.Image(value="astro.jpg", show_label=False, container=False, height=400)
|
| 239 |
+
gr.Markdown("<p style='text-align: center;'>Astrophysics Research</p>")
|
| 240 |
+
|
|
|
|
| 241 |
gr.Markdown("<div style='height: 60px;'></div>")
|
| 242 |
+
|
|
|
|
| 243 |
gr.Markdown("# Galaxy Morphology Classification")
|
| 244 |
gr.Markdown("Upload a galaxy image to classify its morphology and visualize the model's attention using Grad-CAM.")
|
| 245 |
+
|
| 246 |
with gr.Row():
|
| 247 |
with gr.Column():
|
| 248 |
input_image = gr.Image(label="Upload Galaxy Image")
|
|
|
|
| 250 |
|
| 251 |
with gr.Column():
|
| 252 |
output_image = gr.Image(label="Grad-CAM Visualization")
|
| 253 |
+
result_text = gr.Markdown() # 🔴 FIX 2: Textbox → Markdown (read-only)
|
| 254 |
+
|
|
|
|
|
|
|
| 255 |
classify_btn.click(
|
| 256 |
fn=predict_galaxy,
|
| 257 |
inputs=[input_image],
|
| 258 |
outputs=[output_image, result_text],
|
| 259 |
api_name=False
|
| 260 |
)
|
| 261 |
+
|
|
|
|
| 262 |
gr.Markdown("<div style='height: 60px;'></div>")
|
| 263 |
+
|
|
|
|
| 264 |
gr.Markdown("""
|
| 265 |
# Understanding Dark Energy Through Galaxy Morphology
|
| 266 |
|
|
|
|
| 283 |
its role in the fate of the universe.
|
| 284 |
""")
|
| 285 |
|
| 286 |
+
# Launch
|
|
|
|
|
|
|
| 287 |
if __name__ == "__main__":
|
| 288 |
try:
|
| 289 |
demo.launch(show_api=False)
|
| 290 |
+
except Exception:
|
|
|
|
|
|
|
| 291 |
demo.launch()
|