RoAr777 commited on
Commit
812cd20
·
verified ·
1 Parent(s): 51468b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +448 -17
app.py CHANGED
@@ -9,30 +9,398 @@ import pickle
9
  import numpy as np
10
  import cv2
11
  from PIL import Image
12
- from fastapi import FastAPI, File, UploadFile, HTTPException
 
 
13
 
14
- # --- START OF FIX ---
15
- # We must import torch FIRST.
16
  import torch
17
  import torch.nn as nn
18
  import torch.nn.functional as F
19
 
20
- # NOW, we apply the global patch.
21
- # This fixes the "Custom classes... not available" error
22
- # by putting the patched Path object in the global namespace.
23
- import pathlib
24
- pathlib.PosixPath = pathlib.Path
25
 
26
- # FINALLY, we import fastai.
27
  from fastai.vision.all import load_learner, PILImage
28
- # --- END OF FIX ---
29
 
30
 
31
  # =======================
32
  # CONFIG
33
  # =======================
34
- # ... (rest of the file is unchanged) ...
35
- # ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # =======================
37
  # FASTAPI SERVER
38
  # =======================
@@ -40,7 +408,6 @@ from fastai.vision.all import load_learner, PILImage
40
  # Store model in a global cache
41
  class ModelCache:
42
  learn = None
43
- class_names = None
44
 
45
  model_cache = ModelCache()
46
 
@@ -53,11 +420,12 @@ async def lifespan(app: FastAPI):
53
  print(f"FATAL: Model file not found at {model_path}", file=sys.stderr)
54
  else:
55
  try:
56
- # We NO LONGER need the patch here, it's global now.
 
 
57
 
58
  # Force CPU loading
59
  model_cache.learn = load_learner(model_path, cpu=True)
60
- model_cache.class_names = model_cache.learn.dls.vocab
61
  print("Learner loaded successfully.")
62
 
63
  except Exception as e:
@@ -65,7 +433,70 @@ async def lifespan(app: FastAPI):
65
  yield
66
  # Clear model from memory on shutdown
67
  model_cache.learn = None
68
- model_cache.class_names = None
69
  print("Model cache cleared.")
70
 
71
- # ... (rest of the file is unchanged) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import numpy as np
10
  import cv2
11
  from PIL import Image
12
+ from fastapi import FastAPI, File, UploadFile, HTTPException, status
13
+ from pydantic import BaseModel
14
+ from typing import List, Dict, Any
15
 
16
+ # --- IMPORT ORDER FIX ---
17
+ # 1. Import torch FIRST
18
  import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
 
22
+ # 2. REMOVE the pathlib patch
23
+ # The global patch was breaking matplotlib, which fastai imports.
24
+ # import pathlib
25
+ # pathlib.PosixPath = pathlib.Path
 
26
 
27
+ # 3. Import fastai LAST
28
  from fastai.vision.all import load_learner, PILImage
29
+ # --- END IMPORT ORDER FIX ---
30
 
31
 
32
  # =======================
33
  # CONFIG
34
  # =======================
