Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| import os | |
# --------------------
# Configuration
# --------------------
# Checkpoint read by load_model(); a relative path, so it is resolved
# against the current working directory.
MODEL_PATH = "robust_galaxy_model (1).pth"
# Label for each output index of the classifier head (index -> name).
CLASS_NAMES = ["Elliptical", "Spiral"]
# Kept in sync with the label list so the two cannot drift apart.
NUM_CLASSES = len(CLASS_NAMES)
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# --------------------
# Preprocessing
# --------------------
# PIL image -> float tensor of shape (C, H, W) with values in [0, 1],
# resized to the fixed 224x224 input size used throughout this file.
# NOTE(review): no mean/std normalization is applied here. It must match
# whatever the checkpoint was trained with — confirm before adding
# ImageNet normalization.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
# --------------------
# Model Definition
# --------------------
def get_model(num_classes=2):
    """Create a ResNet-18 classifier with a fresh ``num_classes``-way head.

    Starts from torchvision's default pretrained ResNet-18 weights. Every
    backbone parameter is frozen except those of the last residual stage
    (``layer4``). The replacement ``fc`` layer is newly created and trainable.

    Args:
        num_classes: number of output logits of the new ``fc`` layer.

    Returns:
        The ResNet-18 ``nn.Module`` with its new (untrained) head.
    """
    net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Freeze the whole network, then re-enable gradients for the final stage only.
    net.requires_grad_(False)
    net.layer4.requires_grad_(True)
    # Swap the ImageNet head for a num_classes-way linear layer; a freshly
    # built nn.Linear has requires_grad=True on its parameters.
    net.fc = nn.Linear(in_features=net.fc.in_features, out_features=num_classes)
    return net
def load_model():
    """Build the classifier and load the fine-tuned weights from MODEL_PATH.

    Returns:
        The model moved to DEVICE and switched to eval mode.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist. This is checked
            before the network is built, so a missing checkpoint fails fast
            instead of first constructing ResNet-18 (which may download its
            pretrained weights).
    """
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

    model = get_model(NUM_CLASSES)
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints
    # (torch >= 1.13 also offers weights_only=True for plain state dicts).
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    # strict=True: every key must match the architecture built by get_model().
    model.load_state_dict(state_dict, strict=True)
    print(f"Loaded model from {MODEL_PATH}")
    model.to(DEVICE)
    model.eval()
    return model


# Load model ONCE at import time (a module-level side effect), so every
# prediction reuses the same instance.
model = load_model()
# --------------------
# Grad-CAM
# --------------------
class GradCAM:
    """Gradient-weighted Class Activation Mapping for one target layer.

    Hooks are attached to ``target_layer`` only for the duration of each
    ``generate_cam`` call and are always removed afterwards.
    """

    def __init__(self, model, target_layer):
        """
        Args:
            model: network to explain; it is called as ``model(input_image)``.
            target_layer: submodule of ``model`` whose output activations and
                output gradients are combined into the map.
        """
        self.model = model
        self.target_layer = target_layer
        self.gradients = None  # grad w.r.t. target_layer output (backward hook)
        self.activations = None  # target_layer output (forward hook)

    def save_activation(self, module, input, output):
        """Forward hook: keep a detached copy of the target layer's output."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Full backward hook: keep the gradient w.r.t. the layer's output."""
        self.gradients = grad_output[0].detach()

    def generate_cam(self, input_image, target_class):
        """Compute a normalized Grad-CAM map for ``target_class``.

        Runs one forward and one backward pass through ``self.model``. Autograd
        is force-enabled internally, so this also works when the caller is
        inside ``torch.no_grad()``.

        Args:
            input_image: model input; only batch element 0 is explained.
            target_class: index into the model's output logits.

        Returns:
            numpy array over the target layer's spatial grid, scaled to [0, 1].

        Raises:
            RuntimeError: if the hooks did not capture activations/gradients.
        """
        # Drop state from a previous call so a hook that failed to fire
        # cannot silently reuse stale tensors.
        self.gradients = None
        self.activations = None
        forward_handle = self.target_layer.register_forward_hook(self.save_activation)
        backward_handle = self.target_layer.register_full_backward_hook(self.save_gradient)
        try:
            # backward() needs a graph even if the caller disabled autograd.
            with torch.enable_grad():
                output = self.model(input_image)
                score = output[0, target_class]
                self.model.zero_grad()
                score.backward()
        finally:
            forward_handle.remove()
            backward_handle.remove()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations/gradients; "
                "check that target_layer is used in the model's forward pass."
            )

        # Per-sample tensors; assumed (channels, H, W) — the mean over
        # dims (1, 2) below relies on that layout.
        gradients = self.gradients[0]
        activations = self.activations[0]
        # Channel importance = spatially averaged gradient.
        weights = gradients.mean(dim=(1, 2), keepdim=True)
        cam = F.relu((weights * activations).sum(dim=0))
        # Min-max scale to [0, 1]; epsilon avoids 0/0 on an all-zero map.
        cam -= cam.min()
        cam /= cam.max() + 1e-8
        return cam.cpu().numpy()
