tiffany101's picture
update app.py
17a40a1 verified
raw
history blame
10.9 kB
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import numpy as np
import cv2
import os
# Workaround for a Gradio API-schema bug: some gradio_client versions call
# get_type() with a bare boolean schema (presumably JSON-schema ``true``/``false``)
# and crash. Wrap get_type so booleans map to "bool" and every other schema is
# delegated to the original implementation.
try:
    import gradio_client.utils as client_utils

    original_get_type = client_utils.get_type

    def patched_get_type(schema):
        """Return "bool" for boolean schemas; defer to the original get_type otherwise."""
        if isinstance(schema, bool):
            return "bool"
        return original_get_type(schema)

    client_utils.get_type = patched_get_type
except Exception as exc:
    # Best effort: the app can run without the patch, but report why it was
    # skipped. (Unlike a bare ``except:``, this does not swallow
    # KeyboardInterrupt/SystemExit.)
    print(f"Skipping gradio_client get_type patch: {exc!r}")
# Model configuration
MODEL_PATH = "robust_galaxy_model.pth"  # checkpoint read by load_model(), if present
CLASS_NAMES = ["Elliptical", "Spiral"]  # CLASS_NAMES[i] labels model output i
NUM_CLASSES = len(CLASS_NAMES)
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Calibration + out-of-distribution (OOD) thresholds
TEMPERATURE = 2.5  # logits are divided by this to soften over-confidence
CONF_THRESHOLD = 0.60  # top-class probability below this -> "uncertain"
ENTROPY_THRESHOLD = 0.85  # predictive entropy above this -> "uncertain" (see predict_galaxy for its scale)
# Image preprocessing: resize to a fixed 224x224 input, convert to a tensor,
# then normalise each channel with the commonly used ImageNet mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Load model
def get_model(num_classes=2):
    """Build a ResNet-18 classifier with a ``num_classes``-way linear head.

    ``weights=None`` means the backbone is randomly initialised; no pretrained
    weights are downloaded here.
    """
    backbone = models.resnet18(weights=None)
    feature_dim = backbone.fc.in_features
    backbone.fc = nn.Linear(feature_dim, num_classes)
    return backbone
