MeshMax commited on
Commit
a379734
·
verified ·
1 Parent(s): a5705f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -63,7 +63,7 @@ text_model = AutoModel.from_pretrained(TEXT_MODEL).to(DEVICE)
63
  img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)
64
  head = MultimodalRegressor().to(DEVICE)
65
 
66
- ckpt = torch.load(MODEL_FILENAME, map_location=DEVICE)
67
  if "text_model_state" in ckpt:
68
  text_model.load_state_dict(ckpt["text_model_state"])
69
  if "img_model_state" in ckpt:
 
63
  img_model = timm.create_model(IMG_MODEL, pretrained=False, num_classes=0).to(DEVICE)
64
  head = MultimodalRegressor().to(DEVICE)
65
 
66
+ ckpt = torch.load(MODEL_FILENAME, map_location=DEVICE, weights_only=False)
67
  if "text_model_state" in ckpt:
68
  text_model.load_state_dict(ckpt["text_model_state"])
69
  if "img_model_state" in ckpt: