AdarshRajDS commited on
Commit
7d6580c
·
1 Parent(s): 4bb02cf

Use ConvNeXt multitask model and new checkpoint

Browse files
Files changed (4) hide show
  1. Dockerfile +1 -1
  2. app.py +19 -11
  3. best_convnext_multitask.pth +3 -0
  4. model.py +43 -2
Dockerfile CHANGED
@@ -15,7 +15,7 @@ RUN pip install --no-cache-dir --upgrade pip && \
15
  pip install --no-cache-dir -r requirements.txt
16
 
17
  COPY *.py ./
18
- COPY resnet50_multitask_mold.pth ./
19
 
20
  EXPOSE 7860
21
 
 
15
  pip install --no-cache-dir -r requirements.txt
16
 
17
  COPY *.py ./
18
+ COPY best_convnext_multitask.pth ./
19
 
20
  EXPOSE 7860
21
 
app.py CHANGED
@@ -4,7 +4,7 @@ from PIL import Image
4
  import torch, io
5
  from torchvision import transforms
6
 
7
- from model import MultiTaskResNet50
8
  from decision import final_decision
9
  from advanced_decision import (
10
  mc_uncertainty,
@@ -14,7 +14,7 @@ from advanced_decision import (
14
  from gradcam import GradCAM
15
  from dino import load_dino, build_embeddings, similarity
16
 
17
- app = FastAPI(title="Mold Detection API v2")
18
 
19
  app.add_middleware(
20
  CORSMiddleware,
@@ -24,15 +24,20 @@ app.add_middleware(
24
  )
25
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
- mold_idx = 4
28
 
29
  # ------------------
30
- # Load main model
31
  # ------------------
32
- model = MultiTaskResNet50().to(device)
33
- model.load_state_dict(
34
- torch.load("resnet50_multitask_mold.pth", map_location=device)
35
- )
 
 
 
 
 
 
36
  model.eval()
37
 
38
  # ------------------
@@ -48,9 +53,12 @@ transform = transforms.Compose([
48
  ])
49
 
50
  # ------------------
51
- # Grad-CAM
52
- # ------------------
53
- gradcam = GradCAM(model, model.backbone.layer4[-1].conv3)
 
 
 
54
 
55
  # ------------------
56
  # DINO (lazy loaded)
 
4
  import torch, io
5
  from torchvision import transforms
6
 
7
+ from model import MultiTaskResNet50, MultiTaskConvNeXt
8
  from decision import final_decision
9
  from advanced_decision import (
10
  mc_uncertainty,
 
14
  from gradcam import GradCAM
15
  from dino import load_dino, build_embeddings, similarity
16
 
17
+ app = FastAPI(title="Mold Detection API v2 (ConvNeXt)")
18
 
19
  app.add_middleware(
20
  CORSMiddleware,
 
24
  )
25
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
27
 
28
  # ------------------
29
+ # Load main model (ConvNeXt)
30
  # ------------------
31
+ # Expecting checkpoint with keys:
32
+ # - "model": state_dict
33
+ # - "classes": list of class names (length N, mold at some index)
34
+ ckpt = torch.load("best_convnext_multitask.pth", map_location=device)
35
+ classes = ckpt.get("classes") or []
36
+ num_classes = len(classes) if classes else 9
37
+ mold_idx = classes.index("mold") if classes else 4
38
+
39
+ model = MultiTaskConvNeXt(num_classes).to(device)
40
+ model.load_state_dict(ckpt["model"])
41
  model.eval()
42
 
43
  # ------------------
 
53
  ])
54
 
55
  # ------------------
56
+ # Grad-CAM (use exposed last_conv from ConvNeXt wrapper)
57
+ # If missing, fall back to a reasonable conv layer
58
+ target_layer = getattr(model, "last_conv", None)
59
+ if target_layer is None:
60
+ target_layer = model.backbone.features[-1].block[-1].dwconv
61
+ gradcam = GradCAM(model, target_layer)
62
 
63
  # ------------------
64
  # DINO (lazy loaded)
best_convnext_multitask.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb0e73f75e0b9fbc2a548dec97791598d99b67791eec44d2aa35053f1e27e342
3
+ size 350441583
model.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import torch.nn as nn
3
  from torchvision import models
4
 
 
5
  class MultiTaskResNet50(nn.Module):
6
  def __init__(self, num_classes=9):
7
  super().__init__()
@@ -11,10 +12,50 @@ class MultiTaskResNet50(nn.Module):
11
  self.class_head = nn.Linear(feat_dim, num_classes)
12
  self.bio_head = nn.Linear(feat_dim, 2)
13
 
14
- def forward(self, x):
15
  feats = self.backbone(x)
16
  return {
17
  "class": self.class_head(feats),
18
- "bio": self.bio_head(feats)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  }
20
 
 
 
2
  import torch.nn as nn
3
  from torchvision import models
4
 
5
+
6
  class MultiTaskResNet50(nn.Module):
7
  def __init__(self, num_classes=9):
8
  super().__init__()
 
12
  self.class_head = nn.Linear(feat_dim, num_classes)
13
  self.bio_head = nn.Linear(feat_dim, 2)
14
 
15
+ def forward(self, x: torch.Tensor):
16
  feats = self.backbone(x)
17
  return {
18
  "class": self.class_head(feats),
19
+ "bio": self.bio_head(feats),
20
+ }
21
+
22
+
23
+ class MultiTaskConvNeXt(nn.Module):
24
+ """
25
+ ConvNeXt-Base backbone with two heads:
26
+ - N-class structural/mold classifier
27
+ - 2-class biological vs non-biological head
28
+
29
+ Mirrors the training setup from the ConvNeXt Kaggle notebook.
30
+ """
31
+
32
+ def __init__(self, num_classes: int):
33
+ super().__init__()
34
+
35
+ # We load task-specific weights, so no ImageNet weights here.
36
+ self.backbone = models.convnext_base(weights=None)
37
+
38
+ # ConvNeXt classifier is [LayerNorm2d, Flatten, Linear]
39
+ feat_dim = self.backbone.classifier[2].in_features
40
+ self.backbone.classifier = nn.Identity()
41
+
42
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
43
+ self.class_head = nn.Linear(feat_dim, num_classes)
44
+ self.bio_head = nn.Linear(feat_dim, 2)
45
+ self.dropout = nn.Dropout(p=0.1)
46
+
47
+ # Expose a sensible last conv layer ref for Grad-CAM usage.
48
+ self.last_conv = self.backbone.features[-1].block[-1].dwconv
49
+
50
+ def forward(self, x: torch.Tensor):
51
+ feats = self.backbone.features(x)
52
+ feats = self.pool(feats)
53
+ feats = torch.flatten(feats, 1)
54
+ feats = self.dropout(feats)
55
+
56
+ return {
57
+ "class": self.class_head(feats),
58
+ "bio": self.bio_head(feats),
59
  }
60
 
61
+