def load_model():
    """Build the classifier, load trained weights if available, and ready it for inference.

    Weights are read from ``MODEL_PATH``. If the file is missing or cannot be
    loaded, a message is printed and a freshly built, randomly initialised model
    is used instead.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    model = get_model(NUM_CLASSES)
    if os.path.exists(MODEL_PATH):
        try:
            # NOTE(review): torch.load unpickles the file -- only load trusted
            # checkpoints. Consider weights_only=True if the installed torch
            # version supports it.
            state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
            model.load_state_dict(state_dict)
            print(f"Model loaded successfully from {MODEL_PATH}")
        except Exception as e:
            print(f"Error loading model: {e}")
            print("Using untrained model")
            # load_state_dict can copy matching tensors before raising on a
            # mismatch; rebuild so "untrained" really means untouched weights.
            model = get_model(NUM_CLASSES)
    else:
        print(f"Model file not found at {MODEL_PATH}. Using untrained model.")
    model.to(DEVICE)
    model.eval()
    return model
# Load the model at import time. load_model() already falls back to random
# weights when the checkpoint is missing or unreadable; this outer handler covers
# failures while building the network itself.
model = None
try:
    model = load_model()
    # load_model() may have returned an untrained model (it prints its own
    # success/fallback message), so don't claim the weights were loaded here.
    print(f"Model ready on {DEVICE}")
except Exception as e:
    print(f"Failed to load model: {e}")
    import traceback

    traceback.print_exc()
    try:
        model = get_model(NUM_CLASSES).to(DEVICE)
        model.eval()
        print("Using untrained model as fallback")
    except Exception:
        # Leave model as None: predict_galaxy() then reports the problem to the
        # user instead of the whole app failing to start.
        traceback.print_exc()
        model = None
        print("Fallback model could not be built; classification is disabled")
# Grad-CAM implementation
class GradCAM:
    """Grad-CAM saliency maps for one target layer of ``model``.

    Hooks are registered only for the duration of ``generate_cam`` and removed
    afterwards, so an instance can be reused across calls.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        # Filled by the hooks during generate_cam(); reset to None afterwards.
        self.gradients = None
        self.activations = None

    def save_activation(self, module, input, output):
        """Forward hook: keep a detached copy of the target layer's output."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: keep the gradient w.r.t. the target layer's output."""
        self.gradients = grad_output[0].detach()

    def generate_cam(self, input_image, target_class=None):
        """Compute a Grad-CAM heatmap for the first image in ``input_image``.

        Args:
            input_image: Batched input tensor for ``model``; only batch item 0 is
                explained (callers here pass a batch of one).
            target_class: Class index (int or 0-d tensor) to explain. Defaults to
                the top-scoring class of batch item 0.

        Returns:
            A 2-D float numpy array at the target layer's spatial resolution,
            ReLU-ed and scaled to [0, 1] (all zeros if no positive evidence).

        Raises:
            RuntimeError: if the hooks did not capture activations/gradients,
                e.g. because ``target_layer`` is not used in the forward pass.
        """
        forward_handle = self.target_layer.register_forward_hook(self.save_activation)
        backward_handle = self.target_layer.register_full_backward_hook(self.save_gradient)
        try:
            # Gradients are required even when the caller runs under torch.no_grad().
            with torch.enable_grad():
                model_output = self.model(input_image)
                if target_class is None:
                    target_class = model_output[0].argmax().item()
                target_class = int(target_class)
                self.model.zero_grad(set_to_none=True)
                class_score = model_output[0, target_class]
                class_score.backward()

            if self.gradients is None or self.activations is None:
                raise RuntimeError(
                    "Grad-CAM hooks did not fire; is target_layer part of the model's forward pass?"
                )

            # Assumes 4-D (N, C, H, W) activations: take item 0 -> (C, H, W).
            gradients = self.gradients[0]
            activations = self.activations[0]
            # Channel importance = spatially averaged gradient.
            weights = gradients.mean(dim=(1, 2), keepdim=True)
            cam = F.relu((weights * activations).sum(dim=0))
            cam = cam - cam.min()
            peak = cam.max()
            if peak > 0:
                cam = cam / peak
            return cam.detach().cpu().numpy()
        finally:
            forward_handle.remove()
            backward_handle.remove()
            self.gradients = None
            self.activations = None
            # Release the parameter .grad buffers created by backward().
            self.model.zero_grad(set_to_none=True)
def overlay_heatmap(image, heatmap, alpha=0.4):
    """Blend a heatmap onto ``image`` using OpenCV's JET colormap.

    Args:
        image: H x W x 3 uint8 array. NOTE(review): the colormap output is in
            OpenCV's BGR order, so ``image`` should presumably be BGR too.
        heatmap: 2-D float array, expected in [0, 1] (it is scaled by 255 and
            cast to uint8); resized to the image's height and width.
        alpha: Weight of the colored heatmap; the image gets ``1 - alpha``.

    Returns:
        The blended uint8 image, same size as ``image``.
    """
    height, width = image.shape[:2]
    scaled = cv2.resize(heatmap, (width, height))
    colored = cv2.applyColorMap(np.uint8(255 * scaled), cv2.COLORMAP_JET)
    image_weight = 1 - alpha
    return cv2.addWeighted(image, image_weight, colored, alpha, 0)
