AdarshRajDS commited on
Commit
c46050c
·
1 Parent(s): 9f79a25

Add ResNet baseline and ConvNeXt v2 backend Dockerfile 1

Browse files
Files changed (1) hide show
  1. app.py +14 -3
app.py CHANGED
@@ -25,6 +25,9 @@ app.add_middleware(
25
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
 
 
 
 
28
  # ------------------
29
  # Load baseline model (ResNet)
30
  # ------------------
@@ -33,15 +36,23 @@ resnet_ckpt = torch.load(
33
  map_location=device
34
  )
35
 
36
- resnet_classes = resnet_ckpt.get("classes") or []
 
 
 
 
 
 
 
37
  resnet_num_classes = len(resnet_classes) if resnet_classes else 9
38
  resnet_mold_idx = (
39
  resnet_classes.index("mold")
40
- if resnet_classes else 4
 
41
  )
42
 
43
  resnet_model = MultiTaskResNet50(resnet_num_classes).to(device)
44
- resnet_model.load_state_dict(resnet_ckpt["model"])
45
  resnet_model.eval()
46
 
47
 
 
25
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
 
28
+ # ------------------
29
+ # Load baseline model (ResNet)
30
+ # ------------------
31
  # ------------------
32
  # Load baseline model (ResNet)
33
  # ------------------
 
36
  map_location=device
37
  )
38
 
39
+ # Handle different checkpoint formats
40
+ if isinstance(resnet_ckpt, dict) and "model" in resnet_ckpt:
41
+ resnet_state = resnet_ckpt["model"]
42
+ resnet_classes = resnet_ckpt.get("classes", [])
43
+ else:
44
+ resnet_state = resnet_ckpt
45
+ resnet_classes = []
46
+
47
  resnet_num_classes = len(resnet_classes) if resnet_classes else 9
48
  resnet_mold_idx = (
49
  resnet_classes.index("mold")
50
+ if resnet_classes and "mold" in resnet_classes
51
+ else 4
52
  )
53
 
54
  resnet_model = MultiTaskResNet50(resnet_num_classes).to(device)
55
+ resnet_model.load_state_dict(resnet_state)
56
  resnet_model.eval()
57
 
58