TimJaspersTue commited on
Commit
136e9a6
·
verified ·
1 Parent(s): 57a8a81

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +12 -14
README.md CHANGED
@@ -22,7 +22,7 @@ The pretrained model weights are hosted externally and can be downloaded here:
22
 
23
  ➡️ **https://staging.cortex.thetavision.nl/dataset-provider/listing/2/**
24
 
25
- Download the file (e.g., `vit_b_gastronet5m.pth`) and place it locally or on your server.
26
 
27
  ---
28
 
@@ -34,24 +34,22 @@ import torch
34
  import timm
35
 
36
  # Initialize ViT‑B backbone (no classifier head)
37
- model = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=0)
 
 
 
 
 
38
 
39
  # Update this path to where you downloaded the checkpoint
40
- ckpt_path = "./weights/vit_b_gastronet5m.pth"
41
  state = torch.load(ckpt_path, map_location="cpu")
42
-
43
- # Handle state dict variations
44
- if "model" in state:
45
- state_dict = state["model"]
46
- elif "state_dict" in state:
47
- state_dict = state["state_dict"]
48
- else:
49
- state_dict = state
50
 
51
  # Remove 'module.' prefix if present
52
- clean_state = {k.replace("module.", ""): v for k, v in state_dict.items()}
53
- model.load_state_dict(clean_state, strict=False)
54
-
55
  model.eval()
56
  ```
57
 
 
22
 
23
  ➡️ **https://staging.cortex.thetavision.nl/dataset-provider/listing/2/**
24
 
25
+ Download the file (e.g., `dinov2.pth`) and place it locally or on your device.
26
 
27
  ---
28
 
 
34
  import timm
35
 
36
  # Initialize ViT‑B backbone (no classifier head)
37
+
38
+ model = timm.create_model("timm/vit_base_patch14_dinov2.lvd142m",
39
+ pretrained=False,
40
+ num_classes=0,
41
+ img_size=336,
42
+ )
43
 
44
  # Update this path to where you downloaded the checkpoint
45
+ ckpt_path = "dinov2.pth"
46
  state = torch.load(ckpt_path, map_location="cpu")
47
+ state_dict = state['teacher']
 
 
 
 
 
 
 
48
 
49
  # Remove 'module.' prefix if present
50
+ clean_state = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
51
+ msg = model.load_state_dict(clean_state, strict=False)
52
+ print(msg)
53
  model.eval()
54
  ```
55