def predict_galaxy(image):
    """Classify a galaxy image and visualise the model's evidence with Grad-CAM.

    Args:
        image: The uploaded image -- a numpy array, a PIL image, or anything
            ``PIL.Image.open`` accepts (e.g. a file path). ``None`` means nothing
            was uploaded.

    Returns:
        ``(visualisation, markdown)``. ``visualisation`` is a 224x224 PIL image:
        the Grad-CAM overlay for confident predictions, or the plain resized
        input when the prediction is rejected as uncertain. It is ``None`` when
        there is no image or no model. ``markdown`` is the text shown to the user.
    """
    if image is None:
        return None, "Please upload an image."
    if model is None:
        return None, "Error: Model not loaded. Please check the logs."
    model.eval()

    # Normalise the input to an RGB PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype("uint8"))
    elif not isinstance(image, Image.Image):
        image = Image.open(image)
    if image.mode != "RGB":
        image = image.convert("RGB")

    img_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
    # Not needed for classification; keeps an autograd path into Grad-CAM's
    # backward pass even if the model's parameters are frozen.
    img_tensor.requires_grad = True

    # Classification pass: no gradients needed here (Grad-CAM runs its own pass).
    with torch.no_grad():
        logits = model(img_tensor)

    # Temperature scaling softens over-confident logits.
    scaled_logits = logits / TEMPERATURE
    probs = F.softmax(scaled_logits, dim=1)[0]
    confidence, pred_class = torch.max(probs, dim=0)

    # Predictive entropy in nats (natural log); the epsilon avoids log(0).
    # NOTE(review): its maximum is ln(num_classes) -- about 0.693 for two
    # classes -- so a threshold at or above that never triggers and only the
    # confidence test decides. Confirm the intended scale of ENTROPY_THRESHOLD.
    entropy = -(probs * torch.log(probs + 1e-8)).sum().item()

    if confidence.item() < CONF_THRESHOLD or entropy > ENTROPY_THRESHOLD:
        # Blank line between fields: a single "\n" is a soft break in Markdown
        # and would render both fields on one line.
        result_text = (
            "**Prediction:** Uncertain / Not a Galaxy\n\n"
            "**Confidence:** Low"
        )
        return image.resize((224, 224)), result_text

    predicted_index = pred_class.item()
    cam = GradCAM(model, model.layer4).generate_cam(img_tensor, predicted_index)

    # overlay_heatmap blends in OpenCV's JET colormap, which is BGR, so hand it a
    # BGR image and convert the blend back to RGB for display. Blending the RGB
    # image directly and then converting would swap the photo's red/blue channels.
    img_bgr = cv2.cvtColor(np.array(image.resize((224, 224))), cv2.COLOR_RGB2BGR)
    overlay_bgr = overlay_heatmap(img_bgr, cam)
    overlay_pil = Image.fromarray(cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB))

    result_text = (
        f"**Prediction:** {CLASS_NAMES[predicted_index]}\n\n"
        f"**Confidence:** {confidence.item() * 100:.2f}%"
    )
    return overlay_pil, result_text
# =========================
# Custom CSS
# =========================
# Dark theme for the app, passed to gr.Blocks(css=custom_css) below. Rules use
# !important so they override the default Gradio theme's styles. The element
# selector rule (h1 ... div) still recolours most text white; the commented-out
# `.gradio-container *` rule inside the string is the broader selector that was
# removed. Buttons and text boxes get dark-grey backgrounds with grey borders;
# image components get a black background and no border.
custom_css = """
.gradio-container {
background-color: #000000 !important;
color: #ffffff !important;
}
body {
background-color: #000000 !important;
color: #ffffff !important;
}
/* πŸ”΄ FIX 1: REMOVED unsafe global selector */
/* .gradio-container * { color: #ffffff !important; } */
h1, h2, h3, h4, p, label, span, div {
color: #ffffff !important;
}
.gr-markdown, .gr-markdown * {
color: #ffffff !important;
}
.gr-button {
background-color: #333333 !important;
color: #ffffff !important;
border: 1px solid #555555 !important;
}
.gr-button:hover {
background-color: #555555 !important;
}
.gr-textbox, .gr-textbox input, .gr-textbox textarea {
background-color: #1a1a1a !important;
color: #ffffff !important;
border: 1px solid #555555 !important;
}
.gr-image {
background-color: #000000 !important;
border: none !important;
}
.gr-image img {
background-color: #000000 !important;
}
"""
# =========================
# UI
# =========================
# Page layout, top to bottom: hero image and title, a background text/image
# row, the classifier (upload + button on the left, Grad-CAM overlay + result
# on the right), and a closing section on dark energy. Static assets
# "landing.jpg" and "astro.jpg" are read from the working directory.
with gr.Blocks(css=custom_css) as demo:
    # Hero section: banner image and headline.
    with gr.Column():
        gr.Image(value="landing.jpg", height=500, show_label=False, container=False)
        gr.Markdown("""
<div style="text-align: center; padding: 30px;">
<h1 style="font-size: 96px; font-weight: bold;">Galaxy Morphology AI</h1>
<p style="font-size: 56px;">Classify galaxies with state-of-the-art deep learning</p>
</div>
""")
    # Fixed-height empty div used as a vertical spacer between sections.
    gr.Markdown("<div style='height: 60px;'></div>")

    # Background section: explanatory text beside an illustrative image.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("""
