tiffany101's picture
update app.py
2ceffa9 verified
raw
history blame
14.2 kB
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import numpy as np
import cv2
import os
# Workaround for Gradio API schema bug
# Monkey-patch gradio_client so that schema generation tolerates a bare
# boolean where it expects a schema mapping.
def _bool_tolerant(get_type):
    """Wrap *get_type* so that a bare boolean schema maps to ``"bool"``.

    Args:
        get_type: The original callable taking a schema and returning its
            type name.

    Returns:
        A callable with the same signature that returns ``"bool"`` for
        ``True``/``False`` and defers to *get_type* for everything else.
    """
    def patched_get_type(schema):
        if isinstance(schema, bool):
            return "bool"
        return get_type(schema)
    return patched_get_type


try:
    import gradio_client.utils as client_utils
    original_get_type = client_utils.get_type
    patched_get_type = _bool_tolerant(original_get_type)
    client_utils.get_type = patched_get_type
except Exception as patch_error:
    # Best effort: the app must still start when gradio_client is missing or
    # laid out differently. ``Exception`` (rather than a bare ``except``)
    # keeps KeyboardInterrupt/SystemExit propagating, and the reason is
    # logged instead of being silently discarded.
    print(f"Gradio schema workaround not applied: {patch_error}")
# Model configuration
# MODEL_PATH is a relative path, so it is resolved against the process's
# current working directory.
MODEL_PATH = "robust_galaxy_model.pth"
# Size of the classifier head; must agree with len(CLASS_NAMES).
NUM_CLASSES = 2
# Human-readable labels, indexed by the predicted class id.
CLASS_NAMES = ["Elliptical", "Spiral"]
# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Image preprocessing
# Resize to 224x224, convert to a tensor and normalise each channel.
# NOTE(review): the mean/std values are the usual ImageNet statistics —
# presumably the same ones used when the checkpoint was trained; confirm.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Load model
def get_model(num_classes=2):
    """Build a ResNet-18 whose final layer outputs *num_classes* logits.

    The network is created with ``weights=None`` (no pretrained weights are
    requested) and its ``fc`` layer is replaced by a fresh ``nn.Linear`` with
    the same number of input features.

    Args:
        num_classes: Number of output classes for the new ``fc`` layer.

    Returns:
        The constructed model.
    """
    network = models.resnet18(weights=None)
    feature_count = network.fc.in_features
    network.fc = nn.Linear(feature_count, num_classes)
    return network
def load_model():
    """Create the classifier and load weights from ``MODEL_PATH`` if possible.

    When the checkpoint file is missing, or loading it raises, a message is
    printed and the freshly initialised (untrained) network is used instead.
    Either way the model is moved to ``DEVICE`` and switched to eval mode.

    Returns:
        The model, on ``DEVICE`` and in eval mode.
    """
    net = get_model(NUM_CLASSES)

    if not os.path.exists(MODEL_PATH):
        print(f"Model file not found at {MODEL_PATH}. Using untrained model.")
    else:
        try:
            # NOTE: torch.load unpickles the file — only point MODEL_PATH at
            # checkpoints from a trusted source.
            weights = torch.load(MODEL_PATH, map_location=DEVICE)
            net.load_state_dict(weights)
            print(f"Model loaded successfully from {MODEL_PATH}")
        except Exception as e:
            # Deliberate best effort: keep serving with untrained weights.
            print(f"Error loading model: {e}")
            print("Using untrained model")

    net.to(DEVICE)
    net.eval()
    return net
# Load model - handle errors gracefully
# The module-level ``model`` is built once at import time; a failure here
# must never stop the app from starting.
model = None
try:
    model = load_model()
    print("Model loaded successfully")
except Exception as load_error:
    print(f"Failed to load model: {load_error}")
    import traceback
    traceback.print_exc()
    # Create a dummy (untrained) model as fallback so the UI still responds
    model = get_model(NUM_CLASSES)
    model = model.to(DEVICE)
    model.eval()
    print("Using untrained model as fallback")
