monish211205 commited on
Commit
b34d4db
·
verified ·
1 Parent(s): c607932

Upload 2 files

Browse files
Files changed (2) hide show
  1. inference.py +283 -0
  2. requirements.txt +30 -0
inference.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MediScan AI - Production Inference Engine
3
+ Handles model loading, image preprocessing, prediction, and Grad-CAM generation.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import io
9
+ import base64
10
+ import logging
11
+ from pathlib import Path
12
+ from typing import Optional
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from PIL import Image
19
+ from torchvision import transforms
20
+ from torchvision.models import efficientnet_b4, EfficientNet_B4_Weights
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Model Architecture (must match training definition exactly)
26
+ # ---------------------------------------------------------------------------
27
+
28
class MediScanModel(nn.Module):
    """EfficientNet-B4 backbone with a custom two-layer classification head.

    The layer layout must stay in lockstep with the training-time definition
    so that checkpoint state_dicts load without key mismatches.
    """

    def __init__(self, num_classes: int = 2, dropout: float = 0.4):
        super().__init__()
        # weights=None: parameters come from our own checkpoint, not torchvision.
        backbone = efficientnet_b4(weights=None)
        self.features = backbone.features
        self.avgpool = backbone.avgpool
        head_in = backbone.classifier[1].in_features
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(head_in, 512),
            nn.BatchNorm1d(512),
            nn.SiLU(),
            nn.Dropout(p=dropout / 2),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (unnormalized) class logits for a batch of images."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))
50
+
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # Grad-CAM
54
+ # ---------------------------------------------------------------------------
55
+
56
class GradCAM:
    """
    Gradient-weighted Class Activation Mapping.

    Registers forward/backward hooks on the final convolutional block of the
    wrapped model and converts the captured activations and gradients into a
    normalized heatmap matching the spatial size of the input tensor.
    """

    def __init__(self, model: MediScanModel):
        self.model = model
        self.gradients = None
        self.activations = None
        final_block = model.features[-1]
        final_block.register_forward_hook(self._save_activation)
        final_block.register_full_backward_hook(self._save_gradient)

    def _save_activation(self, module, input, output):
        # Forward hook: keep a detached copy of the block's feature maps.
        self.activations = output.detach()

    def _save_gradient(self, module, grad_in, grad_out):
        # Backward hook: grad_out[0] is the gradient w.r.t. the block output.
        self.gradients = grad_out[0].detach()

    def generate(self, input_tensor: torch.Tensor, class_idx: int) -> np.ndarray:
        """Return an HxW heatmap in [0, 1] for the requested class index."""
        self.model.eval()
        logits = self.model(input_tensor)
        self.model.zero_grad()
        logits[0, class_idx].backward()

        # Channel weights = spatial mean of the gradients (Grad-CAM's GAP step).
        channel_weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        heatmap = torch.relu(
            (channel_weights * self.activations).sum(dim=1, keepdim=True)
        )

        # Min-max normalize, guarding against an all-zero map.
        heatmap = heatmap - heatmap.min()
        peak = heatmap.max()
        if peak > 0:
            heatmap = heatmap / peak

        # Upsample the coarse map to the input's spatial resolution.
        heatmap = F.interpolate(
            heatmap,
            size=(input_tensor.shape[2], input_tensor.shape[3]),
            mode='bilinear',
            align_corners=False,
        )
        return heatmap.squeeze().cpu().numpy()
