cfoli commited on
Commit
28cf2c9
·
1 Parent(s): f533365

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -45,10 +45,10 @@ configs = {
45
  "THRESHOLD": 0.5
46
  }
47
 
48
- MODEL_REGISTRY = {
49
- "CheXFormer-small": "m42-health/CXformer-small",
50
- "ViT-base-16": "google/vit-base-patch16-224",
51
- }
52
 
53
  MODEL_CACHE = {}
54
 
@@ -67,7 +67,8 @@ class get_pretrained_model(nn.Module):
67
  print(f"Loading pretrained [{model_name}] model")
68
 
69
  self.backbone = AutoModel.from_pretrained(
70
- MODEL_REGISTRY[model_name],
 
71
  trust_remote_code=True)
72
 
73
  hidden_size = self.backbone.config.hidden_size
@@ -252,7 +253,8 @@ def run_diagnosis(
252
  x = preprocess_fn(input_image).unsqueeze(0)
253
 
254
  # Resolve backbone
255
- ckpt_path = os.path.join(CKPT_ROOT, MODEL_REGISTRY[backbone_name])
 
256
 
257
  if not os.path.exists(ckpt_path):
258
  raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
@@ -260,7 +262,7 @@ def run_diagnosis(
260
  # Load model (cache for speed)
261
  if backbone_name not in MODEL_CACHE:
262
  MODEL_CACHE[backbone_name] = modelModule.load_from_checkpoint(
263
- ckpt_path, backbone_model_name=MODEL_REGISTRY[backbone_name], num_layers_to_unfreeze = 2)
264
  model = MODEL_CACHE[backbone_name]
265
 
266
  model.eval()
 
45
  "THRESHOLD": 0.5
46
  }
47
 
48
+ # MODEL_REGISTRY = {
49
+ # "CheXFormer-small": "m42-health/CXformer-small",
50
+ # "ViT-base-16": "google/vit-base-patch16-224",
51
+ # }
52
 
53
  MODEL_CACHE = {}
54
 
 
67
  print(f"Loading pretrained [{model_name}] model")
68
 
69
  self.backbone = AutoModel.from_pretrained(
70
+ # MODEL_REGISTRY[model_name],
71
+ model_name,
72
  trust_remote_code=True)
73
 
74
  hidden_size = self.backbone.config.hidden_size
 
253
  x = preprocess_fn(input_image).unsqueeze(0)
254
 
255
  # Resolve backbone
256
+ # ckpt_path = os.path.join(CKPT_ROOT, MODEL_REGISTRY[backbone_name])
257
+ kpt_path = os.path.join(CKPT_ROOT, backbone_name)
258
 
259
  if not os.path.exists(ckpt_path):
260
  raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
 
262
  # Load model (cache for speed)
263
  if backbone_name not in MODEL_CACHE:
264
  MODEL_CACHE[backbone_name] = modelModule.load_from_checkpoint(
265
+ ckpt_path, backbone_model_name=backbone_name, num_layers_to_unfreeze = 2)
266
  model = MODEL_CACHE[backbone_name]
267
 
268
  model.eval()