muruga778 commited on
Commit
044ce91
·
verified ·
1 Parent(s): 139a19b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -112,7 +112,7 @@ MAX_LEN = int(CFG.get("max_len", 128))
112
 
113
  # ---- Image model
114
  img_model = timm.create_model(IMG_BACKBONE, pretrained=False, num_classes=NUM_CLASSES)
115
- sd_img = clean_state_dict(torch.load("best_scin_image.pt", map_location="cpu"))
116
  img_model.load_state_dict(sd_img, strict=True)
117
  img_model.to(DEVICE).eval()
118
 
@@ -134,7 +134,7 @@ class TextClassifier(nn.Module):
134
 
135
  tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
136
  text_model = TextClassifier(TEXT_MODEL_NAME, NUM_CLASSES)
137
- sd_txt = clean_state_dict(torch.load("best_scin_text.pt", map_location="cpu"))
138
  text_model.load_state_dict(sd_txt, strict=False)
139
  text_model.to(DEVICE).eval()
140
 
 
112
 
113
  # ---- Image model
114
  img_model = timm.create_model(IMG_BACKBONE, pretrained=False, num_classes=NUM_CLASSES)
115
+ sd_img = clean_state_dict(safe_torch_load("best_scin_image.pt"))
116
  img_model.load_state_dict(sd_img, strict=True)
117
  img_model.to(DEVICE).eval()
118
 
 
134
 
135
  tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
136
  text_model = TextClassifier(TEXT_MODEL_NAME, NUM_CLASSES)
137
+ sd_txt = clean_state_dict(safe_torch_load("best_scin_text.pt"))
138
  text_model.load_state_dict(sd_txt, strict=False)
139
  text_model.to(DEVICE).eval()
140