Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -46,6 +46,25 @@ configs = {
|
|
| 46 |
"THRESHOLD": 0.5
|
| 47 |
}
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
"""### Define helper functions"""
|
| 50 |
|
| 51 |
# helper function for loading pre-trained model
|
|
@@ -211,39 +230,46 @@ class modelModule(torch_light.LightningModule):
|
|
| 211 |
|
| 212 |
"""### Create function for running inference (i.e., assistive medical diagnosis)"""
|
| 213 |
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
-
input_tensor = preprocess_fn(input_image)
|
| 217 |
-
input_tensor = input_tensor.unsqueeze(dim = 0)
|
| 218 |
-
# newimg = transform(img).unsqueeze(dim=0)
|
| 219 |
-
|
| 220 |
-
CKPT_PATH = os.path.join(CKPT_ROOT, f"{backbone_name}.ckpt")
|
| 221 |
-
model = modelModule.load_from_checkpoint(CKPT_PATH)
|
| 222 |
model.eval()
|
| 223 |
|
| 224 |
-
#
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
# File "/app/app.py", line 226, in run_diagnosis
|
| 228 |
-
# probabilities = torch.sigmoid(output_logits)[0].numpy().tolist()
|
| 229 |
-
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
|
| 230 |
-
# RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
|
| 231 |
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
probabilities = torch.sigmoid(output_logits)[0].detach().numpy().tolist()
|
| 236 |
|
| 237 |
-
|
| 238 |
-
|
|
|
|
| 239 |
|
| 240 |
-
|
| 241 |
-
output_probs[Idx2labels[idx]] = prob
|
| 242 |
-
if prob >= threshold:
|
| 243 |
-
predicted_classes.append(Idx2labels[idx])
|
| 244 |
|
| 245 |
-
predicted_classes = "\n".join(predicted_classes)
|
| 246 |
-
return predicted_classes, output_probs
|
| 247 |
|
| 248 |
"""### Gradio app"""
|
| 249 |
CKPT_ROOT = os.path.join(os.getcwd(), "Trained models")
|
|
|
|
| 46 |
"THRESHOLD": 0.5
|
| 47 |
}
|
| 48 |
|
| 49 |
+
BACKBONE_REGISTRY = {
|
| 50 |
+
"EfficientNet(b3)": {
|
| 51 |
+
"torchvision_name": "efficientnet_b3",
|
| 52 |
+
"ckpt": "EfficientNet(b3).ckpt"},
|
| 53 |
+
"ConvNeXt(tiny)": {
|
| 54 |
+
"torchvision_name": "convnext_tiny",
|
| 55 |
+
"ckpt": "ConvNeXt(tiny).ckpt"},
|
| 56 |
+
"EfficientNet(v2_small)": {
|
| 57 |
+
"torchvision_name": "efficientnet_v2_s)",
|
| 58 |
+
"ckpt": "EfficientNet(v2_small).ckpt"},
|
| 59 |
+
"RegNet(x3_2GF)": {
|
| 60 |
+
"torchvision_name": "regnet_x_3_2gf)",
|
| 61 |
+
"ckpt": "RegNet(x3_2GF).ckpt"},
|
| 62 |
+
"ResNet50": {
|
| 63 |
+
"torchvision_name": "resnet50)",
|
| 64 |
+
"ckpt": "ResNet50.ckpt"}
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
MODEL_CACHE = {}
|
| 68 |
"""### Define helper functions"""
|
| 69 |
|
| 70 |
# helper function for loading pre-trained model
|
|
|
|
| 230 |
|
| 231 |
"""### Create function for running inference (i.e., assistive medical diagnosis)"""
|
| 232 |
|
| 233 |
+
@torch.inference_mode()
|
| 234 |
+
def run_diagnosis(
|
| 235 |
+
backbone_name,
|
| 236 |
+
input_image,
|
| 237 |
+
preprocess_fn=None,
|
| 238 |
+
Idx2labels=None,
|
| 239 |
+
threshold=configs["THRESHOLD"]):
|
| 240 |
+
|
| 241 |
+
# Preprocess
|
| 242 |
+
x = preprocess_fn(input_image).unsqueeze(0)
|
| 243 |
+
|
| 244 |
+
# Resolve backbone
|
| 245 |
+
backbone_info = BACKBONE_REGISTRY[backbone_name]
|
| 246 |
+
ckpt_path = os.path.join(CKPT_ROOT, backbone_info["ckpt"])
|
| 247 |
+
|
| 248 |
+
if not os.path.exists(ckpt_path):
|
| 249 |
+
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
|
| 250 |
+
|
| 251 |
+
# Load model (cache for speed)
|
| 252 |
+
if backbone_name not in MODEL_CACHE:
|
| 253 |
+
MODEL_CACHE[backbone_name] = modelModule.load_from_checkpoint(
|
| 254 |
+
ckpt_path, backbone_model_name=backbone_info["torchvision_name"])
|
| 255 |
+
model = MODEL_CACHE[backbone_name]
|
| 256 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
model.eval()
|
| 258 |
|
| 259 |
+
# Forward
|
| 260 |
+
logits = model(x)
|
| 261 |
+
probs = torch.sigmoid(logits)[0].cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
+
output_probs = {
|
| 264 |
+
Idx2labels[i]: float(p) for i, p in enumerate(probs)
|
| 265 |
+
}
|
|
|
|
| 266 |
|
| 267 |
+
predicted_classes = [
|
| 268 |
+
Idx2labels[i] for i, p in enumerate(probs) if p >= threshold
|
| 269 |
+
]
|
| 270 |
|
| 271 |
+
return "\n".join(predicted_classes), output_probs
|
|
|
|
|
|
|
|
|
|
| 272 |
|
|
|
|
|
|
|
| 273 |
|
| 274 |
"""### Gradio app"""
|
| 275 |
CKPT_ROOT = os.path.join(os.getcwd(), "Trained models")
|