Tiffany Degbotse commited on
Commit
77bb7d0
·
1 Parent(s): b9f3d31

Deploy Star Struck model API

Browse files
Files changed (3) hide show
  1. Dockerfile +12 -0
  2. model.py +156 -0
  3. robust_galaxy_model .pth +3 -0
Dockerfile ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ EXPOSE 7860
11
+
12
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
model.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import models, transforms
5
+ from PIL import Image
6
+ import numpy as np
7
+ import cv2
8
+ import os
9
+
10
+ # --------------------
11
+ # Configuration
12
+ # --------------------
13
+ MODEL_PATH = "robust_galaxy_model (1).pth"
14
+ NUM_CLASSES = 2
15
+ CLASS_NAMES = ["Elliptical", "Spiral"]
16
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+
18
+ # --------------------
19
+ # Preprocessing
20
+ # --------------------
21
+ preprocess = transforms.Compose([
22
+ transforms.Resize((224, 224)),
23
+ transforms.ToTensor()
24
+ ])
25
+
26
+ # --------------------
27
+ # Model Definition
28
+ # --------------------
29
+ def get_model(num_classes=2):
30
+ model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
31
+
32
+ # Freeze backbone
33
+ for param in model.parameters():
34
+ param.requires_grad = False
35
+
36
+ # Unfreeze last residual block
37
+ for param in model.layer4.parameters():
38
+ param.requires_grad = True
39
+
40
+ # Replace classifier
41
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
42
+
43
+ return model
44
+
45
+
46
+ def load_model():
47
+ model = get_model(NUM_CLASSES)
48
+
49
+ if os.path.exists(MODEL_PATH):
50
+ state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
51
+ model.load_state_dict(state_dict, strict=True)
52
+ print(f"Loaded model from {MODEL_PATH}")
53
+ else:
54
+ raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")
55
+
56
+ model.to(DEVICE)
57
+ model.eval()
58
+ return model
59
+
60
+
61
+ # Load model ONCE at import time
62
+ model = load_model()
63
+
64
+ # --------------------
65
+ # Grad-CAM
66
+ # --------------------
67
+ class GradCAM:
68
+ def __init__(self, model, target_layer):
69
+ self.model = model
70
+ self.target_layer = target_layer
71
+ self.gradients = None
72
+ self.activations = None
73
+
74
+ def save_activation(self, module, input, output):
75
+ self.activations = output.detach()
76
+
77
+ def save_gradient(self, module, grad_input, grad_output):
78
+ self.gradients = grad_output[0].detach()
79
+
80
+ def generate_cam(self, input_image, target_class):
81
+ forward_handle = self.target_layer.register_forward_hook(self.save_activation)
82
+ backward_handle = self.target_layer.register_full_backward_hook(self.save_gradient)
83
+
84
+ try:
85
+ output = self.model(input_image)
86
+ score = output[0, target_class]
87
+
88
+ self.model.zero_grad()
89
+ score.backward()
90
+
91
+ gradients = self.gradients[0]
92
+ activations = self.activations[0]
93
+
94
+ weights = gradients.mean(dim=(1, 2), keepdim=True)
95
+ cam = (weights * activations).sum(dim=0)
96
+
97
+ cam = F.relu(cam)
98
+ cam -= cam.min()
99
+ cam /= cam.max() + 1e-8
100
+
101
+ return cam.cpu().numpy()
102
+
103
+ finally:
104
+ forward_handle.remove()
105
+ backward_handle.remove()
106
+
107
+
108
+ def overlay_heatmap(image, heatmap, alpha=0.4):
109
+ heatmap_resized = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
110
+ heatmap_colored = cv2.applyColorMap(
111
+ np.uint8(255 * heatmap_resized),
112
+ cv2.COLORMAP_JET
113
+ )
114
+ return cv2.addWeighted(image, 1 - alpha, heatmap_colored, alpha, 0)
115
+
116
+ # --------------------
117
+ # Prediction Function
118
+ # --------------------
119
+ def predict_galaxy(image: Image.Image):
120
+ """
121
+ Args:
122
+ image (PIL.Image)
123
+
124
+ Returns:
125
+ overlay_pil (PIL.Image)
126
+ result_text (str)
127
+ """
128
+
129
+ if image.mode != "RGB":
130
+ image = image.convert("RGB")
131
+
132
+ img_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
133
+ img_tensor.requires_grad = True
134
+
135
+ with torch.set_grad_enabled(True):
136
+ outputs = model(img_tensor)
137
+ probs = F.softmax(outputs, dim=1)
138
+
139
+ raw_probs = probs[0].detach().cpu().numpy()
140
+ pred_class = int(np.argmax(raw_probs))
141
+ pred_prob = raw_probs[pred_class]
142
+
143
+ gradcam = GradCAM(model, model.layer4)
144
+ cam = gradcam.generate_cam(img_tensor, pred_class)
145
+
146
+ img_np = np.array(image.resize((224, 224)))
147
+ overlay = overlay_heatmap(img_np, cam)
148
+ overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
149
+ overlay_pil = Image.fromarray(overlay)
150
+
151
+ result_text = (
152
+ f"Predicted Class: {CLASS_NAMES[pred_class]}\n"
153
+ f"Probability: {pred_prob:.2%}"
154
+ )
155
+
156
+ return overlay_pil, result_text
robust_galaxy_model .pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6217b4ef7679ccf90ab16c733ac4fe7810376c389330a7fe663f718114f8823
3
+ size 44790923