def overlay_heatmap(image, heatmap, alpha=0.4):
    """Blend a [0, 1] heatmap, rendered with OpenCV's JET colormap, onto an image.

    Args:
        image: uint8 array of shape (H, W, 3). ``cv2.applyColorMap`` returns
            BGR output, so pass a BGR image to keep channel order consistent.
        heatmap: 2-D float array, expected in [0, 1]; it is resized to (H, W).
        alpha: weight of the heatmap in the blend; the image gets ``1 - alpha``.

    Returns:
        uint8 array from ``cv2.addWeighted`` with the same shape as ``image``.
    """
    height, width = image.shape[:2]
    resized = cv2.resize(heatmap, (width, height))
    colored = cv2.applyColorMap(np.uint8(resized * 255), cv2.COLORMAP_JET)
    return cv2.addWeighted(image, 1 - alpha, colored, alpha, 0)
# --------------------
# Prediction Function
# --------------------
def predict_galaxy(image: Image.Image):
    """Classify a galaxy image and return a Grad-CAM overlay.

    Args:
        image (PIL.Image): input image; converted to RGB if needed. ``None``
            (e.g. nothing uploaded) gets a message instead of an exception.

    Returns:
        overlay_pil (PIL.Image or None): 224x224 RGB image with the Grad-CAM
            heatmap for the predicted class blended on top.
        result_text (str): predicted class name and its softmax probability.
    """
    if image is None:
        return None, "Please upload an image."
    if image.mode != "RGB":
        image = image.convert("RGB")

    img_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
    # Grad-CAM's backward pass flows through this input; the early backbone
    # layers are frozen, so the input is what makes layer4's inputs need grad.
    img_tensor.requires_grad = True

    # Plain prediction pass: no autograd graph is needed here, because
    # Grad-CAM runs its own forward/backward pass below.
    with torch.no_grad():
        probs = F.softmax(model(img_tensor), dim=1)[0].cpu().numpy()
    pred_class = int(np.argmax(probs))
    pred_prob = probs[pred_class]

    gradcam = GradCAM(model, model.layer4)
    cam = gradcam.generate_cam(img_tensor, pred_class)

    # overlay_heatmap blends an OpenCV JET colormap, which is BGR, so blend
    # in BGR and convert back once. (Blending the RGB image directly would
    # swap the galaxy's own red and blue channels in the output.)
    img_rgb = np.array(image.resize((224, 224)))
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    overlay_bgr = overlay_heatmap(img_bgr, cam)
    overlay_pil = Image.fromarray(cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB))

    result_text = (
        f"Predicted Class: {CLASS_NAMES[pred_class]}\n"
        f"Probability: {pred_prob:.2%}"
    )
    return overlay_pil, result_text