Star_Struck_Model / model.py
Tiffany Degbotse
Deploy Star Struck model API
77bb7d0
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import numpy as np
import cv2
import os
# --------------------
# Configuration
# --------------------
# Checkpoint (a plain state dict) loaded by load_model(); this relative path
# is resolved against the current working directory.
MODEL_PATH = "robust_galaxy_model (1).pth"

# Label for each output logit: predict_galaxy maps argmax index i to CLASS_NAMES[i].
CLASS_NAMES = ["Elliptical", "Spiral"]
NUM_CLASSES = len(CLASS_NAMES)

# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# --------------------
# Preprocessing
# --------------------
# Fixed spatial size fed to the network (predict_galaxy also resizes the
# display image to 224x224 before drawing the Grad-CAM overlay).
_INPUT_SIZE = (224, 224)

# Resize to _INPUT_SIZE, then turn the PIL image into a float tensor
# (torchvision's ToTensor scales uint8 pixels to [0, 1], layout C x H x W).
# NOTE(review): there is no transforms.Normalize step; confirm this matches how
# the checkpoint was trained, since ImageNet-initialised ResNets usually expect
# mean/std normalisation.
preprocess = transforms.Compose([transforms.Resize(_INPUT_SIZE), transforms.ToTensor()])
# --------------------
# Model Definition
# --------------------
def get_model(num_classes=2, *, pretrained=True):
    """Build a ResNet-18 classifier with a ``num_classes``-way linear head.

    All backbone parameters are frozen except those of ``layer4``; ``fc`` is
    replaced by a freshly initialised ``nn.Linear`` (trainable by default).

    Args:
        num_classes: number of output logits produced by the new ``fc`` layer.
        pretrained: if True (the original behaviour), initialise the backbone
            from torchvision's default ImageNet weights, which torchvision may
            download on first use. Pass False when a complete state dict will
            be loaded on top anyway.

    Returns:
        The constructed ``torchvision.models.ResNet`` instance (not moved to
        any device, still in training mode).
    """
    weights = models.ResNet18_Weights.DEFAULT if pretrained else None
    model = models.resnet18(weights=weights)

    # Freeze the whole backbone...
    for param in model.parameters():
        param.requires_grad = False
    # ...then re-enable gradients for the last residual stage only.
    for param in model.layer4.parameters():
        param.requires_grad = True

    # New classification head sized for our label set.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def load_model():
    """Build the classifier, load the checkpoint at MODEL_PATH, ready it for inference.

    Returns:
        The model with weights loaded, moved to DEVICE and put in eval mode.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist.
    """
    # Fail fast, before constructing the network.
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

    # strict=True below overwrites every parameter and buffer of the network,
    # so ImageNet initialisation would be discarded; skip fetching it.
    model = get_model(NUM_CLASSES, pretrained=False)

    # NOTE(review): torch.load unpickles the file. Only point MODEL_PATH at a
    # trusted checkpoint (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(state_dict, strict=True)
    print(f"Loaded model from {MODEL_PATH}")

    model.to(DEVICE)
    model.eval()
    return model


