Create checkpoints/readme.md
Browse files- checkpoints/readme.md +46 -0
checkpoints/readme.md
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- pytorch
|
| 5 |
+
- jepa
|
| 6 |
+
- self-supervised-learning
|
| 7 |
+
- checkpoints
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# Core JEPA: Pretrained Checkpoints
|
| 11 |
+
|
| 12 |
+
[](https://pytorch.org/)
|
| 13 |
+
|
| 14 |
+
This repository contains the weights and checkpoints for the **Core JEPA** (based on LeJEPA) model.
|
| 15 |
+
|
| 16 |
+
Due to storage optimization, the heavy weights are hosted in the associated artifacts repository but can be accessed directly via the links below.
|
| 17 |
+
|
| 18 |
+
## 📥 Download Weights
|
| 19 |
+
|
| 20 |
+
| Model Variant | Filename | Status | Direct Link |
|
| 21 |
+
| :--- | :--- | :--- | :--- |
|
| 22 |
+
| **LeJEPA-Large** | `lejepa-l.pt` | ✅ Available | [**Download .pt File**](https://huggingface.co/datasets/gajeshladharai/artifacts/resolve/main/core-jepa/lejepa-l.pt) |
|
| 23 |
+
|
| 24 |
+
---
|
| 25 |
+
|
| 26 |
+
## 💻 Usage
|
| 27 |
+
|
| 28 |
+
### Option 1: Load directly in Python (Recommended)
|
| 29 |
+
You can load these weights directly into the `mapminer` model using `torch.hub`. This handles downloading, caching, and key remapping automatically.
|
| 30 |
+
```python
|
| 31 |
+
from mapminer import models
|
| 32 |
+
|
| 33 |
+
jepa = models.DINOv3(pretrained=False)
|
| 34 |
+
|
| 35 |
+
ckpt = "https://huggingface.co/datasets/gajeshladharai/artifacts/resolve/main/core-jepa/lejepa-l.pt"
|
| 36 |
+
ckpt = torch.hub.load_state_dict_from_url(ckpt, map_location='cpu')
|
| 37 |
+
jepa.load_state_dict({k.replace('encoder.model.', 'model.'): v for k, v in ckpt.items()},strict=False)
|
| 38 |
+
jepa.eval()
|
| 39 |
+
|
| 40 |
+
# x = uint8 image
|
| 41 |
+
# normalize with model's preprocess
|
| 42 |
+
x = jepa.normalize(x)
|
| 43 |
+
|
| 44 |
+
# forward pass
|
| 45 |
+
with torch.no_grad():
|
| 46 |
+
emb = jepa(x)
|