update deit
Browse files- predict.py +2 -2
predict.py
CHANGED
|
@@ -33,8 +33,8 @@ class DeiT(nn.Module):
|
|
| 33 |
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
|
| 34 |
print("Model path:", model_path)
|
| 35 |
model = DeiT(num_classes=len(class_names))
|
| 36 |
-
|
| 37 |
-
model.load_state_dict(
|
| 38 |
model.eval()
|
| 39 |
|
| 40 |
#deit transform
|
|
|
|
| 33 |
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
|
| 34 |
print("Model path:", model_path)
|
| 35 |
model = DeiT(num_classes=len(class_names))
|
| 36 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
| 37 |
+
model.load_state_dict(state_dict)
|
| 38 |
model.eval()
|
| 39 |
|
| 40 |
#deit transform
|