cfoli commited on
Commit
22e4bb0
·
1 Parent(s): af8ad28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -26
app.py CHANGED
@@ -46,6 +46,25 @@ configs = {
46
  "THRESHOLD": 0.5
47
  }
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  """### Define helper functions"""
50
 
51
  # helper function for loading pre-trained model
@@ -211,39 +230,46 @@ class modelModule(torch_light.LightningModule):
211
 
212
  """### Create function for running inference (i.e., assistive medical diagnosis)"""
213
 
214
- def run_diagnosis(backbone_name, input_image, preprocess_fn = None, Idx2labels=None, threshold = configs["THRESHOLD"]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
- input_tensor = preprocess_fn(input_image)
217
- input_tensor = input_tensor.unsqueeze(dim = 0)
218
- # newimg = transform(img).unsqueeze(dim=0)
219
-
220
- CKPT_PATH = os.path.join(CKPT_ROOT, f"{backbone_name}.ckpt")
221
- model = modelModule.load_from_checkpoint(CKPT_PATH)
222
  model.eval()
223
 
224
- # Generate predictions
225
- output_logits = model(input_tensor).cpu()
226
-
227
- # File "/app/app.py", line 226, in run_diagnosis
228
- # probabilities = torch.sigmoid(output_logits)[0].numpy().tolist()
229
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
230
- # RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
231
 
232
- # ylogit = model(newimg).detach()
233
- # yprob = torch.sigmoid(ylogit)
234
-
235
- probabilities = torch.sigmoid(output_logits)[0].detach().numpy().tolist()
236
 
237
- output_probs = dict()
238
- predicted_classes = []
 
239
 
240
- for idx, prob in enumerate(probabilities):
241
- output_probs[Idx2labels[idx]] = prob
242
- if prob >= threshold:
243
- predicted_classes.append(Idx2labels[idx])
244
 
245
- predicted_classes = "\n".join(predicted_classes)
246
- return predicted_classes, output_probs
247
 
248
  """### Gradio app"""
249
  CKPT_ROOT = os.path.join(os.getcwd(), "Trained models")
 
46
  "THRESHOLD": 0.5
47
  }
48
 
49
+ BACKBONE_REGISTRY = {
50
+ "EfficientNet(b3)": {
51
+ "torchvision_name": "efficientnet_b3",
52
+ "ckpt": "EfficientNet(b3).ckpt"},
53
+ "ConvNeXt(tiny)": {
54
+ "torchvision_name": "convnext_tiny",
55
+ "ckpt": "ConvNeXt(tiny).ckpt"},
56
+ "EfficientNet(v2_small)": {
57
+ "torchvision_name": "efficientnet_v2_s)",
58
+ "ckpt": "EfficientNet(v2_small).ckpt"},
59
+ "RegNet(x3_2GF)": {
60
+ "torchvision_name": "regnet_x_3_2gf)",
61
+ "ckpt": "RegNet(x3_2GF).ckpt"},
62
+ "ResNet50": {
63
+ "torchvision_name": "resnet50)",
64
+ "ckpt": "ResNet50.ckpt"}
65
+ }
66
+
67
+ MODEL_CACHE = {}
68
  """### Define helper functions"""
69
 
70
  # helper function for loading pre-trained model
 
230
 
231
  """### Create function for running inference (i.e., assistive medical diagnosis)"""
232
 
233
+ @torch.inference_mode()
234
+ def run_diagnosis(
235
+ backbone_name,
236
+ input_image,
237
+ preprocess_fn=None,
238
+ Idx2labels=None,
239
+ threshold=configs["THRESHOLD"]):
240
+
241
+ # Preprocess
242
+ x = preprocess_fn(input_image).unsqueeze(0)
243
+
244
+ # Resolve backbone
245
+ backbone_info = BACKBONE_REGISTRY[backbone_name]
246
+ ckpt_path = os.path.join(CKPT_ROOT, backbone_info["ckpt"])
247
+
248
+ if not os.path.exists(ckpt_path):
249
+ raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
250
+
251
+ # Load model (cache for speed)
252
+ if backbone_name not in MODEL_CACHE:
253
+ MODEL_CACHE[backbone_name] = modelModule.load_from_checkpoint(
254
+ ckpt_path, backbone_model_name=backbone_info["torchvision_name"])
255
+ model = MODEL_CACHE[backbone_name]
256
 
 
 
 
 
 
 
257
  model.eval()
258
 
259
+ # Forward
260
+ logits = model(x)
261
+ probs = torch.sigmoid(logits)[0].cpu().numpy()
 
 
 
 
262
 
263
+ output_probs = {
264
+ Idx2labels[i]: float(p) for i, p in enumerate(probs)
265
+ }
 
266
 
267
+ predicted_classes = [
268
+ Idx2labels[i] for i, p in enumerate(probs) if p >= threshold
269
+ ]
270
 
271
+ return "\n".join(predicted_classes), output_probs
 
 
 
272
 
 
 
273
 
274
  """### Gradio app"""
275
  CKPT_ROOT = os.path.join(os.getcwd(), "Trained models")