Update README.md
Browse files
README.md
CHANGED
|
@@ -22,7 +22,7 @@ The pretrained model weights are hosted externally and can be downloaded here:
|
|
| 22 |
|
| 23 |
➡️ **https://staging.cortex.thetavision.nl/dataset-provider/listing/2/**
|
| 24 |
|
| 25 |
-
Download the file (e.g., `
|
| 26 |
|
| 27 |
---
|
| 28 |
|
|
@@ -34,24 +34,22 @@ import torch
|
|
| 34 |
import timm
|
| 35 |
|
| 36 |
# Initialize ViT‑B backbone (no classifier head)
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# Update this path to where you downloaded the checkpoint
|
| 40 |
-
ckpt_path = "
|
| 41 |
state = torch.load(ckpt_path, map_location="cpu")
|
| 42 |
-
|
| 43 |
-
# Handle state dict variations
|
| 44 |
-
if "model" in state:
|
| 45 |
-
state_dict = state["model"]
|
| 46 |
-
elif "state_dict" in state:
|
| 47 |
-
state_dict = state["state_dict"]
|
| 48 |
-
else:
|
| 49 |
-
state_dict = state
|
| 50 |
|
| 51 |
# Remove 'module.' prefix if present
|
| 52 |
-
clean_state = {k.replace("
|
| 53 |
-
model.load_state_dict(clean_state, strict=False)
|
| 54 |
-
|
| 55 |
model.eval()
|
| 56 |
```
|
| 57 |
|
|
|
|
| 22 |
|
| 23 |
➡️ **https://staging.cortex.thetavision.nl/dataset-provider/listing/2/**
|
| 24 |
|
| 25 |
+
Download the file (e.g., `dinov2.pth`) and place it locally or on your device.
|
| 26 |
|
| 27 |
---
|
| 28 |
|
|
|
|
| 34 |
import timm
|
| 35 |
|
| 36 |
# Initialize ViT‑B backbone (no classifier head)
|
| 37 |
+
|
| 38 |
+
model = timm.create_model("timm/vit_base_patch14_dinov2.lvd142m",
|
| 39 |
+
pretrained=False,
|
| 40 |
+
num_classes=0,
|
| 41 |
+
img_size=336,
|
| 42 |
+
)
|
| 43 |
|
| 44 |
# Update this path to where you downloaded the checkpoint
|
| 45 |
+
ckpt_path = "dinov2.pth"
|
| 46 |
state = torch.load(ckpt_path, map_location="cpu")
|
| 47 |
+
state_dict = state['teacher']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
# Remove 'module.' prefix if present
|
| 50 |
+
clean_state = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
| 51 |
+
msg = model.load_state_dict(clean_state, strict=False)
|
| 52 |
+
print(msg)
|
| 53 |
model.eval()
|
| 54 |
```
|
| 55 |
|