Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -45,10 +45,10 @@ configs = {
|
|
| 45 |
"THRESHOLD": 0.5
|
| 46 |
}
|
| 47 |
|
| 48 |
-
MODEL_REGISTRY = {
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
}
|
| 52 |
|
| 53 |
MODEL_CACHE = {}
|
| 54 |
|
|
@@ -67,7 +67,8 @@ class get_pretrained_model(nn.Module):
|
|
| 67 |
print(f"Loading pretrained [{model_name}] model")
|
| 68 |
|
| 69 |
self.backbone = AutoModel.from_pretrained(
|
| 70 |
-
MODEL_REGISTRY[model_name],
|
|
|
|
| 71 |
trust_remote_code=True)
|
| 72 |
|
| 73 |
hidden_size = self.backbone.config.hidden_size
|
|
@@ -252,7 +253,8 @@ def run_diagnosis(
|
|
| 252 |
x = preprocess_fn(input_image).unsqueeze(0)
|
| 253 |
|
| 254 |
# Resolve backbone
|
| 255 |
-
ckpt_path = os.path.join(CKPT_ROOT, MODEL_REGISTRY[backbone_name])
|
|
|
|
| 256 |
|
| 257 |
if not os.path.exists(ckpt_path):
|
| 258 |
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
|
|
@@ -260,7 +262,7 @@ def run_diagnosis(
|
|
| 260 |
# Load model (cache for speed)
|
| 261 |
if backbone_name not in MODEL_CACHE:
|
| 262 |
MODEL_CACHE[backbone_name] = modelModule.load_from_checkpoint(
|
| 263 |
-
ckpt_path, backbone_model_name=
|
| 264 |
model = MODEL_CACHE[backbone_name]
|
| 265 |
|
| 266 |
model.eval()
|
|
|
|
| 45 |
"THRESHOLD": 0.5
|
| 46 |
}
|
| 47 |
|
| 48 |
+
# MODEL_REGISTRY = {
|
| 49 |
+
# "CheXFormer-small": "m42-health/CXformer-small",
|
| 50 |
+
# "ViT-base-16": "google/vit-base-patch16-224",
|
| 51 |
+
# }
|
| 52 |
|
| 53 |
MODEL_CACHE = {}
|
| 54 |
|
|
|
|
| 67 |
print(f"Loading pretrained [{model_name}] model")
|
| 68 |
|
| 69 |
self.backbone = AutoModel.from_pretrained(
|
| 70 |
+
# MODEL_REGISTRY[model_name],
|
| 71 |
+
model_name,
|
| 72 |
trust_remote_code=True)
|
| 73 |
|
| 74 |
hidden_size = self.backbone.config.hidden_size
|
|
|
|
| 253 |
x = preprocess_fn(input_image).unsqueeze(0)
|
| 254 |
|
| 255 |
# Resolve backbone
|
| 256 |
+
# ckpt_path = os.path.join(CKPT_ROOT, MODEL_REGISTRY[backbone_name])
|
| 257 |
+
kpt_path = os.path.join(CKPT_ROOT, backbone_name)
|
| 258 |
|
| 259 |
if not os.path.exists(ckpt_path):
|
| 260 |
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
|
|
|
|
| 262 |
# Load model (cache for speed)
|
| 263 |
if backbone_name not in MODEL_CACHE:
|
| 264 |
MODEL_CACHE[backbone_name] = modelModule.load_from_checkpoint(
|
| 265 |
+
ckpt_path, backbone_model_name=backbone_name, num_layers_to_unfreeze = 2)
|
| 266 |
model = MODEL_CACHE[backbone_name]
|
| 267 |
|
| 268 |
model.eval()
|