# Grad-CAM implementation
class GradCAM:
    """Compute Grad-CAM heatmaps for one layer of a model.

    Hooks are attached to ``target_layer`` only for the duration of a single
    ``generate_cam`` call and are always detached afterwards, so the wrapped
    model is left without extra hooks between calls.
    """

    def __init__(self, model, target_layer):
        """Store the model and the layer whose activations are visualised."""
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Handles of the hooks that are currently registered on target_layer.
        self.hook_handles = []

    def save_activation(self, module, input, output):
        """Forward hook: remember the layer output (detached)."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: remember the gradient w.r.t. the layer output."""
        self.gradients = grad_output[0].detach()

    def generate_cam(self, input_image, target_class=None):
        """Return the normalised Grad-CAM map for the first item of the batch.

        Args:
            input_image: Batched input tensor passed straight to the model.
            target_class: Class index to explain. When ``None`` the model's
                top-scoring class for the first batch item is used.

        Returns:
            A 2-D ``numpy`` array scaled to ``[0, 1]``. If the hooks did not
            capture anything, a ``(7, 7)`` array of zeros is returned.

        Raises:
            Whatever the forward/backward pass raises; hooks are removed
            before the exception propagates.
        """
        try:
            # Register inside the ``try`` and track every handle, so that a
            # failure while registering the second hook cannot leave the
            # first one attached to the (shared) model.
            self.hook_handles.append(
                self.target_layer.register_forward_hook(self.save_activation)
            )
            self.hook_handles.append(
                self.target_layer.register_full_backward_hook(self.save_gradient)
            )

            # Forward pass
            model_output = self.model(input_image)
            if target_class is None:
                target_class = model_output.argmax(dim=1).item()

            # Backward pass from the score of the chosen class
            self.model.zero_grad()
            class_score = model_output[0, target_class]
            class_score.backward(retain_graph=False)

            if self.gradients is None or self.activations is None:
                return np.zeros((7, 7))  # Default size for ResNet layer4

            gradients = self.gradients[0]
            activations = self.activations[0]

            # Global average pooling of gradients gives one weight per channel
            weights = gradients.mean(dim=(1, 2), keepdim=True)
            cam = (weights * activations).sum(dim=0)

            # Keep only positive evidence, then rescale to [0, 1]
            cam = F.relu(cam)
            cam = cam - cam.min()
            if cam.max() > 0:
                cam = cam / cam.max()
            return cam.detach().cpu().numpy()
        finally:
            # Remove hooks and drop the captured tensors
            while self.hook_handles:
                self.hook_handles.pop().remove()
            self.gradients = None
            self.activations = None
def overlay_heatmap(image, heatmap, alpha=0.4):
    """Blend a JET-coloured version of *heatmap* onto *image*.

    Args:
        image: Array indexed as (rows, cols, ...); it is blended with the
            colour-mapped heatmap, so it is expected to be a uint8
            3-channel image of compatible type.
        heatmap: 2-D array, expected in ``[0, 1]``; it is resized to the
            image size and scaled to 0-255 before colouring.
        alpha: Weight of the heatmap; the image gets ``1 - alpha``.

    Returns:
        The blended image. Note that OpenCV's colour map is produced in BGR
        channel order.
    """
    rows, cols = image.shape[0], image.shape[1]
    stretched = cv2.resize(heatmap, (cols, rows))
    as_bytes = np.uint8(255 * stretched)
    coloured = cv2.applyColorMap(as_bytes, cv2.COLORMAP_JET)
    return cv2.addWeighted(image, 1 - alpha, coloured, alpha, 0)