35
+ class Config:
36
+ IMG_SIZE_CLF = 224
37
+ CAM_PERCENTILE = 75
38
+ MIN_AREA_RATIO = 0.01
39
+
40
+ cfg = Config()
41
+
42
+ # =======================
43
+ # GRAD-CAM IMPLEMENTATION
44
+ # =======================
45
+ class GradCAM:
46
+ """Grad-CAM for single image inference."""
47
+
48
+ def __init__(self, learn):
49
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
50
+ self.device = device
51
+ self.model = learn.model.to(device).eval()
52
+ self.target_layer = self._find_target_layer()
53
+
54
+ def _find_target_layer(self):
55
+ """Find last spatial conv layer (not 1x1 convolutions)."""
56
+ last_conv = None
57
+ last_conv_name = None
58
+
59
+ # Iterate through all modules
60
+ for name, m in self.model.named_modules():
61
+ if isinstance(m, nn.Conv2d):
62
+ # Skip 1x1 convolutions (classifier heads)
63
+ if m.kernel_size != (1, 1):
64
+ last_conv = m
65
+ last_conv_name = name
66
+
67
+ if last_conv is None:
68
+ # Fallback: try to find ANY conv layer
69
+ for name, m in self.model.named_modules():
70
+ if isinstance(m, nn.Conv2d):
71
+ last_conv = m
72
+ last_conv_name = name
73
+
74
+ if last_conv is None:
75
+ raise RuntimeError("No Conv2d layer found in model")
76
+
77
+ return last_conv
78
+
79
+ def compute(self, img_path, target_class_idx):
80
+ """Compute Grad-CAM for a single image."""
81
+
82
+ try:
83
+ # Load and preprocess image
84
+ img = PILImage.create(img_path)
85
+ img_np = np.array(img.resize((cfg.IMG_SIZE_CLF, cfg.IMG_SIZE_CLF)))
86
+ img_tensor = torch.from_numpy(img_np).float() / 255.0
87
+
88
+ # Handle grayscale
89
+ if img_tensor.ndim == 2:
90
+ img_tensor = img_tensor.unsqueeze(0).repeat(3, 1, 1)
91
+ elif img_tensor.ndim == 3:
92
+ img_tensor = img_tensor.permute(2, 0, 1)
93
+ # Ensure 3 channels
94
+ if img_tensor.shape[0] == 1:
95
+ img_tensor = img_tensor.repeat(3, 1, 1)
96
+
97
+ # Add batch dimension
98
+ img_tensor = img_tensor.unsqueeze(0)
99
+
100
+ # ImageNet normalization
101
+ mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
102
+ std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
103
+
104
+ xb = img_tensor.to(self.device)
105
+ xb = (xb - mean) / std
106
+ xb = xb.requires_grad_(True)
107
+
108
+ # Hook storage
109
+ activations_list = []
110
+ gradients_list = []
111
+
112
+ def save_activation(module, input, output):
113
+ activations_list.clear()
114
+ activations_list.append(output.detach().clone())
115
+
116
+ def save_gradient(module, grad_in, grad_out):
117
+ gradients_list.clear()
118
+ if grad_out[0] is not None:
119
+ gradients_list.append(grad_out[0].detach().clone())
120
+
121
+ # Register hooks
122
+ fwd_handle = self.target_layer.register_forward_hook(save_activation)
123
+ bwd_handle = self.target_layer.register_full_backward_hook(save_gradient)
124
+
125
+ # Forward pass
126
+ self.model.zero_grad()
127
+ with torch.set_grad_enabled(True):
128
+ output = self.model(xb)
129
+
130
+ # Check activations
131
+ if len(activations_list) == 0:
132
+ print(f"⚠ Warning: Forward hook didn't fire", file=sys.stderr)
133
+ return None
134
+
135
+ # Backward pass
136
+ target_score = output[0, target_class_idx]
137
+ target_score.backward()
138
+
139
+ # Check gradients
140
+ if len(gradients_list) == 0:
141
+ print(f"⚠ Warning: Backward hook didn't fire", file=sys.stderr)
142
+ return None
143
+
144
+ # Get activations and gradients
145
+ acts = activations_list[0].to(self.device)
146
+ grads = gradients_list[0].to(self.device)
147
+
148
+ # Compute CAM
149
+ weights = grads.mean(dim=[2, 3], keepdim=True)
150
+ cam_map = (weights * acts).sum(dim=1).squeeze(0)
151
+ cam_map = F.relu(cam_map)
152
+
153
+ # Resize to original size
154
+ orig_img = Image.open(img_path)
155
+ orig_w, orig_h = orig_img.size
156
+ cam_resized = F.interpolate(
157
+ cam_map.unsqueeze(0).unsqueeze(0),
158
+ size=(orig_h, orig_w),
159
+ mode='bilinear',
160
+ align_corners=False
161
+ ).squeeze()
162
+
163
+ # Normalize
164
+ cam_min = cam_resized.min()
165
+ cam_max = cam_resized.max()
166
+
167
+ if cam_max - cam_min > 1e-8:
168
+ cam_normalized = (cam_resized - cam_min) / (cam_max - cam_min)
169
+ else:
170
+ cam_normalized = torch.zeros_like(cam_resized)
171
+
172
+ # Cleanup
173
+ fwd_handle.remove()
174
+ bwd_handle.remove()
175
+ self.model.zero_grad()
176
+
177
+ return cam_normalized.clamp(0, 1).detach().cpu()
178
+
179
+ except Exception as e:
180
+ print(f"⚠ Grad-CAM error: {e}", file=sys.stderr)
181
+ return None
182
+
183
+ # =======================
184
+ # BBOX GENERATION
185
+ # =======================
186
+ def cam_to_multiscale_bboxes(cam, img_w, img_h):
187
+ """Generate multiple bboxes at different thresholds."""
188
+
189
+ if cam is None:
190
+ return []
191
+
192
+ cam_np = cam.numpy() if isinstance(cam, torch.Tensor) else cam
193
+ cam_np = (cam_np * 255).astype(np.uint8)
194
+
195
+ boxes = []
196
+ img_area = img_w * img_h
197
+
198
+ # Try multiple thresholds
199
+ percentiles = [60, 75, 85]
200
+ seen_boxes = set()
201
+
202
+ for percentile in percentiles:
203
+ non_zero_mask = cam_np > 0
204
+ if not np.any(non_zero_mask):
205
+ continue
206
+
207
+ thresh_val = np.percentile(cam_np[non_zero_mask], percentile)
208
+ _, thresh = cv2.threshold(cam_np, int(thresh_val), 255, cv2.THRESH_BINARY)
209
+
210
+ # Morphological cleanup
211
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
212
+ thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel, iterations=1)
213
+ thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=1)
214
+
215
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
216
+
217
+ for cnt in contours:
218
+ area = cv2.contourArea(cnt)
219
+
220
+ # Dynamic min_area based on threshold
221
+ min_area_ratio = 0.005 if percentile == 60 else 0.01
222
+ min_area = min_area_ratio * img_area
223
+
224
+ if area > min_area:
225
+ x, y, w, h = cv2.boundingRect(cnt)
226
+
227
+ # Filter tiny boxes
228
+ if w < 10 or h < 10:
229
+ continue
230
+
231
+ # Avoid duplicates
232
+ box_key = (x // 5, y // 5, w // 5, h // 5)
233
+ if box_key not in seen_boxes:
234
+ seen_boxes.add(box_key)
235
+
236
+ # Confidence based on area and threshold
237
+ conf = (area / img_area) * (percentile / 100.0)
238
+ boxes.append([x, y, w, h, min(conf, 1.0)])
239
+
240
+ # Apply NMS
241
+ if len(boxes) > 1:
242
+ boxes = apply_nms(boxes, iou_threshold=0.5)
243
+
244
+ # Filter contained boxes
245
+ boxes = filter_contained_boxes(boxes, tolerance=10)
246
+
247
+ return boxes
248
+
249
+ def apply_nms(boxes, iou_threshold=0.5):
250
+ """Non-Maximum Suppression."""
251
+ if len(boxes) == 0:
252
+ return []
253
+
254
+ boxes = np.array(boxes)
255
+ x1 = boxes[:, 0]
256
+ y1 = boxes[:, 1]
257
+ x2 = boxes[:, 0] + boxes[:, 2]
258
+ y2 = boxes[:, 1] + boxes[:, 3]
259
+ scores = boxes[:, 4]
260
+
261
+ areas = (x2 - x1) * (y2 - y1)
262
+ order = scores.argsort()[::-1]
263
+
264
+ keep = []
265
+ while order.size > 0:
266
+ i = order[0]
267
+ keep.append(i)
268
+
269
+ xx1 = np.maximum(x1[i], x1[order[1:]])
270
+ yy1 = np.maximum(y1[i], y1[order[1:]])
271
+ xx2 = np.minimum(x2[i], x2[order[1:]])
272
+ yy2 = np.minimum(y2[i], y2[order[1:]])
273
+
274
+ w = np.maximum(0.0, xx2 - xx1)
275
+ h = np.maximum(0.0, yy2 - yy1)
276
+ inter = w * h
277
+
278
+ iou = inter / (areas[i] + areas[order[1:]] - inter)
279
+
280
+ inds = np.where(iou <= iou_threshold)[0]
281
+ order = order[inds + 1]
282
+
283
+ return boxes[keep].tolist()
284
+
285
+ def filter_contained_boxes(boxes, tolerance=10):
286
+ """Filter out boxes that are contained within larger boxes with tolerance."""
287
+ if len(boxes) <= 1:
288
+ return boxes
289
+
290
+ # Sort by area descending (larger first)
291
+ boxes_sorted = sorted(boxes, key=lambda b: b[2] * b[3], reverse=True)
292
+ filtered = []
293
+
294
+ for box in boxes_sorted:
295
+ contained = False
296
+ for larger_box in filtered:
297
+ if is_contained(box, larger_box, tolerance):
298
+ contained = True
299
+ break
300
+ if not contained:
301
+ filtered.append(box)
302
+
303
+ return filtered
304
+
305
+ def is_contained(small_box, large_box, tolerance):
306
+ """Check if small_box is contained within large_box with tolerance."""
307
+ sx, sy, sw, sh = small_box[:4]
308
+ lx, ly, lw, lh = large_box[:4]
309
+
310
+ return (sx >= lx - tolerance and
311
+ sy >= ly - tolerance and
312
+ sx + sw <= lx + lw + tolerance and
313
+ sy + sh <= ly + lh + tolerance)
314
+
315
+ # =======================
316
+ # MODIFIED MAIN INFERENCE
317
+ # =======================
318
+ def run_inference(image_path, learn):
319
+ """Run inference using classifier + Grad-CAM."""
320
+
321
+ # Get class names from the loaded learner
322
+ class_names = learn.dls.vocab
323
+
324
+ # Get prediction
325
+ img = PILImage.create(image_path)
326
+
327
+ # Manual preprocessing
328
+ img_np = np.array(img.resize((cfg.IMG_SIZE_CLF, cfg.IMG_SIZE_CLF)))
329
+ img_tensor = torch.from_numpy(img_np).float() / 255.0
330
+
331
+ # Handle grayscale
332
+ if img_tensor.ndim == 2:
333
+ img_tensor = img_tensor.unsqueeze(0).repeat(3, 1, 1)
334
+ elif img_tensor.ndim == 3:
335
+ img_tensor = img_tensor.permute(2, 0, 1)
336
+ # Ensure 3 channels
337
+ if img_tensor.shape[0] == 1:
338
+ img_tensor = img_tensor.repeat(3, 1, 1)
339
+
340
+ # Add batch dimension
341
+ img_tensor = img_tensor.unsqueeze(0)
342
+
343
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
344
+
345
+ # ImageNet normalization
346
+ mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
347
+ std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
348
+ xb = (img_tensor - mean) / std
349
+
350
+ with torch.no_grad():
351
+ output = learn.model(xb.to(device))
352
+
353
+ pred_idx = output.argmax(dim=1).item()
354
+ predicted_class = class_names[pred_idx]
355
+ probs = F.softmax(output, dim=1).squeeze(0)
356
+ confidence = probs[pred_idx].item()
357
+
358
+ # Get image dimensions
359
+ orig_img = Image.open(image_path)
360
+ img_w, img_h = orig_img.size
361
+
362
+ # Generate Grad-CAM
363
+ gradcam = GradCAM(learn)
364
+ cam = gradcam.compute(image_path, pred_idx)
365
+
366
+ # Generate bounding boxes
367
+ boxes = cam_to_multiscale_bboxes(cam, img_w, img_h)
368
+
369
+ # Filter overlapping boxes
370
+ boxes = filter_contained_boxes(boxes, tolerance=10)
371
+
372
+ # Format detections
373
+ detections = []
374
+ for box in boxes:
375
+ x, y, w, h, conf = box
376
+ detections.append({
377
+ 'diseaseName': predicted_class,
378
+ 'confidence': float(conf * confidence), # Combined confidence
379
+ 'boundingBox': {
380
+ 'x': int(x),
381
+ 'y': int(y),
382
+ 'width': int(w),
383
+ 'height': int(h)
384
+ },
385
+ 'classId': pred_idx
386
+ })
387
+
388
+ # If no boxes found, return full image as bbox
389
+ if len(detections) == 0:
390
+ detections.append({
391
+ 'diseaseName': predicted_class,
392
+ 'confidence': confidence,
393
+ 'boundingBox': {
394
+ 'x': 0,
395
+ 'y': 0,
396
+ 'width': img_w,
397
+ 'height': img_h
398
+ },
399
+ 'classId': pred_idx
400
+ })
401
+
402
+ return detections
403
+
404
  # =======================
405
  # FASTAPI SERVER
406
  # =======================
 
408
  # Store model in a global cache
409
  class ModelCache:
410
  learn = None
 
411
 
412
  model_cache = ModelCache()
413
 
 
420
  print(f"FATAL: Model file not found at {model_path}", file=sys.stderr)
421
  else:
422
  try:
423
+ # We have REMOVED the pathlib patch.
424
+ # If this fails, the model was saved with a patch and
425
+ # this is a more complex problem.
426
 
427
  # Force CPU loading
428
  model_cache.learn = load_learner(model_path, cpu=True)
 
429
  print("Learner loaded successfully.")
430
 
431
  except Exception as e:
 
433
  yield
434
  # Clear model from memory on shutdown
435
  model_cache.learn = None
 
436
  print("Model cache cleared.")
437
 
438
+ # Define Pydantic models for response
439
+ class BoundingBox(BaseModel):
440
+ x: int
441
+ y: int
442
+ width: int
443
+ height: int
444
+
445
+ class Detection(BaseModel):
446
+ diseaseName: str
447
+ confidence: float
448
+ boundingBox: BoundingBox
449
+ classId: int
450
+
451
+ class PredictionResponse(BaseModel):
452
+ detections: List[Detection]
453
+
454
+ # Initialize FastAPI app with the lifespan event handler
455
+ app = FastAPI(lifespan=lifespan)
456
+
457
+ @app.get("/")
458
+ def read_root():
459
+ """Root endpoint for health check."""
460
+ return {"status": "ok", "model_loaded": model_cache.learn is not None}
461
+
462
+ @app.post("/predict", response_model=PredictionResponse)
463
+ async def predict(file: UploadFile = File(...)):
464
+ """Accepts an image, saves it, runs inference, and returns detections."""
465
+
466
+ # Check if model is loaded
467
+ if model_cache.learn is None:
468
+ raise HTTPException(
469
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
470
+ detail="Model is not loaded. Check startup logs."
471
+ )
472
+
473
+ # Define a temporary path to save the uploaded image
474
+ # Using /tmp/ is standard for temporary files in Linux containers
475
+ temp_image_path = f"/tmp/{file.filename}"
476
+
477
+ try:
478
+ # Asynchronously save the uploaded file
479
+ async with aiofiles.open(temp_image_path, 'wb') as out_file:
480
+ content = await file.read()
481
+ await out_file.write(content)
482
+
483
+ # Run inference using the saved file path
484
+ detections = run_inference(temp_image_path, model_cache.learn)
485
+
486
+ # Return the formatted detections
487
+ return {"detections": detections}
488
+
489
+ except Exception as e:
490
+ print(f"Error during prediction: {e}", file=sys.stderr)
491
+ raise HTTPException(
492
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
493
+ detail=f"Inference error: {str(e)}"
494
+ )
495
+
496
+ finally:
497
+ # Clean up the temporary file
498
+ if os.path.exists(temp_image_path):
499
+ os.remove(temp_image_path)
500
+
501
+ # Note: The `if __name__ == "__main__":` block is removed.
502
+ # Uvicorn will run this "app" object.