# How Astrophysicists Use This
Galaxy morphology classification is a fundamental tool in modern astrophysics.
By automatically identifying whether a galaxy is elliptical or spiral, researchers
can analyze large datasets from telescopes like the Hubble Space Telescope and
the James Webb Space Telescope. This classification helps understand galaxy
formation, evolution, and the distribution of matter in the universe.
The deep learning model uses convolutional neural networks to identify key
features in galaxy images, such as spiral arms, central bulges, and overall
structure. This automated classification enables astronomers to process millions
of galaxy images efficiently, accelerating discoveries in cosmology and
extragalactic astronomy.
""")
        with gr.Column(scale=1):
            gr.Image(value="astro.jpg", show_label=False, container=False, height=400)
            gr.Markdown("<p style='text-align: center;'>Astrophysics Research</p>")
    gr.Markdown("<div style='height: 60px;'></div>")

    # Classifier section.
    gr.Markdown("# Galaxy Morphology Classification")
    gr.Markdown("Upload a galaxy image to classify its morphology and visualize the model's attention using Grad-CAM.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Galaxy Image")
            classify_btn = gr.Button("Classify Galaxy")
        with gr.Column():
            output_image = gr.Image(label="Grad-CAM Visualization")
            # Markdown rather than a Textbox: read-only output that renders the
            # **bold** labels in predict_galaxy's result text.
            result_text = gr.Markdown()
    # predict_galaxy returns (image, markdown_text), matching these two outputs.
    # api_name=False hides this event from Gradio's API.
    classify_btn.click(
        fn=predict_galaxy,
        inputs=[input_image],
        outputs=[output_image, result_text],
        api_name=False
    )
    gr.Markdown("<div style='height: 60px;'></div>")

    # Closing section: how morphology classification relates to dark-energy studies.
    gr.Markdown("""
# Understanding Dark Energy Through Galaxy Morphology
Galaxy morphology classification plays a crucial role in understanding dark energy,
one of the most profound mysteries in modern cosmology. Dark energy is the
mysterious force driving the accelerated expansion of the universe, and its nature
remains one of the biggest questions in physics.
By classifying large numbers of galaxies and mapping their distribution across
cosmic time, astronomers can trace the expansion history of the universe.
Different galaxy types (elliptical vs spiral) form and evolve differently, and
their relative abundances at different redshifts provide clues about the universe's
evolution. The distribution and clustering of these galaxies help measure the
large-scale structure of the universe, which is directly influenced by dark energy.
Automated classification systems like this one enable the analysis of millions of
galaxies from current and future surveys, such as the Vera C. Rubin Observatory's
Legacy Survey of Space and Time (LSST). These massive datasets will provide
unprecedented precision in measuring dark energy's properties and understanding
its role in the fate of the universe.
""")
# Launch
if __name__ == "__main__":
    try:
        demo.launch(show_api=False)
    except Exception as exc:
        # Fallback kept from the original: retry with default settings
        # (presumably for Gradio versions/configurations that reject show_api).
        # Report the first failure instead of silently discarding it.
        print(f"demo.launch(show_api=False) failed: {exc!r}; retrying with defaults")
        demo.launch()