def _coerce_to_rgb_image(image):
    """Turn the value received from the UI into an RGB ``PIL.Image``.

    Accepts a numpy array, a PIL image, a file-like object, a filesystem
    path, or anything ``np.array`` can convert to pixel data.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'))
    elif not isinstance(image, Image.Image):
        if hasattr(image, 'read') or isinstance(image, (str, os.PathLike)):
            image = Image.open(image)
        else:
            image = Image.fromarray(np.array(image))
    # Ensure image is RGB (drops alpha, expands greyscale)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    return image


def predict_galaxy(image):
    """Predict galaxy morphology and generate Grad-CAM

    Args:
        image: The uploaded image (numpy array, PIL image, file-like object
            or path), or ``None`` when nothing was uploaded.

    Returns:
        A ``(overlay, text)`` tuple. ``overlay`` is a 224x224 RGB PIL image
        with the Grad-CAM heatmap blended in (or the plain resized input if
        Grad-CAM or the overlay step failed); ``text`` is the predicted
        class and confidence. On error ``overlay`` is ``None`` and ``text``
        carries the error message.
    """
    if image is None:
        return None, "Please upload an image."
    if model is None:
        return None, "Error: Model not loaded. Please check the logs."

    import traceback

    try:
        # Ensure model is in eval mode
        model.eval()

        image = _coerce_to_rgb_image(image)

        # Preprocess image
        img_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
        img_tensor.requires_grad = True

        # Get prediction; no autograd graph is needed for this pass because
        # Grad-CAM runs its own forward/backward below.
        with torch.no_grad():
            outputs = model(img_tensor)
            probs = F.softmax(outputs, dim=1)
        pred_class = outputs.argmax(dim=1).item()
        confidence = probs[0][pred_class].item()

        # Generate Grad-CAM
        try:
            with torch.enable_grad():
                gradcam = GradCAM(model, model.layer4)
                cam = gradcam.generate_cam(img_tensor, pred_class)
        except Exception as cam_error:
            print(f"Grad-CAM error: {cam_error}")
            traceback.print_exc()
            # If Grad-CAM fails, just return the original image
            cam = None

        # Create overlay if Grad-CAM succeeded
        overlay_pil = None
        if cam is not None:
            try:
                img_rgb = cv2.resize(np.array(image), (224, 224))
                # OpenCV's colour map is BGR, so blend in BGR space and
                # convert the finished overlay back to RGB. Blending the RGB
                # photo directly and converting afterwards swaps the photo's
                # red and blue channels.
                img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
                overlay_bgr = overlay_heatmap(img_bgr, cam)
                overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
                overlay_pil = Image.fromarray(overlay_rgb)
            except Exception as overlay_error:
                print(f"Overlay error: {overlay_error}")
        if overlay_pil is None:
            overlay_pil = image.resize((224, 224))

        # Format results
        result_text = f"Predicted Class: {CLASS_NAMES[pred_class]}\nConfidence: {confidence:.2%}"
        return overlay_pil, result_text
    except Exception as e:
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)  # Print for debugging
        return None, f"Error: {str(e)}"
# Custom CSS for black background and white text
# Passed to gr.Blocks(css=...). Besides the colour scheme it also strips
# borders/padding from image components and hides their toolbars and labels.
# Every rule uses !important so that it overrides the theme's own styling.
custom_css = """
.gradio-container {
background-color: #000000 !important;
color: #ffffff !important;
}
body {
background-color: #000000 !important;
color: #ffffff !important;
}
.gradio-container * {
color: #ffffff !important;
}
h1, h2, h3, h4, p, label, span, div {
color: #ffffff !important;
}
.gr-markdown, .gr-markdown * {
color: #ffffff !important;
}
.gr-button {
background-color: #333333 !important;
color: #ffffff !important;
border: 1px solid #555555 !important;
}
.gr-button:hover {
background-color: #555555 !important;
}
.gr-textbox, .gr-textbox input, .gr-textbox textarea {
background-color: #1a1a1a !important;
color: #ffffff !important;
border: 1px solid #555555 !important;
}
.gr-image {
background-color: #000000 !important;
border: none !important;
padding: 0 !important;
margin: 0 !important;
}
.gr-image img {
border: none !important;
box-shadow: none !important;
background-color: #000000 !important;
}
.gr-image-container, .image-container, .image-wrapper {
border: none !important;
background-color: #000000 !important;
padding: 0 !important;
margin: 0 !important;
}
.gr-image .toolbar, .gr-image .image-controls {
display: none !important;
}
.gr-image label, .gr-image .label-wrap {
display: none !important;
}
.gr-box {
border: none !important;
background-color: #000000 !important;
}
.panel, .panel-header {
background-color: #000000 !important;
border: none !important;
}
"""
# Create Gradio interface
# Note: There's a known Gradio bug with API schema generation that causes errors
# The app will still work for classification, but API endpoints may fail
# Page layout, top to bottom: landing banner, background text with an
# illustration, the interactive classifier, and a closing text section.
with gr.Blocks(css=custom_css) as demo:
    # Landing Section
    with gr.Column():
        # NOTE(review): "landing.jpg" is a relative path — presumably shipped
        # next to this file; confirm it exists in the deployment.
        landing_img = gr.Image(value="landing.jpg", height=500, show_label=False, container=False)
        landing_text = gr.Markdown("""
