RoAr777 committed
Commit f28de38 · verified · 1 Parent(s): d54080c

Update app.py

Files changed (1):
  1. app.py (+189, -237)

app.py CHANGED
@@ -1,106 +1,55 @@
- import sys
+ """
+ FastAPI Disease Detection Service with Grad-CAM
+ Usage: uvicorn app:app --host 0.0.0.0 --port 8000
+ """
+
+ import io
  import json
  import warnings
- import os
- import aiofiles
- from contextlib import asynccontextmanager
- from pathlib import Path
- import pickle
- import platform
- import io
-
  warnings.filterwarnings('ignore')

+
+
+ from typing import List, Optional
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Query
+ from fastapi.responses import JSONResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ import uvicorn
+
  import numpy as np
  import cv2
+ from pathlib import Path
  from PIL import Image
- from fastapi import FastAPI, File, UploadFile, HTTPException, status
- from pydantic import BaseModel
- from typing import List, Dict, Any
-
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from fastai.vision.all import PILImage, Learner
+ from fastai.vision.all import load_learner, PILImage

  # =======================
- # CUSTOM UNPICKLER FOR CROSS-PLATFORM COMPATIBILITY
+ # PYDANTIC MODELS
  # =======================
- class CrossPlatformUnpickler(pickle.Unpickler):
-     """Custom unpickler to handle Windows paths on Linux."""
-
-     def find_class(self, module, name):
-         """Override to handle pathlib classes."""
-         if module == 'pathlib':
-             # Map both PosixPath and WindowsPath to the generic Path
-             if name in ('PosixPath', 'WindowsPath'):
-                 import pathlib
-                 return pathlib.Path
-         return super().find_class(module, name)
+ class BoundingBox(BaseModel):
+     x: int
+     y: int
+     width: int
+     height: int

- def load_model_cross_platform(model_path):
-     """Load fastai model with cross-platform pathlib compatibility."""
-     print(f"Attempting to load model from: {model_path}")
-
-     try:
-         # Read the pickle file
-         with open(model_path, 'rb') as f:
-             pkl_data = f.read()
-
-         # Use custom unpickler
-         print("Unpickling with custom unpickler...")
-         learner = CrossPlatformUnpickler(io.BytesIO(pkl_data)).load()
-
-         # If it's already a Learner object, return it
-         if isinstance(learner, Learner):
-             print("✓ Learner unpickled directly")
-             learner.dls.cpu()
-             return learner
-
-         # If it's a dict or other structure, try to extract the learner
-         print(f"Unpickled object type: {type(learner)}")
-
-         # fastai sometimes wraps the learner in a dict
-         if isinstance(learner, dict):
-             if 'learner' in learner:
-                 learner = learner['learner']
-             elif 'model' in learner:
-                 print("Found model in dict, attempting to reconstruct learner...")
-                 # This is trickier - you may need to reconstruct the Learner
-                 raise ValueError("Model dict format not directly supported. Please re-export your model.")
-
-         if isinstance(learner, Learner):
-             learner.dls.cpu()
-             return learner
-         else:
-             raise ValueError(f"Unexpected unpickled type: {type(learner)}")
-
-     except Exception as e:
-         print(f"Custom unpickler failed: {e}")
-         print("Attempting fallback with pathlib patch...")
-
-         # Fallback: Try with pathlib patch
-         import pathlib
-         original_posix = getattr(pathlib, 'PosixPath', None)
-         original_windows = getattr(pathlib, 'WindowsPath', None)
-
-         try:
-             # Patch pathlib
-             if platform.system() != 'Windows':
-                 pathlib.WindowsPath = pathlib.Path
-                 pathlib.PosixPath = pathlib.Path
-
-             # Try standard fastai loader
-             from fastai.vision.all import load_learner
-             learner = load_learner(model_path, cpu=True)
-             return learner
-
-         finally:
-             # Restore pathlib
-             if original_posix is not None:
-                 pathlib.PosixPath = original_posix
-             if original_windows is not None:
-                 pathlib.WindowsPath = original_windows
+ class Detection(BaseModel):
+     diseaseName: str
+     confidence: float
+     boundingBox: BoundingBox
+     classId: int
+
+ class InferenceResponse(BaseModel):
+     success: bool
+     detections: List[Detection]
+     message: Optional[str] = None
+
+ class HealthResponse(BaseModel):
+     status: str
+     model_loaded: bool
+     device: str

  # =======================
  # CONFIG
@@ -109,7 +58,8 @@ class Config:
      IMG_SIZE_CLF = 224
      CAM_PERCENTILE = 75
      MIN_AREA_RATIO = 0.01
-
+     MODEL_PATH = "classifier.pkl"  # Default model path
+
  cfg = Config()

  # =======================
@@ -117,28 +67,25 @@ cfg = Config()
  # =======================
  class GradCAM:
      """Grad-CAM for single image inference."""
-
+
      def __init__(self, learn):
          device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
          self.device = device
          self.model = learn.model.to(device).eval()
          self.target_layer = self._find_target_layer()
-
+
      def _find_target_layer(self):
          """Find last spatial conv layer (not 1x1 convolutions)."""
          last_conv = None
          last_conv_name = None
-
-         # Iterate through all modules
+
          for name, m in self.model.named_modules():
              if isinstance(m, nn.Conv2d):
-                 # Skip 1x1 convolutions (classifier heads)
                  if m.kernel_size != (1, 1):
                      last_conv = m
                      last_conv_name = name

          if last_conv is None:
-             # Fallback: try to find ANY conv layer
              for name, m in self.model.named_modules():
                  if isinstance(m, nn.Conv2d):
                      last_conv = m
@@ -149,28 +96,22 @@ class GradCAM:

          return last_conv

-     def compute(self, img_path, target_class_idx):
-         """Compute Grad-CAM for a single image."""
+     def compute(self, img_pil, target_class_idx):
+         """Compute Grad-CAM for a PIL image."""

          try:
-             # Load and preprocess image
-             img = PILImage.create(img_path)
-             img_np = np.array(img.resize((cfg.IMG_SIZE_CLF, cfg.IMG_SIZE_CLF)))
+             img_np = np.array(img_pil.resize((cfg.IMG_SIZE_CLF, cfg.IMG_SIZE_CLF)))
              img_tensor = torch.from_numpy(img_np).float() / 255.0

-             # Handle grayscale
              if img_tensor.ndim == 2:
                  img_tensor = img_tensor.unsqueeze(0).repeat(3, 1, 1)
              elif img_tensor.ndim == 3:
                  img_tensor = img_tensor.permute(2, 0, 1)
-             # Ensure 3 channels
              if img_tensor.shape[0] == 1:
                  img_tensor = img_tensor.repeat(3, 1, 1)

-             # Add batch dimension
              img_tensor = img_tensor.unsqueeze(0)

-             # ImageNet normalization
              mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
              std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)

@@ -178,7 +119,6 @@ class GradCAM:
              xb = (xb - mean) / std
              xb = xb.requires_grad_(True)

-             # Hook storage
              activations_list = []
              gradients_list = []

@@ -191,41 +131,30 @@ class GradCAM:
                  if grad_out[0] is not None:
                      gradients_list.append(grad_out[0].detach().clone())

-             # Register hooks
              fwd_handle = self.target_layer.register_forward_hook(save_activation)
              bwd_handle = self.target_layer.register_full_backward_hook(save_gradient)

-             # Forward pass
              self.model.zero_grad()
              with torch.set_grad_enabled(True):
                  output = self.model(xb)

-             # Check activations
              if len(activations_list) == 0:
-                 print(f"⚠ Warning: Forward hook didn't fire", file=sys.stderr)
                  return None

-             # Backward pass
              target_score = output[0, target_class_idx]
              target_score.backward()

-             # Check gradients
              if len(gradients_list) == 0:
-                 print(f"⚠ Warning: Backward hook didn't fire", file=sys.stderr)
                  return None

-             # Get activations and gradients
              acts = activations_list[0].to(self.device)
              grads = gradients_list[0].to(self.device)

-             # Compute CAM
              weights = grads.mean(dim=[2, 3], keepdim=True)
              cam_map = (weights * acts).sum(dim=1).squeeze(0)
              cam_map = F.relu(cam_map)

-             # Resize to original size
-             orig_img = Image.open(img_path)
-             orig_w, orig_h = orig_img.size
+             orig_w, orig_h = img_pil.size
              cam_resized = F.interpolate(
                  cam_map.unsqueeze(0).unsqueeze(0),
                  size=(orig_h, orig_w),
@@ -233,7 +162,6 @@ class GradCAM:
                  align_corners=False
              ).squeeze()

-             # Normalize
              cam_min = cam_resized.min()
              cam_max = cam_resized.max()

@@ -242,7 +170,6 @@ class GradCAM:
              else:
                  cam_normalized = torch.zeros_like(cam_resized)

-             # Cleanup
              fwd_handle.remove()
              bwd_handle.remove()
              self.model.zero_grad()
@@ -250,7 +177,7 @@ class GradCAM:
              return cam_normalized.clamp(0, 1).detach().cpu()

          except Exception as e:
-             print(f"Grad-CAM error: {e}", file=sys.stderr)
+             print(f"Grad-CAM error: {e}")
              return None

  # =======================
@@ -268,7 +195,6 @@ def cam_to_multiscale_bboxes(cam, img_w, img_h):
      boxes = []
      img_area = img_w * img_h

-     # Try multiple thresholds
      percentiles = [60, 75, 85]
      seen_boxes = set()

@@ -280,7 +206,6 @@ def cam_to_multiscale_bboxes(cam, img_w, img_h):
          thresh_val = np.percentile(cam_np[non_zero_mask], percentile)
          _, thresh = cv2.threshold(cam_np, int(thresh_val), 255, cv2.THRESH_BINARY)

-         # Morphological cleanup
          kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
          thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel, iterations=1)
          thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=1)
@@ -290,31 +215,25 @@ def cam_to_multiscale_bboxes(cam, img_w, img_h):
          for cnt in contours:
              area = cv2.contourArea(cnt)

-             # Dynamic min_area based on threshold
              min_area_ratio = 0.005 if percentile == 60 else 0.01
              min_area = min_area_ratio * img_area

              if area > min_area:
                  x, y, w, h = cv2.boundingRect(cnt)

-                 # Filter tiny boxes
                  if w < 10 or h < 10:
                      continue

-                 # Avoid duplicates
                  box_key = (x // 5, y // 5, w // 5, h // 5)
                  if box_key not in seen_boxes:
                      seen_boxes.add(box_key)

-                     # Confidence based on area and threshold
                      conf = (area / img_area) * (percentile / 100.0)
                      boxes.append([x, y, w, h, min(conf, 1.0)])

-     # Apply NMS
      if len(boxes) > 1:
          boxes = apply_nms(boxes, iou_threshold=0.5)

-     # Filter contained boxes
      boxes = filter_contained_boxes(boxes, tolerance=10)

      return boxes
@@ -356,11 +275,10 @@ def apply_nms(boxes, iou_threshold=0.5):
      return boxes[keep].tolist()

  def filter_contained_boxes(boxes, tolerance=10):
-     """Filter out boxes that are contained within larger boxes with tolerance."""
+     """Filter out boxes that are contained within larger boxes."""
      if len(boxes) <= 1:
          return boxes

-     # Sort by area descending (larger first)
      boxes_sorted = sorted(boxes, key=lambda b: b[2] * b[3], reverse=True)
      filtered = []

@@ -376,7 +294,7 @@ def filter_contained_boxes(boxes, tolerance=10):
      return filtered

  def is_contained(small_box, large_box, tolerance):
-     """Check if small_box is contained within large_box with tolerance."""
+     """Check if small_box is contained within large_box."""
      sx, sy, sw, sh = small_box[:4]
      lx, ly, lw, lh = large_box[:4]

@@ -386,36 +304,28 @@ def is_contained(small_box, large_box, tolerance):
              sy + sh <= ly + lh + tolerance)

  # =======================
- # INFERENCE
+ # INFERENCE LOGIC
  # =======================
- def run_inference(image_path, learn):
+ def run_inference(img_pil, learn):
      """Run inference using classifier + Grad-CAM."""

-     # Get class names from the loaded learner
      class_names = learn.dls.vocab

-     # Get prediction
-     img = PILImage.create(image_path)
-
      # Manual preprocessing
-     img_np = np.array(img.resize((cfg.IMG_SIZE_CLF, cfg.IMG_SIZE_CLF)))
+     img_np = np.array(img_pil.resize((cfg.IMG_SIZE_CLF, cfg.IMG_SIZE_CLF)))
      img_tensor = torch.from_numpy(img_np).float() / 255.0

-     # Handle grayscale
      if img_tensor.ndim == 2:
          img_tensor = img_tensor.unsqueeze(0).repeat(3, 1, 1)
      elif img_tensor.ndim == 3:
          img_tensor = img_tensor.permute(2, 0, 1)
-     # Ensure 3 channels
      if img_tensor.shape[0] == 1:
          img_tensor = img_tensor.repeat(3, 1, 1)

-     # Add batch dimension
      img_tensor = img_tensor.unsqueeze(0)

      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-     # ImageNet normalization
      mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
      std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
      xb = (img_tensor - mean) / std
@@ -428,27 +338,20 @@ def run_inference(image_path, learn):
      probs = F.softmax(output, dim=1).squeeze(0)
      confidence = probs[pred_idx].item()

-     # Get image dimensions
-     orig_img = Image.open(image_path)
-     img_w, img_h = orig_img.size
+     img_w, img_h = img_pil.size

-     # Generate Grad-CAM
      gradcam = GradCAM(learn)
-     cam = gradcam.compute(image_path, pred_idx)
+     cam = gradcam.compute(img_pil, pred_idx)

-     # Generate bounding boxes
      boxes = cam_to_multiscale_bboxes(cam, img_w, img_h)
-
-     # Filter overlapping boxes
      boxes = filter_contained_boxes(boxes, tolerance=10)

-     # Format detections
      detections = []
      for box in boxes:
          x, y, w, h, conf = box
          detections.append({
              'diseaseName': predicted_class,
-             'confidence': float(conf * confidence),  # Combined confidence
+             'confidence': float(conf * confidence),
              'boundingBox': {
                  'x': int(x),
                  'y': int(y),
@@ -458,7 +361,6 @@ def run_inference(image_path, learn):
              'classId': pred_idx
          })

-     # If no boxes found, return full image as bbox
      if len(detections) == 0:
          detections.append({
              'diseaseName': predicted_class,
@@ -475,98 +377,148 @@ def run_inference(image_path, learn):
      return detections

  # =======================
- # FASTAPI SERVER
+ # FASTAPI APP
  # =======================
+ app = FastAPI(
+     title="Disease Detection API",
+     description="AI-powered disease detection service with Grad-CAM visualization",
+     version="1.0.0"
+ )

- # Store model in a global cache
- class ModelCache:
-     learn = None
-
- model_cache = ModelCache()
-
- @asynccontextmanager
- async def lifespan(app: FastAPI):
-     # Load the model on startup
-     print("Loading Fastai learner...")
-     model_path = "classifier.pkl"
-     if not Path(model_path).exists():
-         print(f"FATAL: Model file not found at {model_path}", file=sys.stderr)
-     else:
-         try:
-             # Use our safe cross-platform loader
-             model_cache.learn = load_model_cross_platform(model_path)
-             print("✓ Learner loaded successfully.")
-             print(f"✓ Classes: {model_cache.learn.dls.vocab}")
-
-         except Exception as e:
-             print(f"FATAL: Failed to load learner: {e}", file=sys.stderr)
-             import traceback
-             traceback.print_exc()
-     yield
-     # Clear model from memory on shutdown
-     model_cache.learn = None
-     print("Model cache cleared.")
-
- # Define Pydantic models for response
- class BoundingBox(BaseModel):
-     x: int
-     y: int
-     width: int
-     height: int
-
- class Detection(BaseModel):
-     diseaseName: str
-     confidence: float
-     boundingBox: BoundingBox
-     classId: int
-
- class PredictionResponse(BaseModel):
-     detections: List[Detection]
+ # CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )

- # Initialize FastAPI app with the lifespan event handler
- app = FastAPI(lifespan=lifespan)
+ # Global model variable
+ model = None

- @app.get("/")
- def read_root():
-     """Root endpoint for health check."""
-     return {"status": "ok", "model_loaded": model_cache.learn is not None}
+ @app.on_event("startup")
+ async def load_model():
+     """Load the classifier model on startup."""
+     global model
+     try:
+         if Path(cfg.MODEL_PATH).exists():
+             model = load_learner(cfg.MODEL_PATH)
+             print(f"✓ Model loaded from {cfg.MODEL_PATH}")
+         else:
+             print(f"⚠ Warning: Model file not found at {cfg.MODEL_PATH}")
+     except Exception as e:
+         print(f"✗ Error loading model: {e}")

- @app.post("/predict", response_model=PredictionResponse)
- async def predict(file: UploadFile = File(...)):
-     """Accepts an image, saves it, runs inference, and returns detections."""
-
-     # Check if model is loaded
-     if model_cache.learn is None:
-         raise HTTPException(
-             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-             detail="Model is not loaded. Check startup logs."
-         )
+ @app.get("/", response_model=HealthResponse)
+ async def root():
+     """Root endpoint - health check."""
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     return {
+         "status": "running",
+         "model_loaded": model is not None,
+         "device": device
+     }

-     # Define a temporary path to save the uploaded image
-     temp_image_path = f"/tmp/{file.filename}"
+ @app.get("/health", response_model=HealthResponse)
+ async def health_check():
+     """Health check endpoint."""
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     return {
+         "status": "healthy" if model is not None else "model_not_loaded",
+         "model_loaded": model is not None,
+         "device": device
+     }

+ @app.post("/predict", response_model=InferenceResponse)
+ async def predict(
+     file: UploadFile = File(...),
+     model_path: Optional[str] = Query(None, description="Optional custom model path")
+ ):
+     """
+     Predict disease from uploaded image.
+
+     Parameters:
+     - file: Image file (JPG, PNG, etc.)
+     - model_path: Optional custom model path (query parameter)
+
+     Returns:
+     - JSON with detections including disease name, confidence, and bounding boxes
+     """
+
+     # Check model
+     current_model = model
+     if model_path:
+         try:
+             if not Path(model_path).exists():
+                 raise HTTPException(status_code=400, detail=f"Model not found: {model_path}")
+             current_model = load_learner(model_path)
+         except Exception as e:
+             raise HTTPException(status_code=500, detail=f"Error loading custom model: {str(e)}")
+
+     if current_model is None:
+         raise HTTPException(status_code=503, detail="Model not loaded. Please check server logs.")
+
+     # Validate file type
+     if not file.content_type.startswith('image/'):
+         raise HTTPException(status_code=400, detail="File must be an image")
+
      try:
-         # Asynchronously save the uploaded file
-         async with aiofiles.open(temp_image_path, 'wb') as out_file:
-             content = await file.read()
-             await out_file.write(content)
+         # Read image
+         contents = await file.read()
+         img_pil = Image.open(io.BytesIO(contents))
+
+         # Convert RGBA to RGB if needed
+         if img_pil.mode == 'RGBA':
+             img_pil = img_pil.convert('RGB')

-         # Run inference using the saved file path
-         detections = run_inference(temp_image_path, model_cache.learn)
+         # Run inference
+         detections = run_inference(img_pil, current_model)

-         # Return the formatted detections
-         return {"detections": detections}
+         return {
+             "success": True,
+             "detections": detections,
+             "message": f"Detected {len(detections)} region(s)"
+         }

      except Exception as e:
-         print(f"Error during prediction: {e}", file=sys.stderr)
-         import traceback
-         traceback.print_exc()
-         raise HTTPException(
-             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-             detail=f"Inference error: {str(e)}"
-         )
-
-     finally:
-         # Clean up the temporary file
-         if os.path.exists(temp_image_path):
-             os.remove(temp_image_path)
+         raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
+
+ @app.get("/classes")
+ async def get_classes():
+     """Get list of disease classes the model can detect."""
+     if model is None:
+         raise HTTPException(status_code=503, detail="Model not loaded")
+
+     try:
+         classes = list(model.dls.vocab)
+         return {
+             "success": True,
+             "classes": classes,
+             "num_classes": len(classes)
+         }
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error retrieving classes: {str(e)}")
+
+ # =======================
+ # RUN SERVER
+ # =======================
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Disease Detection API Server")
+     parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
+     parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
+     parser.add_argument("--model", default="classifier.pkl", help="Path to classifier model")
+     parser.add_argument("--reload", action="store_true", help="Enable auto-reload")
+
+     args = parser.parse_args()
+
+     cfg.MODEL_PATH = args.model
+
+     uvicorn.run(
+         "app:app" if args.reload else app,
+         host=args.host,
+         port=args.port,
+         reload=args.reload
+     )
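
For reference, a minimal client-side sketch of how the endpoints added in this commit could be exercised once the server is running. This is a hedged example, not part of the commit: the localhost host/port, the "leaf.jpg" test image, and the use of the requests library are assumptions.

# Assumes the API is served locally, e.g. `uvicorn app:app --host 0.0.0.0 --port 8000`,
# and that "leaf.jpg" is a hypothetical test image on disk.
import requests

API_URL = "http://localhost:8000"

# GET /health reports whether classifier.pkl was loaded at startup.
print(requests.get(f"{API_URL}/health").json())

# POST /predict with a multipart image upload; the JSON body carries the detections.
with open("leaf.jpg", "rb") as f:
    resp = requests.post(f"{API_URL}/predict",
                         files={"file": ("leaf.jpg", f, "image/jpeg")})
resp.raise_for_status()

# Each detection holds the predicted class, a combined confidence, and a bounding box.
for det in resp.json()["detections"]:
    box = det["boundingBox"]
    print(det["diseaseName"], round(det["confidence"], 3),
          (box["x"], box["y"], box["width"], box["height"]))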