# Load model ONCE at import time: importing this module reads the checkpoint
# and raises FileNotFoundError if it is missing.
model = load_model()
# --------------------
# Grad-CAM
# --------------------
class GradCAM:
    """Grad-CAM saliency maps for one layer of ``model``.

    Hooks are attached to ``target_layer`` only for the duration of each
    :meth:`generate_cam` call and are always removed afterwards, so no hooks
    are left behind on the model.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        # Filled in by the hooks during generate_cam(); reset at the start of
        # every call so a hook that did not fire can never reuse stale data.
        self.gradients = None
        self.activations = None

    def save_activation(self, module, input, output):
        """Forward hook: keep a detached copy of the target layer's output."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Full backward hook: keep the gradient w.r.t. the target layer's output."""
        self.gradients = grad_output[0].detach()

    def generate_cam(self, input_image, target_class):
        """Compute a Grad-CAM heatmap for ``target_class``.

        Runs one forward and one backward pass through ``self.model``.

        Args:
            input_image: batched input tensor; only batch element 0 is used.
            target_class: index into the model's output logits.

        Returns:
            A 2-D numpy array with the target layer's spatial size, min-max
            normalised to [0, 1]. This assumes the target layer outputs
            (N, C, H, W) feature maps, which the channel/spatial reductions
            below rely on.

        Raises:
            RuntimeError: if the hooks on ``target_layer`` did not fire, e.g.
                because the layer is not part of the model's forward pass.
        """
        self.gradients = None
        self.activations = None

        forward_handle = self.target_layer.register_forward_hook(self.save_activation)
        backward_handle = self.target_layer.register_full_backward_hook(self.save_gradient)
        try:
            output = self.model(input_image)
            score = output[0, target_class]
            self.model.zero_grad()
            score.backward()
        finally:
            forward_handle.remove()
            backward_handle.remove()
            # Release the parameter .grad tensors that backward() created;
            # they are not needed and would otherwise stay allocated on the model.
            self.model.zero_grad(set_to_none=True)

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "Grad-CAM hooks did not fire; is target_layer part of the "
                "model's forward pass?"
            )

        gradients = self.gradients[0]
        activations = self.activations[0]

        # Channel weights = spatially averaged gradients (global average pooling).
        weights = gradients.mean(dim=(1, 2), keepdim=True)
        cam = (weights * activations).sum(dim=0)
        # Keep only regions that push the target score up.
        cam = F.relu(cam)

        # Min-max normalise; the epsilon avoids division by zero for a flat map.
        cam -= cam.min()
        cam /= cam.max() + 1e-8
        return cam.cpu().numpy()
def overlay_heatmap(image, heatmap, alpha=0.4):
    """Colour ``heatmap`` with OpenCV's JET colormap and alpha-blend it onto ``image``.

    Args:
        image: array of shape (H, W, 3) that must match the colormap output's
            type (uint8, 3 channels) for ``cv2.addWeighted``. The colormap is
            produced in OpenCV's BGR channel order, so pass a BGR image for the
            colours to line up.
        heatmap: 2-D float array, expected in [0, 1]; resized to (H, W) first.
        alpha: weight of the heatmap in the blend; the image gets ``1 - alpha``.

    Returns:
        The blended (H, W, 3) array, in the same channel order as the inputs.
    """
    height, width = image.shape[:2]
    resized = cv2.resize(heatmap, (width, height))  # cv2 sizes are (width, height)
    intensity = np.uint8(255 * resized)
    colored = cv2.applyColorMap(intensity, cv2.COLORMAP_JET)
    return cv2.addWeighted(image, 1 - alpha, colored, alpha, 0)
# --------------------
# Prediction Function
# --------------------
def predict_galaxy(image: Image.Image):
    """Classify a galaxy image and render a Grad-CAM overlay for the prediction.

    Args:
        image (PIL.Image): input image; converted to RGB if it is in another mode.

    Returns:
        overlay_pil (PIL.Image): 224x224 RGB image with the Grad-CAM heatmap
            for the predicted class blended on top.
        result_text (str): predicted class name and its softmax probability.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    img_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
    # Track gradients from the input so that Grad-CAM's backward pass runs
    # through the whole network (presumably so the layer4 backward hook
    # reliably fires; kept from the original implementation).
    img_tensor.requires_grad = True

    # The classification pass itself needs no autograd graph.
    with torch.no_grad():
        outputs = model(img_tensor)
    raw_probs = F.softmax(outputs, dim=1)[0].cpu().numpy()
    pred_class = int(np.argmax(raw_probs))
    pred_prob = raw_probs[pred_class]

    # Grad-CAM performs its own forward/backward pass; force autograd on even
    # if the caller has disabled it.
    with torch.enable_grad():
        cam = GradCAM(model, model.layer4).generate_cam(img_tensor, pred_class)

    # overlay_heatmap blends with an OpenCV (BGR) colormap. Convert the PIL
    # (RGB) pixels to BGR first so both layers share a channel order, then
    # convert the blend back to RGB for PIL. Blending RGB pixels directly and
    # converting afterwards would swap the galaxy's red and blue channels.
    img_bgr = cv2.cvtColor(np.array(image.resize((224, 224))), cv2.COLOR_RGB2BGR)
    overlay = cv2.cvtColor(overlay_heatmap(img_bgr, cam), cv2.COLOR_BGR2RGB)
    overlay_pil = Image.fromarray(overlay)

    result_text = (
        f"Predicted Class: {CLASS_NAMES[pred_class]}\n"
        f"Probability: {pred_prob:.2%}"
    )
    return overlay_pil, result_text