95
+
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # Inference Engine
99
+ # ---------------------------------------------------------------------------
100
+
101
class InferenceEngine:
    """
    Singleton-style inference engine.

    Loads the model once (via ``load()``) and serves all prediction requests.
    Call ``load()`` once at startup before invoking ``predict()``.
    """

    CLASSES = ["NORMAL", "PNEUMONIA"]
    IMAGE_SIZE = 380  # EfficientNet-B4 native input resolution

    # predicted class -> confidence band -> (risk level, clinical note)
    RISK_MAP = {
        "NORMAL": {
            "high" : ("LOW", "No radiographic signs of pneumonia detected."),
            "moderate" : ("MODERATE", "Findings suggest normal presentation. Clinical correlation advised."),
            "low" : ("MODERATE", "Result is inconclusive. Radiologist review recommended."),
        },
        "PNEUMONIA": {
            "high" : ("HIGH", "Radiographic signs consistent with pneumonia. Immediate clinical evaluation required."),
            "moderate" : ("HIGH", "Findings suspicious for pneumonia. Clinical and laboratory correlation required."),
            "low" : ("MODERATE", "Possible early pneumonia or other infiltrate. Further workup recommended."),
        },
    }

    def __init__(self):
        # Lazily populated by load(); predict() refuses to run before then.
        self._model: Optional[MediScanModel] = None
        self._gradcam: Optional[GradCAM] = None
        self._device: torch.device = torch.device("cpu")

        # ImageNet normalization — must match the training-time transform.
        self._transform = transforms.Compose([
            transforms.Resize((self.IMAGE_SIZE, self.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def load(self, model_path: str) -> None:
        """Load model weights from a checkpoint file.

        Args:
            model_path: Path to a torch checkpoint — either a training
                checkpoint dict containing ``model_state_dict`` (and
                optionally ``val_auc`` / ``val_acc``) or a bare state_dict.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        path = Path(model_path)
        if not path.exists():
            raise FileNotFoundError(f"Model checkpoint not found: {model_path}")

        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # SECURITY: weights_only=False unpickles arbitrary objects — only load
        # checkpoints from trusted sources. Switch to weights_only=True once
        # the checkpoint format is confirmed to hold tensors/primitives only.
        checkpoint = torch.load(model_path, map_location=self._device, weights_only=False)

        self._model = MediScanModel(num_classes=2, dropout=0.4)
        # Accept either a training-checkpoint dict or a bare state_dict.
        state_dict = checkpoint.get("model_state_dict", checkpoint)
        self._model.load_state_dict(state_dict)
        self._model.eval()
        self._model.to(self._device)

        self._gradcam = GradCAM(self._model)

        # Metrics may be absent (default is the string "N/A"), so format with
        # %s — the previous %.4f/%.2f%% made the logging call raise a
        # formatting error whenever the checkpoint carried no metrics.
        val_auc = checkpoint.get("val_auc", "N/A")
        val_acc = checkpoint.get("val_acc", "N/A")
        logger.info(
            "Model loaded — device=%s val_auc=%s val_acc=%s",
            self._device, val_auc, val_acc
        )

    @property
    def is_loaded(self) -> bool:
        """True once load() has installed a model."""
        return self._model is not None

    def predict(self, image_bytes: bytes) -> dict:
        """
        Run full inference pipeline on raw image bytes.

        Returns a structured result dict containing:
            - predicted_class   : str
            - confidence        : float (0-100)
            - all_probabilities : dict[str, float]
            - risk_level        : 'LOW' | 'MODERATE' | 'HIGH'
            - clinical_note     : str
            - gradcam_overlay   : base64-encoded PNG
            - model_version     : str

        Raises:
            RuntimeError: If load() has not been called yet.
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        # Decode and preprocess image (force RGB: X-rays are often grayscale).
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        input_tensor = self._transform(pil_image).unsqueeze(0).to(self._device)

        # Forward pass — no gradients needed for the probability estimate.
        with torch.no_grad():
            logits = self._model(input_tensor)
            probs = torch.softmax(logits, dim=1).cpu().numpy()[0]

        predicted_idx = int(np.argmax(probs))
        predicted_class = self.CLASSES[predicted_idx]
        confidence = float(probs[predicted_idx]) * 100.0

        # Grad-CAM runs a second, gradient-tracked forward pass internally.
        gradcam_b64 = self._generate_gradcam_overlay(input_tensor, pil_image, predicted_idx)

        # Map (class, confidence band) onto a risk level and clinical note.
        conf_band = "high" if confidence >= 80 else ("moderate" if confidence >= 60 else "low")
        risk_level, clinical_note = self.RISK_MAP[predicted_class][conf_band]

        return {
            "predicted_class": predicted_class,
            "confidence": round(confidence, 2),
            "all_probabilities": {
                cls: round(float(p) * 100, 2)
                for cls, p in zip(self.CLASSES, probs)
            },
            "risk_level": risk_level,
            "clinical_note": clinical_note,
            "gradcam_overlay": gradcam_b64,
            "model_version": "MediScan-EfficientNetB4-v5",
        }

    def _generate_gradcam_overlay(
        self,
        input_tensor: torch.Tensor,
        original_image: Image.Image,
        class_idx: int
    ) -> str:
        """Generate a Grad-CAM heatmap overlaid on the original image.

        Returns a base64-encoded PNG. On any failure (including a missing
        scipy install) it degrades gracefully to the plain original image.
        """
        try:
            # Re-run forward pass with gradient tracking.
            tensor_grad = input_tensor.clone().requires_grad_(True)
            cam = self._gradcam.generate(tensor_grad, class_idx)

            # Suppress border artifacts: zero out the outer 10% margin where
            # normalization padding creates false high-gradient corners.
            h, w = cam.shape
            bh, bw = max(1, int(h * 0.10)), max(1, int(w * 0.10))
            border_mask = np.ones_like(cam)
            border_mask[:bh, :] = 0
            border_mask[-bh:, :] = 0
            border_mask[:, :bw] = 0
            border_mask[:, -bw:] = 0
            cam = cam * border_mask

            # Re-normalize after masking so the lung region fills 0..1.
            if cam.max() > 0:
                cam = cam / cam.max()

            # Gaussian smooth to reduce noise and make regions contiguous.
            # NOTE: local import keeps scipy optional — an ImportError falls
            # through to the no-overlay fallback below. scipy must be listed
            # in requirements.txt for the overlay to actually render.
            from scipy.ndimage import gaussian_filter
            cam = gaussian_filter(cam, sigma=2)
            if cam.max() > 0:
                cam = cam / cam.max()

            # Resize cam to the original image size.
            img_w, img_h = original_image.size
            cam_resized = np.array(
                Image.fromarray((cam * 255).astype(np.uint8)).resize((img_w, img_h), Image.BILINEAR),
                dtype=np.float32
            ) / 255.0

            # Apply jet colormap (matches matplotlib plt.cm.jet used in training notebook)
            # jet: 0.0=blue, 0.25=cyan, 0.5=green, 0.75=yellow, 1.0=red
            c = cam_resized
            r = np.clip(1.5 - np.abs(c * 4.0 - 3.0), 0.0, 1.0)
            g = np.clip(1.5 - np.abs(c * 4.0 - 2.0), 0.0, 1.0)
            b = np.clip(1.5 - np.abs(c * 4.0 - 1.0), 0.0, 1.0)

            # Convert original image to an RGB float array in [0, 1].
            orig_np = np.array(original_image.convert("RGB"), dtype=np.float32) / 255.0

            # Blend: 40% original + 60% jet heatmap (matches notebook visual style).
            alpha = 0.6
            blended = np.zeros((img_h, img_w, 3), dtype=np.float32)
            blended[:, :, 0] = (1 - alpha) * orig_np[:, :, 0] + alpha * r
            blended[:, :, 1] = (1 - alpha) * orig_np[:, :, 1] + alpha * g
            blended[:, :, 2] = (1 - alpha) * orig_np[:, :, 2] + alpha * b
            blended = np.clip(blended * 255, 0, 255).astype(np.uint8)

            # Encode to base64 PNG.
            buf = io.BytesIO()
            Image.fromarray(blended).save(buf, format="PNG", optimize=True)
            return base64.b64encode(buf.getvalue()).decode("utf-8")

        except Exception as exc:
            # Deliberate best-effort: the overlay is cosmetic, so a failure
            # here must never sink the whole prediction.
            logger.warning("Grad-CAM generation failed: %s", exc)
            buf = io.BytesIO()
            original_image.save(buf, format="PNG")
            return base64.b64encode(buf.getvalue()).decode("utf-8")
280
+
281
+
282
# Module-level singleton shared by all importers. Construction is cheap —
# no model weights are read until engine.load() is called explicitly.
engine = InferenceEngine()
requirements.txt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MediScan AI - Python Dependencies
2
+
3
+ # Web Framework
4
+ fastapi==0.115.0
5
+ uvicorn[standard]==0.30.6
6
+
7
+ # Database
8
+ sqlalchemy==2.0.35
9
+
10
+ # Authentication
11
+ python-jose[cryptography]==3.3.0
12
+ passlib[bcrypt]==1.7.4
13
+
14
+ # Request parsing
15
+ python-multipart==0.0.9
16
+ pydantic[email]==2.9.2
17
+
18
+ # ML / Inference
19
+ torch>=2.0.0
20
+ torchvision>=0.15.0
21
+ Pillow>=10.0.0
22
+ numpy>=1.24.0
23
+
24
+ # Training only (Kaggle notebook)
25
+ scikit-learn>=1.3.0
26
+ matplotlib>=3.7.0
27
+ seaborn>=0.12.0
28
+
29
+ # Utilities
30
+ python-dotenv==1.0.1