SaswatML123 commited on
Commit
78aa842
Β·
verified Β·
1 Parent(s): 10b056e

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +113 -51
model_loader.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  model_loader.py
3
  Downloads models from HuggingFace repos, caches them, runs inference.
 
4
  """
5
 
6
  import os
@@ -16,10 +17,10 @@ os.makedirs(CACHE_DIR, exist_ok=True)
16
  # PNEUMONIA β€” Keras .h5
17
  # ══════════════════════════════════════════════════════════════════════════════
18
 
19
- PNEUMO_REPO = "SaswatML123/PneuModel"
20
- PNEUMO_FILE = "pneumodel.h5"
21
- PNEUMO_SIZE = (224, 224)
22
- _pneumo_model = None
23
 
24
 
25
  def _download(repo, filename):
@@ -51,9 +52,9 @@ def predict_pneumonia(image: Image.Image) -> dict:
51
 
52
  if preds.shape[-1] == 1:
53
  pneumonia_prob = float(preds[0][0])
54
- normal_prob = 1.0 - pneumonia_prob
55
  else:
56
- normal_prob = float(preds[0][0])
57
  pneumonia_prob = float(preds[0][1])
58
 
59
  if pneumonia_prob >= 0.5:
@@ -62,70 +63,125 @@ def predict_pneumonia(image: Image.Image) -> dict:
62
  label, confidence = "NORMAL", normal_prob
63
 
64
  return {
65
- "label": label,
66
- "confidence": round(confidence, 4),
67
  "probabilities": {
68
- "NORMAL": round(normal_prob, 4),
69
  "PNEUMONIA": round(pneumonia_prob, 4),
70
  },
71
  }
72
 
73
 
74
  # ══════════════════════════════════════════════════════════════════════════════
75
- # SKIN CANCER β€” PyTorch ensemble
76
  # ══════════════════════════════════════════════════════════════════════════════
77
 
78
  SKIN_REPO = "SaswatML123/Skin_cancer_detection"
79
  SKIN_FILES = {
80
- "efficientnetv2m": "model1_efficientnetv2m.pth",
81
- "efficientnetv2s": "model2_efficientnetv2s.pth",
82
- "convnext": "model3_convnext.pth",
83
  }
 
 
84
  SKIN_CLASSES = [
85
- "Melanocytic nevi",
86
- "Melanoma",
87
- "Benign keratosis",
88
- "Basal cell carcinoma",
89
- "Actinic keratosis",
90
- "Vascular lesions",
91
- "Dermatofibroma",
92
  ]
93
  NUM_SKIN_CLASSES = len(SKIN_CLASSES)
94
- _skin_models = []
95
- SKIN_TRANSFORM = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
  def load_skin_models():
99
  global _skin_models, SKIN_TRANSFORM
100
  if _skin_models:
101
  return
102
- import torch
103
- import timm
104
- from torchvision import transforms
105
 
106
- SKIN_TRANSFORM = transforms.Compose([
107
- transforms.Resize((224, 224)),
108
- transforms.ToTensor(),
109
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
 
 
 
 
 
110
  ])
111
 
112
- arch_map = {
113
- "efficientnetv2m": "tf_efficientnetv2_m",
114
- "efficientnetv2s": "tf_efficientnetv2_s",
115
- "convnext": "convnext_base",
116
- }
117
- device = torch.device("cpu") # Space free tier is CPU only
118
-
119
- for arch, filename in SKIN_FILES.items():
120
- path = _download(SKIN_REPO, filename)
121
- model = timm.create_model(arch_map[arch], pretrained=False, num_classes=NUM_SKIN_CLASSES)
122
- state = torch.load(path, map_location=device)
123
- if isinstance(state, dict):
124
- for key in ("model_state_dict", "state_dict", "model"):
125
- if key in state:
126
- state = state[key]
127
- break
128
- model.load_state_dict(state, strict=False)
 
 
 
129
  model.eval()
130
  _skin_models.append(model)
131
  print(f"[Skin] βœ“ {arch}")
@@ -135,18 +191,24 @@ def load_skin_models():
135
 
136
  def predict_skin(image: Image.Image) -> dict:
137
  import torch
 
138
  load_skin_models()
139
- img_t = SKIN_TRANSFORM(image.convert("RGB")).unsqueeze(0)
 
 
140
  all_probs = []
141
  with torch.no_grad():
142
  for model in _skin_models:
143
- probs = torch.softmax(model(img_t), dim=1).squeeze().numpy()
 
144
  all_probs.append(probs)
 
145
  avg = np.mean(all_probs, axis=0)
146
  top = int(np.argmax(avg))
 
147
  return {
148
- "label": SKIN_CLASSES[top],
149
- "confidence": round(float(avg[top]), 4),
150
  "probabilities": {c: round(float(p), 4) for c, p in zip(SKIN_CLASSES, avg)},
151
- "model_count": len(_skin_models),
152
  }
 
1
  """