<div style="text-align: center; padding: 30px; color: white; background-color: #000000; width: 100%; display: flex; flex-direction: column; align-items: center; justify-content: center;">
<h1 style="font-size: 96px; font-weight: bold; margin: 0 auto 30px auto; text-align: center; width: 100%;">Galaxy Morphology AI</h1>
<p style="font-size: 56px; font-weight: normal; margin: 0 auto; text-align: center; width: 100%;">Classify galaxies with state-of-the-art deep learning</p>
</div>
""")
    # Spacing between sections
    gr.Markdown("<div style='height: 60px;'></div>")
    # How Astrophysicists Use This Section
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("""
# How Astrophysicists Use This
Galaxy morphology classification is a fundamental tool in modern astrophysics.
By automatically identifying whether a galaxy is elliptical or spiral, researchers
can analyze large datasets from telescopes like the Hubble Space Telescope and
the James Webb Space Telescope. This classification helps understand galaxy
formation, evolution, and the distribution of matter in the universe.
The deep learning model uses convolutional neural networks to identify key
features in galaxy images, such as spiral arms, central bulges, and overall
structure. This automated classification enables astronomers to process millions
of galaxy images efficiently, accelerating discoveries in cosmology and
extragalactic astronomy.
""")
        with gr.Column(scale=1):
            # NOTE(review): "astro.jpg" is a relative path, as with the
            # landing image — confirm it is deployed with the app.
            astro_img = gr.Image(value="astro.jpg", show_label=False, container=False, height=400)
            gr.Markdown("<p style='text-align: center; color: white; margin-top: 10px;'>Astrophysics Research</p>")
    # Spacing between sections
    gr.Markdown("<div style='height: 60px;'></div>")
    # Classification Section
    gr.Markdown("# Galaxy Morphology Classification")
    gr.Markdown("Upload a galaxy image to classify its morphology and visualize the model's attention using Grad-CAM.")
    with gr.Row():
        # Left column: upload + trigger; right column: heatmap + result text
        with gr.Column():
            input_image = gr.Image(label="Upload Galaxy Image")
            classify_btn = gr.Button("Classify Galaxy")
        with gr.Column():
            output_image = gr.Image(label="Grad-CAM Visualization")
            result_text = gr.Textbox(label="Classification Result")
    # Register the classification function
    # Disable API to avoid Gradio schema generation bug
    # predict_galaxy returns (image, text), matching the two outputs below.
    classify_btn.click(
        fn=predict_galaxy,
        inputs=[input_image],
        outputs=[output_image, result_text],
        api_name=False
    )
    # Spacing between sections
    gr.Markdown("<div style='height: 60px;'></div>")
    # Dark Energy Section
    gr.Markdown("""
# Understanding Dark Energy Through Galaxy Morphology
Galaxy morphology classification plays a crucial role in understanding dark energy,
one of the most profound mysteries in modern cosmology. Dark energy is the
mysterious force driving the accelerated expansion of the universe, and its nature
remains one of the biggest questions in physics.
By classifying large numbers of galaxies and mapping their distribution across
cosmic time, astronomers can trace the expansion history of the universe.
Different galaxy types (elliptical vs spiral) form and evolve differently, and
their relative abundances at different redshifts provide clues about the universe's
evolution. The distribution and clustering of these galaxies help measure the
large-scale structure of the universe, which is directly influenced by dark energy.
Automated classification systems like this one enable the analysis of millions of
galaxies from current and future surveys, such as the Vera C. Rubin Observatory's
Legacy Survey of Space and Time (LSST). These massive datasets will provide
unprecedented precision in measuring dark energy's properties and understanding
its role in the fate of the universe.
""")
# Launch the demo
# For Hugging Face Spaces, Gradio will automatically detect and launch the demo
# The API error is a known Gradio bug - the app will still work for classification
if __name__ == "__main__":
    try:
        # First attempt: hide the API page.
        demo.launch(show_api=False)
    except Exception as e:
        # If that launch fails for any reason, log it and retry with Gradio's
        # default launch arguments. Note the retry does not pass
        # show_api=False, so the API setting falls back to Gradio's default.
        print(f"Launch error (non-critical): {e}")
        demo.launch()