Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -112,7 +112,7 @@ MAX_LEN = int(CFG.get("max_len", 128))
|
|
| 112 |
|
| 113 |
# ---- Image model
|
| 114 |
img_model = timm.create_model(IMG_BACKBONE, pretrained=False, num_classes=NUM_CLASSES)
|
| 115 |
-
sd_img = clean_state_dict(
|
| 116 |
img_model.load_state_dict(sd_img, strict=True)
|
| 117 |
img_model.to(DEVICE).eval()
|
| 118 |
|
|
@@ -134,7 +134,7 @@ class TextClassifier(nn.Module):
|
|
| 134 |
|
| 135 |
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
|
| 136 |
text_model = TextClassifier(TEXT_MODEL_NAME, NUM_CLASSES)
|
| 137 |
-
sd_txt = clean_state_dict(
|
| 138 |
text_model.load_state_dict(sd_txt, strict=False)
|
| 139 |
text_model.to(DEVICE).eval()
|
| 140 |
|
|
|
|
| 112 |
|
| 113 |
# ---- Image model
|
| 114 |
img_model = timm.create_model(IMG_BACKBONE, pretrained=False, num_classes=NUM_CLASSES)
|
| 115 |
+
sd_img = clean_state_dict(safe_torch_load("best_scin_image.pt"))
|
| 116 |
img_model.load_state_dict(sd_img, strict=True)
|
| 117 |
img_model.to(DEVICE).eval()
|
| 118 |
|
|
|
|
| 134 |
|
| 135 |
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
|
| 136 |
text_model = TextClassifier(TEXT_MODEL_NAME, NUM_CLASSES)
|
| 137 |
+
sd_txt = clean_state_dict(safe_torch_load("best_scin_text.pt"))
|
| 138 |
text_model.load_state_dict(sd_txt, strict=False)
|
| 139 |
text_model.to(DEVICE).eval()
|
| 140 |
|