2
  model_loader.py
3
  Downloads models from HuggingFace repos, caches them, runs inference.
4
+ IMPORTANT: Uses exact same architecture + transforms as Colab training code.
5
  """
6
 
7
  import os
 
17
  # PNEUMONIA β€” Keras .h5
18
  # ══════════════════════════════════════════════════════════════════════════════
19
 
20
+ PNEUMO_REPO = "SaswatML123/PneuModel"
21
+ PNEUMO_FILE = "pneumodel.h5"
22
+ PNEUMO_SIZE = (224, 224)
23
+ _pneumo_model = None
24
 
25
 
26
  def _download(repo, filename):
 
52
 
53
  if preds.shape[-1] == 1:
54
  pneumonia_prob = float(preds[0][0])
55
+ normal_prob = 1.0 - pneumonia_prob
56
  else:
57
+ normal_prob = float(preds[0][0])
58
  pneumonia_prob = float(preds[0][1])
59
 
60
  if pneumonia_prob >= 0.5:
 
63
  label, confidence = "NORMAL", normal_prob
64
 
65
  return {
66
+ "label": label,
67
+ "confidence": round(confidence, 4),
68
  "probabilities": {
69
+ "NORMAL": round(normal_prob, 4),
70
  "PNEUMONIA": round(pneumonia_prob, 4),
71
  },
72
  }
73
 
74
 
75
  # ══════════════════════════════════════════════════════════════════════════════
76
+ # SKIN CANCER β€” exact same architecture as Colab training
77
  # ══════════════════════════════════════════════════════════════════════════════
78
 
79
  SKIN_REPO = "SaswatML123/Skin_cancer_detection"
80
  SKIN_FILES = {
81
+ "efficientnetv2m": ("model1_efficientnetv2m.pth", "tf_efficientnetv2_m"),
82
+ "efficientnetv2s": ("model2_efficientnetv2s.pth", "tf_efficientnetv2_s"),
83
+ "convnext": ("model3_convnext.pth", "convnext_base"),
84
  }
85
+
86
+ # Alphabetical sorted order β€” matches CLASS_NAMES = sorted(df['dx'].unique())
87
  SKIN_CLASSES = [
88
+ "Actinic Keratoses", # akiec β€” index 0
89
+ "Basal Cell Carcinoma", # bcc β€” index 1
90
+ "Benign Keratosis", # bkl β€” index 2
91
+ "Dermatofibroma", # df β€” index 3
92
+ "Melanoma", # mel β€” index 4
93
+ "Melanocytic Nevi", # nv β€” index 5
94
+ "Vascular Lesions", # vasc β€” index 6
95
  ]
96
  NUM_SKIN_CLASSES = len(SKIN_CLASSES)
97
+ _skin_models = []
98
+ SKIN_TRANSFORM = None
99
+
100
+
101
+ # ── Exact replica of Colab SkinCancerModel ────────────────────────────────────
102
+ def _build_skin_model(model_name: str):
103
+ import torch
104
+ import torch.nn as nn
105
+ import torch.nn.functional as F
106
+ import timm
107
+
108
+ class GeM(nn.Module):
109
+ def __init__(self, p=3, eps=1e-6):
110
+ super().__init__()
111
+ self.p = nn.Parameter(torch.ones(1) * p)
112
+ self.eps = eps
113
+
114
+ def forward(self, x):
115
+ return F.avg_pool2d(
116
+ x.clamp(min=self.eps).pow(self.p),
117
+ (x.size(-2), x.size(-1))
118
+ ).pow(1.0 / self.p)
119
+
120
+ class SkinCancerModel(nn.Module):
121
+ def __init__(self, num_classes=7, model_name='tf_efficientnetv2_m',
122
+ pretrained=False, drop_rate=0.3):
123
+ super().__init__()
124
+ self.backbone = timm.create_model(
125
+ model_name, pretrained=pretrained,
126
+ num_classes=0, global_pool='', drop_rate=drop_rate
127
+ )
128
+ in_features = self.backbone.num_features
129
+ self.pool = GeM()
130
+ self.head = nn.Sequential(
131
+ nn.Flatten(),
132
+ nn.Linear(in_features, 512),
133
+ nn.BatchNorm1d(512),
134
+ nn.SiLU(),
135
+ nn.Dropout(drop_rate),
136
+ nn.Linear(512, num_classes)
137
+ )
138
+
139
+ def forward(self, x):
140
+ return self.head(self.pool(self.backbone(x)))
141
+
142
+ return SkinCancerModel(
143
+ num_classes=NUM_SKIN_CLASSES,
144
+ model_name=model_name,
145
+ pretrained=False
146
+ )
147
 
148
 
149
  def load_skin_models():
150
  global _skin_models, SKIN_TRANSFORM
151
  if _skin_models:
152
  return
 
 
 
153
 
154
+ import torch
155
+ import albumentations as A
156
+ from albumentations.pytorch import ToTensorV2
157
+
158
+ # Exact same transforms as Colab get_val_transforms(300)
159
+ _albu_transform = A.Compose([
160
+ A.Resize(height=300, width=300),
161
+ A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
162
+ ToTensorV2(),
163
  ])
164
 
165
+ # Store as a wrapper function
166
+ def transform_fn(pil_img):
167
+ img_np = np.array(pil_img.convert("RGB"))
168
+ return _albu_transform(image=img_np)["image"].unsqueeze(0)
169
+
170
+ global SKIN_TRANSFORM
171
+ SKIN_TRANSFORM = transform_fn
172
+
173
+ device = torch.device("cpu")
174
+
175
+ for arch, (filename, model_name) in SKIN_FILES.items():
176
+ path = _download(SKIN_REPO, filename)
177
+ model = _build_skin_model(model_name)
178
+
179
+ checkpoint = torch.load(path, map_location=device)
180
+ if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
181
+ model.load_state_dict(checkpoint["model_state_dict"])
182
+ else:
183
+ model.load_state_dict(checkpoint)
184
+
185
  model.eval()
186
  _skin_models.append(model)
187
  print(f"[Skin] βœ“ {arch}")
 
191
 
192
  def predict_skin(image: Image.Image) -> dict:
193
  import torch
194
+ import torch.nn.functional as F
195
  load_skin_models()
196
+
197
+ img_t = SKIN_TRANSFORM(image) # (1, 3, 300, 300)
198
+
199
  all_probs = []
200
  with torch.no_grad():
201
  for model in _skin_models:
202
+ logits = model(img_t)
203
+ probs = F.softmax(logits, dim=1).squeeze().numpy()
204
  all_probs.append(probs)
205
+
206
  avg = np.mean(all_probs, axis=0)
207
  top = int(np.argmax(avg))
208
+
209
  return {
210
+ "label": SKIN_CLASSES[top],
211
+ "confidence": round(float(avg[top]), 4),
212
  "probabilities": {c: round(float(p), 4) for c, p in zip(SKIN_CLASSES, avg)},
213
+ "model_count": len(_skin_models),
214
  }