AdarshRajDS committed on
Commit
7a5f7fb
·
1 Parent(s): cca95f0

Fix ConvNeXt checkpoint loading and Grad-CAM layer selection

Browse files
Files changed (2) hide show
  1. app.py +30 -13
  2. model.py +13 -5
app.py CHANGED
@@ -1,10 +1,10 @@
1
- from fastapi import FastAPI, UploadFile
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from PIL import Image
4
  import torch, io
5
  from torchvision import transforms
6
 
7
- from model import MultiTaskResNet50, MultiTaskConvNeXt
8
  from decision import final_decision
9
  from advanced_decision import (
10
  mc_uncertainty,
@@ -12,7 +12,7 @@ from advanced_decision import (
12
  final_decision_v2
13
  )
14
  from gradcam import GradCAM
15
- from dino import load_dino, build_embeddings, similarity
16
 
17
  app = FastAPI(title="Mold Detection API v2 (ConvNeXt)")
18
 
@@ -53,26 +53,41 @@ transform = transforms.Compose([
53
  ])
54
 
55
  # ------------------
56
- # Grad-CAM (use exposed last_conv from ConvNeXt wrapper)
57
- # If missing, fall back to a reasonable conv layer
58
- target_layer = getattr(model, "last_conv", None)
59
- if target_layer is None:
60
- # ConvNeXt features[-1] is a ConvNeXt block with a depthwise conv `dwconv`
61
- target_layer = model.backbone.features[-1].dwconv
62
- gradcam = GradCAM(model, target_layer)
63
 
64
  # ------------------
65
  # DINO (lazy loaded)
66
  # ------------------
67
- dino = None
68
  mold_embs = None
69
 
70
 
71
  def ensure_dino():
72
  global dino, mold_embs
73
  if dino is None:
74
- dino = load_dino(device)
75
- mold_embs = build_embeddings(dino, transform, device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
 
78
  # ------------------
@@ -89,6 +104,8 @@ async def predict_v1(file: UploadFile):
89
  @app.post("/predict/v2")
90
  async def predict_v2(file: UploadFile):
91
  ensure_dino()
 
 
92
 
93
  img = Image.open(io.BytesIO(await file.read())).convert("RGB")
94
  img_t = transform(img).to(device)
 
1
+ from fastapi import FastAPI, UploadFile, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from PIL import Image
4
  import torch, io
5
  from torchvision import transforms
6
 
7
+ from model import MultiTaskResNet50, MultiTaskConvNeXt, find_last_conv2d
8
  from decision import final_decision
9
  from advanced_decision import (
10
  mc_uncertainty,
 
12
  final_decision_v2
13
  )
14
  from gradcam import GradCAM
15
+ from typing import Optional
16
 
17
  app = FastAPI(title="Mold Detection API v2 (ConvNeXt)")
18
 
 
53
  ])
54
 
55
  # ------------------
56
+ # Grad-CAM target layer (computed, not stored in model state_dict)
57
+ # ------------------
58
+ target_layer = find_last_conv2d(model.backbone)
59
+ gradcam = GradCAM(model, target_layer) if target_layer is not None else None
 
 
 
60
 
61
  # ------------------
62
  # DINO (lazy loaded)
63
  # ------------------
64
+ dino: Optional[object] = None
65
  mold_embs = None
66
 
67
 
68
def ensure_dino():
    """Lazily initialise the DINO model and its reference embeddings.

    On first use this imports the optional ``dino`` module, loads the model
    onto ``device`` and builds the reference embeddings with ``transform``.
    Subsequent calls are no-ops once initialisation has succeeded.

    The module-level ``dino`` and ``mold_embs`` are only assigned after BOTH
    steps succeed. Assigning ``dino`` first would leave it set while
    ``mold_embs`` stays ``None`` if ``build_embeddings`` fails, and every
    later call would then skip initialisation because ``dino`` is not None.

    Raises:
        HTTPException: 503 if the optional DINO dependencies are not
            installed, or if loading the model / building the embeddings
            fails for any reason.
    """
    global dino, mold_embs
    if dino is not None:
        return

    try:
        from dino import load_dino, build_embeddings
    except ModuleNotFoundError as e:
        # Local/dev environments might not have optional deps like `datasets`.
        raise HTTPException(
            status_code=503,
            detail=(
                "DINO dependencies are not installed. "
                "Install extras with: pip install datasets scikit-learn"
            ),
        ) from e

    try:
        # Build into locals so a failure halfway through leaves the globals
        # untouched and the next request retries from scratch.
        loaded_dino = load_dino(device)
        loaded_embs = build_embeddings(loaded_dino, transform, device)
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Failed to initialize DINO reference embeddings: {e}",
        ) from e

    # Commit both pieces of state together, only after full success.
    dino, mold_embs = loaded_dino, loaded_embs
91
 
92
 
93
  # ------------------
 
104
  @app.post("/predict/v2")
105
  async def predict_v2(file: UploadFile):
106
  ensure_dino()
107
+ # Import similarity lazily (only needed for v2)
108
+ from dino import similarity
109
 
110
  img = Image.open(io.BytesIO(await file.read())).convert("RGB")
111
  img_t = transform(img).to(device)
model.py CHANGED
@@ -3,6 +3,19 @@ import torch.nn as nn
3
  from torchvision import models
4
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  class MultiTaskResNet50(nn.Module):
7
  def __init__(self, num_classes=9):
8
  super().__init__()
@@ -44,11 +57,6 @@ class MultiTaskConvNeXt(nn.Module):
44
  self.bio_head = nn.Linear(feat_dim, 2)
45
  self.dropout = nn.Dropout(p=0.1)
46
 
47
- # Expose a sensible last conv layer ref for Grad-CAM usage.
48
- # In torchvision ConvNeXt, each element of `features` is a ConvNeXt block
49
- # and has a depthwise conv named `dwconv`.
50
- self.last_conv = self.backbone.features[-1].dwconv
51
-
52
  def forward(self, x: torch.Tensor):
53
  feats = self.backbone.features(x)
54
  feats = self.pool(feats)
 
3
  from torchvision import models
4
 
5
 
6
def find_last_conv2d(module: nn.Module) -> "nn.Conv2d | None":
    """Return the last ``nn.Conv2d`` found while traversing *module*.

    The traversal uses ``module.modules()``, so *module* itself is considered
    along with all of its descendants; the final ``nn.Conv2d`` yielded by
    that iteration is returned.

    Important: the result is deliberately computed on demand and NOT attached
    as a child module on the model instance, otherwise it becomes part of
    ``state_dict`` and breaks checkpoint loading.

    Args:
        module: Root module to search.

    Returns:
        The last ``nn.Conv2d`` encountered, or ``None`` if there is none.
    """
    # The return annotation is quoted so it is not evaluated at definition
    # time; the unquoted ``X | None`` form fails on Python < 3.10.
    last = None
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            last = m
    return last
17
+
18
+
19
  class MultiTaskResNet50(nn.Module):
20
  def __init__(self, num_classes=9):
21
  super().__init__()
 
57
  self.bio_head = nn.Linear(feat_dim, 2)
58
  self.dropout = nn.Dropout(p=0.1)
59
 
 
 
 
 
 
60
  def forward(self, x: torch.Tensor):
61
  feats = self.backbone.features(x)
62
  feats = self.pool(feats)