Copilot Copilot commited on
Commit ·
d28c63e
1
Parent(s): d9e4621
Add AI-Endo project hub UI
Browse filesCo-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This view is limited to 50 files because it contains too many changes. See raw diff
- Dockerfile +1 -1
- README.md +33 -11
- app.py +544 -114
- dinov2/.github/workflows/lint.yaml +38 -0
- dinov2/.gitignore +11 -0
- dinov2/CODE_OF_CONDUCT.md +80 -0
- dinov2/CONTRIBUTING.md +31 -0
- dinov2/LICENSE +203 -0
- dinov2/MODEL_CARD.md +272 -0
- dinov2/README.md +620 -0
- dinov2/conda-extras.yaml +24 -0
- dinov2/conda.yaml +21 -0
- dinov2/pyproject.toml +29 -0
- dinov2/requirements-dev.txt +3 -0
- dinov2/requirements-extras.txt +2 -0
- dinov2/requirements.txt +11 -0
- dinov2/scripts/lint.sh +28 -0
- dinov2/setup.cfg +8 -0
- dinov2/setup.py +88 -0
- explainability.py +112 -0
- model/transformer.py +28 -8
- model_manager.py +10 -0
- predictor.py +372 -15
- runtime-requirements.txt +1 -1
- scripts/publish_model_repo.py +156 -0
- scripts/publish_space_repo.py +98 -0
- scripts/stage_space_bundle.py +104 -0
- scripts/stage_vendor_sources.py +38 -0
- vjepa2/.flake8 +5 -0
- vjepa2/.github/workflows/base_tests.yaml +29 -0
- vjepa2/.github/workflows/linters.yaml +48 -0
- vjepa2/.gitignore +32 -0
- vjepa2/APACHE-LICENSE +201 -0
- vjepa2/CHANGELOG.md +5 -0
- vjepa2/CODE_OF_CONDUCT.md +80 -0
- vjepa2/CONTRIBUTING.md +39 -0
- vjepa2/LICENSE +21 -0
- vjepa2/README.md +450 -0
- vjepa2/app/__init__.py +0 -0
- vjepa2/app/main.py +84 -0
- vjepa2/app/main_distributed.py +269 -0
- vjepa2/app/scaffold.py +17 -0
- vjepa2/app/vjepa/train.py +536 -0
- vjepa2/app/vjepa/transforms.py +154 -0
- vjepa2/app/vjepa/utils.py +267 -0
- vjepa2/app/vjepa_droid/droid.py +232 -0
- vjepa2/app/vjepa_droid/train.py +524 -0
- vjepa2/app/vjepa_droid/transforms.py +156 -0
- vjepa2/app/vjepa_droid/utils.py +253 -0
- vjepa2/configs/eval/vitg-384/coin.yaml +163 -0
Dockerfile
CHANGED
|
@@ -4,7 +4,7 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
|
| 4 |
PYTHONUNBUFFERED=1 \
|
| 5 |
PIP_NO_CACHE_DIR=1 \
|
| 6 |
SPACE_MODEL_DIR=/app/model \
|
| 7 |
-
SPACE_ENABLED_MODELS=dinov2 \
|
| 8 |
SPACE_DEFAULT_MODEL=dinov2
|
| 9 |
|
| 10 |
RUN apt-get update && apt-get install -y --no-install-recommends \
|
|
|
|
| 4 |
PYTHONUNBUFFERED=1 \
|
| 5 |
PIP_NO_CACHE_DIR=1 \
|
| 6 |
SPACE_MODEL_DIR=/app/model \
|
| 7 |
+
SPACE_ENABLED_MODELS=dinov2,aiendo,vjepa2 \
|
| 8 |
SPACE_DEFAULT_MODEL=dinov2
|
| 9 |
|
| 10 |
RUN apt-get update && apt-get install -y --no-install-recommends \
|
README.md
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
-
-
|
| 2 |
-
title: DINO-ENDO Phase Recognition
|
| 3 |
emoji: 🩺
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: green
|
|
@@ -7,11 +6,12 @@ sdk: docker
|
|
| 7 |
app_port: 7860
|
| 8 |
---
|
| 9 |
|
| 10 |
-
#
|
| 11 |
|
| 12 |
This folder is an isolated Hugging Face Space scaffold for the phase-recognition models in this repository.
|
| 13 |
-
It is intentionally separate from the existing FastAPI webapp and
|
| 14 |
-
The
|
|
|
|
| 15 |
|
| 16 |
## Supported model families
|
| 17 |
|
|
@@ -42,15 +42,15 @@ A fully local `model/` folder is still supported as a fallback.
|
|
| 42 |
|
| 43 |
## Default Space behavior
|
| 44 |
|
| 45 |
-
The Docker Space is configured to boot as a **
|
| 46 |
|
| 47 |
-
- `SPACE_ENABLED_MODELS=dinov2`
|
| 48 |
- `SPACE_DEFAULT_MODEL=dinov2`
|
| 49 |
|
| 50 |
-
If you want
|
| 51 |
|
| 52 |
```text
|
| 53 |
-
SPACE_ENABLED_MODELS=dinov2
|
| 54 |
SPACE_DEFAULT_MODEL=dinov2
|
| 55 |
```
|
| 56 |
|
|
@@ -67,11 +67,22 @@ If a required checkpoint is missing locally, it will try to download it from the
|
|
| 67 |
|
| 68 |
### Upload and dashboard behavior
|
| 69 |
|
| 70 |
-
- The
|
|
|
|
|
|
|
| 71 |
- MP4 is the primary video upload format, while `mov`, `avi`, `mkv`, `webm`, and `m4v` remain enabled as fallback containers.
|
| 72 |
- `.streamlit/config.toml` raises the default Streamlit single-file upload ceiling to **4096 MB** for this Space.
|
| 73 |
- Uploaded videos are immediately spooled to local disk for metadata probing and analysis, instead of repeatedly reading the in-memory upload object on every rerun.
|
| 74 |
- The UI shows file size, duration, fps, frame count, resolution, working-storage headroom, and suppresses inline preview for very large uploads to keep the browser path lighter.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
### Common environment variables
|
| 77 |
|
|
@@ -116,6 +127,15 @@ That script refreshes the vendored source copies inside this folder before publi
|
|
| 116 |
4. Upload your checkpoints to HF **model repo(s)**.
|
| 117 |
5. Configure the relevant repo IDs (and `HF_TOKEN` only if the repos are private).
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
## Local smoke test
|
| 120 |
|
| 121 |
Once the Space dependencies are installed, you can smoke test a predictor directly:
|
|
@@ -129,12 +149,14 @@ python scripts/smoke_test.py --model vjepa2 --model-dir /path/to/model
|
|
| 129 |
## Scope of v1
|
| 130 |
|
| 131 |
- Streamlit UI
|
| 132 |
-
-
|
|
|
|
| 133 |
- image upload and video upload
|
| 134 |
- dashboard-style model/runtime status
|
| 135 |
- robust video metadata probing with OpenCV + ffprobe fallback
|
| 136 |
- large single-file uploads up to the configured Streamlit cap
|
| 137 |
- per-frame phase timeline output for video
|
|
|
|
| 138 |
- JSON / CSV export
|
| 139 |
|
| 140 |
Not included in v1:
|
|
|
|
| 1 |
+
title: AI-Endo Project Hub
|
|
|
|
| 2 |
emoji: 🩺
|
| 3 |
colorFrom: blue
|
| 4 |
colorTo: green
|
|
|
|
| 6 |
app_port: 7860
|
| 7 |
---
|
| 8 |
|
| 9 |
+
# AI-Endo Project Hub
|
| 10 |
|
| 11 |
This folder is an isolated Hugging Face Space scaffold for the phase-recognition models in this repository.
|
| 12 |
+
It is intentionally separate from the existing FastAPI webapp and is designed to expose **DINO-Endo, AI-Endo, and V-JEPA2** on paid GPU hardware such as **1x A10G (24 GB VRAM)**.
|
| 13 |
+
The public UI now behaves like a small **project hub**: DINO-Endo Surgery is the first featured workspace, and the same landing page can later host additional projects without rebuilding the overall shell.
|
| 14 |
+
The default featured model remains **DINO-Endo**, but the same Space can load and unload all three model families one at a time.
|
| 15 |
|
| 16 |
## Supported model families
|
| 17 |
|
|
|
|
| 42 |
|
| 43 |
## Default Space behavior
|
| 44 |
|
| 45 |
+
The Docker Space is configured to boot as a **three-model public demo** with **DINO-Endo** selected by default:
|
| 46 |
|
| 47 |
+
- `SPACE_ENABLED_MODELS=dinov2,aiendo,vjepa2`
|
| 48 |
- `SPACE_DEFAULT_MODEL=dinov2`
|
| 49 |
|
| 50 |
+
If you want to narrow the public picker to a subset of models, override those environment variables in Space Settings, for example:
|
| 51 |
|
| 52 |
```text
|
| 53 |
+
SPACE_ENABLED_MODELS=dinov2
|
| 54 |
SPACE_DEFAULT_MODEL=dinov2
|
| 55 |
```
|
| 56 |
|
|
|
|
| 67 |
|
| 68 |
### Upload and dashboard behavior
|
| 69 |
|
| 70 |
+
- The top of the app is a reusable project-hub landing section, with DINO-Endo Surgery as the current live workspace.
|
| 71 |
+
- The active model family is selected through a visible **model slider** in the workspace rather than a hidden picker.
|
| 72 |
+
- The Space now keeps a single active predictor loaded at a time and unloads the previous model when the model slider changes.
|
| 73 |
- MP4 is the primary video upload format, while `mov`, `avi`, `mkv`, `webm`, and `m4v` remain enabled as fallback containers.
|
| 74 |
- `.streamlit/config.toml` raises the default Streamlit single-file upload ceiling to **4096 MB** for this Space.
|
| 75 |
- Uploaded videos are immediately spooled to local disk for metadata probing and analysis, instead of repeatedly reading the in-memory upload object on every rerun.
|
| 76 |
- The UI shows file size, duration, fps, frame count, resolution, working-storage headroom, and suppresses inline preview for very large uploads to keep the browser path lighter.
|
| 77 |
+
- V-JEPA2 is labeled as a slower first load so users understand the cold-cache cost of its very large encoder checkpoint.
|
| 78 |
+
|
| 79 |
+
### Explainability behavior
|
| 80 |
+
|
| 81 |
+
- The sidebar includes an opt-in live explainability toggle for encoder/decoder visualizations.
|
| 82 |
+
- DINO-Endo and V-JEPA2 use true encoder self-attention maps, while AI-Endo uses a labeled proxy encoder overlay from ResNet activations.
|
| 83 |
+
- AI-Endo and DINO-Endo render decoder-side temporal attention strips from the custom Transformer path.
|
| 84 |
+
- V-JEPA2 renders a labeled proxy temporal strip from decoder feature energy because its classifier head is an MLP, not an attention block.
|
| 85 |
+
- Encoder controls expose **layer/head sliders** when the loaded model supports true encoder attention.
|
| 86 |
|
| 87 |
### Common environment variables
|
| 88 |
|
|
|
|
| 127 |
4. Upload your checkpoints to HF **model repo(s)**.
|
| 128 |
5. Configure the relevant repo IDs (and `HF_TOKEN` only if the repos are private).
|
| 129 |
|
| 130 |
+
### Deployment helper scripts
|
| 131 |
+
|
| 132 |
+
- `python scripts/stage_space_bundle.py --overwrite --output-dir /tmp/dino_space_minimal_upload`
|
| 133 |
+
- stages a code-only upload bundle for the current multi-model Space without local caches or checkpoints.
|
| 134 |
+
- `python scripts/publish_model_repo.py --family aiendo --repo-id <owner/repo> --model-dir /path/to/model`
|
| 135 |
+
- publishes one model family to a Hugging Face **model repo** and automatically switches to `upload_large_folder()` for very large bundles.
|
| 136 |
+
- `python scripts/publish_space_repo.py --repo-id <owner/space> --dino-model-repo-id <owner/dino-repo> --aiendo-model-repo-id <owner/aiendo-repo> --vjepa2-model-repo-id <owner/vjepa2-repo>`
|
| 137 |
+
- stages/uploads the Docker Space bundle and updates the key Space environment variables for the three-model demo.
|
| 138 |
+
|
| 139 |
## Local smoke test
|
| 140 |
|
| 141 |
Once the Space dependencies are installed, you can smoke test a predictor directly:
|
|
|
|
| 149 |
## Scope of v1
|
| 150 |
|
| 151 |
- Streamlit UI
|
| 152 |
+
- project-hub landing page with DINO-Endo Surgery as the first hosted workspace
|
| 153 |
+
- three-model slider for DINO-Endo, AI-Endo, and V-JEPA2, with DINO-Endo selected by default
|
| 154 |
- image upload and video upload
|
| 155 |
- dashboard-style model/runtime status
|
| 156 |
- robust video metadata probing with OpenCV + ffprobe fallback
|
| 157 |
- large single-file uploads up to the configured Streamlit cap
|
| 158 |
- per-frame phase timeline output for video
|
| 159 |
+
- optional live encoder/decoder explainability sidebar with true attention where available and labeled proxies elsewhere
|
| 160 |
- JSON / CSV export
|
| 161 |
|
| 162 |
Not included in v1:
|
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import json
|
|
| 4 |
import os
|
| 5 |
import time
|
| 6 |
from collections import Counter
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
|
| 9 |
import cv2
|
|
@@ -13,6 +14,7 @@ import streamlit as st
|
|
| 13 |
import torch
|
| 14 |
from PIL import Image
|
| 15 |
|
|
|
|
| 16 |
from model_manager import SpaceModelManager
|
| 17 |
from model_registry import MODEL_SPECS, get_model_source_summary
|
| 18 |
from predictor import MODEL_LABELS, PHASE_LABELS, normalize_model_key
|
|
@@ -28,7 +30,80 @@ from video_utils import (
|
|
| 28 |
spool_uploaded_video,
|
| 29 |
)
|
| 30 |
|
| 31 |
-
st.set_page_config(page_title="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
def _phase_index(phase: str) -> int:
|
|
@@ -43,6 +118,10 @@ def _image_to_rgb(uploaded_file) -> np.ndarray:
|
|
| 43 |
return np.array(image)
|
| 44 |
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
def _enabled_model_keys() -> list[str]:
|
| 47 |
configured = os.getenv("SPACE_ENABLED_MODELS", "").strip()
|
| 48 |
if not configured:
|
|
@@ -82,7 +161,223 @@ def _default_model_key(enabled_model_keys: list[str]) -> str:
|
|
| 82 |
def _space_caption(enabled_model_keys: list[str]) -> str:
|
| 83 |
if enabled_model_keys == ["dinov2"]:
|
| 84 |
return "Streamlit Hugging Face Space demo for the DINO-Endo phase-recognition stack."
|
| 85 |
-
return "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
|
| 88 |
def _get_model_manager() -> SpaceModelManager:
|
|
@@ -126,7 +421,100 @@ def _prepare_staged_video(uploaded_file):
|
|
| 126 |
return temp_path, meta
|
| 127 |
|
| 128 |
|
| 129 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
temp_path = Path(video_path)
|
| 131 |
capture = cv2.VideoCapture(str(temp_path))
|
| 132 |
if not capture.isOpened():
|
|
@@ -141,6 +529,7 @@ def _analyse_video(video_path: str | Path, predictor, frame_stride: int, max_fra
|
|
| 141 |
records = []
|
| 142 |
processed = 0
|
| 143 |
frame_index = 0
|
|
|
|
| 144 |
|
| 145 |
try:
|
| 146 |
while True:
|
|
@@ -154,7 +543,7 @@ def _analyse_video(video_path: str | Path, predictor, frame_stride: int, max_fra
|
|
| 154 |
|
| 155 |
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 156 |
started = time.perf_counter()
|
| 157 |
-
result = predictor.predict(rgb)
|
| 158 |
elapsed_ms = (time.perf_counter() - started) * 1000.0
|
| 159 |
|
| 160 |
probs = result.get("probs", [0.0, 0.0, 0.0, 0.0])
|
|
@@ -174,6 +563,9 @@ def _analyse_video(video_path: str | Path, predictor, frame_stride: int, max_fra
|
|
| 174 |
records.append(record)
|
| 175 |
processed += 1
|
| 176 |
|
|
|
|
|
|
|
|
|
|
| 177 |
if total_frames > 0:
|
| 178 |
progress.progress(min(frame_index + 1, total_frames) / total_frames)
|
| 179 |
else:
|
|
@@ -192,18 +584,6 @@ def _analyse_video(video_path: str | Path, predictor, frame_stride: int, max_fra
|
|
| 192 |
return records, {"fps": fps, "total_frames": total_frames, "sampled_frames": processed}
|
| 193 |
|
| 194 |
|
| 195 |
-
def _records_to_frame(records):
|
| 196 |
-
if not records:
|
| 197 |
-
return pd.DataFrame(columns=["frame_index", "timestamp_sec", "phase", "confidence"])
|
| 198 |
-
return pd.DataFrame.from_records(records)
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
def _download_payloads(df: pd.DataFrame):
|
| 202 |
-
json_payload = df.to_json(orient="records", indent=2).encode("utf-8")
|
| 203 |
-
csv_payload = df.to_csv(index=False).encode("utf-8")
|
| 204 |
-
return json_payload, csv_payload
|
| 205 |
-
|
| 206 |
-
|
| 207 |
def _render_single_result(result: dict):
|
| 208 |
probs = result.get("probs", [0.0, 0.0, 0.0, 0.0])
|
| 209 |
metrics = st.columns(3)
|
|
@@ -215,7 +595,7 @@ def _render_single_result(result: dict):
|
|
| 215 |
st.bar_chart(prob_df.set_index("phase"))
|
| 216 |
st.download_button(
|
| 217 |
label="Download JSON",
|
| 218 |
-
data=json.dumps(result, indent=2).encode("utf-8"),
|
| 219 |
file_name="phase_prediction.json",
|
| 220 |
mime="application/json",
|
| 221 |
key="download-single-json",
|
|
@@ -269,31 +649,23 @@ def main():
|
|
| 269 |
enabled_model_keys = _enabled_model_keys()
|
| 270 |
default_model_key = _default_model_key(enabled_model_keys)
|
| 271 |
manager = _get_model_manager()
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
-
|
| 274 |
st.caption(_space_caption(enabled_model_keys))
|
| 275 |
|
| 276 |
-
st.
|
| 277 |
-
if len(enabled_model_keys) == 1:
|
| 278 |
-
model_key = enabled_model_keys[0]
|
| 279 |
-
st.sidebar.write(MODEL_LABELS[model_key])
|
| 280 |
-
else:
|
| 281 |
-
model_key = st.sidebar.selectbox(
|
| 282 |
-
"Model",
|
| 283 |
-
options=enabled_model_keys,
|
| 284 |
-
index=enabled_model_keys.index(default_model_key),
|
| 285 |
-
format_func=lambda key: MODEL_LABELS[key],
|
| 286 |
-
)
|
| 287 |
-
|
| 288 |
-
previous_selected_model_key = st.session_state.get("selected_model_key")
|
| 289 |
-
st.session_state["selected_model_key"] = model_key
|
| 290 |
if previous_selected_model_key is not None and previous_selected_model_key != model_key:
|
| 291 |
manager.unload_model()
|
| 292 |
|
|
|
|
|
|
|
| 293 |
source_summary = get_model_source_summary(model_key)
|
| 294 |
-
manager_status = manager.status()
|
| 295 |
st.sidebar.markdown("### Runtime")
|
| 296 |
st.sidebar.write(f"Selected model: `{MODEL_LABELS[model_key]}`")
|
|
|
|
| 297 |
st.sidebar.write(f"CUDA available: `{torch.cuda.is_available()}`")
|
| 298 |
if torch.cuda.is_available():
|
| 299 |
st.sidebar.write(f"Device: `{torch.cuda.get_device_name(torch.cuda.current_device())}`")
|
|
@@ -308,16 +680,13 @@ def main():
|
|
| 308 |
st.sidebar.write(f"HF repo: `{source_summary['repo_id'] or 'local-only'}`")
|
| 309 |
if source_summary["subfolder"]:
|
| 310 |
st.sidebar.write(f"Repo subfolder: `{source_summary['subfolder']}`")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
st.sidebar.write(f"Video upload cap: `{STREAMLIT_SERVER_MAX_UPLOAD_MB} MB`")
|
| 312 |
st.sidebar.write(f"Working storage free: `{format_bytes(get_workspace_free_bytes())}`")
|
| 313 |
|
| 314 |
-
if manager_status.is_loaded and manager_status.active_model_label:
|
| 315 |
-
st.sidebar.success(f"Loaded model: {manager_status.active_model_label}")
|
| 316 |
-
else:
|
| 317 |
-
st.sidebar.info("No model is currently loaded.")
|
| 318 |
-
if manager_status.last_error:
|
| 319 |
-
st.sidebar.error(manager_status.last_error)
|
| 320 |
-
|
| 321 |
prepare_col, unload_col = st.sidebar.columns(2)
|
| 322 |
if prepare_col.button("Load model", use_container_width=True):
|
| 323 |
try:
|
|
@@ -331,90 +700,151 @@ def main():
|
|
| 331 |
manager.unload_model()
|
| 332 |
st.sidebar.success("Model unloaded")
|
| 333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
image_tab, video_tab = st.tabs(["Image", "Video"])
|
| 335 |
|
| 336 |
with image_tab:
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
|
| 355 |
with video_tab:
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
"MP4 is preferred; MOV/AVI/MKV/WEBM/M4V stay enabled as fallback containers."
|
| 365 |
-
),
|
| 366 |
-
max_upload_size=STREAMLIT_SERVER_MAX_UPLOAD_MB,
|
| 367 |
)
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
)
|
| 388 |
-
|
| 389 |
-
if should_show_inline_preview(video_meta["file_size_bytes"]):
|
| 390 |
-
st.video(uploaded_video)
|
| 391 |
else:
|
| 392 |
-
st.
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
)
|
| 397 |
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
with st.spinner(f"Running {MODEL_LABELS[model_key]} on {uploaded_video.name}..."):
|
| 401 |
-
predictor = manager.get_predictor(model_key)
|
| 402 |
-
records, analysis_meta = _analyse_video(
|
| 403 |
-
temp_path,
|
| 404 |
-
predictor,
|
| 405 |
-
frame_stride=frame_stride,
|
| 406 |
-
max_frames=max_frames,
|
| 407 |
-
)
|
| 408 |
-
meta = {
|
| 409 |
-
**video_meta,
|
| 410 |
-
**analysis_meta,
|
| 411 |
-
}
|
| 412 |
-
except Exception as exc:
|
| 413 |
-
st.error(str(exc))
|
| 414 |
else:
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
|
| 419 |
|
| 420 |
if __name__ == "__main__":
|
|
|
|
| 4 |
import os
|
| 5 |
import time
|
| 6 |
from collections import Counter
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
from pathlib import Path
|
| 9 |
|
| 10 |
import cv2
|
|
|
|
| 14 |
import torch
|
| 15 |
from PIL import Image
|
| 16 |
|
| 17 |
+
from explainability import ExplainabilitySpec
|
| 18 |
from model_manager import SpaceModelManager
|
| 19 |
from model_registry import MODEL_SPECS, get_model_source_summary
|
| 20 |
from predictor import MODEL_LABELS, PHASE_LABELS, normalize_model_key
|
|
|
|
| 30 |
spool_uploaded_video,
|
| 31 |
)
|
| 32 |
|
| 33 |
+
st.set_page_config(page_title="AI-Endo Project Hub", layout="wide")
|
| 34 |
+
|
| 35 |
+
MODEL_OPTION_LABELS = {
|
| 36 |
+
"aiendo": "AI-Endo",
|
| 37 |
+
"dinov2": "DINO-Endo",
|
| 38 |
+
"vjepa2": "V-JEPA2 (slower first load)",
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
MODEL_LOAD_NOTES = {
|
| 42 |
+
"aiendo": "AI-Endo uses the ResNet + MS-TCN + Transformer stack.",
|
| 43 |
+
"dinov2": "DINO-Endo remains the default public model in this demo.",
|
| 44 |
+
"vjepa2": "V-JEPA2 can take longer on the first load because the encoder checkpoint is several gigabytes.",
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
FALLBACK_EXPLAINABILITY_SPECS = {
|
| 48 |
+
"aiendo": ExplainabilitySpec(
|
| 49 |
+
encoder_mode="proxy",
|
| 50 |
+
encoder_label="ResNet layer4 activation energy (proxy)",
|
| 51 |
+
decoder_mode="attention",
|
| 52 |
+
decoder_label="Temporal Transformer attention",
|
| 53 |
+
),
|
| 54 |
+
"dinov2": ExplainabilitySpec(
|
| 55 |
+
encoder_mode="attention",
|
| 56 |
+
encoder_label="DINOv2 encoder self-attention",
|
| 57 |
+
decoder_mode="attention",
|
| 58 |
+
decoder_label="Fusion Transformer temporal attention",
|
| 59 |
+
encoder_layer_count=12,
|
| 60 |
+
encoder_head_count=6,
|
| 61 |
+
),
|
| 62 |
+
"vjepa2": ExplainabilitySpec(
|
| 63 |
+
encoder_mode="attention",
|
| 64 |
+
encoder_label="V-JEPA2 encoder self-attention",
|
| 65 |
+
decoder_mode="proxy",
|
| 66 |
+
decoder_label="MLP decoder feature energy (proxy)",
|
| 67 |
+
encoder_layer_count=24,
|
| 68 |
+
encoder_head_count=16,
|
| 69 |
+
),
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
SPACE_TITLE = "AI-Endo Project Hub"
|
| 74 |
+
FEATURED_PROJECT_TITLE = "DINO-Endo Surgery Workspace"
|
| 75 |
+
MODEL_SLIDER_KEY = "workspace-model-slider"
|
| 76 |
+
SELECTED_MODEL_STATE_KEY = "selected_model_key"
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@dataclass(frozen=True)
|
| 80 |
+
class HostedProject:
|
| 81 |
+
key: str
|
| 82 |
+
title: str
|
| 83 |
+
status: str
|
| 84 |
+
summary: str
|
| 85 |
+
highlights: tuple[str, ...]
|
| 86 |
+
tags: tuple[str, ...]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
HOSTED_PROJECTS = (
|
| 90 |
+
HostedProject(
|
| 91 |
+
key="dino-endo-surgery",
|
| 92 |
+
title=FEATURED_PROJECT_TITLE,
|
| 93 |
+
status="Live now",
|
| 94 |
+
summary=(
|
| 95 |
+
"Upload single frames or full videos, swap between DINO-Endo, AI-Endo, and V-JEPA2, "
|
| 96 |
+
"and inspect optional explainability overlays inside one surgical phase-recognition workspace."
|
| 97 |
+
),
|
| 98 |
+
highlights=(
|
| 99 |
+
"Large video uploads with on-disk staging",
|
| 100 |
+
"One-click JSON and CSV export",
|
| 101 |
+
"Live encoder and decoder explainability",
|
| 102 |
+
"Manual load and unload for GPU-safe model switching",
|
| 103 |
+
),
|
| 104 |
+
tags=("Computer vision", "Medical video", "Multi-model inference"),
|
| 105 |
+
),
|
| 106 |
+
)
|
| 107 |
|
| 108 |
|
| 109 |
def _phase_index(phase: str) -> int:
|
|
|
|
| 118 |
return np.array(image)
|
| 119 |
|
| 120 |
|
| 121 |
+
def _model_option_label(model_key: str) -> str:
|
| 122 |
+
return MODEL_OPTION_LABELS.get(model_key, MODEL_LABELS.get(model_key, model_key))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
def _enabled_model_keys() -> list[str]:
|
| 126 |
configured = os.getenv("SPACE_ENABLED_MODELS", "").strip()
|
| 127 |
if not configured:
|
|
|
|
| 161 |
def _space_caption(enabled_model_keys: list[str]) -> str:
|
| 162 |
if enabled_model_keys == ["dinov2"]:
|
| 163 |
return "Streamlit Hugging Face Space demo for the DINO-Endo phase-recognition stack."
|
| 164 |
+
return "Streamlit Hugging Face Space demo for DINO-Endo, AI-Endo, and V-JEPA2 with one active model loaded at a time."
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _inject_app_styles() -> None:
|
| 168 |
+
st.markdown(
|
| 169 |
+
"""
|
| 170 |
+
<style>
|
| 171 |
+
.block-container {
|
| 172 |
+
padding-top: 2.4rem;
|
| 173 |
+
padding-bottom: 2rem;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
.hub-hero,
|
| 177 |
+
.hub-card,
|
| 178 |
+
.workspace-card {
|
| 179 |
+
border-radius: 22px;
|
| 180 |
+
border: 1px solid rgba(148, 163, 184, 0.22);
|
| 181 |
+
background: linear-gradient(180deg, rgba(15, 23, 42, 0.86), rgba(15, 23, 42, 0.66));
|
| 182 |
+
box-shadow: 0 20px 45px rgba(15, 23, 42, 0.18);
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
.hub-hero {
|
| 186 |
+
padding: 2rem 2.25rem;
|
| 187 |
+
margin-bottom: 1rem;
|
| 188 |
+
background: linear-gradient(135deg, rgba(14, 165, 233, 0.18), rgba(16, 185, 129, 0.18), rgba(15, 23, 42, 0.9));
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
.hub-eyebrow {
|
| 192 |
+
margin: 0;
|
| 193 |
+
color: #67e8f9;
|
| 194 |
+
font-size: 0.78rem;
|
| 195 |
+
font-weight: 700;
|
| 196 |
+
letter-spacing: 0.18em;
|
| 197 |
+
text-transform: uppercase;
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
.hub-hero h1,
|
| 201 |
+
.workspace-card h2,
|
| 202 |
+
.hub-card h3 {
|
| 203 |
+
margin: 0.4rem 0 0 0;
|
| 204 |
+
color: #f8fafc;
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
.hub-subtitle,
|
| 208 |
+
.workspace-copy,
|
| 209 |
+
.hub-card p,
|
| 210 |
+
.hub-card li {
|
| 211 |
+
color: rgba(226, 232, 240, 0.92);
|
| 212 |
+
line-height: 1.55;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
.hub-subtitle {
|
| 216 |
+
margin-top: 0.8rem;
|
| 217 |
+
max-width: 62rem;
|
| 218 |
+
font-size: 1.03rem;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
.hub-chip-row {
|
| 222 |
+
display: flex;
|
| 223 |
+
flex-wrap: wrap;
|
| 224 |
+
gap: 0.55rem;
|
| 225 |
+
margin-top: 1rem;
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
.hub-chip,
|
| 229 |
+
.hub-status {
|
| 230 |
+
display: inline-flex;
|
| 231 |
+
align-items: center;
|
| 232 |
+
border-radius: 999px;
|
| 233 |
+
padding: 0.32rem 0.78rem;
|
| 234 |
+
font-size: 0.82rem;
|
| 235 |
+
font-weight: 600;
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
.hub-chip {
|
| 239 |
+
background: rgba(15, 23, 42, 0.56);
|
| 240 |
+
border: 1px solid rgba(103, 232, 249, 0.24);
|
| 241 |
+
color: #e2e8f0;
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
.hub-status {
|
| 245 |
+
background: rgba(34, 197, 94, 0.18);
|
| 246 |
+
border: 1px solid rgba(34, 197, 94, 0.28);
|
| 247 |
+
color: #bbf7d0;
|
| 248 |
+
margin-bottom: 0.7rem;
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
.hub-card,
|
| 252 |
+
.workspace-card {
|
| 253 |
+
padding: 1.25rem 1.4rem;
|
| 254 |
+
height: 100%;
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
.hub-card ul {
|
| 258 |
+
margin: 0.8rem 0 0 1rem;
|
| 259 |
+
padding: 0;
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
.workspace-card {
|
| 263 |
+
margin: 0.3rem 0 1rem 0;
|
| 264 |
+
}
|
| 265 |
+
</style>
|
| 266 |
+
""",
|
| 267 |
+
unsafe_allow_html=True,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def _render_hub_chips(labels: list[str] | tuple[str, ...]) -> str:
|
| 272 |
+
return "".join(f'<span class="hub-chip">{label}</span>' for label in labels)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def _render_project_hub(enabled_model_keys: list[str]) -> None:
|
| 276 |
+
featured = HOSTED_PROJECTS[0]
|
| 277 |
+
enabled_labels = [_model_option_label(key) for key in enabled_model_keys]
|
| 278 |
+
st.markdown(
|
| 279 |
+
f"""
|
| 280 |
+
<section class="hub-hero">
|
| 281 |
+
<p class="hub-eyebrow">Multi-project landing page</p>
|
| 282 |
+
<h1>{SPACE_TITLE}</h1>
|
| 283 |
+
<p class="hub-subtitle">
|
| 284 |
+
A polished landing page for applied vision demos. {FEATURED_PROJECT_TITLE} is the first live workspace,
|
| 285 |
+
and the layout is ready to host more projects later without rebuilding the app shell.
|
| 286 |
+
</p>
|
| 287 |
+
<div class="hub-chip-row">
|
| 288 |
+
{_render_hub_chips(tuple(enabled_labels) + ("Future-project ready", "Streamlit + Docker Space"))}
|
| 289 |
+
</div>
|
| 290 |
+
</section>
|
| 291 |
+
""",
|
| 292 |
+
unsafe_allow_html=True,
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
metrics = st.columns(4)
|
| 296 |
+
metrics[0].metric("Hosted projects", len(HOSTED_PROJECTS))
|
| 297 |
+
metrics[1].metric("Model families", len(enabled_model_keys))
|
| 298 |
+
metrics[2].metric("Explainability", "Opt-in")
|
| 299 |
+
metrics[3].metric("Exports", "JSON + CSV")
|
| 300 |
+
|
| 301 |
+
left_col, right_col = st.columns([1.8, 1.2], gap="large")
|
| 302 |
+
with left_col:
|
| 303 |
+
highlights_html = "".join(f"<li>{item}</li>" for item in featured.highlights)
|
| 304 |
+
st.markdown(
|
| 305 |
+
f"""
|
| 306 |
+
<section class="hub-card">
|
| 307 |
+
<span class="hub-status">{featured.status}</span>
|
| 308 |
+
<h3>{featured.title}</h3>
|
| 309 |
+
<p>{featured.summary}</p>
|
| 310 |
+
<div class="hub-chip-row">{_render_hub_chips(featured.tags)}</div>
|
| 311 |
+
<ul>{highlights_html}</ul>
|
| 312 |
+
</section>
|
| 313 |
+
""",
|
| 314 |
+
unsafe_allow_html=True,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
with right_col:
|
| 318 |
+
st.markdown(
|
| 319 |
+
"""
|
| 320 |
+
<section class="hub-card">
|
| 321 |
+
<span class="hub-status">Platform shell</span>
|
| 322 |
+
<h3>Ready for more demos</h3>
|
| 323 |
+
<p>
|
| 324 |
+
The top section now works as a reusable project hub instead of a one-off page. Add more project cards
|
| 325 |
+
and workspace blocks here later, while keeping one shared brand, layout, and deployment target.
|
| 326 |
+
</p>
|
| 327 |
+
<ul>
|
| 328 |
+
<li>Keep each project's controls inside its own workspace section.</li>
|
| 329 |
+
<li>Reuse the same landing-page hero, metrics, and project-card layout.</li>
|
| 330 |
+
<li>Preserve one-model-at-a-time loading so future demos stay GPU-friendly.</li>
|
| 331 |
+
</ul>
|
| 332 |
+
</section>
|
| 333 |
+
""",
|
| 334 |
+
unsafe_allow_html=True,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def _render_workspace_header(enabled_model_keys: list[str], model_key: str) -> None:
|
| 339 |
+
selected_label = _model_option_label(model_key)
|
| 340 |
+
selection_note = (
|
| 341 |
+
"Use the model slider to move between DINO-Endo, AI-Endo, and V-JEPA2. "
|
| 342 |
+
"Only one model stays loaded at a time so the Space remains responsive on shared GPU hardware."
|
| 343 |
+
)
|
| 344 |
+
st.markdown(
|
| 345 |
+
f"""
|
| 346 |
+
<section class="workspace-card">
|
| 347 |
+
<p class="hub-eyebrow">Featured project</p>
|
| 348 |
+
<h2>{FEATURED_PROJECT_TITLE}</h2>
|
| 349 |
+
<p class="workspace-copy">
|
| 350 |
+
{selection_note}
|
| 351 |
+
</p>
|
| 352 |
+
<div class="hub-chip-row">
|
| 353 |
+
{_render_hub_chips(tuple(_model_option_label(key) for key in enabled_model_keys))}
|
| 354 |
+
<span class="hub-chip">Selected: {selected_label}</span>
|
| 355 |
+
</div>
|
| 356 |
+
</section>
|
| 357 |
+
""",
|
| 358 |
+
unsafe_allow_html=True,
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def _resolve_model_selection(enabled_model_keys: list[str], default_model_key: str) -> tuple[str | None, str]:
|
| 363 |
+
previous_selected_model_key = st.session_state.get(SELECTED_MODEL_STATE_KEY)
|
| 364 |
+
current_slider_value = st.session_state.get(MODEL_SLIDER_KEY)
|
| 365 |
+
if current_slider_value not in enabled_model_keys:
|
| 366 |
+
st.session_state[MODEL_SLIDER_KEY] = default_model_key
|
| 367 |
+
|
| 368 |
+
if len(enabled_model_keys) == 1:
|
| 369 |
+
model_key = enabled_model_keys[0]
|
| 370 |
+
st.session_state[MODEL_SLIDER_KEY] = model_key
|
| 371 |
+
return previous_selected_model_key, model_key
|
| 372 |
+
|
| 373 |
+
model_key = st.select_slider(
|
| 374 |
+
"Project model slider",
|
| 375 |
+
options=enabled_model_keys,
|
| 376 |
+
key=MODEL_SLIDER_KEY,
|
| 377 |
+
format_func=_model_option_label,
|
| 378 |
+
help="Prominent model-family slider for the DINO-Endo project workspace.",
|
| 379 |
+
)
|
| 380 |
+
return previous_selected_model_key, model_key
|
| 381 |
|
| 382 |
|
| 383 |
def _get_model_manager() -> SpaceModelManager:
|
|
|
|
| 421 |
return temp_path, meta
|
| 422 |
|
| 423 |
|
| 424 |
+
def _records_to_frame(records):
|
| 425 |
+
if not records:
|
| 426 |
+
return pd.DataFrame(columns=["frame_index", "timestamp_sec", "phase", "confidence"])
|
| 427 |
+
return pd.DataFrame.from_records(records)
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def _download_payloads(df: pd.DataFrame):
|
| 431 |
+
json_payload = df.to_json(orient="records", indent=2).encode("utf-8")
|
| 432 |
+
csv_payload = df.to_csv(index=False).encode("utf-8")
|
| 433 |
+
return json_payload, csv_payload
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def _get_explainability_spec(manager: SpaceModelManager, model_key: str) -> ExplainabilitySpec:
|
| 437 |
+
predictor = manager.get_loaded_predictor(model_key)
|
| 438 |
+
if predictor is not None and hasattr(predictor, "get_explainability_spec"):
|
| 439 |
+
return predictor.get_explainability_spec()
|
| 440 |
+
return FALLBACK_EXPLAINABILITY_SPECS[model_key]
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def _build_explainability_config(manager: SpaceModelManager, model_key: str):
|
| 444 |
+
spec = _get_explainability_spec(manager, model_key)
|
| 445 |
+
st.sidebar.markdown("### Explainability")
|
| 446 |
+
enabled = st.sidebar.toggle(
|
| 447 |
+
"Enable live encoder/decoder maps",
|
| 448 |
+
value=False,
|
| 449 |
+
help="Shows encoder heatmaps and decoder temporal strips on every processed frame. Leave this off if you want the fastest video analysis path.",
|
| 450 |
+
)
|
| 451 |
+
config = {"enabled": enabled}
|
| 452 |
+
if not enabled:
|
| 453 |
+
return config, spec
|
| 454 |
+
|
| 455 |
+
st.sidebar.caption(f"Encoder view: {spec.encoder_label}")
|
| 456 |
+
st.sidebar.caption(f"Decoder view: {spec.decoder_label}")
|
| 457 |
+
if spec.encoder_mode == "attention" and spec.encoder_layer_count > 0 and spec.encoder_head_count > 0:
|
| 458 |
+
default_layer = spec.encoder_layer_count - 1
|
| 459 |
+
config["encoder_layer"] = st.sidebar.slider(
|
| 460 |
+
"Encoder layer",
|
| 461 |
+
min_value=1,
|
| 462 |
+
max_value=spec.encoder_layer_count,
|
| 463 |
+
value=default_layer + 1,
|
| 464 |
+
key=f"explainability-layer-{model_key}",
|
| 465 |
+
) - 1
|
| 466 |
+
config["encoder_head"] = st.sidebar.slider(
|
| 467 |
+
"Encoder head",
|
| 468 |
+
min_value=1,
|
| 469 |
+
max_value=spec.encoder_head_count,
|
| 470 |
+
value=1,
|
| 471 |
+
key=f"explainability-head-{model_key}",
|
| 472 |
+
) - 1
|
| 473 |
+
else:
|
| 474 |
+
st.sidebar.info("This model uses a proxy encoder overlay instead of true encoder attention.")
|
| 475 |
+
|
| 476 |
+
st.sidebar.caption("Decoder strips are rendered as temporal heat strips rather than projected back onto the frame.")
|
| 477 |
+
return config, spec
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def _render_explainability_panel(target, payload: dict | None, *, enabled: bool, spec: ExplainabilitySpec, title: str) -> None:
|
| 481 |
+
with target.container():
|
| 482 |
+
st.markdown(f"### {title}")
|
| 483 |
+
if not enabled:
|
| 484 |
+
st.caption("Turn on the explainability toggle in the sidebar to inspect encoder heatmaps and decoder temporal strips.")
|
| 485 |
+
return
|
| 486 |
+
|
| 487 |
+
st.caption(f"Encoder default: {spec.encoder_label}")
|
| 488 |
+
st.caption(f"Decoder default: {spec.decoder_label}")
|
| 489 |
+
if payload is None:
|
| 490 |
+
st.info("Run image or video inference to populate this live explainability panel.")
|
| 491 |
+
return
|
| 492 |
+
|
| 493 |
+
layer_index = payload.get("encoder_layer")
|
| 494 |
+
head_index = payload.get("encoder_head")
|
| 495 |
+
encoder_caption = f"{payload['encoder_label']} ({payload['encoder_kind']})"
|
| 496 |
+
if layer_index is not None and head_index is not None:
|
| 497 |
+
encoder_caption += f" · layer {int(layer_index) + 1}, head {int(head_index) + 1}"
|
| 498 |
+
st.caption(encoder_caption)
|
| 499 |
+
st.image(payload["encoder_visualization"], use_container_width=True)
|
| 500 |
+
|
| 501 |
+
st.caption(f"{payload['decoder_label']} ({payload['decoder_kind']})")
|
| 502 |
+
st.image(payload["decoder_visualization"], use_container_width=True)
|
| 503 |
+
|
| 504 |
+
notes = payload.get("notes")
|
| 505 |
+
if notes:
|
| 506 |
+
st.caption(notes)
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def _analyse_video(
|
| 510 |
+
video_path: str | Path,
|
| 511 |
+
predictor,
|
| 512 |
+
frame_stride: int,
|
| 513 |
+
max_frames: int,
|
| 514 |
+
*,
|
| 515 |
+
explainability_config: dict | None = None,
|
| 516 |
+
explainability_callback=None,
|
| 517 |
+
):
|
| 518 |
temp_path = Path(video_path)
|
| 519 |
capture = cv2.VideoCapture(str(temp_path))
|
| 520 |
if not capture.isOpened():
|
|
|
|
| 529 |
records = []
|
| 530 |
processed = 0
|
| 531 |
frame_index = 0
|
| 532 |
+
explain_enabled = bool(explainability_config and explainability_config.get("enabled"))
|
| 533 |
|
| 534 |
try:
|
| 535 |
while True:
|
|
|
|
| 543 |
|
| 544 |
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 545 |
started = time.perf_counter()
|
| 546 |
+
result = predictor.predict(rgb, explainability=explainability_config if explain_enabled else None)
|
| 547 |
elapsed_ms = (time.perf_counter() - started) * 1000.0
|
| 548 |
|
| 549 |
probs = result.get("probs", [0.0, 0.0, 0.0, 0.0])
|
|
|
|
| 563 |
records.append(record)
|
| 564 |
processed += 1
|
| 565 |
|
| 566 |
+
if explain_enabled and explainability_callback is not None:
|
| 567 |
+
explainability_callback(result.get("explainability"), processed, frame_index)
|
| 568 |
+
|
| 569 |
if total_frames > 0:
|
| 570 |
progress.progress(min(frame_index + 1, total_frames) / total_frames)
|
| 571 |
else:
|
|
|
|
| 584 |
return records, {"fps": fps, "total_frames": total_frames, "sampled_frames": processed}
|
| 585 |
|
| 586 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
def _render_single_result(result: dict):
|
| 588 |
probs = result.get("probs", [0.0, 0.0, 0.0, 0.0])
|
| 589 |
metrics = st.columns(3)
|
|
|
|
| 595 |
st.bar_chart(prob_df.set_index("phase"))
|
| 596 |
st.download_button(
|
| 597 |
label="Download JSON",
|
| 598 |
+
data=json.dumps(result, indent=2, default=str).encode("utf-8"),
|
| 599 |
file_name="phase_prediction.json",
|
| 600 |
mime="application/json",
|
| 601 |
key="download-single-json",
|
|
|
|
| 649 |
enabled_model_keys = _enabled_model_keys()
|
| 650 |
default_model_key = _default_model_key(enabled_model_keys)
|
| 651 |
manager = _get_model_manager()
|
| 652 |
+
_inject_app_styles()
|
| 653 |
+
_render_project_hub(enabled_model_keys)
|
| 654 |
+
previous_selected_model_key, model_key = _resolve_model_selection(enabled_model_keys, default_model_key)
|
| 655 |
|
| 656 |
+
_render_workspace_header(enabled_model_keys, model_key)
|
| 657 |
st.caption(_space_caption(enabled_model_keys))
|
| 658 |
|
| 659 |
+
st.session_state[SELECTED_MODEL_STATE_KEY] = model_key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
if previous_selected_model_key is not None and previous_selected_model_key != model_key:
|
| 661 |
manager.unload_model()
|
| 662 |
|
| 663 |
+
explainability_config, explainability_spec = _build_explainability_config(manager, model_key)
|
| 664 |
+
|
| 665 |
source_summary = get_model_source_summary(model_key)
|
|
|
|
| 666 |
st.sidebar.markdown("### Runtime")
|
| 667 |
st.sidebar.write(f"Selected model: `{MODEL_LABELS[model_key]}`")
|
| 668 |
+
st.sidebar.caption(MODEL_LOAD_NOTES[model_key])
|
| 669 |
st.sidebar.write(f"CUDA available: `{torch.cuda.is_available()}`")
|
| 670 |
if torch.cuda.is_available():
|
| 671 |
st.sidebar.write(f"Device: `{torch.cuda.get_device_name(torch.cuda.current_device())}`")
|
|
|
|
| 680 |
st.sidebar.write(f"HF repo: `{source_summary['repo_id'] or 'local-only'}`")
|
| 681 |
if source_summary["subfolder"]:
|
| 682 |
st.sidebar.write(f"Repo subfolder: `{source_summary['subfolder']}`")
|
| 683 |
+
with st.sidebar.expander("Checkpoint requirements", expanded=False):
|
| 684 |
+
st.write(", ".join(source_summary["required_files"]))
|
| 685 |
+
if source_summary["optional_files"]:
|
| 686 |
+
st.caption("Optional: " + ", ".join(source_summary["optional_files"]))
|
| 687 |
st.sidebar.write(f"Video upload cap: `{STREAMLIT_SERVER_MAX_UPLOAD_MB} MB`")
|
| 688 |
st.sidebar.write(f"Working storage free: `{format_bytes(get_workspace_free_bytes())}`")
|
| 689 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 690 |
prepare_col, unload_col = st.sidebar.columns(2)
|
| 691 |
if prepare_col.button("Load model", use_container_width=True):
|
| 692 |
try:
|
|
|
|
| 700 |
manager.unload_model()
|
| 701 |
st.sidebar.success("Model unloaded")
|
| 702 |
|
| 703 |
+
manager_status = manager.status()
|
| 704 |
+
if manager_status.is_loaded and manager_status.active_model_label:
|
| 705 |
+
st.sidebar.success(f"Loaded model: {manager_status.active_model_label}")
|
| 706 |
+
else:
|
| 707 |
+
st.sidebar.info("No model is currently loaded.")
|
| 708 |
+
if manager_status.last_error:
|
| 709 |
+
st.sidebar.error(manager_status.last_error)
|
| 710 |
+
|
| 711 |
image_tab, video_tab = st.tabs(["Image", "Video"])
|
| 712 |
|
| 713 |
with image_tab:
|
| 714 |
+
image_main_col, image_explain_col = st.columns([3, 2], gap="large")
|
| 715 |
+
image_explain_placeholder = image_explain_col.empty()
|
| 716 |
+
image_result = None
|
| 717 |
+
|
| 718 |
+
with image_main_col:
|
| 719 |
+
uploaded_image = st.file_uploader("Upload an RGB frame", type=["png", "jpg", "jpeg"], key="image-uploader")
|
| 720 |
+
if uploaded_image is not None:
|
| 721 |
+
rgb = _image_to_rgb(uploaded_image)
|
| 722 |
+
st.image(rgb, caption=uploaded_image.name, use_container_width=True)
|
| 723 |
+
if st.button("Run image inference", key="run-image"):
|
| 724 |
+
try:
|
| 725 |
+
with st.spinner(f"Running {MODEL_LABELS[model_key]} on {uploaded_image.name}..."):
|
| 726 |
+
predictor = manager.get_predictor(model_key)
|
| 727 |
+
predictor.reset_state()
|
| 728 |
+
started = time.perf_counter()
|
| 729 |
+
image_result = predictor.predict(
|
| 730 |
+
rgb,
|
| 731 |
+
explainability=explainability_config if explainability_config.get("enabled") else None,
|
| 732 |
+
)
|
| 733 |
+
image_result["inference_ms"] = round((time.perf_counter() - started) * 1000.0, 3)
|
| 734 |
+
predictor.reset_state()
|
| 735 |
+
except Exception as exc:
|
| 736 |
+
st.error(str(exc))
|
| 737 |
+
else:
|
| 738 |
+
_render_single_result(image_result)
|
| 739 |
+
|
| 740 |
+
_render_explainability_panel(
|
| 741 |
+
image_explain_placeholder,
|
| 742 |
+
image_result.get("explainability") if image_result else None,
|
| 743 |
+
enabled=bool(explainability_config.get("enabled")),
|
| 744 |
+
spec=explainability_spec,
|
| 745 |
+
title="Explainability",
|
| 746 |
+
)
|
| 747 |
|
| 748 |
with video_tab:
|
| 749 |
+
video_main_col, video_explain_col = st.columns([3, 2], gap="large")
|
| 750 |
+
video_explain_placeholder = video_explain_col.empty()
|
| 751 |
+
_render_explainability_panel(
|
| 752 |
+
video_explain_placeholder,
|
| 753 |
+
None,
|
| 754 |
+
enabled=bool(explainability_config.get("enabled")),
|
| 755 |
+
spec=explainability_spec,
|
| 756 |
+
title="Explainability",
|
|
|
|
|
|
|
|
|
|
| 757 |
)
|
| 758 |
+
|
| 759 |
+
with video_main_col:
|
| 760 |
+
frame_stride = st.slider("Analyze every Nth frame", min_value=1, max_value=30, value=5, step=1)
|
| 761 |
+
max_frames = st.slider("Maximum sampled frames", min_value=10, max_value=600, value=180, step=10)
|
| 762 |
+
uploaded_video = st.file_uploader(
|
| 763 |
+
"Upload a video (MP4 preferred)",
|
| 764 |
+
type=SUPPORTED_VIDEO_TYPES,
|
| 765 |
+
key="video-uploader",
|
| 766 |
+
help=(
|
| 767 |
+
f"Single-file uploads are enabled up to {STREAMLIT_SERVER_MAX_UPLOAD_MB} MB. "
|
| 768 |
+
"MP4 is preferred; MOV/AVI/MKV/WEBM/M4V stay enabled as fallback containers."
|
| 769 |
+
),
|
| 770 |
+
max_upload_size=STREAMLIT_SERVER_MAX_UPLOAD_MB,
|
| 771 |
+
)
|
| 772 |
+
if uploaded_video is not None:
|
| 773 |
+
try:
|
| 774 |
+
temp_path, video_meta = _prepare_staged_video(uploaded_video)
|
| 775 |
+
except Exception as exc:
|
| 776 |
+
st.error(str(exc))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 777 |
else:
|
| 778 |
+
info_cols = st.columns(5)
|
| 779 |
+
info_cols[0].metric("File size", video_meta["file_size_label"])
|
| 780 |
+
info_cols[1].metric("Duration", video_meta["duration_label"])
|
| 781 |
+
info_cols[2].metric("FPS", f"{video_meta.get('fps', 0.0):.2f}" if video_meta.get("fps") else "Unknown")
|
| 782 |
+
info_cols[3].metric("Frames", int(video_meta.get("frame_count", 0)))
|
| 783 |
+
info_cols[4].metric("Resolution", video_meta["resolution_label"])
|
| 784 |
+
if video_meta.get("format_name"):
|
| 785 |
+
st.caption(f"Container detected by ffprobe: {video_meta['format_name']}")
|
| 786 |
+
|
| 787 |
+
recommended_stride = recommended_frame_stride(video_meta.get("duration_seconds"))
|
| 788 |
+
st.caption(
|
| 789 |
+
f"Recommended frame stride for this video: every {recommended_stride} frame(s). "
|
| 790 |
+
"Use higher values for very long videos to keep analysis times reasonable."
|
| 791 |
)
|
| 792 |
|
| 793 |
+
if should_show_inline_preview(video_meta["file_size_bytes"]):
|
| 794 |
+
st.video(uploaded_video)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 795 |
else:
|
| 796 |
+
st.info(
|
| 797 |
+
"Inline preview is disabled for uploads larger than "
|
| 798 |
+
"256 MB to avoid pushing very large media back through the browser. "
|
| 799 |
+
"The staged video on disk is still used for analysis."
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
if st.button("Analyze video", key="run-video"):
|
| 803 |
+
latest_payload = {"value": None}
|
| 804 |
+
|
| 805 |
+
def _video_explainability_callback(payload, processed_count: int, current_frame_index: int):
|
| 806 |
+
latest_payload["value"] = payload
|
| 807 |
+
_render_explainability_panel(
|
| 808 |
+
video_explain_placeholder,
|
| 809 |
+
payload,
|
| 810 |
+
enabled=True,
|
| 811 |
+
spec=explainability_spec,
|
| 812 |
+
title=f"Live explainability · sampled frame {processed_count}",
|
| 813 |
+
)
|
| 814 |
+
|
| 815 |
+
try:
|
| 816 |
+
with st.spinner(f"Running {MODEL_LABELS[model_key]} on {uploaded_video.name}..."):
|
| 817 |
+
predictor = manager.get_predictor(model_key)
|
| 818 |
+
records, analysis_meta = _analyse_video(
|
| 819 |
+
temp_path,
|
| 820 |
+
predictor,
|
| 821 |
+
frame_stride=frame_stride,
|
| 822 |
+
max_frames=max_frames,
|
| 823 |
+
explainability_config=explainability_config if explainability_config.get("enabled") else None,
|
| 824 |
+
explainability_callback=(
|
| 825 |
+
_video_explainability_callback
|
| 826 |
+
if explainability_config.get("enabled")
|
| 827 |
+
else None
|
| 828 |
+
),
|
| 829 |
+
)
|
| 830 |
+
meta = {
|
| 831 |
+
**video_meta,
|
| 832 |
+
**analysis_meta,
|
| 833 |
+
}
|
| 834 |
+
except Exception as exc:
|
| 835 |
+
st.error(str(exc))
|
| 836 |
+
else:
|
| 837 |
+
_render_video_results(records, meta)
|
| 838 |
+
if explainability_config.get("enabled"):
|
| 839 |
+
_render_explainability_panel(
|
| 840 |
+
video_explain_placeholder,
|
| 841 |
+
latest_payload["value"],
|
| 842 |
+
enabled=True,
|
| 843 |
+
spec=explainability_spec,
|
| 844 |
+
title="Explainability",
|
| 845 |
+
)
|
| 846 |
+
else:
|
| 847 |
+
_clear_video_stage()
|
| 848 |
|
| 849 |
|
| 850 |
if __name__ == "__main__":
|
dinov2/.github/workflows/lint.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Lint
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
pull_request:
|
| 8 |
+
branches:
|
| 9 |
+
- main
|
| 10 |
+
|
| 11 |
+
jobs:
|
| 12 |
+
run-linters:
|
| 13 |
+
name: Run linters
|
| 14 |
+
runs-on: ubuntu-20.04
|
| 15 |
+
|
| 16 |
+
steps:
|
| 17 |
+
- name: Checkout repository
|
| 18 |
+
uses: actions/checkout@v3
|
| 19 |
+
- name: Set up Python
|
| 20 |
+
uses: actions/setup-python@v4
|
| 21 |
+
with:
|
| 22 |
+
python-version: 3.9
|
| 23 |
+
cache: 'pip'
|
| 24 |
+
cache-dependency-path: '**/requirements*.txt'
|
| 25 |
+
- name: Install Python (development) dependencies
|
| 26 |
+
run: |
|
| 27 |
+
pip install -r requirements-dev.txt
|
| 28 |
+
- name: Run flake8
|
| 29 |
+
run: |
|
| 30 |
+
flake8
|
| 31 |
+
- name: Run black
|
| 32 |
+
if: always()
|
| 33 |
+
run: |
|
| 34 |
+
black --check dinov2
|
| 35 |
+
- name: Run pylint
|
| 36 |
+
if: always()
|
| 37 |
+
run: |
|
| 38 |
+
pylint --exit-zero dinov2
|
dinov2/.gitignore
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
build/
|
| 2 |
+
dist/
|
| 3 |
+
*.egg-info/
|
| 4 |
+
**/__pycache__/
|
| 5 |
+
|
| 6 |
+
**/.ipynb_checkpoints
|
| 7 |
+
**/.ipynb_checkpoints/**
|
| 8 |
+
|
| 9 |
+
*.swp
|
| 10 |
+
|
| 11 |
+
.vscode/
|
dinov2/CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to make participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
| 49 |
+
an individual is representing the project or its community in public spaces.
|
| 50 |
+
Examples of representing a project or community include using an official
|
| 51 |
+
project e-mail address, posting via an official social media account, or acting
|
| 52 |
+
as an appointed representative at an online or offline event. Representation of
|
| 53 |
+
a project may be further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
This Code of Conduct also applies outside the project spaces when there is a
|
| 56 |
+
reasonable belief that an individual's behavior may have a negative impact on
|
| 57 |
+
the project or its community.
|
| 58 |
+
|
| 59 |
+
## Enforcement
|
| 60 |
+
|
| 61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 62 |
+
reported by contacting the project team at <opensource-conduct@meta.com>. All
|
| 63 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 64 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 65 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 66 |
+
Further details of specific enforcement policies may be posted separately.
|
| 67 |
+
|
| 68 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 69 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 70 |
+
members of the project's leadership.
|
| 71 |
+
|
| 72 |
+
## Attribution
|
| 73 |
+
|
| 74 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
| 75 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
| 76 |
+
|
| 77 |
+
[homepage]: https://www.contributor-covenant.org
|
| 78 |
+
|
| 79 |
+
For answers to common questions about this code of conduct, see
|
| 80 |
+
https://www.contributor-covenant.org/faq
|
dinov2/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to DINOv2
|
| 2 |
+
We want to make contributing to this project as easy and transparent as
|
| 3 |
+
possible.
|
| 4 |
+
|
| 5 |
+
## Pull Requests
|
| 6 |
+
We actively welcome your pull requests.
|
| 7 |
+
|
| 8 |
+
1. Fork the repo and create your branch from `main`.
|
| 9 |
+
2. If you've added code that should be tested, add tests.
|
| 10 |
+
3. If you've changed APIs, update the documentation.
|
| 11 |
+
4. Ensure the test suite passes.
|
| 12 |
+
5. Make sure your code lints.
|
| 13 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
| 14 |
+
|
| 15 |
+
## Contributor License Agreement ("CLA")
|
| 16 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
| 17 |
+
to do this once to work on any of Meta's open source projects.
|
| 18 |
+
|
| 19 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
| 20 |
+
|
| 21 |
+
## Issues
|
| 22 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
| 23 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
| 24 |
+
|
| 25 |
+
Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
| 26 |
+
disclosure of security bugs. In those cases, please go through the process
|
| 27 |
+
outlined on that page and do not file a public issue.
|
| 28 |
+
|
| 29 |
+
## License
|
| 30 |
+
By contributing to DINOv2, you agree that your contributions will be licensed
|
| 31 |
+
under the LICENSE file in the root directory of this source tree.
|
dinov2/LICENSE
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
Apache License
|
| 4 |
+
Version 2.0, January 2004
|
| 5 |
+
http://www.apache.org/licenses/
|
| 6 |
+
|
| 7 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 8 |
+
|
| 9 |
+
1. Definitions.
|
| 10 |
+
|
| 11 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 12 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 13 |
+
|
| 14 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 15 |
+
the copyright owner that is granting the License.
|
| 16 |
+
|
| 17 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 18 |
+
other entities that control, are controlled by, or are under common
|
| 19 |
+
control with that entity. For the purposes of this definition,
|
| 20 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 21 |
+
direction or management of such entity, whether by contract or
|
| 22 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 23 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 24 |
+
|
| 25 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 26 |
+
exercising permissions granted by this License.
|
| 27 |
+
|
| 28 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 29 |
+
including but not limited to software source code, documentation
|
| 30 |
+
source, and configuration files.
|
| 31 |
+
|
| 32 |
+
"Object" form shall mean any form resulting from mechanical
|
| 33 |
+
transformation or translation of a Source form, including but
|
| 34 |
+
not limited to compiled object code, generated documentation,
|
| 35 |
+
and conversions to other media types.
|
| 36 |
+
|
| 37 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 38 |
+
Object form, made available under the License, as indicated by a
|
| 39 |
+
copyright notice that is included in or attached to the work
|
| 40 |
+
(an example is provided in the Appendix below).
|
| 41 |
+
|
| 42 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 43 |
+
form, that is based on (or derived from) the Work and for which the
|
| 44 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 45 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 46 |
+
of this License, Derivative Works shall not include works that remain
|
| 47 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 48 |
+
the Work and Derivative Works thereof.
|
| 49 |
+
|
| 50 |
+
"Contribution" shall mean any work of authorship, including
|
| 51 |
+
the original version of the Work and any modifications or additions
|
| 52 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 53 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 54 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 55 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 56 |
+
means any form of electronic, verbal, or written communication sent
|
| 57 |
+
to the Licensor or its representatives, including but not limited to
|
| 58 |
+
communication on electronic mailing lists, source code control systems,
|
| 59 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 60 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 61 |
+
excluding communication that is conspicuously marked or otherwise
|
| 62 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 63 |
+
|
| 64 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 65 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 66 |
+
subsequently incorporated within the Work.
|
| 67 |
+
|
| 68 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 69 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 70 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 71 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 72 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 73 |
+
Work and such Derivative Works in Source or Object form.
|
| 74 |
+
|
| 75 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 76 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 77 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 78 |
+
(except as stated in this section) patent license to make, have made,
|
| 79 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 80 |
+
where such license applies only to those patent claims licensable
|
| 81 |
+
by such Contributor that are necessarily infringed by their
|
| 82 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 83 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 84 |
+
institute patent litigation against any entity (including a
|
| 85 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 86 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 87 |
+
or contributory patent infringement, then any patent licenses
|
| 88 |
+
granted to You under this License for that Work shall terminate
|
| 89 |
+
as of the date such litigation is filed.
|
| 90 |
+
|
| 91 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 92 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 93 |
+
modifications, and in Source or Object form, provided that You
|
| 94 |
+
meet the following conditions:
|
| 95 |
+
|
| 96 |
+
(a) You must give any other recipients of the Work or
|
| 97 |
+
Derivative Works a copy of this License; and
|
| 98 |
+
|
| 99 |
+
(b) You must cause any modified files to carry prominent notices
|
| 100 |
+
stating that You changed the files; and
|
| 101 |
+
|
| 102 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 103 |
+
that You distribute, all copyright, patent, trademark, and
|
| 104 |
+
attribution notices from the Source form of the Work,
|
| 105 |
+
excluding those notices that do not pertain to any part of
|
| 106 |
+
the Derivative Works; and
|
| 107 |
+
|
| 108 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 109 |
+
distribution, then any Derivative Works that You distribute must
|
| 110 |
+
include a readable copy of the attribution notices contained
|
| 111 |
+
within such NOTICE file, excluding those notices that do not
|
| 112 |
+
pertain to any part of the Derivative Works, in at least one
|
| 113 |
+
of the following places: within a NOTICE text file distributed
|
| 114 |
+
as part of the Derivative Works; within the Source form or
|
| 115 |
+
documentation, if provided along with the Derivative Works; or,
|
| 116 |
+
within a display generated by the Derivative Works, if and
|
| 117 |
+
wherever such third-party notices normally appear. The contents
|
| 118 |
+
of the NOTICE file are for informational purposes only and
|
| 119 |
+
do not modify the License. You may add Your own attribution
|
| 120 |
+
notices within Derivative Works that You distribute, alongside
|
| 121 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 122 |
+
that such additional attribution notices cannot be construed
|
| 123 |
+
as modifying the License.
|
| 124 |
+
|
| 125 |
+
You may add Your own copyright statement to Your modifications and
|
| 126 |
+
may provide additional or different license terms and conditions
|
| 127 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 128 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 129 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 130 |
+
the conditions stated in this License.
|
| 131 |
+
|
| 132 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 133 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 134 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 135 |
+
this License, without any additional terms or conditions.
|
| 136 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 137 |
+
the terms of any separate license agreement you may have executed
|
| 138 |
+
with Licensor regarding such Contributions.
|
| 139 |
+
|
| 140 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 141 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 142 |
+
except as required for reasonable and customary use in describing the
|
| 143 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 144 |
+
|
| 145 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 146 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 147 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 148 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 149 |
+
implied, including, without limitation, any warranties or conditions
|
| 150 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 151 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 152 |
+
appropriateness of using or redistributing the Work and assume any
|
| 153 |
+
risks associated with Your exercise of permissions under this License.
|
| 154 |
+
|
| 155 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 156 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 157 |
+
unless required by applicable law (such as deliberate and grossly
|
| 158 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 159 |
+
liable to You for damages, including any direct, indirect, special,
|
| 160 |
+
incidental, or consequential damages of any character arising as a
|
| 161 |
+
result of this License or out of the use or inability to use the
|
| 162 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 163 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 164 |
+
other commercial damages or losses), even if such Contributor
|
| 165 |
+
has been advised of the possibility of such damages.
|
| 166 |
+
|
| 167 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 168 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 169 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 170 |
+
or other liability obligations and/or rights consistent with this
|
| 171 |
+
License. However, in accepting such obligations, You may act only
|
| 172 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 173 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 174 |
+
defend, and hold each Contributor harmless for any liability
|
| 175 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 176 |
+
of your accepting any such warranty or additional liability.
|
| 177 |
+
|
| 178 |
+
END OF TERMS AND CONDITIONS
|
| 179 |
+
|
| 180 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 181 |
+
|
| 182 |
+
To apply the Apache License to your work, attach the following
|
| 183 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 184 |
+
replaced with your own identifying information. (Don't include
|
| 185 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 186 |
+
comment syntax for the file format. We also recommend that a
|
| 187 |
+
file or class name and description of purpose be included on the
|
| 188 |
+
same "printed page" as the copyright notice for easier
|
| 189 |
+
identification within third-party archives.
|
| 190 |
+
|
| 191 |
+
Copyright [yyyy] [name of copyright owner]
|
| 192 |
+
|
| 193 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 194 |
+
you may not use this file except in compliance with the License.
|
| 195 |
+
You may obtain a copy of the License at
|
| 196 |
+
|
| 197 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 198 |
+
|
| 199 |
+
Unless required by applicable law or agreed to in writing, software
|
| 200 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 201 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 202 |
+
See the License for the specific language governing permissions and
|
| 203 |
+
limitations under the License.
|
dinov2/MODEL_CARD.md
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model Card for DINOv2-S/B/L/g
|
| 2 |
+
|
| 3 |
+
These are Vision Transformer models trained following the method described in the papers:
|
| 4 |
+
"DINOv2: Learning Robust Visual Features without Supervision"
|
| 5 |
+
and
|
| 6 |
+
"Vision Transformers Need Registers".
|
| 7 |
+
|
| 8 |
+
We provide 8 models:
|
| 9 |
+
- 1 ViT-g trained from scratch with 3 ViT-S/B/L models distilled from the ViT-g, without registers.
|
| 10 |
+
- 1 ViT-g trained from scratch with 3 ViT-S/B/L models distilled from the ViT-g, with registers.
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
The model takes an image as input and returns a class token and patch tokens, and optionally 4 register tokens.
|
| 14 |
+
|
| 15 |
+
The embedding dimension is:
|
| 16 |
+
- 384 for ViT-S.
|
| 17 |
+
- 768 for ViT-B.
|
| 18 |
+
- 1024 for ViT-L.
|
| 19 |
+
- 1536 for ViT-g.
|
| 20 |
+
|
| 21 |
+
The models follow a Transformer architecture, with a patch size of 14. In the case of registers, we add 4 register tokens, learned during training, to the input sequence after the patch embedding.
|
| 22 |
+
|
| 23 |
+
For a 224x224 image, this results in 1 class token + 256 patch tokens, and optionally 4 register tokens.
|
| 24 |
+
|
| 25 |
+
The models can accept larger images provided the image shapes are multiples of the patch size (14).
|
| 26 |
+
If this condition is not verified, the model will crop to the closest smaller multiple of the patch size.
|
| 27 |
+
|
| 28 |
+
### Model Description
|
| 29 |
+
|
| 30 |
+
- **Developed by:** Meta AI
|
| 31 |
+
- **Model type:** Vision Transformer
|
| 32 |
+
- **License:** Apache License 2.0
|
| 33 |
+
|
| 34 |
+
- **Repository:** https://github.com/facebookresearch/dinov2
|
| 35 |
+
- **Paper:** https://arxiv.org/abs/2304.07193
|
| 36 |
+
- **Demo:** https://dinov2.metademolab.com/
|
| 37 |
+
|
| 38 |
+
## Uses
|
| 39 |
+
|
| 40 |
+
The models are vision backbones providing multi-purpose features for downstream tasks.
|
| 41 |
+
|
| 42 |
+
### Direct Use
|
| 43 |
+
|
| 44 |
+
The models can be used without fine-tuning, with downstream classifiers as simple as linear layers, to obtain competitive results:
|
| 45 |
+
- on depth estimation, semantic segmentation, using linear layers.
|
| 46 |
+
- on image classification, using k-NN classifiers on the class token.
|
| 47 |
+
- on image classification, with logistic regression classifiers applied on the class token.
|
| 48 |
+
- on image classification, with a linear layer applied on the class token and the average of the patch tokens.
|
| 49 |
+
- on image retrieval using nearest neighbors.
|
| 50 |
+
|
| 51 |
+
### Downstream Use
|
| 52 |
+
|
| 53 |
+
It is technically possible to perform fine-tuning on the models, for small gains (we measured +2% on ImageNet-1k classification).
|
| 54 |
+
We recommend keeping this as a very last step and only when necessary, as the features already provide good performance out-of-the-box.
|
| 55 |
+
|
| 56 |
+
## Bias, Risks, and Limitations
|
| 57 |
+
|
| 58 |
+
Despite improvements thanks to the training method not using annotations, we still observe significant biases in our models toward rich households from Western countries.
|
| 59 |
+
|
| 60 |
+
### Recommendations
|
| 61 |
+
|
| 62 |
+
We expect fine-tuning will increase the biases in the features produced by the model as they will be tuned to the fine-tuning labels.
|
| 63 |
+
|
| 64 |
+
## How to Get Started with the Model
|
| 65 |
+
|
| 66 |
+
Use the code below to get started with the model.
|
| 67 |
+
|
| 68 |
+
```python
|
| 69 |
+
import torch
|
| 70 |
+
|
| 71 |
+
# DINOv2
|
| 72 |
+
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
|
| 73 |
+
dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
|
| 74 |
+
dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
|
| 75 |
+
dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')
|
| 76 |
+
|
| 77 |
+
# DINOv2 with registers
|
| 78 |
+
dinov2_vits14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
|
| 79 |
+
dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
|
| 80 |
+
dinov2_vitl14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
|
| 81 |
+
dinov2_vitg14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
## Training Details
|
| 85 |
+
|
| 86 |
+
### Training Data
|
| 87 |
+
|
| 88 |
+
- **Training data:** LVD-142M (see paper)
|
| 89 |
+
- **Training regime:** fp16 using PyTorch-FSDP mixed-precision.
|
| 90 |
+
|
| 91 |
+
### Training Procedure
|
| 92 |
+
|
| 93 |
+
- **Training objective:**
|
| 94 |
+
- DINO self-distillation loss with multi-crop
|
| 95 |
+
- iBOT masked-image modeling loss
|
| 96 |
+
- KoLeo regularization on [CLS] tokens
|
| 97 |
+
- **Architectures:**
|
| 98 |
+
- ViT-S (21M params): Patch size 14, embedding dimension 384, 6 heads, MLP FFN
|
| 99 |
+
- ViT-B (86M params): Patch size 14, embedding dimension 768, 12 heads, MLP FFN
|
| 100 |
+
- ViT-L (0.3B params): Patch size 14, embedding dimension 1024, 16 heads, MLP FFN
|
| 101 |
+
- ViT-g (1.1B params): Patch size 14, embedding dimension 1536, 24 heads, SwiGLU FFN
|
| 102 |
+
- **Distillation:**
|
| 103 |
+
- Distillation follows the standard DINOv2 pretraining procedure, except the teacher is a pretrained ViT-g, frozen.
|
| 104 |
+
|
| 105 |
+
## Evaluation
|
| 106 |
+
|
| 107 |
+
We refer users to the associated papers for the evaluation protocols.
|
| 108 |
+
|
| 109 |
+
<table>
|
| 110 |
+
<tr>
|
| 111 |
+
<th colspan="2"></th>
|
| 112 |
+
<th colspan="3">ImageNet-1k</th>
|
| 113 |
+
<th>NYU-Depth v2</th>
|
| 114 |
+
<th>SUN-RGBD</th>
|
| 115 |
+
<th>ADE20k</th>
|
| 116 |
+
<th>iNaturalist 2018</th>
|
| 117 |
+
<th>Oxford-H</th>
|
| 118 |
+
</tr>
|
| 119 |
+
<tr>
|
| 120 |
+
<th rowspan="2">model</th>
|
| 121 |
+
<th rowspan="2">with <br /> registers</th>
|
| 122 |
+
<th>classif. (acc)</th>
|
| 123 |
+
<th>classif. (acc)</th>
|
| 124 |
+
<th>classif. V2 (acc)</th>
|
| 125 |
+
<th>depth (RMSE)</th>
|
| 126 |
+
<th>depth (RMSE)</th>
|
| 127 |
+
<th>segm. (mAP)</th>
|
| 128 |
+
<th>classif. (acc)</th>
|
| 129 |
+
<th>retrieval (mAP)</th>
|
| 130 |
+
</tr>
|
| 131 |
+
<tr>
|
| 132 |
+
<!-- <th>^</th> -->
|
| 133 |
+
<th>k-NN</th>
|
| 134 |
+
<th>linear</th>
|
| 135 |
+
<th>linear</th>
|
| 136 |
+
<th>linear<br />4 layers</th>
|
| 137 |
+
<th>NYU-D transfer</th>
|
| 138 |
+
<th>multiscale</th>
|
| 139 |
+
<th>linear</th>
|
| 140 |
+
<th>nearest neighbor</th>
|
| 141 |
+
</tr>
|
| 142 |
+
<tr>
|
| 143 |
+
<td>ViT-S/14</td>
|
| 144 |
+
<td align="center">:x:</td>
|
| 145 |
+
<td align="right">79.0%</td>
|
| 146 |
+
<td align="right">81.1%</td>
|
| 147 |
+
<td align="right">70.8%</td>
|
| 148 |
+
<td align="right">0.417</td>
|
| 149 |
+
<td align="right">0.431</td>
|
| 150 |
+
<td align="right">47.2</td>
|
| 151 |
+
<td align="right">69.5%</td>
|
| 152 |
+
<td align="right">43.2</td>
|
| 153 |
+
</tr>
|
| 154 |
+
<tr>
|
| 155 |
+
<td>ViT-S/14</td>
|
| 156 |
+
<td align="center">:white_check_mark:</td>
|
| 157 |
+
<td align="right">79.1%</td>
|
| 158 |
+
<td align="right">80.9%</td>
|
| 159 |
+
<td align="right">71.0%</td>
|
| 160 |
+
<td align="right">N/A</td>
|
| 161 |
+
<td align="right">N/A</td>
|
| 162 |
+
<td align="right">N/A</td>
|
| 163 |
+
<td align="right">67.6%</td>
|
| 164 |
+
<td align="right">39.5</td>
|
| 165 |
+
</tr>
|
| 166 |
+
<tr>
|
| 167 |
+
<td>ViT-B/14</td>
|
| 168 |
+
<td align="center">:x:</td>
|
| 169 |
+
<td align="right">82.1%</td>
|
| 170 |
+
<td align="right">84.5%</td>
|
| 171 |
+
<td align="right">74.9%</td>
|
| 172 |
+
<td align="right">0.362</td>
|
| 173 |
+
<td align="right">0.400</td>
|
| 174 |
+
<td align="right">51.3</td>
|
| 175 |
+
<td align="right">76.3%</td>
|
| 176 |
+
<td align="right">49.5</td>
|
| 177 |
+
</tr>
|
| 178 |
+
<td>ViT-B/14</td>
|
| 179 |
+
<td align="center">:white_check_mark:</td>
|
| 180 |
+
<td align="right">82.0%</td>
|
| 181 |
+
<td align="right">84.6%</td>
|
| 182 |
+
<td align="right">75.6%</td>
|
| 183 |
+
<td align="right">N/A</td>
|
| 184 |
+
<td align="right">N/A</td>
|
| 185 |
+
<td align="right">N/A</td>
|
| 186 |
+
<td align="right">73.8%</td>
|
| 187 |
+
<td align="right">51.0</td>
|
| 188 |
+
</tr>
|
| 189 |
+
<tr>
|
| 190 |
+
<td>ViT-L/14</td>
|
| 191 |
+
<td align="center">:x:</td>
|
| 192 |
+
<td align="right">83.5%</td>
|
| 193 |
+
<td align="right">86.3%</td>
|
| 194 |
+
<td align="right">77.6%</td>
|
| 195 |
+
<td align="right">0.333</td>
|
| 196 |
+
<td align="right">0.396</td>
|
| 197 |
+
<td align="right">53.1</td>
|
| 198 |
+
<td align="right">79.8%</td>
|
| 199 |
+
<td align="right">54.0</td>
|
| 200 |
+
</tr>
|
| 201 |
+
<tr>
|
| 202 |
+
<td>ViT-L/14</td>
|
| 203 |
+
<td align="center">:white_check_mark:</td>
|
| 204 |
+
<td align="right">83.8%</td>
|
| 205 |
+
<td align="right">86.7%</td>
|
| 206 |
+
<td align="right">78.5%</td>
|
| 207 |
+
<td align="right">N/A</td>
|
| 208 |
+
<td align="right">N/A</td>
|
| 209 |
+
<td align="right">N/A</td>
|
| 210 |
+
<td align="right">80.9%</td>
|
| 211 |
+
<td align="right">55.7</td>
|
| 212 |
+
</tr>
|
| 213 |
+
<tr>
|
| 214 |
+
<td>ViT-g/14</td>
|
| 215 |
+
<td align="center">:x:</td>
|
| 216 |
+
<td align="right">83.5%</td>
|
| 217 |
+
<td align="right">86.5%</td>
|
| 218 |
+
<td align="right">78.4%</td>
|
| 219 |
+
<td align="right">0.298</td>
|
| 220 |
+
<td align="right">0.362</td>
|
| 221 |
+
<td align="right">53.0</td>
|
| 222 |
+
<td align="right">81.6%</td>
|
| 223 |
+
<td align="right">52.3</td>
|
| 224 |
+
</tr>
|
| 225 |
+
<tr>
|
| 226 |
+
<tr>
|
| 227 |
+
<td>ViT-g/14</td>
|
| 228 |
+
<td align="center">:white_check_mark:</td>
|
| 229 |
+
<td align="right">83.7%</td>
|
| 230 |
+
<td align="right">87.1%</td>
|
| 231 |
+
<td align="right">78.8%</td>
|
| 232 |
+
<td align="right">N/A</td>
|
| 233 |
+
<td align="right">N/A</td>
|
| 234 |
+
<td align="right">N/A</td>
|
| 235 |
+
<td align="right">81.5%</td>
|
| 236 |
+
<td align="right">58.2</td>
|
| 237 |
+
</tr>
|
| 238 |
+
</table>
|
| 239 |
+
|
| 240 |
+
## Environmental Impact
|
| 241 |
+
|
| 242 |
+
- **Hardware Type:** Nvidia A100
|
| 243 |
+
- **Hours used:** 22,000 for ViT-g, 4,500 for ViT-S distillation, 5,300 for ViT-B distillation, 8,000 for ViT-L distillation
|
| 244 |
+
- **Cloud Provider:** Private infra
|
| 245 |
+
- **Compute Region:** USA
|
| 246 |
+
- **Carbon Emitted:** 7t CO2eq
|
| 247 |
+
|
| 248 |
+
#### Hardware
|
| 249 |
+
|
| 250 |
+
Nvidia A100 GPUs
|
| 251 |
+
|
| 252 |
+
#### Software
|
| 253 |
+
|
| 254 |
+
PyTorch 2.0,
|
| 255 |
+
xFormers 0.0.18
|
| 256 |
+
|
| 257 |
+
**BibTeX**
|
| 258 |
+
|
| 259 |
+
```
|
| 260 |
+
@misc{oquab2023dinov2,
|
| 261 |
+
title={DINOv2: Learning Robust Visual Features without Supervision},
|
| 262 |
+
author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
|
| 263 |
+
journal={arXiv:2304.07193},
|
| 264 |
+
year={2023}
|
| 265 |
+
}
|
| 266 |
+
@misc{darcet2023vitneedreg,
|
| 267 |
+
title={Vision Transformers Need Registers},
|
| 268 |
+
author={Darcet, Timothée and Oquab, Maxime and Mairal, Julien and Bojanowski, Piotr},
|
| 269 |
+
journal={arXiv:2309.16588},
|
| 270 |
+
year={2023}
|
| 271 |
+
}
|
| 272 |
+
```
|
dinov2/README.md
ADDED
|
@@ -0,0 +1,620 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
:new: [2023-10-26] *Added DINOv2 backbones with registers, following [Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588).*
|
| 2 |
+
|
| 3 |
+
# DINOv2: Learning Robust Visual Features without Supervision
|
| 4 |
+
|
| 5 |
+
**[Meta AI Research, FAIR](https://ai.facebook.com/research/)**
|
| 6 |
+
|
| 7 |
+
Maxime Oquab,
|
| 8 |
+
Timothée Darcet,
|
| 9 |
+
Théo Moutakanni,
|
| 10 |
+
Huy V. Vo,
|
| 11 |
+
Marc Szafraniec,
|
| 12 |
+
Vasil Khalidov,
|
| 13 |
+
Patrick Labatut,
|
| 14 |
+
Armand Joulin,
|
| 15 |
+
Piotr Bojanowski
|
| 16 |
+
|
| 17 |
+
[[`Paper #1`](https://arxiv.org/abs/2304.07193)] [`Paper #2`](https://arxiv.org/abs/2309.16588)] [[`Blog`](https://ai.facebook.com/blog/dino-v2-computer-vision-self-supervised-learning/)] [[`Demo`](https://dinov2.metademolab.com)] [[`BibTeX`](#citing-dinov2)]
|
| 18 |
+
|
| 19 |
+
PyTorch implementation and pretrained models for DINOv2. For details, see the papers: **[DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)** and **[Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588)**.
|
| 20 |
+
|
| 21 |
+
DINOv2 models produce high-performance visual features that can be directly employed with classifiers as simple as linear layers on a variety of computer vision tasks; these visual features are robust and perform well across domains without any requirement for fine-tuning. The models were pretrained on a dataset of 142 M images without using any labels or annotations.
|
| 22 |
+
|
| 23 |
+
https://github.com/facebookresearch/dinov2/assets/60359573/f168823e-7922-415a-b429-578badf5c356
|
| 24 |
+
|
| 25 |
+
<div align="center">
|
| 26 |
+
Visualization of the three first principal components of the patch features of all frames, mapped to RGB values.
|
| 27 |
+
</div>
|
| 28 |
+
|
| 29 |
+
## Pretrained models
|
| 30 |
+
|
| 31 |
+
<table style="margin: auto">
|
| 32 |
+
<thead>
|
| 33 |
+
<tr>
|
| 34 |
+
<th>model</th>
|
| 35 |
+
<th># of<br />params</th>
|
| 36 |
+
<th>with<br />registers</th>
|
| 37 |
+
<th>ImageNet<br />k-NN</th>
|
| 38 |
+
<th>ImageNet<br />linear</th>
|
| 39 |
+
<th>download</th>
|
| 40 |
+
</tr>
|
| 41 |
+
</thead>
|
| 42 |
+
<tbody>
|
| 43 |
+
<tr>
|
| 44 |
+
<td>ViT-S/14 distilled</td>
|
| 45 |
+
<td align="right">21 M</td>
|
| 46 |
+
<td align="center">:x:</td>
|
| 47 |
+
<td align="right">79.0%</td>
|
| 48 |
+
<td align="right">81.1%</td>
|
| 49 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth">backbone only</a></td>
|
| 50 |
+
</tr>
|
| 51 |
+
<tr>
|
| 52 |
+
<td>ViT-S/14 distilled</td>
|
| 53 |
+
<td align="right">21 M</td>
|
| 54 |
+
<td align="center">:white_check_mark:</td>
|
| 55 |
+
<td align="right">79.1%</td>
|
| 56 |
+
<td align="right">80.9%</td>
|
| 57 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth">backbone only</a></td>
|
| 58 |
+
</tr>
|
| 59 |
+
<tr>
|
| 60 |
+
<td>ViT-B/14 distilled</td>
|
| 61 |
+
<td align="right">86 M</td>
|
| 62 |
+
<td align="center">:x:</td>
|
| 63 |
+
<td align="right">82.1%</td>
|
| 64 |
+
<td align="right">84.5%</td>
|
| 65 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth">backbone only</a></td>
|
| 66 |
+
</tr>
|
| 67 |
+
<tr>
|
| 68 |
+
<td>ViT-B/14 distilled</td>
|
| 69 |
+
<td align="right">86 M</td>
|
| 70 |
+
<td align="center">:white_check_mark:</td>
|
| 71 |
+
<td align="right">82.0%</td>
|
| 72 |
+
<td align="right">84.6%</td>
|
| 73 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth">backbone only</a></td>
|
| 74 |
+
</tr>
|
| 75 |
+
<tr>
|
| 76 |
+
<td>ViT-L/14 distilled</td>
|
| 77 |
+
<td align="right">300 M</td>
|
| 78 |
+
<td align="center">:x:</td>
|
| 79 |
+
<td align="right">83.5%</td>
|
| 80 |
+
<td align="right">86.3%</td>
|
| 81 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth">backbone only</a></td>
|
| 82 |
+
</tr>
|
| 83 |
+
<tr>
|
| 84 |
+
<td>ViT-L/14 distilled</td>
|
| 85 |
+
<td align="right">300 M</td>
|
| 86 |
+
<td align="center">:white_check_mark:</td>
|
| 87 |
+
<td align="right">83.8%</td>
|
| 88 |
+
<td align="right">86.7%</td>
|
| 89 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth">backbone only</a></td>
|
| 90 |
+
</tr>
|
| 91 |
+
<tr>
|
| 92 |
+
<td>ViT-g/14</td>
|
| 93 |
+
<td align="right">1,100 M</td>
|
| 94 |
+
<td align="center">:x:</td>
|
| 95 |
+
<td align="right">83.5%</td>
|
| 96 |
+
<td align="right">86.5%</td>
|
| 97 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth">backbone only</a></td>
|
| 98 |
+
</tr>
|
| 99 |
+
<tr>
|
| 100 |
+
<td>ViT-g/14</td>
|
| 101 |
+
<td align="right">1,100 M</td>
|
| 102 |
+
<td align="center">:white_check_mark:</td>
|
| 103 |
+
<td align="right">83.7%</td>
|
| 104 |
+
<td align="right">87.1%</td>
|
| 105 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth">backbone only</a></td>
|
| 106 |
+
</tr>
|
| 107 |
+
</tbody>
|
| 108 |
+
</table>
|
| 109 |
+
|
| 110 |
+
### Pretrained backbones (via PyTorch Hub)
|
| 111 |
+
|
| 112 |
+
Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install PyTorch (the only required dependency for loading the model). Installing PyTorch with CUDA support is strongly recommended.
|
| 113 |
+
|
| 114 |
+
A corresponding [model card](MODEL_CARD.md) is included in the repository.
|
| 115 |
+
|
| 116 |
+
```python
|
| 117 |
+
import torch
|
| 118 |
+
|
| 119 |
+
# DINOv2
|
| 120 |
+
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
|
| 121 |
+
dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
|
| 122 |
+
dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
|
| 123 |
+
dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')
|
| 124 |
+
|
| 125 |
+
# DINOv2 with registers
|
| 126 |
+
dinov2_vits14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
|
| 127 |
+
dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
|
| 128 |
+
dinov2_vitl14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
|
| 129 |
+
dinov2_vitg14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
### Pretrained heads - Image classification
|
| 133 |
+
|
| 134 |
+
<table style="margin: auto">
|
| 135 |
+
<thead>
|
| 136 |
+
<tr>
|
| 137 |
+
<th rowspan="2">backbone</th>
|
| 138 |
+
<th rowspan="2">with<br />registers</th>
|
| 139 |
+
<th>download</th>
|
| 140 |
+
</tr>
|
| 141 |
+
<tr>
|
| 142 |
+
<th>ImageNet</th>
|
| 143 |
+
</tr>
|
| 144 |
+
</thead>
|
| 145 |
+
<tbody>
|
| 146 |
+
<tr>
|
| 147 |
+
<td>ViT-S/14 distilled</td>
|
| 148 |
+
<td align="center">:x:</td>
|
| 149 |
+
<td>
|
| 150 |
+
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth">1 layer</a>,
|
| 151 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear4_head.pth">4 layers</a>)
|
| 152 |
+
</td>
|
| 153 |
+
</tr>
|
| 154 |
+
<tr>
|
| 155 |
+
<td>ViT-S/14 distilled</td>
|
| 156 |
+
<td align="center">:white_check_mark:</td>
|
| 157 |
+
<td>
|
| 158 |
+
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear_head.pth">1 layer</a>,
|
| 159 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear4_head.pth">4 layers</a>)
|
| 160 |
+
</td>
|
| 161 |
+
</tr>
|
| 162 |
+
<tr>
|
| 163 |
+
<td>ViT-B/14 distilled</td>
|
| 164 |
+
<td align="center">:x:</td>
|
| 165 |
+
<td>
|
| 166 |
+
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">1 layer</a>,
|
| 167 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear4_head.pth">4 layers</a>)
|
| 168 |
+
</tr>
|
| 169 |
+
<tr>
|
| 170 |
+
<td>ViT-B/14 distilled</td>
|
| 171 |
+
<td align="center">:white_check_mark:</td>
|
| 172 |
+
<td>
|
| 173 |
+
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear_head.pth">1 layer</a>,
|
| 174 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear4_head.pth">4 layers</a>)
|
| 175 |
+
</tr>
|
| 176 |
+
<tr>
|
| 177 |
+
<td>ViT-L/14 distilled</td>
|
| 178 |
+
<td align="center">:x:</td>
|
| 179 |
+
<td>
|
| 180 |
+
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">1 layer</a>,
|
| 181 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear4_head.pth">4 layers</a>)
|
| 182 |
+
</tr>
|
| 183 |
+
<tr>
|
| 184 |
+
<td>ViT-L/14 distilled</td>
|
| 185 |
+
<td align="center">:white_check_mark:</td>
|
| 186 |
+
<td>
|
| 187 |
+
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear_head.pth">1 layer</a>,
|
| 188 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear4_head.pth">4 layers</a>)
|
| 189 |
+
</tr>
|
| 190 |
+
<tr>
|
| 191 |
+
<td>ViT-g/14</td>
|
| 192 |
+
<td align="center">:x:</td>
|
| 193 |
+
<td>
|
| 194 |
+
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">1 layer</a>,
|
| 195 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear4_head.pth">4 layers</a>)
|
| 196 |
+
</tr>
|
| 197 |
+
<tr>
|
| 198 |
+
<td>ViT-g/14</td>
|
| 199 |
+
<td align="center">:white_check_mark:</td>
|
| 200 |
+
<td>
|
| 201 |
+
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_lreg4_inear_head.pth">1 layer</a>,
|
| 202 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear4_head.pth">4 layers</a>)
|
| 203 |
+
</tr>
|
| 204 |
+
</tbody>
|
| 205 |
+
</table>
|
| 206 |
+
|
| 207 |
+
The (full) classifier models can be loaded via PyTorch Hub:
|
| 208 |
+
|
| 209 |
+
```python
|
| 210 |
+
import torch
|
| 211 |
+
|
| 212 |
+
# DINOv2
|
| 213 |
+
dinov2_vits14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_lc')
|
| 214 |
+
dinov2_vitb14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_lc')
|
| 215 |
+
dinov2_vitl14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_lc')
|
| 216 |
+
dinov2_vitg14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_lc')
|
| 217 |
+
|
| 218 |
+
# DINOv2 with registers
|
| 219 |
+
dinov2_vits14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg_lc')
|
| 220 |
+
dinov2_vitb14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg_lc')
|
| 221 |
+
dinov2_vitl14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg_lc')
|
| 222 |
+
dinov2_vitg14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg_lc')
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
### Pretrained heads - Depth estimation
|
| 226 |
+
|
| 227 |
+
<table style="margin: auto">
|
| 228 |
+
<thead>
|
| 229 |
+
<tr>
|
| 230 |
+
<th rowspan="2">backbone</th>
|
| 231 |
+
<th colspan="2">download head</th>
|
| 232 |
+
</tr>
|
| 233 |
+
<tr>
|
| 234 |
+
<th>NYUd</th>
|
| 235 |
+
<th>KITTI</th>
|
| 236 |
+
</tr>
|
| 237 |
+
</thead>
|
| 238 |
+
<tbody>
|
| 239 |
+
<tr>
|
| 240 |
+
<td>ViT-S/14 distilled</td>
|
| 241 |
+
<td>
|
| 242 |
+
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_nyu_linear_head.pth">1 layer</a>,
|
| 243 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_nyu_linear4_head.pth">4 layers</a>),
|
| 244 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_nyu_dpt_head.pth">DPT</a>
|
| 245 |
+
</td>
|
| 246 |
+
<td>
|
| 247 |
+
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_kitti_linear_head.pth">1 layer</a>,
|
| 248 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_kitti_linear4_head.pth">4 layers</a>),
|
| 249 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_kitti_dpt_head.pth">DPT</a>
|
| 250 |
+
</td>
|
| 251 |
+
</tr>
|
| 252 |
+
<tr>
|
| 253 |
+
<td>ViT-B/14 distilled</td>
|
| 254 |
+
<td>
|
| 255 |
+
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">1 layer</a>,
|
| 256 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_nyu_linear4_head.pth">4 layers</a>),
|
| 257 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_nyu_dpt_head.pth">DPT</a>
|
| 258 |
+
</td>
|
| 259 |
+
<td>
|
| 260 |
+
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_kitti_linear_head.pth">1 layer</a>,
|
| 261 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_kitti_linear4_head.pth">4 layers</a>),
|
| 262 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_kitti_dpt_head.pth">DPT</a>
|
| 263 |
+
</td>
|
| 264 |
+
</tr>
|
| 265 |
+
<tr>
|
| 266 |
+
<td>ViT-L/14 distilled</td>
|
| 267 |
+
<td>
|
| 268 |
+
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">1 layer</a>,
|
| 269 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_nyu_linear4_head.pth">4 layers</a>),
|
| 270 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_nyu_dpt_head.pth">DPT</a>
|
| 271 |
+
</td>
|
| 272 |
+
<td>
|
| 273 |
+
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_kitti_linear_head.pth">1 layer</a>,
|
| 274 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_kitti_linear4_head.pth">4 layers</a>),
|
| 275 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_kitti_dpt_head.pth">DPT</a>
|
| 276 |
+
</td>
|
| 277 |
+
</tr>
|
| 278 |
+
<tr>
|
| 279 |
+
<td>ViT-g/14</td>
|
| 280 |
+
<td>
|
| 281 |
+
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">1 layer</a>,
|
| 282 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_nyu_linear4_head.pth">4 layers</a>),
|
| 283 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_nyu_dpt_head.pth">DPT</a>
|
| 284 |
+
</td>
|
| 285 |
+
<td>
|
| 286 |
+
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_kitti_linear_head.pth">1 layer</a>,
|
| 287 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_kitti_linear4_head.pth">4 layers</a>),
|
| 288 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_kitti_dpt_head.pth">DPT</a>
|
| 289 |
+
</td>
|
| 290 |
+
</tr>
|
| 291 |
+
</tbody>
|
| 292 |
+
</table>
|
| 293 |
+
|
| 294 |
+
### Pretrained heads - Semantic segmentation
|
| 295 |
+
|
| 296 |
+
<table style="margin: auto">
|
| 297 |
+
<thead>
|
| 298 |
+
<tr>
|
| 299 |
+
<th rowspan="2">backbone</th>
|
| 300 |
+
<th>download model</th>
|
| 301 |
+
<th colspan="2">download head</th>
|
| 302 |
+
</tr>
|
| 303 |
+
<tr>
|
| 304 |
+
<th>ADE20K</th>
|
| 305 |
+
<th>ADE20K</th>
|
| 306 |
+
<th>VOC2012</th>
|
| 307 |
+
</tr>
|
| 308 |
+
</thead>
|
| 309 |
+
<tbody>
|
| 310 |
+
<tr>
|
| 311 |
+
<td>ViT-S/14 distilled</td>
|
| 312 |
+
<td></td>
|
| 313 |
+
<td>
|
| 314 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_ade20k_linear_head.pth">linear</a>,
|
| 315 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_ade20k_ms_head.pth">multi-scale</a>
|
| 316 |
+
</td>
|
| 317 |
+
<td>
|
| 318 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_voc2012_linear_head.pth">linear</a>,
|
| 319 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_voc2012_ms_head.pth">multi-scale</a>
|
| 320 |
+
</td>
|
| 321 |
+
</tr>
|
| 322 |
+
<tr>
|
| 323 |
+
<td>ViT-B/14 distilled</td>
|
| 324 |
+
<td></td>
|
| 325 |
+
<td>
|
| 326 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_ade20k_linear_head.pth">linear</a>,
|
| 327 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_ade20k_ms_head.pth">multi-scale</a>
|
| 328 |
+
</td>
|
| 329 |
+
<td>
|
| 330 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_voc2012_linear_head.pth">linear</a>,
|
| 331 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_voc2012_ms_head.pth">multi-scale</a>
|
| 332 |
+
</td>
|
| 333 |
+
</tr>
|
| 334 |
+
<tr>
|
| 335 |
+
<td>ViT-L/14 distilled</td>
|
| 336 |
+
<td></td>
|
| 337 |
+
<td>
|
| 338 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_ade20k_linear_head.pth">linear</a>,
|
| 339 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_ade20k_ms_head.pth">multi-scale</a>
|
| 340 |
+
</td>
|
| 341 |
+
<td>
|
| 342 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_voc2012_linear_head.pth">linear</a>,
|
| 343 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_voc2012_ms_head.pth">multi-scale</a>
|
| 344 |
+
</td>
|
| 345 |
+
</tr>
|
| 346 |
+
<tr>
|
| 347 |
+
<td>ViT-g/14</td>
|
| 348 |
+
<td>
|
| 349 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_ade20k_m2f.pth">Mask2Former</a>
|
| 350 |
+
</td>
|
| 351 |
+
<td>
|
| 352 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_ade20k_linear_head.pth">linear</a>,
|
| 353 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_ade20k_ms_head.pth">multi-scale</a>
|
| 354 |
+
</td>
|
| 355 |
+
<td>
|
| 356 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_voc2012_linear_head.pth">linear</a>,
|
| 357 |
+
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_voc2012_ms_head.pth">multi-scale</a>
|
| 358 |
+
</td>
|
| 359 |
+
</tr>
|
| 360 |
+
</tbody>
|
| 361 |
+
</table>
|
| 362 |
+
|
| 363 |
+
## Installation
|
| 364 |
+
|
| 365 |
+
The training and evaluation code requires PyTorch 2.0 and [xFormers](https://github.com/facebookresearch/xformers) 0.0.18 as well as a number of other 3rd party packages. Note that the code has only been tested with the specified versions and also expects a Linux environment. To setup all the required dependencies for training and evaluation, please follow the instructions below:
|
| 366 |
+
|
| 367 |
+
*[conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html)* **(Recommended)** - Clone the repository and then create and activate a `dinov2` conda environment using the provided environment definition:
|
| 368 |
+
|
| 369 |
+
```shell
|
| 370 |
+
conda env create -f conda.yaml
|
| 371 |
+
conda activate dinov2
|
| 372 |
+
```
|
| 373 |
+
|
| 374 |
+
*[pip](https://pip.pypa.io/en/stable/getting-started/)* - Clone the repository and then use the provided `requirements.txt` to install the dependencies:
|
| 375 |
+
|
| 376 |
+
```shell
|
| 377 |
+
pip install -r requirements.txt
|
| 378 |
+
```
|
| 379 |
+
|
| 380 |
+
For dense tasks (depth estimation and semantic segmentation), there are additional dependencies (specific versions of `mmcv` and `mmsegmentation`) which are captured in the `extras` dependency specifications:
|
| 381 |
+
|
| 382 |
+
*[conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html)* **(Recommended)**:
|
| 383 |
+
|
| 384 |
+
```shell
|
| 385 |
+
conda env create -f conda-extras.yaml
|
| 386 |
+
conda activate dinov2-extras
|
| 387 |
+
```
|
| 388 |
+
|
| 389 |
+
*[pip](https://pip.pypa.io/en/stable/getting-started/)*:
|
| 390 |
+
|
| 391 |
+
```shell
|
| 392 |
+
pip install -r requirements.txt -r requirements-extras.txt
|
| 393 |
+
```
|
| 394 |
+
|
| 395 |
+
## Data preparation
|
| 396 |
+
|
| 397 |
+
### ImageNet-1k
|
| 398 |
+
|
| 399 |
+
The root directory of the dataset should hold the following contents:
|
| 400 |
+
|
| 401 |
+
- `<ROOT>/test/ILSVRC2012_test_00000001.JPEG`
|
| 402 |
+
- `<ROOT>/test/[..]`
|
| 403 |
+
- `<ROOT>/test/ILSVRC2012_test_00100000.JPEG`
|
| 404 |
+
- `<ROOT>/train/n01440764/n01440764_10026.JPEG`
|
| 405 |
+
- `<ROOT>/train/[...]`
|
| 406 |
+
- `<ROOT>/train/n15075141/n15075141_9993.JPEG`
|
| 407 |
+
- `<ROOT>/val/n01440764/ILSVRC2012_val_00000293.JPEG`
|
| 408 |
+
- `<ROOT>/val/[...]`
|
| 409 |
+
- `<ROOT>/val/n15075141/ILSVRC2012_val_00049174.JPEG`
|
| 410 |
+
- `<ROOT>/labels.txt`
|
| 411 |
+
|
| 412 |
+
The provided dataset implementation expects a few additional metadata files to be present under the extra directory:
|
| 413 |
+
|
| 414 |
+
- `<EXTRA>/class-ids-TRAIN.npy`
|
| 415 |
+
- `<EXTRA>/class-ids-VAL.npy`
|
| 416 |
+
- `<EXTRA>/class-names-TRAIN.npy`
|
| 417 |
+
- `<EXTRA>/class-names-VAL.npy`
|
| 418 |
+
- `<EXTRA>/entries-TEST.npy`
|
| 419 |
+
- `<EXTRA>/entries-TRAIN.npy`
|
| 420 |
+
- `<EXTRA>/entries-VAL.npy`
|
| 421 |
+
|
| 422 |
+
These metadata files can be generated (once) with the following lines of Python code:
|
| 423 |
+
|
| 424 |
+
```python
|
| 425 |
+
from dinov2.data.datasets import ImageNet
|
| 426 |
+
|
| 427 |
+
for split in ImageNet.Split:
|
| 428 |
+
dataset = ImageNet(split=split, root="<ROOT>", extra="<EXTRA>")
|
| 429 |
+
dataset.dump_extra()
|
| 430 |
+
```
|
| 431 |
+
|
| 432 |
+
Note that the root and extra directories do not have to be distinct directories.
|
| 433 |
+
|
| 434 |
+
### ImageNet-22k
|
| 435 |
+
|
| 436 |
+
Please adapt the [dataset class](dinov2/data/datasets/image_net_22k.py) to match your local setup.
|
| 437 |
+
|
| 438 |
+
<br />
|
| 439 |
+
|
| 440 |
+
:warning: To execute the commands provided in the next sections for training and evaluation, the `dinov2` package should be included in the Python module search path, i.e. simply prefix the command to run with `PYTHONPATH=.`.
|
| 441 |
+
|
| 442 |
+
## Training
|
| 443 |
+
|
| 444 |
+
### Fast setup: training DINOv2 ViT-L/16 on ImageNet-1k
|
| 445 |
+
|
| 446 |
+
Run DINOv2 training on 4 A100-80GB nodes (32 GPUs) in a SLURM cluster environment with submitit:
|
| 447 |
+
|
| 448 |
+
```shell
|
| 449 |
+
python dinov2/run/train/train.py \
|
| 450 |
+
--nodes 4 \
|
| 451 |
+
--config-file dinov2/configs/train/vitl16_short.yaml \
|
| 452 |
+
--output-dir <PATH/TO/OUTPUT/DIR> \
|
| 453 |
+
train.dataset_path=ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
|
| 454 |
+
```
|
| 455 |
+
|
| 456 |
+
Training time is approximately 1 day and the resulting checkpoint should reach 81.6% on k-NN eval and 82.9% on linear eval.
|
| 457 |
+
|
| 458 |
+
The training code saves the weights of the teacher in the `eval` folder every 12500 iterations for evaluation.
|
| 459 |
+
|
| 460 |
+
### Long setup: training DINOv2 ViT-L/14 on ImageNet-22k
|
| 461 |
+
|
| 462 |
+
Run DINOv2 training on 12 A100-80GB nodes (96 GPUs) in a SLURM cluster environment with submitit:
|
| 463 |
+
|
| 464 |
+
```shell
|
| 465 |
+
python dinov2/run/train/train.py \
|
| 466 |
+
--nodes 12 \
|
| 467 |
+
--config-file dinov2/configs/train/vitl14.yaml \
|
| 468 |
+
--output-dir <PATH/TO/OUTPUT/DIR> \
|
| 469 |
+
train.dataset_path=ImageNet22k:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
|
| 470 |
+
```
|
| 471 |
+
|
| 472 |
+
Training time is approximately 3.3 days and the resulting checkpoint should reach 82.0% on k-NN eval and 84.5% on linear eval.
|
| 473 |
+
|
| 474 |
+
The training code saves the weights of the teacher in the `eval` folder every 12500 iterations for evaluation.
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
## Evaluation
|
| 478 |
+
|
| 479 |
+
The training code regularly saves the teacher weights. In order to evaluate the model, run the following evaluation on a single node:
|
| 480 |
+
|
| 481 |
+
### k-NN classification on ImageNet-1k
|
| 482 |
+
|
| 483 |
+
```shell
|
| 484 |
+
python dinov2/run/eval/knn.py \
|
| 485 |
+
--config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
|
| 486 |
+
--pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
|
| 487 |
+
--output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/knn \
|
| 488 |
+
--train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
|
| 489 |
+
--val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
|
| 490 |
+
```
|
| 491 |
+
|
| 492 |
+
### Logistic regression classification on ImageNet-1k
|
| 493 |
+
|
| 494 |
+
```shell
|
| 495 |
+
python dinov2/run/eval/log_regression.py \
|
| 496 |
+
--config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
|
| 497 |
+
--pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
|
| 498 |
+
--output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/logreg \
|
| 499 |
+
--train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
|
| 500 |
+
--val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
|
| 501 |
+
```
|
| 502 |
+
|
| 503 |
+
### Linear classification with data augmentation on ImageNet-1k
|
| 504 |
+
|
| 505 |
+
```shell
|
| 506 |
+
python dinov2/run/eval/linear.py \
|
| 507 |
+
--config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
|
| 508 |
+
--pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
|
| 509 |
+
--output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/linear \
|
| 510 |
+
--train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
|
| 511 |
+
--val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
|
| 512 |
+
```
|
| 513 |
+
|
| 514 |
+
We release the weights from evaluating the different models:
|
| 515 |
+
|
| 516 |
+
<table style="margin: auto">
|
| 517 |
+
<tr>
|
| 518 |
+
<th>model</th>
|
| 519 |
+
<th>with<br />registers</th>
|
| 520 |
+
<th>ImageNet<br />top-1</th>
|
| 521 |
+
<th>linear evaluation</th>
|
| 522 |
+
</tr>
|
| 523 |
+
<tr>
|
| 524 |
+
<td>ViT-S/14 distilled</td>
|
| 525 |
+
<td align="center">:x:</td>
|
| 526 |
+
<td align="right">81.1%</td>
|
| 527 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth">linear head weights</a></td>
|
| 528 |
+
</tr>
|
| 529 |
+
<tr>
|
| 530 |
+
<td>ViT-S/14 distilled</td>
|
| 531 |
+
<td align="center">:white_check_mark:</td>
|
| 532 |
+
<td align="right">80.8%</td>
|
| 533 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear_head.pth">linear head weights</a></td>
|
| 534 |
+
</tr>
|
| 535 |
+
<tr>
|
| 536 |
+
<td>ViT-B/14 distilled</td>
|
| 537 |
+
<td align="center">:x:</td>
|
| 538 |
+
<td align="right">84.5%</td>
|
| 539 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">linear head weights</a></td>
|
| 540 |
+
</tr>
|
| 541 |
+
<tr>
|
| 542 |
+
<td>ViT-B/14 distilled</td>
|
| 543 |
+
<td align="center">:white_check_mark:</td>
|
| 544 |
+
<td align="right">84.4%</td>
|
| 545 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear_head.pth">linear head weights</a></td>
|
| 546 |
+
</tr>
|
| 547 |
+
<tr>
|
| 548 |
+
<td>ViT-L/14 distilled</td>
|
| 549 |
+
<td align="center">:x:</td>
|
| 550 |
+
<td align="right">86.3%</td>
|
| 551 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">linear head weights</a></td>
|
| 552 |
+
</tr>
|
| 553 |
+
<tr>
|
| 554 |
+
<td>ViT-L/14 distilled</td>
|
| 555 |
+
<td align="center">:white_check_mark:</td>
|
| 556 |
+
<td align="right">86.5%</td>
|
| 557 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear_head.pth">linear head weights</a></td>
|
| 558 |
+
</tr>
|
| 559 |
+
<tr>
|
| 560 |
+
<td>ViT-g/14</td>
|
| 561 |
+
<td align="center">:x:</td>
|
| 562 |
+
<td align="right">86.5%</td>
|
| 563 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">linear head weights</a></td>
|
| 564 |
+
</tr>
|
| 565 |
+
<tr>
|
| 566 |
+
<td>ViT-g/14</td>
|
| 567 |
+
<td align="center">:white_check_mark:</td>
|
| 568 |
+
<td align="right">87.0%</td>
|
| 569 |
+
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear_head.pth">linear head weights</a></td>
|
| 570 |
+
</tr>
|
| 571 |
+
</table>
|
| 572 |
+
|
| 573 |
+
The performance of the provided pretrained model weights can be evaluated as follows on ImageNet-1k:
|
| 574 |
+
|
| 575 |
+
```shell
|
| 576 |
+
python dinov2/run/eval/linear.py \
|
| 577 |
+
--config-file dinov2/configs/eval/vitg14_pretrain.yaml \
|
| 578 |
+
--pretrained-weights https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth \
|
| 579 |
+
--train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
|
| 580 |
+
--val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
|
| 581 |
+
```
|
| 582 |
+
|
| 583 |
+
## Notebooks
|
| 584 |
+
|
| 585 |
+
A few notebooks are provided to help the community leverage the models and code:
|
| 586 |
+
|
| 587 |
+
<ul>
|
| 588 |
+
<li><a href="https://github.com/facebookresearch/dinov2/blob/main/notebooks/depth_estimation.ipynb">Depth estimation</a> - How to load and use the depth heads in combination with a matching backbone via mmcv</li>
|
| 589 |
+
<li><a href="https://github.com/facebookresearch/dinov2/blob/main/notebooks/semantic_segmentation.ipynb">Semantic segmentation</a> - How to load and use the segmentation heads in combination with a matching backbone via mmcv, and also how to load and use the Mask2Former-based segmentation model trained on ADE20K</li>
|
| 590 |
+
</ul>
|
| 591 |
+
|
| 592 |
+
## License
|
| 593 |
+
|
| 594 |
+
DINOv2 code and model weights are released under the Apache License 2.0. See [LICENSE](LICENSE) for additional details.
|
| 595 |
+
|
| 596 |
+
## Contributing
|
| 597 |
+
|
| 598 |
+
See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
|
| 599 |
+
|
| 600 |
+
## Citing DINOv2
|
| 601 |
+
|
| 602 |
+
If you find this repository useful, please consider giving a star :star: and citation :t-rex::
|
| 603 |
+
|
| 604 |
+
```
|
| 605 |
+
@misc{oquab2023dinov2,
|
| 606 |
+
title={DINOv2: Learning Robust Visual Features without Supervision},
|
| 607 |
+
author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
|
| 608 |
+
journal={arXiv:2304.07193},
|
| 609 |
+
year={2023}
|
| 610 |
+
}
|
| 611 |
+
```
|
| 612 |
+
|
| 613 |
+
```
|
| 614 |
+
@misc{darcet2023vitneedreg,
|
| 615 |
+
title={Vision Transformers Need Registers},
|
| 616 |
+
author={Darcet, Timothée and Oquab, Maxime and Mairal, Julien and Bojanowski, Piotr},
|
| 617 |
+
journal={arXiv:2309.16588},
|
| 618 |
+
year={2023}
|
| 619 |
+
}
|
| 620 |
+
```
|
dinov2/conda-extras.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: dinov2-extras
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
- pytorch
|
| 5 |
+
- nvidia
|
| 6 |
+
- xformers
|
| 7 |
+
- conda-forge
|
| 8 |
+
dependencies:
|
| 9 |
+
- python=3.9
|
| 10 |
+
- pytorch::pytorch=2.0.0
|
| 11 |
+
- pytorch::pytorch-cuda=11.7.0
|
| 12 |
+
- pytorch::torchvision=0.15.0
|
| 13 |
+
- omegaconf
|
| 14 |
+
- torchmetrics=0.10.3
|
| 15 |
+
- fvcore
|
| 16 |
+
- iopath
|
| 17 |
+
- xformers::xformers=0.0.18
|
| 18 |
+
- pip
|
| 19 |
+
- pip:
|
| 20 |
+
- git+https://github.com/facebookincubator/submitit
|
| 21 |
+
- --extra-index-url https://pypi.nvidia.com
|
| 22 |
+
- cuml-cu11
|
| 23 |
+
- mmcv-full==1.5.0
|
| 24 |
+
- mmsegmentation==0.27.0
|
dinov2/conda.yaml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: dinov2
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
- pytorch
|
| 5 |
+
- nvidia
|
| 6 |
+
- conda-forge
|
| 7 |
+
dependencies:
|
| 8 |
+
- python=3.9
|
| 9 |
+
- pytorch=2.0.0
|
| 10 |
+
- pytorch-cuda=11.7
|
| 11 |
+
- torchvision=0.15.0
|
| 12 |
+
- omegaconf
|
| 13 |
+
- torchmetrics=0.10.3
|
| 14 |
+
- fvcore
|
| 15 |
+
- iopath
|
| 16 |
+
- pip
|
| 17 |
+
- pip:
|
| 18 |
+
- git+https://github.com/facebookincubator/submitit
|
| 19 |
+
- --extra-index-url https://pypi.nvidia.com
|
| 20 |
+
- cuml-cu11
|
| 21 |
+
- xformers==0.0.20 # Updated xformers version compatible with PyTorch 2.0
|
dinov2/pyproject.toml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.black]
|
| 2 |
+
line-length = 120
|
| 3 |
+
|
| 4 |
+
[tool.pylint.master]
|
| 5 |
+
persistent = false
|
| 6 |
+
score = false
|
| 7 |
+
|
| 8 |
+
[tool.pylint.messages_control]
|
| 9 |
+
disable = "all"
|
| 10 |
+
enable = [
|
| 11 |
+
"miscellaneous",
|
| 12 |
+
"similarities",
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
[tool.pylint.similarities]
|
| 16 |
+
ignore-comments = true
|
| 17 |
+
ignore-docstrings = true
|
| 18 |
+
ignore-imports = true
|
| 19 |
+
min-similarity-lines = 8
|
| 20 |
+
|
| 21 |
+
[tool.pylint.reports]
|
| 22 |
+
reports = false
|
| 23 |
+
|
| 24 |
+
[tool.pylint.miscellaneous]
|
| 25 |
+
notes = [
|
| 26 |
+
"FIXME",
|
| 27 |
+
"XXX",
|
| 28 |
+
"TODO",
|
| 29 |
+
]
|
dinov2/requirements-dev.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
black==22.6.0
|
| 2 |
+
flake8==5.0.4
|
| 3 |
+
pylint==2.15.0
|
dinov2/requirements-extras.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mmcv-full==1.5.0
|
| 2 |
+
mmsegmentation==0.27.0
|
dinov2/requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu117
|
| 2 |
+
torch==2.0.0
|
| 3 |
+
torchvision==0.15.0
|
| 4 |
+
omegaconf
|
| 5 |
+
torchmetrics==0.10.3
|
| 6 |
+
fvcore
|
| 7 |
+
iopath
|
| 8 |
+
xformers==0.0.18
|
| 9 |
+
submitit
|
| 10 |
+
--extra-index-url https://pypi.nvidia.com
|
| 11 |
+
cuml-cu11
|
dinov2/scripts/lint.sh
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
|
| 3 |
+
if [ -n "$1" ]; then
|
| 4 |
+
echo "linting \"$1\""
|
| 5 |
+
fi
|
| 6 |
+
|
| 7 |
+
echo "running black"
|
| 8 |
+
if [ -n "$1" ]; then
|
| 9 |
+
black "$1"
|
| 10 |
+
else
|
| 11 |
+
black dinov2
|
| 12 |
+
fi
|
| 13 |
+
|
| 14 |
+
echo "running flake8"
|
| 15 |
+
if [ -n "$1" ]; then
|
| 16 |
+
flake8 "$1"
|
| 17 |
+
else
|
| 18 |
+
flake8
|
| 19 |
+
fi
|
| 20 |
+
|
| 21 |
+
echo "running pylint"
|
| 22 |
+
if [ -n "$1" ]; then
|
| 23 |
+
pylint "$1"
|
| 24 |
+
else
|
| 25 |
+
pylint dinov2
|
| 26 |
+
fi
|
| 27 |
+
|
| 28 |
+
exit 0
|
dinov2/setup.cfg
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[flake8]
|
| 2 |
+
max-line-length = 120
|
| 3 |
+
ignore = E203,E501,W503
|
| 4 |
+
per-file-ignores =
|
| 5 |
+
__init__.py:F401
|
| 6 |
+
hubconf.py:F401
|
| 7 |
+
exclude =
|
| 8 |
+
venv
|
dinov2/setup.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import re
|
| 8 |
+
from typing import List, Tuple
|
| 9 |
+
|
| 10 |
+
from setuptools import setup, find_packages
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
NAME = "dinov2"
|
| 14 |
+
DESCRIPTION = "PyTorch code and models for the DINOv2 self-supervised learning method."
|
| 15 |
+
|
| 16 |
+
URL = "https://github.com/facebookresearch/dinov2"
|
| 17 |
+
AUTHOR = "FAIR"
|
| 18 |
+
REQUIRES_PYTHON = ">=3.9.0"
|
| 19 |
+
HERE = Path(__file__).parent
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
with open(HERE / "README.md", encoding="utf-8") as f:
|
| 24 |
+
long_description = "\n" + f.read()
|
| 25 |
+
except FileNotFoundError:
|
| 26 |
+
long_description = DESCRIPTION
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_requirements(path: str = HERE / "requirements.txt") -> Tuple[List[str], List[str]]:
|
| 30 |
+
requirements = []
|
| 31 |
+
extra_indices = []
|
| 32 |
+
with open(path) as f:
|
| 33 |
+
for line in f.readlines():
|
| 34 |
+
line = line.rstrip("\r\n")
|
| 35 |
+
if line.startswith("--extra-index-url "):
|
| 36 |
+
extra_indices.append(line[18:])
|
| 37 |
+
continue
|
| 38 |
+
requirements.append(line)
|
| 39 |
+
return requirements, extra_indices
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_package_version() -> str:
|
| 43 |
+
with open(HERE / "dinov2/__init__.py") as f:
|
| 44 |
+
result = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M)
|
| 45 |
+
if result:
|
| 46 |
+
return result.group(1)
|
| 47 |
+
raise RuntimeError("Can't get package version")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
requirements, extra_indices = get_requirements()
|
| 51 |
+
version = get_package_version()
|
| 52 |
+
dev_requirements, _ = get_requirements(HERE / "requirements-dev.txt")
|
| 53 |
+
extras_requirements, _ = get_requirements(HERE / "requirements-extras.txt")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
setup(
|
| 57 |
+
name=NAME,
|
| 58 |
+
version=version,
|
| 59 |
+
description=DESCRIPTION,
|
| 60 |
+
long_description=long_description,
|
| 61 |
+
long_description_content_type="text/markdown",
|
| 62 |
+
author=AUTHOR,
|
| 63 |
+
python_requires=REQUIRES_PYTHON,
|
| 64 |
+
url=URL,
|
| 65 |
+
packages=find_packages(),
|
| 66 |
+
package_data={
|
| 67 |
+
"": ["*.yaml"],
|
| 68 |
+
},
|
| 69 |
+
install_requires=requirements,
|
| 70 |
+
extras_require={
|
| 71 |
+
"dev": dev_requirements,
|
| 72 |
+
"extras": extras_requirements,
|
| 73 |
+
},
|
| 74 |
+
dependency_links=extra_indices,
|
| 75 |
+
install_package_data=True,
|
| 76 |
+
license="Apache",
|
| 77 |
+
license_files=("LICENSE",),
|
| 78 |
+
classifiers=[
|
| 79 |
+
# Trove classifiers: https://github.com/pypa/trove-classifiers/blob/main/src/trove_classifiers/__init__.py
|
| 80 |
+
"Development Status :: 3 - Alpha",
|
| 81 |
+
"Intended Audience :: Developers",
|
| 82 |
+
"Intended Audience :: Science/Research",
|
| 83 |
+
"License :: OSI Approved :: Apache Software License",
|
| 84 |
+
"Programming Language :: Python :: 3.9",
|
| 85 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 86 |
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
| 87 |
+
],
|
| 88 |
+
)
|
explainability.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass(frozen=True)
|
| 11 |
+
class ExplainabilitySpec:
|
| 12 |
+
encoder_mode: str
|
| 13 |
+
encoder_label: str
|
| 14 |
+
decoder_mode: str
|
| 15 |
+
decoder_label: str
|
| 16 |
+
encoder_layer_count: int = 0
|
| 17 |
+
encoder_head_count: int = 0
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ModuleOutputRecorder:
|
| 21 |
+
def __init__(self) -> None:
|
| 22 |
+
self.handle = None
|
| 23 |
+
self.output = None
|
| 24 |
+
|
| 25 |
+
def attach(self, module) -> None:
|
| 26 |
+
self.remove()
|
| 27 |
+
self.handle = module.register_forward_hook(self._hook)
|
| 28 |
+
|
| 29 |
+
def clear(self) -> None:
|
| 30 |
+
self.output = None
|
| 31 |
+
|
| 32 |
+
def remove(self) -> None:
|
| 33 |
+
if self.handle is not None:
|
| 34 |
+
self.handle.remove()
|
| 35 |
+
self.handle = None
|
| 36 |
+
self.output = None
|
| 37 |
+
|
| 38 |
+
def _hook(self, module, inputs, output) -> None: # pragma: no cover - hook signature
|
| 39 |
+
if torch.is_tensor(output):
|
| 40 |
+
self.output = output.detach()
|
| 41 |
+
else:
|
| 42 |
+
self.output = output
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def clamp_index(index: int | None, upper_bound: int) -> int:
|
| 46 |
+
if upper_bound <= 0:
|
| 47 |
+
return 0
|
| 48 |
+
if index is None:
|
| 49 |
+
return upper_bound - 1
|
| 50 |
+
return max(0, min(int(index), upper_bound - 1))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def normalize_map(values) -> np.ndarray:
|
| 54 |
+
array = np.asarray(values, dtype=np.float32)
|
| 55 |
+
if array.ndim != 2:
|
| 56 |
+
raise ValueError(f"Expected a 2D array, got shape {array.shape}")
|
| 57 |
+
|
| 58 |
+
array = array.copy()
|
| 59 |
+
min_value = float(array.min(initial=0.0))
|
| 60 |
+
array -= min_value
|
| 61 |
+
max_value = float(array.max(initial=0.0))
|
| 62 |
+
if max_value > 0:
|
| 63 |
+
array /= max_value
|
| 64 |
+
return array
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def resize_rgb_image(rgb_image: np.ndarray, size: tuple[int, int]) -> np.ndarray:
|
| 68 |
+
width, height = size
|
| 69 |
+
return cv2.resize(rgb_image, (width, height), interpolation=cv2.INTER_LINEAR)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def feature_energy_map(feature_tensor: torch.Tensor, output_shape: tuple[int, int]) -> np.ndarray:
|
| 73 |
+
tensor = feature_tensor.detach().float()
|
| 74 |
+
while tensor.dim() > 3:
|
| 75 |
+
tensor = tensor[0]
|
| 76 |
+
if tensor.dim() == 3:
|
| 77 |
+
tensor = tensor.abs().mean(dim=0)
|
| 78 |
+
elif tensor.dim() != 2:
|
| 79 |
+
raise ValueError(f"Unexpected feature tensor shape: {tuple(feature_tensor.shape)}")
|
| 80 |
+
|
| 81 |
+
heatmap = normalize_map(tensor.cpu().numpy())
|
| 82 |
+
height, width = output_shape
|
| 83 |
+
return cv2.resize(heatmap, (width, height), interpolation=cv2.INTER_CUBIC)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def render_heatmap_overlay(rgb_image: np.ndarray, heatmap: np.ndarray, alpha: float = 0.45) -> np.ndarray:
|
| 87 |
+
if heatmap.shape != rgb_image.shape[:2]:
|
| 88 |
+
heatmap = cv2.resize(heatmap, (rgb_image.shape[1], rgb_image.shape[0]), interpolation=cv2.INTER_CUBIC)
|
| 89 |
+
colored = cv2.applyColorMap((normalize_map(heatmap) * 255.0).astype(np.uint8), cv2.COLORMAP_TURBO)
|
| 90 |
+
colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
|
| 91 |
+
return cv2.addWeighted(rgb_image, 1.0 - alpha, colored, alpha, 0.0)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def render_temporal_strip(values, *, active_index: int | None = None, cell_width: int = 12, height: int = 72) -> np.ndarray:
|
| 95 |
+
sequence = np.asarray(values, dtype=np.float32).reshape(1, -1)
|
| 96 |
+
if sequence.size == 0:
|
| 97 |
+
sequence = np.zeros((1, 1), dtype=np.float32)
|
| 98 |
+
|
| 99 |
+
normalized = normalize_map(sequence)
|
| 100 |
+
strip = (normalized * 255.0).astype(np.uint8)
|
| 101 |
+
strip = np.repeat(strip, height, axis=0)
|
| 102 |
+
strip = np.repeat(strip, cell_width, axis=1)
|
| 103 |
+
colored = cv2.applyColorMap(strip, cv2.COLORMAP_TURBO)
|
| 104 |
+
colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
|
| 105 |
+
|
| 106 |
+
if active_index is not None and sequence.shape[1] > 0:
|
| 107 |
+
clamped = clamp_index(active_index, sequence.shape[1])
|
| 108 |
+
x0 = clamped * cell_width
|
| 109 |
+
x1 = min(colored.shape[1] - 1, x0 + cell_width - 1)
|
| 110 |
+
cv2.rectangle(colored, (x0, 0), (x1, colored.shape[0] - 1), (255, 255, 255), 2)
|
| 111 |
+
|
| 112 |
+
return colored
|
model/transformer.py
CHANGED
|
@@ -146,7 +146,7 @@ class Decoder(nn.Module):
|
|
| 146 |
super(Decoder, self).__init__()
|
| 147 |
self.layers = nn.ModuleList([DecoderLayer(d_model, d_ff, d_k, d_v, n_heads, len_q) for _ in range(n_layers)])
|
| 148 |
|
| 149 |
-
def forward(self, dec_inputs, enc_outputs):
|
| 150 |
'''
|
| 151 |
dec_inputs: [batch_size, tgt_len, d_model] [512, 1, 5]
|
| 152 |
enc_intpus: [batch_size, src_len, d_model] [512, 30, 5]
|
|
@@ -160,6 +160,8 @@ class Decoder(nn.Module):
|
|
| 160 |
# dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, h_heads, tgt_len, src_len]
|
| 161 |
dec_outputs, dec_enc_attn = layer(dec_outputs, enc_outputs)
|
| 162 |
dec_enc_attns.append(dec_enc_attn)
|
|
|
|
|
|
|
| 163 |
return dec_outputs
|
| 164 |
|
| 165 |
|
|
@@ -175,7 +177,7 @@ class Transformer2_3_1(nn.Module):
|
|
| 175 |
self.encoder = Encoder(d_model, d_ff, d_k, d_v, n_layers, n_heads, len_q)
|
| 176 |
self.decoder = Decoder(d_model, d_ff, d_k, d_v, 1, n_heads, len_q)
|
| 177 |
|
| 178 |
-
def forward(self, enc_inputs, dec_inputs):
|
| 179 |
'''
|
| 180 |
enc_inputs: [Frames, src_len, d_model] [512, 30, 5]
|
| 181 |
dec_inputs: [Frames, 1, d_model] [512, 1, 5]
|
|
@@ -185,8 +187,11 @@ class Transformer2_3_1(nn.Module):
|
|
| 185 |
|
| 186 |
# enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
|
| 187 |
enc_outputs, enc_self_attns = self.encoder(enc_inputs) # Self-attention for temporal features
|
| 188 |
-
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
|
| 192 |
class Transformer(nn.Module):
|
|
@@ -210,7 +215,7 @@ class Transformer(nn.Module):
|
|
| 210 |
nn.Linear(self.d_model, out_features, bias=False)
|
| 211 |
)
|
| 212 |
|
| 213 |
-
def forward(self, x, long_feature):
|
| 214 |
# x: [B, 256, T]; long_feature: [B, T, 256]
|
| 215 |
B, D, T = x.shape
|
| 216 |
out_features = x.transpose(1, 2) # [B, T, 256]
|
|
@@ -238,9 +243,24 @@ class Transformer(nn.Module):
|
|
| 238 |
win = out_features[:, i - spa_len + 1:i + 1, :]
|
| 239 |
out_feas.append(win)
|
| 240 |
out_feas = torch.stack(out_feas, dim=0).squeeze(1)
|
| 241 |
-
out_feas,
|
| 242 |
|
| 243 |
# Temporal-spatial fusion
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
output = self.out(output) # [T, B, C]
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
super(Decoder, self).__init__()
|
| 147 |
self.layers = nn.ModuleList([DecoderLayer(d_model, d_ff, d_k, d_v, n_heads, len_q) for _ in range(n_layers)])
|
| 148 |
|
| 149 |
+
def forward(self, dec_inputs, enc_outputs, return_attentions=False):
|
| 150 |
'''
|
| 151 |
dec_inputs: [batch_size, tgt_len, d_model] [512, 1, 5]
|
| 152 |
enc_intpus: [batch_size, src_len, d_model] [512, 30, 5]
|
|
|
|
| 160 |
# dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, h_heads, tgt_len, src_len]
|
| 161 |
dec_outputs, dec_enc_attn = layer(dec_outputs, enc_outputs)
|
| 162 |
dec_enc_attns.append(dec_enc_attn)
|
| 163 |
+
if return_attentions:
|
| 164 |
+
return dec_outputs, dec_enc_attns
|
| 165 |
return dec_outputs
|
| 166 |
|
| 167 |
|
|
|
|
| 177 |
self.encoder = Encoder(d_model, d_ff, d_k, d_v, n_layers, n_heads, len_q)
|
| 178 |
self.decoder = Decoder(d_model, d_ff, d_k, d_v, 1, n_heads, len_q)
|
| 179 |
|
| 180 |
+
def forward(self, enc_inputs, dec_inputs, return_attentions=False):
|
| 181 |
'''
|
| 182 |
enc_inputs: [Frames, src_len, d_model] [512, 30, 5]
|
| 183 |
dec_inputs: [Frames, 1, d_model] [512, 1, 5]
|
|
|
|
| 187 |
|
| 188 |
# enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
|
| 189 |
enc_outputs, enc_self_attns = self.encoder(enc_inputs) # Self-attention for temporal features
|
| 190 |
+
decoder_outputs = self.decoder(dec_inputs, enc_outputs, return_attentions=return_attentions)
|
| 191 |
+
if return_attentions:
|
| 192 |
+
dec_outputs, dec_enc_attns = decoder_outputs
|
| 193 |
+
return dec_outputs, {"encoder_self_attns": enc_self_attns, "decoder_cross_attns": dec_enc_attns}
|
| 194 |
+
return decoder_outputs
|
| 195 |
|
| 196 |
|
| 197 |
class Transformer(nn.Module):
|
|
|
|
| 215 |
nn.Linear(self.d_model, out_features, bias=False)
|
| 216 |
)
|
| 217 |
|
| 218 |
+
def forward(self, x, long_feature, return_attention=False):
|
| 219 |
# x: [B, 256, T]; long_feature: [B, T, 256]
|
| 220 |
B, D, T = x.shape
|
| 221 |
out_features = x.transpose(1, 2) # [B, T, 256]
|
|
|
|
| 243 |
win = out_features[:, i - spa_len + 1:i + 1, :]
|
| 244 |
out_feas.append(win)
|
| 245 |
out_feas = torch.stack(out_feas, dim=0).squeeze(1)
|
| 246 |
+
out_feas, spatial_attn = self.spatial_encoder(out_feas)
|
| 247 |
|
| 248 |
# Temporal-spatial fusion
|
| 249 |
+
transformer_outputs = self.transformer(inputs, out_feas, return_attentions=return_attention)
|
| 250 |
+
if return_attention:
|
| 251 |
+
output, attention_meta = transformer_outputs
|
| 252 |
+
else:
|
| 253 |
+
output = transformer_outputs
|
| 254 |
output = self.out(output) # [T, B, C]
|
| 255 |
+
output = output.transpose(0, 1) # [B, T, C]
|
| 256 |
+
if not return_attention:
|
| 257 |
+
return output
|
| 258 |
+
|
| 259 |
+
decoder_attn = attention_meta["decoder_cross_attns"][-1]
|
| 260 |
+
spatial_attn_last = spatial_attn[-1]
|
| 261 |
+
decoder_strip = decoder_attn[-1].mean(dim=0).mean(dim=0).detach()
|
| 262 |
+
spatial_strip = spatial_attn_last.mean(dim=0).mean(dim=0).detach()
|
| 263 |
+
return output, {
|
| 264 |
+
"decoder_strip": decoder_strip,
|
| 265 |
+
"spatial_strip": spatial_strip,
|
| 266 |
+
}
|
model_manager.py
CHANGED
|
@@ -56,6 +56,16 @@ class SpaceModelManager:
|
|
| 56 |
self.current_predictor = predictor
|
| 57 |
return predictor
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
def reset_predictor_state(self) -> None:
|
| 60 |
if self.current_predictor is not None and hasattr(self.current_predictor, "reset_state"):
|
| 61 |
self.current_predictor.reset_state()
|
|
|
|
| 56 |
self.current_predictor = predictor
|
| 57 |
return predictor
|
| 58 |
|
| 59 |
+
def get_loaded_predictor(self, model_key: str | None = None):
|
| 60 |
+
if self.current_predictor is None:
|
| 61 |
+
return None
|
| 62 |
+
if model_key is None:
|
| 63 |
+
return self.current_predictor
|
| 64 |
+
normalized_key = normalize_model_key(model_key)
|
| 65 |
+
if self.current_model_key != normalized_key:
|
| 66 |
+
return None
|
| 67 |
+
return self.current_predictor
|
| 68 |
+
|
| 69 |
def reset_predictor_state(self) -> None:
|
| 70 |
if self.current_predictor is not None and hasattr(self.current_predictor, "reset_state"):
|
| 71 |
self.current_predictor.reset_state()
|
predictor.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
| 3 |
import os
|
| 4 |
import sys
|
| 5 |
from contextlib import nullcontext
|
|
@@ -21,6 +22,15 @@ except ImportError: # pragma: no cover
|
|
| 21 |
from model.resnet import ResNet
|
| 22 |
from model.mstcn import MultiStageModel
|
| 23 |
from model.transformer import Transformer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
PHASE_LABELS = ("idle", "marking", "injection", "dissection")
|
| 26 |
MODEL_LABELS = {
|
|
@@ -108,6 +118,37 @@ def _resolve_vendor_repo(repo_name: str, extra_candidates=()):
|
|
| 108 |
raise FileNotFoundError(f"Required vendor repo '{repo_name}' not found. Stage it into this folder or keep the repo-root copy available.")
|
| 109 |
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
class Predictor:
|
| 112 |
def __init__(self, model_dir: str | None = None, device: str = "cuda"):
|
| 113 |
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
|
@@ -118,6 +159,14 @@ class Predictor:
|
|
| 118 |
self.frame_feature_cache = None
|
| 119 |
self.label_dict = dict(enumerate(PHASE_LABELS))
|
| 120 |
self.available = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
self._norm_mean = None
|
| 123 |
self._norm_std = None
|
|
@@ -134,6 +183,9 @@ class Predictor:
|
|
| 134 |
paras = {k.replace("share.", "resnet."): v for k, v in paras.items()}
|
| 135 |
self.resnet.load_state_dict(paras, strict=True)
|
| 136 |
self.resnet.to(self.device).eval()
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
self.fusion = MultiStageModel(
|
| 139 |
mstcn_stages=2,
|
|
@@ -174,11 +226,18 @@ class Predictor:
|
|
| 174 |
self.predict(dummy)
|
| 175 |
self.reset_state()
|
| 176 |
|
|
|
|
|
|
|
|
|
|
| 177 |
def reset_state(self):
|
| 178 |
self.frame_feature_cache = None
|
|
|
|
| 179 |
if torch.cuda.is_available():
|
| 180 |
torch.cuda.empty_cache()
|
| 181 |
|
|
|
|
|
|
|
|
|
|
| 182 |
def unload(self):
|
| 183 |
self.available = False
|
| 184 |
self.resnet.to("cpu")
|
|
@@ -188,6 +247,10 @@ class Predictor:
|
|
| 188 |
self.fusion = None
|
| 189 |
self.transformer = None
|
| 190 |
self.frame_feature_cache = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
if torch.cuda.is_available():
|
| 192 |
torch.cuda.empty_cache()
|
| 193 |
|
|
@@ -200,7 +263,10 @@ class Predictor:
|
|
| 200 |
self.frame_feature_cache = torch.cat([self.frame_feature_cache, feature], dim=0)
|
| 201 |
|
| 202 |
@torch.inference_mode()
|
| 203 |
-
def predict(self, rgb_image: np.ndarray):
|
|
|
|
|
|
|
|
|
|
| 204 |
if self._norm_mean is not None:
|
| 205 |
tensor = self._preprocess_gpu(rgb_image)
|
| 206 |
else:
|
|
@@ -216,33 +282,91 @@ class Predictor:
|
|
| 216 |
single_frame_feature = feature.unsqueeze(1)
|
| 217 |
temporal_input = single_frame_feature.transpose(1, 2)
|
| 218 |
temporal_feature = self.fusion(temporal_input)
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
final_logits = outputs[-1, -1, :]
|
| 221 |
probs = F.softmax(final_logits.float(), dim=-1)
|
| 222 |
pred_np = probs.detach().cpu().numpy()
|
| 223 |
confidence = float(np.max(pred_np))
|
| 224 |
phase_idx = max(0, min(3, int(np.argmax(pred_np))))
|
| 225 |
phase = self.label_dict.get(phase_idx, "idle")
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
if self.frame_feature_cache.shape[0] < 30:
|
| 229 |
available_frames = self.frame_feature_cache.shape[0] + 1
|
| 230 |
cat_frame_feature = torch.cat([self.frame_feature_cache, feature], dim=0).unsqueeze(0)
|
| 231 |
temporal_input = cat_frame_feature.transpose(1, 2)
|
| 232 |
temporal_feature = self.fusion(temporal_input)
|
| 233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
final_logits = outputs[-1, -1, :]
|
| 235 |
probs = F.softmax(final_logits.float(), dim=-1)
|
| 236 |
pred_np = probs.detach().cpu().numpy()
|
| 237 |
confidence = float(np.max(pred_np))
|
| 238 |
phase_idx = max(0, min(3, int(np.argmax(pred_np))))
|
| 239 |
phase = self.label_dict.get(phase_idx, "idle")
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
cat_frame_feature = self.frame_feature_cache.unsqueeze(0)
|
| 243 |
temporal_input = cat_frame_feature.transpose(1, 2)
|
| 244 |
temporal_feature = self.fusion(temporal_input)
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
final_logits = outputs[-1, -1, :]
|
| 247 |
probs = F.softmax(final_logits.float(), dim=-1)
|
| 248 |
pred_np = probs.detach().cpu().numpy()
|
|
@@ -250,7 +374,22 @@ class Predictor:
|
|
| 250 |
confidence = float(np.max(pred_np))
|
| 251 |
phase_idx = max(0, min(3, int(np.argmax(pred_np))))
|
| 252 |
phase = self.label_dict.get(phase_idx, "idle")
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
|
| 256 |
class PredictorDinoV2:
|
|
@@ -267,7 +406,19 @@ class PredictorDinoV2:
|
|
| 267 |
A.CenterCrop(height=224, width=224),
|
| 268 |
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0),
|
| 269 |
])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
self.frame_features = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
self._load_models(self.model_dir)
|
| 272 |
|
| 273 |
def _amp_context(self):
|
|
@@ -297,6 +448,14 @@ class PredictorDinoV2:
|
|
| 297 |
encoder_load = self.backbone.load_state_dict(encoder_state, strict=False)
|
| 298 |
_validate_load_result(encoder_load, "DINOv2 backbone")
|
| 299 |
self.backbone.to(self.device).eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
|
| 301 |
decoder_path = os.path.join(model_dir, "fusion_transformer_decoder_best_model.pth")
|
| 302 |
if not os.path.exists(decoder_path):
|
|
@@ -326,7 +485,7 @@ class PredictorDinoV2:
|
|
| 326 |
d_model=d_model,
|
| 327 |
)
|
| 328 |
|
| 329 |
-
def forward(self, x):
|
| 330 |
x = x.permute(0, 2, 1)
|
| 331 |
x_reduced = self.reduce(x)
|
| 332 |
mstcn_input = x_reduced.permute(0, 2, 1)
|
|
@@ -341,8 +500,15 @@ class PredictorDinoV2:
|
|
| 341 |
else:
|
| 342 |
transformer_input = mstcn_input.detach()
|
| 343 |
|
| 344 |
-
|
| 345 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
self.decoder = FusionTransformerDecoder()
|
| 348 |
decoder_load = self.decoder.load_state_dict(decoder_state, strict=False)
|
|
@@ -359,6 +525,7 @@ class PredictorDinoV2:
|
|
| 359 |
|
| 360 |
def reset_state(self):
|
| 361 |
self.frame_features = []
|
|
|
|
| 362 |
if torch.cuda.is_available():
|
| 363 |
torch.cuda.empty_cache()
|
| 364 |
|
|
@@ -367,6 +534,40 @@ class PredictorDinoV2:
|
|
| 367 |
self.predict(dummy_img)
|
| 368 |
self.reset_state()
|
| 369 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
def unload(self):
|
| 371 |
if self.backbone is not None:
|
| 372 |
self.backbone.to("cpu")
|
|
@@ -375,15 +576,33 @@ class PredictorDinoV2:
|
|
| 375 |
self.backbone = None
|
| 376 |
self.decoder = None
|
| 377 |
self.frame_features = []
|
|
|
|
|
|
|
| 378 |
self.available = False
|
| 379 |
if torch.cuda.is_available():
|
| 380 |
torch.cuda.empty_cache()
|
| 381 |
|
| 382 |
@torch.inference_mode()
|
| 383 |
-
def predict(self, rgb_image: np.ndarray):
|
| 384 |
if not self.available or self.backbone is None or self.decoder is None:
|
| 385 |
raise RuntimeError("DINO-Endo predictor is not available")
|
| 386 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
processed = self.aug(image=rgb_image)["image"]
|
| 388 |
chw = np.transpose(processed, (2, 0, 1))
|
| 389 |
tensor = torch.tensor(chw, dtype=torch.float32).unsqueeze(0).to(self.device)
|
|
@@ -408,7 +627,11 @@ class PredictorDinoV2:
|
|
| 408 |
|
| 409 |
decoder_input = seq.transpose(1, 2)
|
| 410 |
with self._amp_context():
|
| 411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
|
| 413 |
if logits.dim() != 3:
|
| 414 |
raise ValueError(f"Unexpected DINOv2 decoder output shape: {tuple(logits.shape)}")
|
|
@@ -424,7 +647,22 @@ class PredictorDinoV2:
|
|
| 424 |
confidence = float(np.max(pred_np))
|
| 425 |
phase_idx = int(np.argmax(pred_np))
|
| 426 |
phase = self.label_dict.get(phase_idx, "idle")
|
| 427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
|
| 429 |
|
| 430 |
class PredictorVJEPA2:
|
|
@@ -443,6 +681,15 @@ class PredictorVJEPA2:
|
|
| 443 |
self._feature_buffer = []
|
| 444 |
self._vjepa_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(3, 1, 1, 1)
|
| 445 |
self._vjepa_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(3, 1, 1, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
self._load_models(self.model_dir)
|
| 447 |
|
| 448 |
def _amp_context(self):
|
|
@@ -509,7 +756,9 @@ class PredictorVJEPA2:
|
|
| 509 |
sys.path.insert(0, str(vjepa2_path))
|
| 510 |
|
| 511 |
from src.models import vision_transformer as vjepa_vit
|
|
|
|
| 512 |
from src.utils.checkpoint_loader import robust_checkpoint_loader
|
|
|
|
| 513 |
|
| 514 |
encoder_path = os.path.join(model_dir, "vjepa_encoder_human.pt")
|
| 515 |
if not os.path.exists(encoder_path):
|
|
@@ -530,6 +779,14 @@ class PredictorVJEPA2:
|
|
| 530 |
encoder_load = self.encoder.load_state_dict(encoder_state, strict=False)
|
| 531 |
self._validate_load_result(encoder_load, "V-JEPA2 encoder")
|
| 532 |
self.encoder.to(self.device).eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
|
| 534 |
decoder_path = os.path.join(model_dir, "mlp_decoder_human.pth")
|
| 535 |
if not os.path.exists(decoder_path):
|
|
@@ -566,6 +823,7 @@ class PredictorVJEPA2:
|
|
| 566 |
def reset_state(self):
|
| 567 |
self._frame_buffer = []
|
| 568 |
self._feature_buffer = []
|
|
|
|
| 569 |
if torch.cuda.is_available():
|
| 570 |
torch.cuda.empty_cache()
|
| 571 |
|
|
@@ -574,6 +832,67 @@ class PredictorVJEPA2:
|
|
| 574 |
self.predict(dummy)
|
| 575 |
self.reset_state()
|
| 576 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
def unload(self):
|
| 578 |
if self.encoder is not None:
|
| 579 |
self.encoder.to("cpu")
|
|
@@ -583,15 +902,30 @@ class PredictorVJEPA2:
|
|
| 583 |
self.decoder = None
|
| 584 |
self._frame_buffer = []
|
| 585 |
self._feature_buffer = []
|
|
|
|
|
|
|
| 586 |
self.available = False
|
| 587 |
if torch.cuda.is_available():
|
| 588 |
torch.cuda.empty_cache()
|
| 589 |
|
| 590 |
@torch.inference_mode()
|
| 591 |
-
def predict(self, rgb_image: np.ndarray):
|
| 592 |
if not self.available:
|
| 593 |
raise RuntimeError("V-JEPA2 predictor is not available")
|
| 594 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 595 |
frame = np.ascontiguousarray(rgb_image, dtype=np.uint8)
|
| 596 |
self._frame_buffer.append(frame)
|
| 597 |
if len(self._frame_buffer) > self._clip_frames:
|
|
@@ -625,7 +959,30 @@ class PredictorVJEPA2:
|
|
| 625 |
confidence = float(np.max(pred_np))
|
| 626 |
phase_idx = int(np.argmax(pred_np))
|
| 627 |
phase = self.label_dict.get(phase_idx, "idle")
|
| 628 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 629 |
|
| 630 |
|
| 631 |
def create_predictor(model_key: str, model_dir: str | None = None, device: str | None = None):
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import math
|
| 4 |
import os
|
| 5 |
import sys
|
| 6 |
from contextlib import nullcontext
|
|
|
|
| 22 |
from model.resnet import ResNet
|
| 23 |
from model.mstcn import MultiStageModel
|
| 24 |
from model.transformer import Transformer
|
| 25 |
+
from explainability import (
|
| 26 |
+
ExplainabilitySpec,
|
| 27 |
+
ModuleOutputRecorder,
|
| 28 |
+
clamp_index,
|
| 29 |
+
feature_energy_map,
|
| 30 |
+
render_heatmap_overlay,
|
| 31 |
+
render_temporal_strip,
|
| 32 |
+
resize_rgb_image,
|
| 33 |
+
)
|
| 34 |
|
| 35 |
PHASE_LABELS = ("idle", "marking", "injection", "dissection")
|
| 36 |
MODEL_LABELS = {
|
|
|
|
| 118 |
raise FileNotFoundError(f"Required vendor repo '{repo_name}' not found. Stage it into this folder or keep the repo-root copy available.")
|
| 119 |
|
| 120 |
|
| 121 |
+
def _build_explainability_payload(
|
| 122 |
+
*,
|
| 123 |
+
display_image: np.ndarray,
|
| 124 |
+
encoder_heatmap: np.ndarray,
|
| 125 |
+
encoder_kind: str,
|
| 126 |
+
encoder_label: str,
|
| 127 |
+
decoder_values,
|
| 128 |
+
decoder_kind: str,
|
| 129 |
+
decoder_label: str,
|
| 130 |
+
active_decoder_index: int | None = None,
|
| 131 |
+
encoder_layer: int | None = None,
|
| 132 |
+
encoder_head: int | None = None,
|
| 133 |
+
notes: str | None = None,
|
| 134 |
+
) -> dict:
|
| 135 |
+
payload = {
|
| 136 |
+
"encoder_kind": encoder_kind,
|
| 137 |
+
"encoder_label": encoder_label,
|
| 138 |
+
"encoder_visualization": render_heatmap_overlay(display_image, encoder_heatmap),
|
| 139 |
+
"decoder_kind": decoder_kind,
|
| 140 |
+
"decoder_label": decoder_label,
|
| 141 |
+
"decoder_visualization": render_temporal_strip(decoder_values, active_index=active_decoder_index),
|
| 142 |
+
}
|
| 143 |
+
if encoder_layer is not None:
|
| 144 |
+
payload["encoder_layer"] = int(encoder_layer)
|
| 145 |
+
if encoder_head is not None:
|
| 146 |
+
payload["encoder_head"] = int(encoder_head)
|
| 147 |
+
if notes:
|
| 148 |
+
payload["notes"] = notes
|
| 149 |
+
return payload
|
| 150 |
+
|
| 151 |
+
|
| 152 |
class Predictor:
|
| 153 |
def __init__(self, model_dir: str | None = None, device: str = "cuda"):
|
| 154 |
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
|
|
|
| 159 |
self.frame_feature_cache = None
|
| 160 |
self.label_dict = dict(enumerate(PHASE_LABELS))
|
| 161 |
self.available = False
|
| 162 |
+
self._resnet_activation = None
|
| 163 |
+
self._resnet_activation_hook = None
|
| 164 |
+
self._explainability_spec = ExplainabilitySpec(
|
| 165 |
+
encoder_mode="proxy",
|
| 166 |
+
encoder_label="ResNet layer4 activation energy (proxy)",
|
| 167 |
+
decoder_mode="attention",
|
| 168 |
+
decoder_label="Temporal Transformer attention",
|
| 169 |
+
)
|
| 170 |
|
| 171 |
self._norm_mean = None
|
| 172 |
self._norm_std = None
|
|
|
|
| 183 |
paras = {k.replace("share.", "resnet."): v for k, v in paras.items()}
|
| 184 |
self.resnet.load_state_dict(paras, strict=True)
|
| 185 |
self.resnet.to(self.device).eval()
|
| 186 |
+
self._resnet_activation_hook = self.resnet.resnet.layer4[-1].relu.register_forward_hook(
|
| 187 |
+
self._capture_resnet_activation
|
| 188 |
+
)
|
| 189 |
|
| 190 |
self.fusion = MultiStageModel(
|
| 191 |
mstcn_stages=2,
|
|
|
|
| 226 |
self.predict(dummy)
|
| 227 |
self.reset_state()
|
| 228 |
|
| 229 |
+
def _capture_resnet_activation(self, module, inputs, output): # pragma: no cover - hook signature
|
| 230 |
+
self._resnet_activation = output.detach()
|
| 231 |
+
|
| 232 |
def reset_state(self):
|
| 233 |
self.frame_feature_cache = None
|
| 234 |
+
self._resnet_activation = None
|
| 235 |
if torch.cuda.is_available():
|
| 236 |
torch.cuda.empty_cache()
|
| 237 |
|
| 238 |
+
def get_explainability_spec(self) -> ExplainabilitySpec:
|
| 239 |
+
return self._explainability_spec
|
| 240 |
+
|
| 241 |
def unload(self):
|
| 242 |
self.available = False
|
| 243 |
self.resnet.to("cpu")
|
|
|
|
| 247 |
self.fusion = None
|
| 248 |
self.transformer = None
|
| 249 |
self.frame_feature_cache = None
|
| 250 |
+
self._resnet_activation = None
|
| 251 |
+
if self._resnet_activation_hook is not None:
|
| 252 |
+
self._resnet_activation_hook.remove()
|
| 253 |
+
self._resnet_activation_hook = None
|
| 254 |
if torch.cuda.is_available():
|
| 255 |
torch.cuda.empty_cache()
|
| 256 |
|
|
|
|
| 263 |
self.frame_feature_cache = torch.cat([self.frame_feature_cache, feature], dim=0)
|
| 264 |
|
| 265 |
@torch.inference_mode()
|
| 266 |
+
def predict(self, rgb_image: np.ndarray, explainability: dict | None = None):
|
| 267 |
+
explain_enabled = bool(explainability and explainability.get("enabled"))
|
| 268 |
+
attention_meta = None
|
| 269 |
+
display_image = resize_rgb_image(rgb_image, (224, 224)) if explain_enabled else None
|
| 270 |
if self._norm_mean is not None:
|
| 271 |
tensor = self._preprocess_gpu(rgb_image)
|
| 272 |
else:
|
|
|
|
| 282 |
single_frame_feature = feature.unsqueeze(1)
|
| 283 |
temporal_input = single_frame_feature.transpose(1, 2)
|
| 284 |
temporal_feature = self.fusion(temporal_input)
|
| 285 |
+
transformer_outputs = self.transformer(
|
| 286 |
+
temporal_feature.detach(),
|
| 287 |
+
single_frame_feature,
|
| 288 |
+
return_attention=explain_enabled,
|
| 289 |
+
)
|
| 290 |
+
if explain_enabled:
|
| 291 |
+
outputs, attention_meta = transformer_outputs
|
| 292 |
+
else:
|
| 293 |
+
outputs = transformer_outputs
|
| 294 |
final_logits = outputs[-1, -1, :]
|
| 295 |
probs = F.softmax(final_logits.float(), dim=-1)
|
| 296 |
pred_np = probs.detach().cpu().numpy()
|
| 297 |
confidence = float(np.max(pred_np))
|
| 298 |
phase_idx = max(0, min(3, int(np.argmax(pred_np))))
|
| 299 |
phase = self.label_dict.get(phase_idx, "idle")
|
| 300 |
+
frames_used = 1
|
| 301 |
+
result = {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": frames_used}
|
| 302 |
+
if explain_enabled and attention_meta is not None and display_image is not None and self._resnet_activation is not None:
|
| 303 |
+
encoder_heatmap = feature_energy_map(self._resnet_activation, display_image.shape[:2])
|
| 304 |
+
result["explainability"] = _build_explainability_payload(
|
| 305 |
+
display_image=display_image,
|
| 306 |
+
encoder_heatmap=encoder_heatmap,
|
| 307 |
+
encoder_kind="proxy",
|
| 308 |
+
encoder_label=self._explainability_spec.encoder_label,
|
| 309 |
+
decoder_values=attention_meta["decoder_strip"].detach().cpu().numpy(),
|
| 310 |
+
decoder_kind="attention",
|
| 311 |
+
decoder_label=self._explainability_spec.decoder_label,
|
| 312 |
+
active_decoder_index=frames_used - 1,
|
| 313 |
+
notes="Encoder view is a proxy activation map because the ResNet backbone is not attention-based.",
|
| 314 |
+
)
|
| 315 |
+
return result
|
| 316 |
|
| 317 |
if self.frame_feature_cache.shape[0] < 30:
|
| 318 |
available_frames = self.frame_feature_cache.shape[0] + 1
|
| 319 |
cat_frame_feature = torch.cat([self.frame_feature_cache, feature], dim=0).unsqueeze(0)
|
| 320 |
temporal_input = cat_frame_feature.transpose(1, 2)
|
| 321 |
temporal_feature = self.fusion(temporal_input)
|
| 322 |
+
transformer_outputs = self.transformer(
|
| 323 |
+
temporal_feature.detach(),
|
| 324 |
+
cat_frame_feature,
|
| 325 |
+
return_attention=explain_enabled,
|
| 326 |
+
)
|
| 327 |
+
if explain_enabled:
|
| 328 |
+
outputs, attention_meta = transformer_outputs
|
| 329 |
+
else:
|
| 330 |
+
outputs = transformer_outputs
|
| 331 |
final_logits = outputs[-1, -1, :]
|
| 332 |
probs = F.softmax(final_logits.float(), dim=-1)
|
| 333 |
pred_np = probs.detach().cpu().numpy()
|
| 334 |
confidence = float(np.max(pred_np))
|
| 335 |
phase_idx = max(0, min(3, int(np.argmax(pred_np))))
|
| 336 |
phase = self.label_dict.get(phase_idx, "idle")
|
| 337 |
+
result = {
|
| 338 |
+
"phase": phase,
|
| 339 |
+
"probs": pred_np.tolist(),
|
| 340 |
+
"confidence": confidence,
|
| 341 |
+
"frames_used": available_frames,
|
| 342 |
+
}
|
| 343 |
+
if explain_enabled and attention_meta is not None and display_image is not None and self._resnet_activation is not None:
|
| 344 |
+
encoder_heatmap = feature_energy_map(self._resnet_activation, display_image.shape[:2])
|
| 345 |
+
result["explainability"] = _build_explainability_payload(
|
| 346 |
+
display_image=display_image,
|
| 347 |
+
encoder_heatmap=encoder_heatmap,
|
| 348 |
+
encoder_kind="proxy",
|
| 349 |
+
encoder_label=self._explainability_spec.encoder_label,
|
| 350 |
+
decoder_values=attention_meta["decoder_strip"].detach().cpu().numpy(),
|
| 351 |
+
decoder_kind="attention",
|
| 352 |
+
decoder_label=self._explainability_spec.decoder_label,
|
| 353 |
+
active_decoder_index=available_frames - 1,
|
| 354 |
+
notes="Encoder view is a proxy activation map because the ResNet backbone is not attention-based.",
|
| 355 |
+
)
|
| 356 |
+
return result
|
| 357 |
|
| 358 |
cat_frame_feature = self.frame_feature_cache.unsqueeze(0)
|
| 359 |
temporal_input = cat_frame_feature.transpose(1, 2)
|
| 360 |
temporal_feature = self.fusion(temporal_input)
|
| 361 |
+
transformer_outputs = self.transformer(
|
| 362 |
+
temporal_feature.detach(),
|
| 363 |
+
cat_frame_feature,
|
| 364 |
+
return_attention=explain_enabled,
|
| 365 |
+
)
|
| 366 |
+
if explain_enabled:
|
| 367 |
+
outputs, attention_meta = transformer_outputs
|
| 368 |
+
else:
|
| 369 |
+
outputs = transformer_outputs
|
| 370 |
final_logits = outputs[-1, -1, :]
|
| 371 |
probs = F.softmax(final_logits.float(), dim=-1)
|
| 372 |
pred_np = probs.detach().cpu().numpy()
|
|
|
|
| 374 |
confidence = float(np.max(pred_np))
|
| 375 |
phase_idx = max(0, min(3, int(np.argmax(pred_np))))
|
| 376 |
phase = self.label_dict.get(phase_idx, "idle")
|
| 377 |
+
frames_used = min(self.trans_seq, self.frame_feature_cache.shape[0])
|
| 378 |
+
result = {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": frames_used}
|
| 379 |
+
if explain_enabled and attention_meta is not None and display_image is not None and self._resnet_activation is not None:
|
| 380 |
+
encoder_heatmap = feature_energy_map(self._resnet_activation, display_image.shape[:2])
|
| 381 |
+
result["explainability"] = _build_explainability_payload(
|
| 382 |
+
display_image=display_image,
|
| 383 |
+
encoder_heatmap=encoder_heatmap,
|
| 384 |
+
encoder_kind="proxy",
|
| 385 |
+
encoder_label=self._explainability_spec.encoder_label,
|
| 386 |
+
decoder_values=attention_meta["decoder_strip"].detach().cpu().numpy(),
|
| 387 |
+
decoder_kind="attention",
|
| 388 |
+
decoder_label=self._explainability_spec.decoder_label,
|
| 389 |
+
active_decoder_index=frames_used - 1,
|
| 390 |
+
notes="Encoder view is a proxy activation map because the ResNet backbone is not attention-based.",
|
| 391 |
+
)
|
| 392 |
+
return result
|
| 393 |
|
| 394 |
|
| 395 |
class PredictorDinoV2:
|
|
|
|
| 406 |
A.CenterCrop(height=224, width=224),
|
| 407 |
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0),
|
| 408 |
])
|
| 409 |
+
self.display_aug = A.Compose([
|
| 410 |
+
A.SmallestMaxSize(max_size=256, interpolation=cv2.INTER_LINEAR),
|
| 411 |
+
A.CenterCrop(height=224, width=224),
|
| 412 |
+
])
|
| 413 |
self.frame_features = []
|
| 414 |
+
self._attention_recorder = ModuleOutputRecorder()
|
| 415 |
+
self._attention_layer_index = None
|
| 416 |
+
self._explainability_spec = ExplainabilitySpec(
|
| 417 |
+
encoder_mode="attention",
|
| 418 |
+
encoder_label="DINOv2 encoder self-attention",
|
| 419 |
+
decoder_mode="attention",
|
| 420 |
+
decoder_label="Fusion Transformer temporal attention",
|
| 421 |
+
)
|
| 422 |
self._load_models(self.model_dir)
|
| 423 |
|
| 424 |
def _amp_context(self):
|
|
|
|
| 448 |
encoder_load = self.backbone.load_state_dict(encoder_state, strict=False)
|
| 449 |
_validate_load_result(encoder_load, "DINOv2 backbone")
|
| 450 |
self.backbone.to(self.device).eval()
|
| 451 |
+
self._explainability_spec = ExplainabilitySpec(
|
| 452 |
+
encoder_mode="attention",
|
| 453 |
+
encoder_label="DINOv2 encoder self-attention",
|
| 454 |
+
decoder_mode="attention",
|
| 455 |
+
decoder_label="Fusion Transformer temporal attention",
|
| 456 |
+
encoder_layer_count=len(self.backbone.blocks),
|
| 457 |
+
encoder_head_count=int(self.backbone.num_heads),
|
| 458 |
+
)
|
| 459 |
|
| 460 |
decoder_path = os.path.join(model_dir, "fusion_transformer_decoder_best_model.pth")
|
| 461 |
if not os.path.exists(decoder_path):
|
|
|
|
| 485 |
d_model=d_model,
|
| 486 |
)
|
| 487 |
|
| 488 |
+
def forward(self, x, return_attention=False):
|
| 489 |
x = x.permute(0, 2, 1)
|
| 490 |
x_reduced = self.reduce(x)
|
| 491 |
mstcn_input = x_reduced.permute(0, 2, 1)
|
|
|
|
| 500 |
else:
|
| 501 |
transformer_input = mstcn_input.detach()
|
| 502 |
|
| 503 |
+
transformer_outputs = self.transformer(
|
| 504 |
+
transformer_input,
|
| 505 |
+
x_reduced,
|
| 506 |
+
return_attention=return_attention,
|
| 507 |
+
)
|
| 508 |
+
if return_attention:
|
| 509 |
+
transformer_out, attention_meta = transformer_outputs
|
| 510 |
+
return transformer_out.permute(0, 2, 1), attention_meta
|
| 511 |
+
return transformer_outputs.permute(0, 2, 1)
|
| 512 |
|
| 513 |
self.decoder = FusionTransformerDecoder()
|
| 514 |
decoder_load = self.decoder.load_state_dict(decoder_state, strict=False)
|
|
|
|
| 525 |
|
| 526 |
def reset_state(self):
|
| 527 |
self.frame_features = []
|
| 528 |
+
self._attention_recorder.clear()
|
| 529 |
if torch.cuda.is_available():
|
| 530 |
torch.cuda.empty_cache()
|
| 531 |
|
|
|
|
| 534 |
self.predict(dummy_img)
|
| 535 |
self.reset_state()
|
| 536 |
|
| 537 |
+
def get_explainability_spec(self) -> ExplainabilitySpec:
|
| 538 |
+
return self._explainability_spec
|
| 539 |
+
|
| 540 |
+
def _ensure_attention_hook(self, layer_index: int) -> None:
|
| 541 |
+
clamped_layer = clamp_index(layer_index, self._explainability_spec.encoder_layer_count)
|
| 542 |
+
if self._attention_layer_index == clamped_layer and self._attention_recorder.handle is not None:
|
| 543 |
+
return
|
| 544 |
+
self._attention_recorder.attach(self.backbone.blocks[clamped_layer].norm1)
|
| 545 |
+
self._attention_layer_index = clamped_layer
|
| 546 |
+
|
| 547 |
+
def _compute_encoder_attention_map(self, head_index: int, output_shape: tuple[int, int]) -> np.ndarray:
|
| 548 |
+
if self._attention_recorder.output is None or self._attention_layer_index is None:
|
| 549 |
+
raise RuntimeError("DINO encoder attention recorder did not capture any tokens")
|
| 550 |
+
|
| 551 |
+
tokens = self._attention_recorder.output.to(self.device)
|
| 552 |
+
block = self.backbone.blocks[self._attention_layer_index]
|
| 553 |
+
attn_module = block.attn
|
| 554 |
+
qkv = attn_module.qkv(tokens).reshape(tokens.shape[0], tokens.shape[1], 3, attn_module.num_heads, -1).permute(
|
| 555 |
+
2, 0, 3, 1, 4
|
| 556 |
+
)
|
| 557 |
+
q = qkv[0] * attn_module.scale
|
| 558 |
+
k = qkv[1]
|
| 559 |
+
attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
|
| 560 |
+
|
| 561 |
+
head = clamp_index(head_index, attn.shape[1])
|
| 562 |
+
patch_start = 1 + int(getattr(self.backbone, "num_register_tokens", 0))
|
| 563 |
+
cls_attention = attn[0, head, 0, patch_start:]
|
| 564 |
+
patch_count = int(cls_attention.numel())
|
| 565 |
+
grid_size = int(math.sqrt(patch_count))
|
| 566 |
+
if grid_size * grid_size != patch_count:
|
| 567 |
+
raise RuntimeError(f"Unexpected DINO patch attention size: {patch_count}")
|
| 568 |
+
heatmap = cls_attention.view(grid_size, grid_size).detach().cpu().numpy()
|
| 569 |
+
return cv2.resize(heatmap, (output_shape[1], output_shape[0]), interpolation=cv2.INTER_CUBIC)
|
| 570 |
+
|
| 571 |
def unload(self):
|
| 572 |
if self.backbone is not None:
|
| 573 |
self.backbone.to("cpu")
|
|
|
|
| 576 |
self.backbone = None
|
| 577 |
self.decoder = None
|
| 578 |
self.frame_features = []
|
| 579 |
+
self._attention_recorder.remove()
|
| 580 |
+
self._attention_layer_index = None
|
| 581 |
self.available = False
|
| 582 |
if torch.cuda.is_available():
|
| 583 |
torch.cuda.empty_cache()
|
| 584 |
|
| 585 |
@torch.inference_mode()
|
| 586 |
+
def predict(self, rgb_image: np.ndarray, explainability: dict | None = None):
|
| 587 |
if not self.available or self.backbone is None or self.decoder is None:
|
| 588 |
raise RuntimeError("DINO-Endo predictor is not available")
|
| 589 |
|
| 590 |
+
explain_enabled = bool(explainability and explainability.get("enabled"))
|
| 591 |
+
encoder_layer = clamp_index(
|
| 592 |
+
explainability.get("encoder_layer") if explainability else None,
|
| 593 |
+
self._explainability_spec.encoder_layer_count,
|
| 594 |
+
)
|
| 595 |
+
encoder_head = clamp_index(
|
| 596 |
+
explainability.get("encoder_head") if explainability else None,
|
| 597 |
+
self._explainability_spec.encoder_head_count,
|
| 598 |
+
)
|
| 599 |
+
if explain_enabled:
|
| 600 |
+
self._ensure_attention_hook(encoder_layer)
|
| 601 |
+
self._attention_recorder.clear()
|
| 602 |
+
display_image = self.display_aug(image=rgb_image)["image"]
|
| 603 |
+
else:
|
| 604 |
+
display_image = None
|
| 605 |
+
|
| 606 |
processed = self.aug(image=rgb_image)["image"]
|
| 607 |
chw = np.transpose(processed, (2, 0, 1))
|
| 608 |
tensor = torch.tensor(chw, dtype=torch.float32).unsqueeze(0).to(self.device)
|
|
|
|
| 627 |
|
| 628 |
decoder_input = seq.transpose(1, 2)
|
| 629 |
with self._amp_context():
|
| 630 |
+
decoder_outputs = self.decoder(decoder_input, return_attention=explain_enabled)
|
| 631 |
+
if explain_enabled:
|
| 632 |
+
logits, attention_meta = decoder_outputs
|
| 633 |
+
else:
|
| 634 |
+
logits = decoder_outputs
|
| 635 |
|
| 636 |
if logits.dim() != 3:
|
| 637 |
raise ValueError(f"Unexpected DINOv2 decoder output shape: {tuple(logits.shape)}")
|
|
|
|
| 647 |
confidence = float(np.max(pred_np))
|
| 648 |
phase_idx = int(np.argmax(pred_np))
|
| 649 |
phase = self.label_dict.get(phase_idx, "idle")
|
| 650 |
+
result = {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": available_frames}
|
| 651 |
+
if explain_enabled and display_image is not None:
|
| 652 |
+
encoder_heatmap = self._compute_encoder_attention_map(encoder_head, display_image.shape[:2])
|
| 653 |
+
result["explainability"] = _build_explainability_payload(
|
| 654 |
+
display_image=display_image,
|
| 655 |
+
encoder_heatmap=encoder_heatmap,
|
| 656 |
+
encoder_kind="attention",
|
| 657 |
+
encoder_label=self._explainability_spec.encoder_label,
|
| 658 |
+
decoder_values=attention_meta["decoder_strip"].detach().cpu().numpy(),
|
| 659 |
+
decoder_kind="attention",
|
| 660 |
+
decoder_label=self._explainability_spec.decoder_label,
|
| 661 |
+
active_decoder_index=available_frames - 1,
|
| 662 |
+
encoder_layer=encoder_layer,
|
| 663 |
+
encoder_head=encoder_head,
|
| 664 |
+
)
|
| 665 |
+
return result
|
| 666 |
|
| 667 |
|
| 668 |
class PredictorVJEPA2:
|
|
|
|
| 681 |
self._feature_buffer = []
|
| 682 |
self._vjepa_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(3, 1, 1, 1)
|
| 683 |
self._vjepa_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(3, 1, 1, 1)
|
| 684 |
+
self._attention_recorder = ModuleOutputRecorder()
|
| 685 |
+
self._attention_layer_index = None
|
| 686 |
+
self._rotate_queries_or_keys = None
|
| 687 |
+
self._explainability_spec = ExplainabilitySpec(
|
| 688 |
+
encoder_mode="attention",
|
| 689 |
+
encoder_label="V-JEPA2 encoder self-attention",
|
| 690 |
+
decoder_mode="proxy",
|
| 691 |
+
decoder_label="MLP decoder feature energy (proxy)",
|
| 692 |
+
)
|
| 693 |
self._load_models(self.model_dir)
|
| 694 |
|
| 695 |
def _amp_context(self):
|
|
|
|
| 756 |
sys.path.insert(0, str(vjepa2_path))
|
| 757 |
|
| 758 |
from src.models import vision_transformer as vjepa_vit
|
| 759 |
+
from src.models.utils.modules import rotate_queries_or_keys
|
| 760 |
from src.utils.checkpoint_loader import robust_checkpoint_loader
|
| 761 |
+
self._rotate_queries_or_keys = rotate_queries_or_keys
|
| 762 |
|
| 763 |
encoder_path = os.path.join(model_dir, "vjepa_encoder_human.pt")
|
| 764 |
if not os.path.exists(encoder_path):
|
|
|
|
| 779 |
encoder_load = self.encoder.load_state_dict(encoder_state, strict=False)
|
| 780 |
self._validate_load_result(encoder_load, "V-JEPA2 encoder")
|
| 781 |
self.encoder.to(self.device).eval()
|
| 782 |
+
self._explainability_spec = ExplainabilitySpec(
|
| 783 |
+
encoder_mode="attention",
|
| 784 |
+
encoder_label="V-JEPA2 encoder self-attention",
|
| 785 |
+
decoder_mode="proxy",
|
| 786 |
+
decoder_label="MLP decoder feature energy (proxy)",
|
| 787 |
+
encoder_layer_count=len(self.encoder.blocks),
|
| 788 |
+
encoder_head_count=int(self.encoder.num_heads),
|
| 789 |
+
)
|
| 790 |
|
| 791 |
decoder_path = os.path.join(model_dir, "mlp_decoder_human.pth")
|
| 792 |
if not os.path.exists(decoder_path):
|
|
|
|
| 823 |
def reset_state(self):
|
| 824 |
self._frame_buffer = []
|
| 825 |
self._feature_buffer = []
|
| 826 |
+
self._attention_recorder.clear()
|
| 827 |
if torch.cuda.is_available():
|
| 828 |
torch.cuda.empty_cache()
|
| 829 |
|
|
|
|
| 832 |
self.predict(dummy)
|
| 833 |
self.reset_state()
|
| 834 |
|
| 835 |
+
def get_explainability_spec(self) -> ExplainabilitySpec:
|
| 836 |
+
return self._explainability_spec
|
| 837 |
+
|
| 838 |
+
def _ensure_attention_hook(self, layer_index: int) -> None:
|
| 839 |
+
clamped_layer = clamp_index(layer_index, self._explainability_spec.encoder_layer_count)
|
| 840 |
+
if self._attention_layer_index == clamped_layer and self._attention_recorder.handle is not None:
|
| 841 |
+
return
|
| 842 |
+
self._attention_recorder.attach(self.encoder.blocks[clamped_layer].norm1)
|
| 843 |
+
self._attention_layer_index = clamped_layer
|
| 844 |
+
|
| 845 |
+
def _compute_encoder_attention_map(
|
| 846 |
+
self,
|
| 847 |
+
*,
|
| 848 |
+
head_index: int,
|
| 849 |
+
temporal_group_index: int,
|
| 850 |
+
output_shape: tuple[int, int],
|
| 851 |
+
) -> np.ndarray:
|
| 852 |
+
if self._attention_recorder.output is None or self._attention_layer_index is None:
|
| 853 |
+
raise RuntimeError("V-JEPA2 encoder attention recorder did not capture any tokens")
|
| 854 |
+
if self._rotate_queries_or_keys is None:
|
| 855 |
+
raise RuntimeError("V-JEPA2 rotation helper is unavailable")
|
| 856 |
+
|
| 857 |
+
tokens = self._attention_recorder.output.to(self.device)
|
| 858 |
+
block = self.encoder.blocks[self._attention_layer_index]
|
| 859 |
+
attn_module = block.attn
|
| 860 |
+
qkv = attn_module.qkv(tokens).unflatten(-1, (3, attn_module.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
| 861 |
+
q, k = qkv[0], qkv[1]
|
| 862 |
+
|
| 863 |
+
patch_grid = self._crop_size // 16
|
| 864 |
+
temporal_groups = self._clip_frames // self._tubelet_size
|
| 865 |
+
if hasattr(attn_module, "separate_positions"):
|
| 866 |
+
mask = torch.arange(int(temporal_groups * patch_grid * patch_grid), device=tokens.device)
|
| 867 |
+
d_mask, h_mask, w_mask = attn_module.separate_positions(mask, patch_grid, patch_grid)
|
| 868 |
+
offset = 0
|
| 869 |
+
qd = self._rotate_queries_or_keys(q[..., offset : offset + attn_module.d_dim], pos=d_mask)
|
| 870 |
+
kd = self._rotate_queries_or_keys(k[..., offset : offset + attn_module.d_dim], pos=d_mask)
|
| 871 |
+
offset += attn_module.d_dim
|
| 872 |
+
qh = self._rotate_queries_or_keys(q[..., offset : offset + attn_module.h_dim], pos=h_mask)
|
| 873 |
+
kh = self._rotate_queries_or_keys(k[..., offset : offset + attn_module.h_dim], pos=h_mask)
|
| 874 |
+
offset += attn_module.h_dim
|
| 875 |
+
qw = self._rotate_queries_or_keys(q[..., offset : offset + attn_module.w_dim], pos=w_mask)
|
| 876 |
+
kw = self._rotate_queries_or_keys(k[..., offset : offset + attn_module.w_dim], pos=w_mask)
|
| 877 |
+
offset += attn_module.w_dim
|
| 878 |
+
q_parts = [qd, qh, qw]
|
| 879 |
+
k_parts = [kd, kh, kw]
|
| 880 |
+
if offset < attn_module.head_dim:
|
| 881 |
+
q_parts.append(q[..., offset:])
|
| 882 |
+
k_parts.append(k[..., offset:])
|
| 883 |
+
q = torch.cat(q_parts, dim=-1)
|
| 884 |
+
k = torch.cat(k_parts, dim=-1)
|
| 885 |
+
|
| 886 |
+
attn = ((q @ k.transpose(-2, -1)) * attn_module.scale).softmax(dim=-1)
|
| 887 |
+
head = clamp_index(head_index, attn.shape[1])
|
| 888 |
+
group_size = patch_grid * patch_grid
|
| 889 |
+
group_index = clamp_index(temporal_group_index, temporal_groups)
|
| 890 |
+
start = group_index * group_size
|
| 891 |
+
end = start + group_size
|
| 892 |
+
group_attention = attn[0, head, start:end, start:end].mean(dim=0)
|
| 893 |
+
heatmap = group_attention.view(patch_grid, patch_grid).detach().cpu().numpy()
|
| 894 |
+
return cv2.resize(heatmap, (output_shape[1], output_shape[0]), interpolation=cv2.INTER_CUBIC)
|
| 895 |
+
|
| 896 |
def unload(self):
|
| 897 |
if self.encoder is not None:
|
| 898 |
self.encoder.to("cpu")
|
|
|
|
| 902 |
self.decoder = None
|
| 903 |
self._frame_buffer = []
|
| 904 |
self._feature_buffer = []
|
| 905 |
+
self._attention_recorder.remove()
|
| 906 |
+
self._attention_layer_index = None
|
| 907 |
self.available = False
|
| 908 |
if torch.cuda.is_available():
|
| 909 |
torch.cuda.empty_cache()
|
| 910 |
|
| 911 |
@torch.inference_mode()
|
| 912 |
+
def predict(self, rgb_image: np.ndarray, explainability: dict | None = None):
|
| 913 |
if not self.available:
|
| 914 |
raise RuntimeError("V-JEPA2 predictor is not available")
|
| 915 |
|
| 916 |
+
explain_enabled = bool(explainability and explainability.get("enabled"))
|
| 917 |
+
encoder_layer = clamp_index(
|
| 918 |
+
explainability.get("encoder_layer") if explainability else None,
|
| 919 |
+
self._explainability_spec.encoder_layer_count,
|
| 920 |
+
)
|
| 921 |
+
encoder_head = clamp_index(
|
| 922 |
+
explainability.get("encoder_head") if explainability else None,
|
| 923 |
+
self._explainability_spec.encoder_head_count,
|
| 924 |
+
)
|
| 925 |
+
if explain_enabled:
|
| 926 |
+
self._ensure_attention_hook(encoder_layer)
|
| 927 |
+
self._attention_recorder.clear()
|
| 928 |
+
|
| 929 |
frame = np.ascontiguousarray(rgb_image, dtype=np.uint8)
|
| 930 |
self._frame_buffer.append(frame)
|
| 931 |
if len(self._frame_buffer) > self._clip_frames:
|
|
|
|
| 959 |
confidence = float(np.max(pred_np))
|
| 960 |
phase_idx = int(np.argmax(pred_np))
|
| 961 |
phase = self.label_dict.get(phase_idx, "idle")
|
| 962 |
+
result = {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": available_frames}
|
| 963 |
+
if explain_enabled:
|
| 964 |
+
latest_group_index = latest_feature_idx // self._tubelet_size
|
| 965 |
+
display_image = resize_rgb_image(frame, (self._crop_size, self._crop_size))
|
| 966 |
+
encoder_heatmap = self._compute_encoder_attention_map(
|
| 967 |
+
head_index=encoder_head,
|
| 968 |
+
temporal_group_index=latest_group_index,
|
| 969 |
+
output_shape=display_image.shape[:2],
|
| 970 |
+
)
|
| 971 |
+
decoder_proxy_values = [feature.abs().mean().item() for feature in self._feature_buffer]
|
| 972 |
+
result["explainability"] = _build_explainability_payload(
|
| 973 |
+
display_image=display_image,
|
| 974 |
+
encoder_heatmap=encoder_heatmap,
|
| 975 |
+
encoder_kind="attention",
|
| 976 |
+
encoder_label=self._explainability_spec.encoder_label,
|
| 977 |
+
decoder_values=decoder_proxy_values,
|
| 978 |
+
decoder_kind="proxy",
|
| 979 |
+
decoder_label=self._explainability_spec.decoder_label,
|
| 980 |
+
active_decoder_index=available_frames - 1,
|
| 981 |
+
encoder_layer=encoder_layer,
|
| 982 |
+
encoder_head=encoder_head,
|
| 983 |
+
notes="Decoder view is a proxy feature-energy strip because the V-JEPA2 classifier head is an MLP.",
|
| 984 |
+
)
|
| 985 |
+
return result
|
| 986 |
|
| 987 |
|
| 988 |
def create_predictor(model_key: str, model_dir: str | None = None, device: str | None = None):
|
runtime-requirements.txt
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
--extra-index-url https://download.pytorch.org/whl/cu121
|
| 2 |
-
streamlit>=1.
|
| 3 |
torch==2.5.1
|
| 4 |
torchvision==0.20.1
|
| 5 |
numpy>=1.26,<3
|
|
|
|
| 1 |
--extra-index-url https://download.pytorch.org/whl/cu121
|
| 2 |
+
streamlit>=1.55,<2
|
| 3 |
torch==2.5.1
|
| 4 |
torchvision==0.20.1
|
| 5 |
numpy>=1.26,<3
|
scripts/publish_model_repo.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
import shutil
|
| 6 |
+
import tempfile
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
from huggingface_hub import HfApi
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
SCRIPT_PATH = Path(__file__).resolve()
|
| 14 |
+
SPACE_ROOT = SCRIPT_PATH.parents[1]
|
| 15 |
+
if str(SPACE_ROOT) not in sys.path:
|
| 16 |
+
sys.path.insert(0, str(SPACE_ROOT))
|
| 17 |
+
|
| 18 |
+
from model_registry import MODEL_SPECS
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
ENV_VAR_BY_FAMILY = {
|
| 22 |
+
"aiendo": "AIENDO_MODEL_REPO_ID",
|
| 23 |
+
"dinov2": "DINO_MODEL_REPO_ID",
|
| 24 |
+
"vjepa2": "VJEPA2_MODEL_REPO_ID",
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _render_model_card(*, family: str, repo_id: str, copied_files: list[str]) -> str:
|
| 29 |
+
spec = MODEL_SPECS[family]
|
| 30 |
+
file_list = "\n".join(f"- `{name}`" for name in copied_files)
|
| 31 |
+
return f"""---
|
| 32 |
+
tags:
|
| 33 |
+
- medical-imaging
|
| 34 |
+
- endoscopy
|
| 35 |
+
- surgical-phase-recognition
|
| 36 |
+
- {family}
|
| 37 |
+
---
|
| 38 |
+
|
| 39 |
+
# {spec.label} checkpoints for the AI-Endo Hugging Face Space
|
| 40 |
+
|
| 41 |
+
This repository stores the published checkpoint set for the **{spec.label}** phase-recognition path used by `hf_spaces/DINO-ENDO/`.
|
| 42 |
+
|
| 43 |
+
## Files
|
| 44 |
+
|
| 45 |
+
{file_list}
|
| 46 |
+
|
| 47 |
+
## Consumed by the Space
|
| 48 |
+
|
| 49 |
+
Set the following Space environment variable so the Streamlit Space can download these files lazily at runtime:
|
| 50 |
+
|
| 51 |
+
```text
|
| 52 |
+
{ENV_VAR_BY_FAMILY[family]}={repo_id}
|
| 53 |
+
```
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _stage_model_family(*, family: str, model_dir: Path, staging_dir: Path, repo_id: str) -> int:
|
| 58 |
+
spec = MODEL_SPECS[family]
|
| 59 |
+
copied_files: list[str] = []
|
| 60 |
+
total_bytes = 0
|
| 61 |
+
|
| 62 |
+
for filename in spec.required_files:
|
| 63 |
+
src = model_dir / filename
|
| 64 |
+
if not src.exists():
|
| 65 |
+
raise FileNotFoundError(f"Missing required checkpoint: {src}")
|
| 66 |
+
dst = staging_dir / filename
|
| 67 |
+
shutil.copy2(src, dst)
|
| 68 |
+
copied_files.append(filename)
|
| 69 |
+
total_bytes += src.stat().st_size
|
| 70 |
+
|
| 71 |
+
for filename in spec.optional_files:
|
| 72 |
+
src = model_dir / filename
|
| 73 |
+
if not src.exists():
|
| 74 |
+
continue
|
| 75 |
+
dst = staging_dir / filename
|
| 76 |
+
shutil.copy2(src, dst)
|
| 77 |
+
copied_files.append(filename)
|
| 78 |
+
total_bytes += src.stat().st_size
|
| 79 |
+
|
| 80 |
+
(staging_dir / "README.md").write_text(
|
| 81 |
+
_render_model_card(family=family, repo_id=repo_id, copied_files=copied_files),
|
| 82 |
+
encoding="utf-8",
|
| 83 |
+
)
|
| 84 |
+
return total_bytes
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _should_use_large_upload(mode: str, total_bytes: int) -> bool:
|
| 88 |
+
if mode == "always":
|
| 89 |
+
return True
|
| 90 |
+
if mode == "never":
|
| 91 |
+
return False
|
| 92 |
+
return total_bytes >= 2 * 1024 * 1024 * 1024
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def parse_args() -> argparse.Namespace:
|
| 96 |
+
parser = argparse.ArgumentParser(description="Publish a model-family checkpoint repo for the HF Space.")
|
| 97 |
+
parser.add_argument("--family", choices=sorted(MODEL_SPECS), required=True)
|
| 98 |
+
parser.add_argument("--repo-id", required=True, help="Target Hugging Face model repo ID.")
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"--model-dir",
|
| 101 |
+
default=str(SPACE_ROOT / "model"),
|
| 102 |
+
help="Directory containing the local checkpoints to publish.",
|
| 103 |
+
)
|
| 104 |
+
parser.add_argument(
|
| 105 |
+
"--upload-mode",
|
| 106 |
+
choices=("auto", "never", "always"),
|
| 107 |
+
default="auto",
|
| 108 |
+
help="Choose whether to force upload_large_folder for this family.",
|
| 109 |
+
)
|
| 110 |
+
parser.add_argument("--revision", default=None, help="Optional target revision or branch.")
|
| 111 |
+
parser.add_argument("--private", action="store_true", help="Create the model repo as private.")
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--token-env",
|
| 114 |
+
default="HF_TOKEN",
|
| 115 |
+
help="Environment variable name containing the Hugging Face write token.",
|
| 116 |
+
)
|
| 117 |
+
return parser.parse_args()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def main() -> None:
|
| 121 |
+
args = parse_args()
|
| 122 |
+
model_dir = Path(args.model_dir).expanduser().resolve()
|
| 123 |
+
token = os.getenv(args.token_env) or None
|
| 124 |
+
api = HfApi(token=token)
|
| 125 |
+
api.create_repo(repo_id=args.repo_id, repo_type="model", private=args.private, exist_ok=True)
|
| 126 |
+
|
| 127 |
+
with tempfile.TemporaryDirectory(prefix=f"hf-space-{args.family}-") as temp_dir:
|
| 128 |
+
staging_dir = Path(temp_dir)
|
| 129 |
+
total_bytes = _stage_model_family(
|
| 130 |
+
family=args.family,
|
| 131 |
+
model_dir=model_dir,
|
| 132 |
+
staging_dir=staging_dir,
|
| 133 |
+
repo_id=args.repo_id,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
upload_kwargs = {
|
| 137 |
+
"repo_id": args.repo_id,
|
| 138 |
+
"repo_type": "model",
|
| 139 |
+
"folder_path": str(staging_dir),
|
| 140 |
+
}
|
| 141 |
+
if args.revision:
|
| 142 |
+
upload_kwargs["revision"] = args.revision
|
| 143 |
+
|
| 144 |
+
if _should_use_large_upload(args.upload_mode, total_bytes):
|
| 145 |
+
api.upload_large_folder(**upload_kwargs)
|
| 146 |
+
mode = "upload_large_folder"
|
| 147 |
+
else:
|
| 148 |
+
api.upload_folder(**upload_kwargs)
|
| 149 |
+
mode = "upload_folder"
|
| 150 |
+
|
| 151 |
+
print(f"Published {args.family} checkpoints to {args.repo_id} via {mode}")
|
| 152 |
+
print(f"Suggested Space variable: {ENV_VAR_BY_FAMILY[args.family]}={args.repo_id}")
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
if __name__ == "__main__":
|
| 156 |
+
main()
|
scripts/publish_space_repo.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
import tempfile
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from huggingface_hub import HfApi
|
| 9 |
+
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
SCRIPT_PATH = Path(__file__).resolve()
|
| 13 |
+
SPACE_ROOT = SCRIPT_PATH.parents[1]
|
| 14 |
+
if str(SPACE_ROOT) not in sys.path:
|
| 15 |
+
sys.path.insert(0, str(SPACE_ROOT))
|
| 16 |
+
|
| 17 |
+
from stage_space_bundle import stage_bundle
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _space_variables(args: argparse.Namespace) -> dict[str, str]:
|
| 21 |
+
variables = {
|
| 22 |
+
"SPACE_ENABLED_MODELS": args.enabled_models,
|
| 23 |
+
"SPACE_DEFAULT_MODEL": args.default_model,
|
| 24 |
+
}
|
| 25 |
+
if args.aiendo_model_repo_id:
|
| 26 |
+
variables["AIENDO_MODEL_REPO_ID"] = args.aiendo_model_repo_id
|
| 27 |
+
if args.dino_model_repo_id:
|
| 28 |
+
variables["DINO_MODEL_REPO_ID"] = args.dino_model_repo_id
|
| 29 |
+
if args.vjepa2_model_repo_id:
|
| 30 |
+
variables["VJEPA2_MODEL_REPO_ID"] = args.vjepa2_model_repo_id
|
| 31 |
+
return variables
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def parse_args() -> argparse.Namespace:
|
| 35 |
+
parser = argparse.ArgumentParser(description="Publish the staged Docker Space bundle and set its variables.")
|
| 36 |
+
parser.add_argument("--repo-id", required=True, help="Target Hugging Face Space repo ID.")
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"--bundle-dir",
|
| 39 |
+
default=None,
|
| 40 |
+
help="Optional pre-staged bundle directory. If omitted, a temporary bundle is staged automatically.",
|
| 41 |
+
)
|
| 42 |
+
parser.add_argument("--enabled-models", default="dinov2,aiendo,vjepa2")
|
| 43 |
+
parser.add_argument("--default-model", default="dinov2")
|
| 44 |
+
parser.add_argument("--aiendo-model-repo-id", default=None)
|
| 45 |
+
parser.add_argument("--dino-model-repo-id", default=None)
|
| 46 |
+
parser.add_argument("--vjepa2-model-repo-id", default=None)
|
| 47 |
+
parser.add_argument("--revision", default=None, help="Optional target revision or branch.")
|
| 48 |
+
parser.add_argument("--private", action="store_true", help="Create the Space repo as private.")
|
| 49 |
+
parser.add_argument(
|
| 50 |
+
"--token-env",
|
| 51 |
+
default="HF_TOKEN",
|
| 52 |
+
help="Environment variable name containing the Hugging Face write token.",
|
| 53 |
+
)
|
| 54 |
+
return parser.parse_args()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _publish_bundle(api: HfApi, *, repo_id: str, bundle_dir: Path, revision: str | None) -> None:
|
| 58 |
+
upload_kwargs = {
|
| 59 |
+
"repo_id": repo_id,
|
| 60 |
+
"repo_type": "space",
|
| 61 |
+
"folder_path": str(bundle_dir),
|
| 62 |
+
}
|
| 63 |
+
if revision:
|
| 64 |
+
upload_kwargs["revision"] = revision
|
| 65 |
+
api.upload_folder(**upload_kwargs)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def main() -> None:
|
| 69 |
+
args = parse_args()
|
| 70 |
+
token = os.getenv(args.token_env) or None
|
| 71 |
+
api = HfApi(token=token)
|
| 72 |
+
api.create_repo(repo_id=args.repo_id, repo_type="space", space_sdk="docker", private=args.private, exist_ok=True)
|
| 73 |
+
|
| 74 |
+
if args.bundle_dir:
|
| 75 |
+
bundle_dir = Path(args.bundle_dir).expanduser().resolve()
|
| 76 |
+
if not bundle_dir.exists():
|
| 77 |
+
raise FileNotFoundError(f"Bundle directory not found: {bundle_dir}")
|
| 78 |
+
_publish_bundle(api, repo_id=args.repo_id, bundle_dir=bundle_dir, revision=args.revision)
|
| 79 |
+
else:
|
| 80 |
+
with tempfile.TemporaryDirectory(prefix="hf-space-bundle-") as temp_dir:
|
| 81 |
+
bundle_dir = stage_bundle(SPACE_ROOT, Path(temp_dir), overwrite=True)
|
| 82 |
+
_publish_bundle(api, repo_id=args.repo_id, bundle_dir=bundle_dir, revision=args.revision)
|
| 83 |
+
|
| 84 |
+
for key, value in _space_variables(args).items():
|
| 85 |
+
api.add_space_variable(
|
| 86 |
+
repo_id=args.repo_id,
|
| 87 |
+
key=key,
|
| 88 |
+
value=value,
|
| 89 |
+
description=f"Managed by publish_space_repo.py for {key}",
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
print(f"Published Space bundle to {args.repo_id}")
|
| 93 |
+
for key, value in _space_variables(args).items():
|
| 94 |
+
print(f"{key}={value}")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
if __name__ == "__main__":
|
| 98 |
+
main()
|
scripts/stage_space_bundle.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import shutil
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
ROOT_FILES = (
|
| 9 |
+
".dockerignore",
|
| 10 |
+
".gitattributes",
|
| 11 |
+
".gitignore",
|
| 12 |
+
"Dockerfile",
|
| 13 |
+
"README.md",
|
| 14 |
+
"app.py",
|
| 15 |
+
"explainability.py",
|
| 16 |
+
"model_manager.py",
|
| 17 |
+
"model_registry.py",
|
| 18 |
+
"predictor.py",
|
| 19 |
+
"requirements.txt",
|
| 20 |
+
"runtime-requirements.txt",
|
| 21 |
+
"video_utils.py",
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
ROOT_DIRS = (
|
| 25 |
+
".streamlit",
|
| 26 |
+
"dinov2",
|
| 27 |
+
"model",
|
| 28 |
+
"scripts",
|
| 29 |
+
"vjepa2",
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
IGNORE_PATTERNS = (
|
| 33 |
+
".git",
|
| 34 |
+
".cache",
|
| 35 |
+
"__pycache__",
|
| 36 |
+
".pytest_cache",
|
| 37 |
+
".mypy_cache",
|
| 38 |
+
"*.egg-info",
|
| 39 |
+
"*.ipynb",
|
| 40 |
+
"*.pt",
|
| 41 |
+
"*.pth",
|
| 42 |
+
"*.pyc",
|
| 43 |
+
"*.pyo",
|
| 44 |
+
"assets",
|
| 45 |
+
"notebooks",
|
| 46 |
+
"tests",
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _copy_item(src: Path, dst: Path) -> None:
|
| 51 |
+
if not src.exists():
|
| 52 |
+
raise FileNotFoundError(f"Missing required Space item: {src}")
|
| 53 |
+
|
| 54 |
+
if src.is_dir():
|
| 55 |
+
shutil.copytree(src, dst, ignore=shutil.ignore_patterns(*IGNORE_PATTERNS))
|
| 56 |
+
else:
|
| 57 |
+
dst.parent.mkdir(parents=True, exist_ok=True)
|
| 58 |
+
shutil.copy2(src, dst)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def stage_bundle(space_root: Path, output_dir: Path, overwrite: bool) -> Path:
|
| 62 |
+
if output_dir.exists():
|
| 63 |
+
if not overwrite:
|
| 64 |
+
raise FileExistsError(f"Destination already exists: {output_dir}")
|
| 65 |
+
shutil.rmtree(output_dir)
|
| 66 |
+
|
| 67 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 68 |
+
|
| 69 |
+
for name in ROOT_FILES:
|
| 70 |
+
_copy_item(space_root / name, output_dir / name)
|
| 71 |
+
for name in ROOT_DIRS:
|
| 72 |
+
_copy_item(space_root / name, output_dir / name)
|
| 73 |
+
|
| 74 |
+
return output_dir
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def parse_args() -> argparse.Namespace:
|
| 78 |
+
parser = argparse.ArgumentParser(
|
| 79 |
+
description="Stage a code-only Hugging Face Space bundle from the local DINO-ENDO scaffold."
|
| 80 |
+
)
|
| 81 |
+
parser.add_argument(
|
| 82 |
+
"--output-dir",
|
| 83 |
+
default="/tmp/dino_space_minimal_upload",
|
| 84 |
+
help="Destination directory for the staged bundle.",
|
| 85 |
+
)
|
| 86 |
+
parser.add_argument(
|
| 87 |
+
"--overwrite",
|
| 88 |
+
action="store_true",
|
| 89 |
+
help="Replace the destination directory if it already exists.",
|
| 90 |
+
)
|
| 91 |
+
return parser.parse_args()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def main() -> None:
|
| 95 |
+
args = parse_args()
|
| 96 |
+
script_path = Path(__file__).resolve()
|
| 97 |
+
space_root = script_path.parents[1]
|
| 98 |
+
output_dir = Path(args.output_dir).expanduser().resolve()
|
| 99 |
+
staged_dir = stage_bundle(space_root, output_dir, overwrite=args.overwrite)
|
| 100 |
+
print(f"Staged Space bundle at {staged_dir}")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__ == "__main__":
|
| 104 |
+
main()
|
scripts/stage_vendor_sources.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import shutil
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def copy_tree(src: Path, dst: Path, overwrite: bool) -> None:
|
| 9 |
+
if not src.exists():
|
| 10 |
+
raise FileNotFoundError(f"Source directory not found: {src}")
|
| 11 |
+
if dst.exists():
|
| 12 |
+
if not overwrite:
|
| 13 |
+
print(f"Skipping existing {dst}")
|
| 14 |
+
return
|
| 15 |
+
shutil.rmtree(dst)
|
| 16 |
+
shutil.copytree(
|
| 17 |
+
src,
|
| 18 |
+
dst,
|
| 19 |
+
ignore=shutil.ignore_patterns('.git', '__pycache__', '.pytest_cache', '.mypy_cache', '*.pyc', '*.pyo'),
|
| 20 |
+
)
|
| 21 |
+
print(f"Copied {src} -> {dst}")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def main() -> None:
|
| 25 |
+
parser = argparse.ArgumentParser(description='Copy vendored dinov2/ and vjepa2/ source trees into the Space folder.')
|
| 26 |
+
parser.add_argument('--overwrite', action='store_true', help='Replace existing destination directories.')
|
| 27 |
+
args = parser.parse_args()
|
| 28 |
+
|
| 29 |
+
script_path = Path(__file__).resolve()
|
| 30 |
+
space_root = script_path.parents[1]
|
| 31 |
+
repo_root = script_path.parents[3]
|
| 32 |
+
|
| 33 |
+
copy_tree(repo_root / 'dinov2', space_root / 'dinov2', overwrite=args.overwrite)
|
| 34 |
+
copy_tree(repo_root / 'vjepa2', space_root / 'vjepa2', overwrite=args.overwrite)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
if __name__ == '__main__':
|
| 38 |
+
main()
|
vjepa2/.flake8
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[flake8]
|
| 2 |
+
max-line-length = 119
|
| 3 |
+
select = E,F,W
|
| 4 |
+
ignore = E203,E701,W503
|
| 5 |
+
per-file-ignores=__init__.py:F401 version.py:F401
|
vjepa2/.github/workflows/base_tests.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: UnitTests
|
| 2 |
+
|
| 3 |
+
on: [push]
|
| 4 |
+
|
| 5 |
+
jobs:
|
| 6 |
+
unittests:
|
| 7 |
+
runs-on: ubuntu-latest
|
| 8 |
+
strategy:
|
| 9 |
+
max-parallel: 4
|
| 10 |
+
|
| 11 |
+
steps:
|
| 12 |
+
- uses: actions/checkout@v4
|
| 13 |
+
- name: Set up Python 3.12
|
| 14 |
+
uses: actions/setup-python@v5
|
| 15 |
+
with:
|
| 16 |
+
python-version: '3.12'
|
| 17 |
+
- name: Add conda to system path
|
| 18 |
+
run: |
|
| 19 |
+
# $CONDA is an environment variable pointing to the root of the miniconda directory
|
| 20 |
+
echo $CONDA/bin >> $GITHUB_PATH
|
| 21 |
+
- name: Install dependencies
|
| 22 |
+
run: |
|
| 23 |
+
conda create --name test-env python=3.12
|
| 24 |
+
conda install pytest
|
| 25 |
+
echo "Starting setup from $PWD"
|
| 26 |
+
pip install -e .
|
| 27 |
+
- name: Test with pytest
|
| 28 |
+
run: |
|
| 29 |
+
pytest tests
|
vjepa2/.github/workflows/linters.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Lint (Common Code)
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- master
|
| 7 |
+
paths:
|
| 8 |
+
- 'app/'
|
| 9 |
+
- 'evals/*.py'
|
| 10 |
+
- 'src/'
|
| 11 |
+
- 'tests/'
|
| 12 |
+
pull_request:
|
| 13 |
+
branches:
|
| 14 |
+
- master
|
| 15 |
+
- 'gh/**'
|
| 16 |
+
paths:
|
| 17 |
+
- 'app/'
|
| 18 |
+
- 'evals/*.py'
|
| 19 |
+
- 'src/'
|
| 20 |
+
- 'tests/'
|
| 21 |
+
|
| 22 |
+
jobs:
|
| 23 |
+
run-linters:
|
| 24 |
+
name: Run linters
|
| 25 |
+
runs-on: ubuntu-latest
|
| 26 |
+
|
| 27 |
+
steps:
|
| 28 |
+
- uses: actions/checkout@v4
|
| 29 |
+
- name: Set up Python 3.12
|
| 30 |
+
uses: actions/setup-python@v5
|
| 31 |
+
with:
|
| 32 |
+
python-version: '3.12'
|
| 33 |
+
- name: Install Python lint dependencies
|
| 34 |
+
run: |
|
| 35 |
+
pip install -r requirements-test.txt
|
| 36 |
+
- name: Set lint paths
|
| 37 |
+
run: echo "lint_paths=app evals/*.py src tests" >> "$GITHUB_ENV"
|
| 38 |
+
- name: Run isort
|
| 39 |
+
run: |
|
| 40 |
+
python -m isort $lint_paths --check
|
| 41 |
+
- name: Run flake8
|
| 42 |
+
if: always()
|
| 43 |
+
run: |
|
| 44 |
+
python -m flake8 --config .flake8 --show-source --statistics $lint_paths
|
| 45 |
+
- name: Run black
|
| 46 |
+
if: always()
|
| 47 |
+
run: |
|
| 48 |
+
python -m black --check $lint_paths
|
vjepa2/.gitignore
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.pyc
|
| 2 |
+
.vscode/
|
| 3 |
+
.*.swp
|
| 4 |
+
|
| 5 |
+
run_vjepa_aws.py
|
| 6 |
+
run.py
|
| 7 |
+
main_distributed_video.py
|
| 8 |
+
main_video.py
|
| 9 |
+
|
| 10 |
+
app/vjepa/configs/temp_aws
|
| 11 |
+
app/main_dev.py
|
| 12 |
+
app/main_distributed_dev.py
|
| 13 |
+
evals/ava/alphaction/data
|
| 14 |
+
|
| 15 |
+
run_evals.py
|
| 16 |
+
run_evals_v2.py
|
| 17 |
+
run_pretrain.py
|
| 18 |
+
|
| 19 |
+
*.egg-info/
|
| 20 |
+
*.ipynb_checkpoints/
|
| 21 |
+
|
| 22 |
+
traces/
|
| 23 |
+
third_party/*
|
| 24 |
+
|
| 25 |
+
evals/simu_env_planning/local/
|
| 26 |
+
evals/simu_env_planning/docker2/
|
| 27 |
+
evals/simu_env_planning/docker/
|
| 28 |
+
app/vjepa_droid/local/
|
| 29 |
+
app/vjepa_droid_v2/local/
|
| 30 |
+
app/vjepa_droid_v3/local/
|
| 31 |
+
app/vjepa_droid_v4/local/
|
| 32 |
+
configs/local
|
vjepa2/APACHE-LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2018-2021 William Falcon
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
vjepa2/CHANGELOG.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Changelog
|
| 2 |
+
|
| 3 |
+
## [0.0.1] - 2025-06-05
|
| 4 |
+
|
| 5 |
+
Initial release of V-JEPA 2 codebase
|
vjepa2/CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to make participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
| 49 |
+
an individual is representing the project or its community in public spaces.
|
| 50 |
+
Examples of representing a project or community include using an official
|
| 51 |
+
project e-mail address, posting via an official social media account, or acting
|
| 52 |
+
as an appointed representative at an online or offline event. Representation of
|
| 53 |
+
a project may be further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
This Code of Conduct also applies outside the project spaces when there is a
|
| 56 |
+
reasonable belief that an individual's behavior may have a negative impact on
|
| 57 |
+
the project or its community.
|
| 58 |
+
|
| 59 |
+
## Enforcement
|
| 60 |
+
|
| 61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 62 |
+
reported by contacting the project team at <opensource-conduct@fb.com>. All
|
| 63 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 64 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 65 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 66 |
+
Further details of specific enforcement policies may be posted separately.
|
| 67 |
+
|
| 68 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 69 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 70 |
+
members of the project's leadership.
|
| 71 |
+
|
| 72 |
+
## Attribution
|
| 73 |
+
|
| 74 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
| 75 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
| 76 |
+
|
| 77 |
+
[homepage]: https://www.contributor-covenant.org
|
| 78 |
+
|
| 79 |
+
For answers to common questions about this code of conduct, see
|
| 80 |
+
https://www.contributor-covenant.org/faq
|
vjepa2/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to V-JEPA 2
|
| 2 |
+
We want to make contributing to this project as easy and transparent as
|
| 3 |
+
possible.
|
| 4 |
+
|
| 5 |
+
## Pull Requests
|
| 6 |
+
We welcome your pull requests.
|
| 7 |
+
|
| 8 |
+
1. Fork the repo and create your branch from `main`.
|
| 9 |
+
2. If you've added code that should be tested, add tests.
|
| 10 |
+
3. If you've changed APIs, update the documentation.
|
| 11 |
+
4. Ensure the test suite passes.
|
| 12 |
+
5. Make sure your code is consistent with style guidance (below) and lints.
|
| 13 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
| 14 |
+
7. Add reviewer(s) for approval.
|
| 15 |
+
|
| 16 |
+
## Contributor License Agreement ("CLA")
|
| 17 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
| 18 |
+
to do this once to work on any of Facebook's open source projects.
|
| 19 |
+
|
| 20 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
| 21 |
+
|
| 22 |
+
## Issues
|
| 23 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
| 24 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
| 25 |
+
|
| 26 |
+
Meta has a [bounty program](https://bugbounty.meta.com/) for the safe
|
| 27 |
+
disclosure of security bugs. In those cases, please go through the process
|
| 28 |
+
outlined on that page and do not file a public issue.
|
| 29 |
+
|
| 30 |
+
## Coding Style
|
| 31 |
+
* 4 spaces for indentation rather than tabs
|
| 32 |
+
* 119 character line length
|
| 33 |
+
* PEP8 formatting
|
| 34 |
+
|
| 35 |
+
We recommend using `black`, `isort`, and `flake8` to format your code changes.
|
| 36 |
+
|
| 37 |
+
## License
|
| 38 |
+
By contributing to this repository, you agree that your contributions will be licensed
|
| 39 |
+
under the LICENSE file in the root directory of this source tree.
|
vjepa2/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
vjepa2/README.md
ADDED
|
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning
|
| 2 |
+
|
| 3 |
+
### [Meta FAIR](https://ai.meta.com/research/)
|
| 4 |
+
|
| 5 |
+
Mahmoud Assran∗, Adrien Bardes∗, David Fan∗, Quentin Garrido∗, Russell Howes∗, Mojtaba
|
| 6 |
+
Komeili∗, Matthew Muckley∗, Ammar Rizvi∗, Claire Roberts∗, Koustuv Sinha∗, Artem Zholus*,
|
| 7 |
+
Sergio Arnaud*, Abha Gejji*, Ada Martin*, Francois Robert Hogan*, Daniel Dugas*, Piotr
|
| 8 |
+
Bojanowski, Vasil Khalidov, Patrick Labatut, Francisco Massa, Marc Szafraniec, Kapil
|
| 9 |
+
Krishnakumar, Yong Li, Xiaodong Ma, Sarath Chandar, Franziska Meier*, Yann LeCun*, Michael
|
| 10 |
+
Rabbat*, Nicolas Ballas*
|
| 11 |
+
|
| 12 |
+
*Core Team
|
| 13 |
+
|
| 14 |
+
[[`Paper`](https://arxiv.org/abs/2506.09985)] [[`Blog`](https://ai.meta.com/blog/v-jepa-2-world-model-benchmarks)] [[`BibTex`](#Citation)]
|
| 15 |
+
|
| 16 |
+
Official Pytorch codebase for V-JEPA 2 and V-JEPA 2-AC.
|
| 17 |
+
|
| 18 |
+
V-JEPA 2 is a self-supervised approach to training video encoders, using internet-scale video data, that attains state-of-the-art performance on motion understanding and human action anticipation tasks. V-JEPA 2-AC is a latent action-conditioned world model post-trained from V-JEPA 2 (using a small amount of robot trajectory interaction data) that solves robot manipulation tasks without environment-specific data collection or task-specific training or calibration.
|
| 19 |
+
|
| 20 |
+
<p align="center">
|
| 21 |
+
<img src="assets/flowchart.png" width=100%>
|
| 22 |
+
</p>
|
| 23 |
+
|
| 24 |
+
<!---
|
| 25 |
+
## Updates
|
| 26 |
+
|
| 27 |
+
* **[Jun-6-25]:** V-JEPA 2 is released. [[`Blog`](https://ai.meta.com/blog/v-jepa-2-world-model-benchmarks)]
|
| 28 |
+
--->
|
| 29 |
+
|
| 30 |
+
## V-JEPA 2 Pre-training
|
| 31 |
+
|
| 32 |
+
**(Top)** The encoder and predictor are pre-trained through self-supervised learning from video using a masked latent feature prediction objective, leveraging abundant natural videos to bootstrap physical world understanding and prediction. **(Bottom)** Performance of V-JEPA 2 on downstream understanding and prediction tasks.
|
| 33 |
+
|
| 34 |
+
<img align="left" src="https://github.com/user-attachments/assets/914942d8-6a1e-409d-86ff-ff856b7346ab" width=65%>
|
| 35 |
+
<table>
|
| 36 |
+
<tr>
|
| 37 |
+
<th colspan="1">Benchmark</th>
|
| 38 |
+
<th colspan="1">VJEPA 2</th>
|
| 39 |
+
<th colspan="1">Previous Best</th>
|
| 40 |
+
</tr>
|
| 41 |
+
<tr>
|
| 42 |
+
<td>EK100</td>
|
| 43 |
+
<td>39.7%</td>
|
| 44 |
+
<td>27.6% (PlausiVL)</td>
|
| 45 |
+
</tr>
|
| 46 |
+
<tr>
|
| 47 |
+
<td>SSv2 (Probe)</td>
|
| 48 |
+
<td>77.3%</td>
|
| 49 |
+
<td>69.7% (InternVideo2-1B)</td>
|
| 50 |
+
</tr>
|
| 51 |
+
<tr>
|
| 52 |
+
<td>Diving48 (Probe)</td>
|
| 53 |
+
<td>90.2%</td>
|
| 54 |
+
<td>86.4% (InternVideo2-1B)</td>
|
| 55 |
+
</tr>
|
| 56 |
+
<tr>
|
| 57 |
+
<td>MVP (Video QA)</td>
|
| 58 |
+
<td>44.5%</td>
|
| 59 |
+
<td>39.9% (InternVL-2.5)</td>
|
| 60 |
+
</tr>
|
| 61 |
+
<tr>
|
| 62 |
+
<td>TempCompass (Video QA)</td>
|
| 63 |
+
<td>76.9%</td>
|
| 64 |
+
<td>75.3% (Tarsier 2)</td>
|
| 65 |
+
</tr>
|
| 66 |
+
</table>
|
| 67 |
+
|
| 68 |
+
## V-JEPA 2-AC Post-training
|
| 69 |
+
|
| 70 |
+
**(Top)** After post-training with a small amount of robot data, we can deploy the model on a robot arm in new environments, and tackle foundational tasks like reaching, grasping, and pick-and-place by planning from image goals. **(Bottom)** Performance on robot manipulation tasks using a Franka arm, with input provided through a monocular RGB camera.
|
| 71 |
+
|
| 72 |
+
<img align="left" src="https://github.com/user-attachments/assets/c5d42221-0102-4216-911d-061a4369a805" width=65%>
|
| 73 |
+
<table>
|
| 74 |
+
<tr>
|
| 75 |
+
<th colspan="1"></th>
|
| 76 |
+
<th colspan="1"></th>
|
| 77 |
+
<th colspan="2">Grasp</th>
|
| 78 |
+
<th colspan="2">Pick-and-Place</th>
|
| 79 |
+
</tr>
|
| 80 |
+
<tr>
|
| 81 |
+
<th colspan="1">Method</th>
|
| 82 |
+
<th colspan="1">Reach</th>
|
| 83 |
+
<th colspan="1">Cup</th>
|
| 84 |
+
<th colspan="1">Box</th>
|
| 85 |
+
<th colspan="1">Cup</th>
|
| 86 |
+
<th colspan="1">Box</th>
|
| 87 |
+
</tr>
|
| 88 |
+
<tr>
|
| 89 |
+
<td>Octo</td>
|
| 90 |
+
<td>100%</td>
|
| 91 |
+
<td>10%</td>
|
| 92 |
+
<td>0%</td>
|
| 93 |
+
<td>10%</td>
|
| 94 |
+
<td>10%</td>
|
| 95 |
+
</tr>
|
| 96 |
+
<tr>
|
| 97 |
+
<td>Cosmos</td>
|
| 98 |
+
<td>80%</td>
|
| 99 |
+
<td>0%</td>
|
| 100 |
+
<td>20%</td>
|
| 101 |
+
<td>0%</td>
|
| 102 |
+
<td>0%</td>
|
| 103 |
+
</tr>
|
| 104 |
+
<tr>
|
| 105 |
+
<td>VJEPA 2-AC</td>
|
| 106 |
+
<td>100%</td>
|
| 107 |
+
<td>60%</td>
|
| 108 |
+
<td>20%</td>
|
| 109 |
+
<td>80%</td>
|
| 110 |
+
<td>50%</td>
|
| 111 |
+
</tr>
|
| 112 |
+
</table>
|
| 113 |
+
|
| 114 |
+
## Models
|
| 115 |
+
|
| 116 |
+
### V-JEPA 2
|
| 117 |
+
|
| 118 |
+
#### HuggingFace
|
| 119 |
+
|
| 120 |
+
See our [HuggingFace collection](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6) for V-JEPA 2.
|
| 121 |
+
|
| 122 |
+
#### Pretrained Checkpoints
|
| 123 |
+
|
| 124 |
+
<table>
|
| 125 |
+
<tr>
|
| 126 |
+
<th colspan="1">Model</th>
|
| 127 |
+
<th colspan="1">#Parameters</th>
|
| 128 |
+
<th colspan="1">Resolution</th>
|
| 129 |
+
<th colspan="1">Download Link</th>
|
| 130 |
+
<th colspan="1">Pretraining Config</th>
|
| 131 |
+
</tr>
|
| 132 |
+
<tr>
|
| 133 |
+
<td>ViT-L/16</td>
|
| 134 |
+
<td>300M</td>
|
| 135 |
+
<td>256</td>
|
| 136 |
+
<td><a href="https://dl.fbaipublicfiles.com/vjepa2/vitl.pt">checkpoint</a></td>
|
| 137 |
+
<td><a href="configs/train/vitl16">configs</a></td>
|
| 138 |
+
</tr>
|
| 139 |
+
<tr>
|
| 140 |
+
<td>ViT-H/16</td>
|
| 141 |
+
<td>600M</td>
|
| 142 |
+
<td>256</td>
|
| 143 |
+
<td><a href="https://dl.fbaipublicfiles.com/vjepa2/vith.pt">checkpoint</a></td>
|
| 144 |
+
<td><a href="configs/train/vith16/">configs</a></td>
|
| 145 |
+
</tr>
|
| 146 |
+
<tr>
|
| 147 |
+
<td>ViT-g/16</td>
|
| 148 |
+
<td>1B</td>
|
| 149 |
+
<td>256</td>
|
| 150 |
+
<td><a href="https://dl.fbaipublicfiles.com/vjepa2/vitg.pt">checkpoint</a></td>
|
| 151 |
+
<td><a href="configs/train/vitg16">configs</a></td>
|
| 152 |
+
</tr>
|
| 153 |
+
<tr>
|
| 154 |
+
<td>ViT-g/16<sub>384</sub></td>
|
| 155 |
+
<td>1B</td>
|
| 156 |
+
<td>384</td>
|
| 157 |
+
<td><a href="https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt">checkpoint</a></td>
|
| 158 |
+
<td><a href="configs/train/vitg16">configs</a></td>
|
| 159 |
+
</tr>
|
| 160 |
+
</table>
|
| 161 |
+
|
| 162 |
+
#### Pretrained backbones (via PyTorch Hub)
|
| 163 |
+
|
| 164 |
+
Please install [Pytorch](https://pytorch.org/get-started/locally/), [timm](https://pypi.org/project/timm/) and [einops](https://pypi.org/project/einops/) locally, then run the following to load each model. Installing Pytorch with CUDA support is strongly recommended.
|
| 165 |
+
|
| 166 |
+
```python
|
| 167 |
+
import torch
|
| 168 |
+
|
| 169 |
+
# preprocessor
|
| 170 |
+
processor = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_preprocessor')
|
| 171 |
+
# models
|
| 172 |
+
vjepa2_vit_large = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_large')
|
| 173 |
+
vjepa2_vit_huge = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_huge')
|
| 174 |
+
vjepa2_vit_giant = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_giant')
|
| 175 |
+
vjepa2_vit_giant_384 = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_giant_384')
|
| 176 |
+
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
#### Pretrained checkpoints on Huggingface
|
| 180 |
+
|
| 181 |
+
You can also use our pretrained checkpoints on [Huggingface](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6).
|
| 182 |
+
|
| 183 |
+
```python
|
| 184 |
+
from transformers import AutoVideoProcessor, AutoModel
|
| 185 |
+
|
| 186 |
+
hf_repo = "facebook/vjepa2-vitg-fpc64-256"
|
| 187 |
+
# facebook/vjepa2-vitl-fpc64-256
|
| 188 |
+
# facebook/vjepa2-vith-fpc64-256
|
| 189 |
+
# facebook/vjepa2-vitg-fpc64-256
|
| 190 |
+
# facebook/vjepa2-vitg-fpc64-384
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
model = AutoModel.from_pretrained(hf_repo)
|
| 194 |
+
processor = AutoVideoProcessor.from_pretrained(hf_repo)
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
#### Evaluation Attentive Probes
|
| 198 |
+
|
| 199 |
+
We share the trained attentive probes for two of our visual understanding evals (Something-Something v2 and Diving48) and the action anticipation eval EPIC-KITCHENS-100.
|
| 200 |
+
|
| 201 |
+
<table>
|
| 202 |
+
<tr>
|
| 203 |
+
<th colspan="1">Model</th>
|
| 204 |
+
<th colspan="4">SSv2</th>
|
| 205 |
+
<th colspan="4">Diving48</th>
|
| 206 |
+
<th colspan="4">EK100</th>
|
| 207 |
+
</tr>
|
| 208 |
+
<tr>
|
| 209 |
+
<th colspan="1"></th>
|
| 210 |
+
<th colspan="1">Checkpoint</th>
|
| 211 |
+
<th colspan="1">Training Config</th>
|
| 212 |
+
<th colspan="1">Inference Config</th>
|
| 213 |
+
<th colspan="1">Result</th>
|
| 214 |
+
<th colspan="1">Checkpoint</th>
|
| 215 |
+
<th colspan="1">Training Config</th>
|
| 216 |
+
<th colspan="1">Inference Config</th>
|
| 217 |
+
<th colspan="1">Result</th>
|
| 218 |
+
<th colspan="1">Checkpoint</th>
|
| 219 |
+
<th colspan="1">Training Config</th>
|
| 220 |
+
<th colspan="1">Inference Config</th>
|
| 221 |
+
<th colspan="1">Result</th>
|
| 222 |
+
</tr>
|
| 223 |
+
<tr>
|
| 224 |
+
<td>ViT-L/16</td>
|
| 225 |
+
<td><a href="https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitl-16x2x3.pt">checkpoint</a></td>
|
| 226 |
+
<td><a href="configs/eval/vitl/ssv2.yaml">config</a></td>
|
| 227 |
+
<td><a href="configs/inference/vitl/ssv2.yaml">config</a></td>
|
| 228 |
+
<td>73.7%</td>
|
| 229 |
+
<td><a href="https://dl.fbaipublicfiles.com/vjepa2/evals/diving48-vitl-256.pt">checkpoint</a></td>
|
| 230 |
+
<td><a href="configs/eval/vitl/diving48.yaml">config</a></td>
|
| 231 |
+
<td><a href="configs/inference/vitl/diving48.yaml">config</a></td>
|
| 232 |
+
<td>89.0%</td>
|
| 233 |
+
<td><a href="https://dl.fbaipublicfiles.com/vjepa2/evals/ek100-vitl-256.pt">checkpoint</a></td>
|
| 234 |
+
<td><a href="configs/eval/vitl/ek100.yaml">config</a></td>
|
| 235 |
+
<td><a href="configs/inference/vitl/ek100.yaml">config</a></td>
|
| 236 |
+
<td>32.7 R@5</td>
|
| 237 |
+
</tr>
|
| 238 |
+
<tr>
|
| 239 |
+
<td>ViT-g/16<sub>384</td>
|
| 240 |
+
<td><a href="https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt">checkpoint</a></td>
|
| 241 |
+
<td><a href="configs/eval/vitg-384/ssv2.yaml">config</a></td>
|
| 242 |
+
<td><a href="configs/inference/vitg-384/ssv2.yaml">config</a></td>
|
| 243 |
+
<td>77.3%</td>
|
| 244 |
+
<td><a href="https://dl.fbaipublicfiles.com/vjepa2/evals/diving48-vitg-384-32x4x3.pt">checkpoint</a></td>
|
| 245 |
+
<td><a href="configs/eval/vitg-384/diving48.yaml">config</a></td>
|
| 246 |
+
<td><a href="configs/inference/vitg-384/diving48.yaml">config</a></td>
|
| 247 |
+
<td>90.2%</td>
|
| 248 |
+
<td><a href="https://dl.fbaipublicfiles.com/vjepa2/evals/ek100-vitg-384.pt">checkpoint</a></td>
|
| 249 |
+
<td><a href="configs/eval/vitg-384/ek100.yaml">config</a></td>
|
| 250 |
+
<td><a href="configs/inference/vitg-384/ek100.yaml">config</a></td>
|
| 251 |
+
<td>39.7 R@5</td>
|
| 252 |
+
</tr>
|
| 253 |
+
</table>
|
| 254 |
+
|
| 255 |
+
### V-JEPA 2-AC
|
| 256 |
+
|
| 257 |
+
Our action-conditioned checkpoint was trained from the ViT-g encoder.
|
| 258 |
+
<table>
|
| 259 |
+
<tr>
|
| 260 |
+
<th colspan="1">Model</th>
|
| 261 |
+
<th colspan="1">Download Link</th>
|
| 262 |
+
<th colspan="1">Training Config</th>
|
| 263 |
+
</tr>
|
| 264 |
+
<tr>
|
| 265 |
+
<td>ViT-g/16</td>
|
| 266 |
+
<td><a href="https://dl.fbaipublicfiles.com/vjepa2/vjepa2-ac-vitg.pt">checkpoint</a></td>
|
| 267 |
+
<td><a href="configs/train/vitg16/droid-256px-8f.yaml">config</a></td>
|
| 268 |
+
</tr>
|
| 269 |
+
</table>
|
| 270 |
+
|
| 271 |
+
#### Pretrained action-conditioned backbone (via PyTorch Hub)
|
| 272 |
+
|
| 273 |
+
Please install [Pytorch](https://pytorch.org/get-started/locally/), [timm](https://pypi.org/project/timm/) and [einops](https://pypi.org/project/einops/) locally, then run the following to load each model. Installing Pytorch with CUDA support is strongly recommended.
|
| 274 |
+
|
| 275 |
+
```python
|
| 276 |
+
import torch
|
| 277 |
+
|
| 278 |
+
vjepa2_encoder, vjepa2_ac_predictor = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_ac_vit_giant')
|
| 279 |
+
```
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
See [energy_landscape_example.ipynb](notebooks/energy_landscape_example.ipynb) for an example notebook computing the energy landscape of the pretrained action-conditioned backbone using a robot trajectory collected from our lab.
|
| 283 |
+
To run this notebook, you'll need to additionally install [Jupyter](https://jupyter.org/install) and [Scipy](https://scipy.org/install/) in your conda environment.
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
## Getting Started
|
| 287 |
+
|
| 288 |
+
### Setup
|
| 289 |
+
|
| 290 |
+
```
|
| 291 |
+
conda create -n vjepa2-312 python=3.12
|
| 292 |
+
conda activate vjepa2-312
|
| 293 |
+
pip install . # or `pip install -e .` for development mode
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
**Note to macOS users:** V-JEPA 2 relies on [`decord`](https://github.com/dmlc/decord), which does not support macOS (and, unfortunately, is also no longer under development). In order to run the V-JEPA 2 code on macOS, you will need a different `decord` implementation. We do not make specific recommendations, although some users have reported the use of [`eva-decord`](https://github.com/georgia-tech-db/eva-decord) (see [PR 1](https://github.com/facebookresearch/vjepa2/pull/1)) or [`decord2`](https://github.com/johnnynunez/decord2) (see [PR 31](https://github.com/facebookresearch/vjepa2/pull/31)). We leave the selection of the `decord` package up to the user's discretion.
|
| 297 |
+
|
| 298 |
+
### Usage Demo
|
| 299 |
+
|
| 300 |
+
See [vjepa2_demo.ipynb](notebooks/vjepa2_demo.ipynb) [(Colab Link)](https://colab.research.google.com/github/facebookresearch/vjepa2/blob/main/notebooks/vjepa2_demo.ipynb) or [vjepa2_demo.py](notebooks/vjepa2_demo.py) for an example of how to load both the HuggingFace and PyTorch V-JEPA 2 models and run inference on a sample video to get a sample classification result.
|
| 301 |
+
|
| 302 |
+
The script assumes the presence of downloaded model checkpoints so you will need to download the model weights and update the corresponding paths in the script. E.g.:
|
| 303 |
+
```
|
| 304 |
+
wget https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt -P YOUR_DIR
|
| 305 |
+
wget https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt -P YOUR_DIR
|
| 306 |
+
|
| 307 |
+
# Then update your model paths in vjepa2_demo.py.
|
| 308 |
+
pt_model_path = YOUR_DIR/vitg-384.pt
|
| 309 |
+
classifier_model_path = YOUR_DIR/ssv2-vitg-384-64x2x3.pt
|
| 310 |
+
|
| 311 |
+
# Then run the script (assumes your machine has a GPU)
|
| 312 |
+
python -m notebooks.vjepa2_demo
|
| 313 |
+
```
|
| 314 |
+
|
| 315 |
+
### Probe-based evaluation
|
| 316 |
+
|
| 317 |
+
Probe-based evaluation consists in training an attentive probe on top of frozen V-JEPA 2 features. We provide training scripts for training your own probes, and checkpoints to run inference directly.
|
| 318 |
+
|
| 319 |
+
#### Training probes
|
| 320 |
+
|
| 321 |
+
Evaluations can be run either locally, or distributed via SLURM. (Running locally is useful for debugging and validation).
|
| 322 |
+
These sample commands launch Something-Something v2 video classification; other evals are launched by specifying the corresponding config.
|
| 323 |
+
Use provided training configs under "Evaluation Attentive Probes". These configs allow to train multiple probes in parallel with various optimization parameters.
|
| 324 |
+
Change filepaths as needed (e.g. `folder`, `checkpoint`, `dataset_train`, `dataset_val`) to match locations of data and downloaded checkpoints on your local filesystem.
|
| 325 |
+
Change \# nodes and local batch size as needed to not exceed available GPU memory.
|
| 326 |
+
|
| 327 |
+
##### Local
|
| 328 |
+
|
| 329 |
+
To run locally, specify the GPUs to use on
|
| 330 |
+
```
|
| 331 |
+
python -m evals.main --fname configs/eval/vitl16/ssv2.yaml \
|
| 332 |
+
--devices cuda:0 cuda:1
|
| 333 |
+
```
|
| 334 |
+
|
| 335 |
+
##### Distributed
|
| 336 |
+
|
| 337 |
+
```
|
| 338 |
+
python -m evals.main_distributed \
|
| 339 |
+
--fname configs/eval/vitl/ssv2.yaml \
|
| 340 |
+
--time 8600 \
|
| 341 |
+
--account my_account --qos=my_qos
|
| 342 |
+
```
|
| 343 |
+
|
| 344 |
+
#### Inference from existing probes
|
| 345 |
+
|
| 346 |
+
Use provided inference configs under [Evaluation Attentive Probes](#evaluation-attentive-probes).
|
| 347 |
+
Download the corresponding checkpoint, rename it to 'latest.pt', and create a folder with the checkpoint inside, with the format matching the variables in the config:
|
| 348 |
+
```
|
| 349 |
+
[folder]/[eval_name]/[tag]/latest.pt
|
| 350 |
+
```
|
| 351 |
+
Then run inference, locally or distributed, using the same evaluation commands as above, but with configs from `configs/inference`.
|
| 352 |
+
|
| 353 |
+
### Pretraining
|
| 354 |
+
|
| 355 |
+
Likewise, training can also be run locally or distributed. Pretraining and cooldown training phases are
|
| 356 |
+
run with the same command using different configs.
|
| 357 |
+
These sample commands launch initial training of a ViT-L model. Configs for cooldown (or action-conditioned) training
|
| 358 |
+
can be found in the same directory as the config for initial training.
|
| 359 |
+
|
| 360 |
+
#### Local
|
| 361 |
+
|
| 362 |
+
```
|
| 363 |
+
python -m app.main --fname configs/train/vitl16/pretrain-256px-16f.yaml \
|
| 364 |
+
--devices cuda:0
|
| 365 |
+
```
|
| 366 |
+
|
| 367 |
+
#### Distributed
|
| 368 |
+
|
| 369 |
+
```
|
| 370 |
+
python -m app.main_distributed \
|
| 371 |
+
--fname configs/train/vitl16/pretrain-256px-16f.yaml
|
| 372 |
+
--time 6000
|
| 373 |
+
--account my_account --qos=my_qos
|
| 374 |
+
```
|
| 375 |
+
|
| 376 |
+
### Postraining
|
| 377 |
+
|
| 378 |
+
Post-training of the action-conditioned model, starting from the pretrained VJEPA 2 backbone, also follows a similar interface, and can be run locally or distributed using [this config](configs/train/vitg16/droid-256px-8f.yaml).
|
| 379 |
+
We post-train the model starting from the ViT-g/16 backbone.
|
| 380 |
+
|
| 381 |
+
#### Local
|
| 382 |
+
|
| 383 |
+
```
|
| 384 |
+
python -m app.main --fname configs/train/vitg16/droid-256px-8f.yaml \
|
| 385 |
+
--devices cuda:0
|
| 386 |
+
```
|
| 387 |
+
|
| 388 |
+
#### Distributed
|
| 389 |
+
|
| 390 |
+
```
|
| 391 |
+
python -m app.main_distributed \
|
| 392 |
+
--fname configs/train/vitg16/droid-256px-8f.yaml
|
| 393 |
+
--time 6000
|
| 394 |
+
--account my_account --qos=my_qos
|
| 395 |
+
```
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
## Code Structure
|
| 399 |
+
|
| 400 |
+
```
|
| 401 |
+
.
|
| 402 |
+
├── app # training loops
|
| 403 |
+
│ ├── vjepa # video JEPA pre-training
|
| 404 |
+
│ ├── vjepa_droid # training the action-conditioned model
|
| 405 |
+
│ ├── main_distributed.py # entrypoint for launch app on slurm cluster
|
| 406 |
+
│ └── main.py # entrypoint for launch app locally on your machine
|
| 407 |
+
├── configs # config files with experiment params for training and evaluation
|
| 408 |
+
│ ├── train # pretraining (phase 1), cooldown (phase 2), and action-conditioned training
|
| 409 |
+
│ └── eval # frozen evaluations
|
| 410 |
+
├── evals # evaluation loops training an attentive probe with frozen backbone...
|
| 411 |
+
│ ├── action_anticipation_frozen # action anticipation
|
| 412 |
+
│ ├── image_classification_frozen # image understanding
|
| 413 |
+
│ ├── video_classification_frozen # video understanding
|
| 414 |
+
│ ├── main_distributed.py # entrypoint for distributed evaluations
|
| 415 |
+
│ └── main.py # entrypoint for locally-run evaluations
|
| 416 |
+
├── src # the package
|
| 417 |
+
│ ├── datasets # datasets, data loaders, ...
|
| 418 |
+
│ ├── models # model definitions
|
| 419 |
+
│ ├── masks # mask collators, masking utilities, ...
|
| 420 |
+
│ └── utils # shared utilities
|
| 421 |
+
├── tests # unit tests for some modules in `src`
|
| 422 |
+
|
| 423 |
+
```
|
| 424 |
+
|
| 425 |
+
## License
|
| 426 |
+
|
| 427 |
+
The majority of V-JEPA 2 is licensed under MIT, however portions of the project are available under separate license terms:
|
| 428 |
+
|
| 429 |
+
[src/datasets/utils/video/randaugment.py](src/datasets/utils/video/randaugment.py)<br>
|
| 430 |
+
[src/datasets/utils/video/randerase.py](src/datasets/utils/video/randerase.py)<br>
|
| 431 |
+
[src/datasets/utils/worker_init_fn.py](src/datasets/utils/worker_init_fn.py)<br>
|
| 432 |
+
|
| 433 |
+
are licensed under the Apache 2.0 license.
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
## Citation
|
| 437 |
+
If you find this repository useful in your research, please consider giving a star :star: and a citation
|
| 438 |
+
```bibtex
|
| 439 |
+
@article{assran2025vjepa2,
|
| 440 |
+
title={V-JEPA~2: Self-Supervised Video Models Enable Understanding, Prediction and Planning},
|
| 441 |
+
author={Assran, Mahmoud and Bardes, Adrien and Fan, David and Garrido, Quentin and Howes, Russell and
|
| 442 |
+
Komeili, Mojtaba and Muckley, Matthew and Rizvi, Ammar and Roberts, Claire and Sinha, Koustuv and Zholus, Artem and
|
| 443 |
+
Arnaud, Sergio and Gejji, Abha and Martin, Ada and Robert Hogan, Francois and Dugas, Daniel and
|
| 444 |
+
Bojanowski, Piotr and Khalidov, Vasil and Labatut, Patrick and Massa, Francisco and Szafraniec, Marc and
|
| 445 |
+
Krishnakumar, Kapil and Li, Yong and Ma, Xiaodong and Chandar, Sarath and Meier, Franziska and LeCun, Yann and
|
| 446 |
+
Rabbat, Michael and Ballas, Nicolas},
|
| 447 |
+
journal={arXiv preprint arXiv:2506.09985},
|
| 448 |
+
year={2025}
|
| 449 |
+
}
|
| 450 |
+
```
|
vjepa2/app/__init__.py
ADDED
|
File without changes
|
vjepa2/app/main.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import multiprocessing as mp
|
| 8 |
+
import pprint
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import yaml
|
| 12 |
+
|
| 13 |
+
from app.scaffold import main as app_main
|
| 14 |
+
from src.utils.distributed import init_distributed
|
| 15 |
+
|
| 16 |
+
parser = argparse.ArgumentParser()
|
| 17 |
+
parser.add_argument("--fname", type=str, help="name of config file to load", default="configs.yaml")
|
| 18 |
+
parser.add_argument(
|
| 19 |
+
"--devices",
|
| 20 |
+
type=str,
|
| 21 |
+
nargs="+",
|
| 22 |
+
default=["cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7"],
|
| 23 |
+
help="which devices to use on local machine",
|
| 24 |
+
)
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--debugmode",
|
| 27 |
+
type=bool,
|
| 28 |
+
default=False,
|
| 29 |
+
help="Setting this to true will not spin up new processes. "
|
| 30 |
+
"The main code runs the main process, which makes it easier to \
|
| 31 |
+
debug with checkpointing.",
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def process_main(rank, fname, world_size, devices):
|
| 36 |
+
import os
|
| 37 |
+
|
| 38 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[rank].split(":")[-1])
|
| 39 |
+
|
| 40 |
+
import logging
|
| 41 |
+
|
| 42 |
+
from src.utils.logging import get_logger
|
| 43 |
+
|
| 44 |
+
logger = get_logger(force=True)
|
| 45 |
+
if rank == 0:
|
| 46 |
+
logger.setLevel(logging.INFO)
|
| 47 |
+
else:
|
| 48 |
+
logger.setLevel(logging.ERROR)
|
| 49 |
+
|
| 50 |
+
logger.info(f"called-params {fname}")
|
| 51 |
+
|
| 52 |
+
# Load config
|
| 53 |
+
params = None
|
| 54 |
+
with open(fname, "r") as y_file:
|
| 55 |
+
params = yaml.load(y_file, Loader=yaml.FullLoader)
|
| 56 |
+
logger.info("loaded params...")
|
| 57 |
+
|
| 58 |
+
# Log config
|
| 59 |
+
if rank == 0:
|
| 60 |
+
pprint.PrettyPrinter(indent=4).pprint(params)
|
| 61 |
+
folder = params["folder"]
|
| 62 |
+
params_path = os.path.join(folder, "params-pretrain.yaml")
|
| 63 |
+
folder = Path(folder)
|
| 64 |
+
folder.mkdir(parents=True, exist_ok=True)
|
| 65 |
+
with open(params_path, "w") as f:
|
| 66 |
+
yaml.dump(params, f)
|
| 67 |
+
|
| 68 |
+
# Init distributed (access to comm between GPUS on same machine)
|
| 69 |
+
world_size, rank = init_distributed(rank_and_world_size=(rank, world_size))
|
| 70 |
+
logger.info(f"Running... (rank: {rank}/{world_size})")
|
| 71 |
+
|
| 72 |
+
# Launch the app with loaded config
|
| 73 |
+
app_main(params["app"], args=params)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
|
| 77 |
+
args = parser.parse_args()
|
| 78 |
+
if args.debugmode:
|
| 79 |
+
process_main(rank=0, fname=args.fname, world_size=1, devices=["cuda:0"])
|
| 80 |
+
else:
|
| 81 |
+
num_gpus = len(args.devices)
|
| 82 |
+
mp.set_start_method("spawn")
|
| 83 |
+
for rank in range(num_gpus):
|
| 84 |
+
mp.Process(target=process_main, args=(rank, args.fname, num_gpus, args.devices)).start()
|
vjepa2/app/main_distributed.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import copy
|
| 8 |
+
import datetime
|
| 9 |
+
import os
|
| 10 |
+
import pprint
|
| 11 |
+
import shutil
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import submitit
|
| 15 |
+
import yaml
|
| 16 |
+
|
| 17 |
+
from app.scaffold import main as app_main
|
| 18 |
+
from src.utils.logging import get_logger, git_information
|
| 19 |
+
|
| 20 |
+
logger = get_logger(force=True)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
parser = argparse.ArgumentParser()
|
| 24 |
+
parser.add_argument(
|
| 25 |
+
"--fname",
|
| 26 |
+
type=str,
|
| 27 |
+
help="yaml file containing config file names to launch",
|
| 28 |
+
default="configs.yaml",
|
| 29 |
+
)
|
| 30 |
+
parser.add_argument("--exclude", type=str, help="nodes to exclude from training", default=None)
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--batch-launch",
|
| 33 |
+
action="store_true",
|
| 34 |
+
help="whether fname points to a file to batch-launch several config files",
|
| 35 |
+
)
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--use_fname_as_folder",
|
| 38 |
+
action="store_true",
|
| 39 |
+
help="whether to append fname filename to folder",
|
| 40 |
+
)
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--folder",
|
| 43 |
+
type=str,
|
| 44 |
+
default=None,
|
| 45 |
+
help="if specified, override 'folder' field in the .yaml with this",
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--account",
|
| 49 |
+
type=str,
|
| 50 |
+
default="jepa",
|
| 51 |
+
help="Cluster account to use when submitting jobs",
|
| 52 |
+
)
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--partition",
|
| 55 |
+
type=str,
|
| 56 |
+
default="learn",
|
| 57 |
+
help="Cluster partition to use when submitting jobs",
|
| 58 |
+
)
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
"--qos",
|
| 61 |
+
type=str,
|
| 62 |
+
default=None,
|
| 63 |
+
help="If specified, cluster partition to use when submitting jobs",
|
| 64 |
+
)
|
| 65 |
+
parser.add_argument("--time", type=int, default=4300, help="time in minutes to run job")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Trainer:
|
| 69 |
+
def __init__(self, args_pretrain, load_model=None):
|
| 70 |
+
self.app = args_pretrain["app"]
|
| 71 |
+
self.args_pretrain = args_pretrain
|
| 72 |
+
self.load_model = load_model
|
| 73 |
+
|
| 74 |
+
def __call__(self):
|
| 75 |
+
app = self.app
|
| 76 |
+
params = self.args_pretrain
|
| 77 |
+
load_model = self.load_model
|
| 78 |
+
|
| 79 |
+
logger.info("loaded pretrain params...")
|
| 80 |
+
pp = pprint.PrettyPrinter(indent=4)
|
| 81 |
+
pp.pprint(params)
|
| 82 |
+
|
| 83 |
+
# Launch app with loaded config
|
| 84 |
+
resume_preempt = False if load_model is None else load_model
|
| 85 |
+
app_main(app, args=params, resume_preempt=resume_preempt)
|
| 86 |
+
|
| 87 |
+
def checkpoint(self):
|
| 88 |
+
fb_trainer = Trainer(self.args_pretrain, True)
|
| 89 |
+
return submitit.helpers.DelayedSubmission(
|
| 90 |
+
fb_trainer,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def copy_code_folder(code_folder, ignore_patterns, ignore_paths):
|
| 95 |
+
path_to_node_folder = {}
|
| 96 |
+
|
| 97 |
+
for path in ignore_paths:
|
| 98 |
+
split_path = path.split("/")
|
| 99 |
+
base_path = "/".join(split_path[:-1])
|
| 100 |
+
node_folder = split_path[-1]
|
| 101 |
+
path_to_node_folder[base_path] = node_folder
|
| 102 |
+
|
| 103 |
+
def ignore_func(path, names):
|
| 104 |
+
ignore_list = ignore_patterns
|
| 105 |
+
if path in path_to_node_folder.keys():
|
| 106 |
+
ignore_list.append(path_to_node_folder[path])
|
| 107 |
+
return ignore_list
|
| 108 |
+
|
| 109 |
+
if not os.path.exists(code_folder):
|
| 110 |
+
shutil.copytree(".", code_folder, ignore=ignore_func)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def update_folder_with_timestamp(args_list):
|
| 114 |
+
new_args_list = copy.deepcopy(args_list)
|
| 115 |
+
for i, args in enumerate(args_list):
|
| 116 |
+
folder = args["folder"]
|
| 117 |
+
load_checkpoint = args["meta"].get("load_checkpoint", False) if "meta" in args else False
|
| 118 |
+
if not load_checkpoint and Path(folder).exists():
|
| 119 |
+
timestamp = datetime.datetime.now().strftime("%y_%m_%d_%H_%M_%S")
|
| 120 |
+
folder = folder.rstrip("/") + f"_{timestamp}"
|
| 121 |
+
logger.info(f"Folder already exists but `load_checkpoint` is False. Logging to new folder {folder}...")
|
| 122 |
+
new_args_list[i]["folder"] = folder
|
| 123 |
+
return new_args_list
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def launch_app_with_parsed_args(
|
| 127 |
+
args_for_pretrain,
|
| 128 |
+
account,
|
| 129 |
+
partition,
|
| 130 |
+
qos,
|
| 131 |
+
mem_per_gpu="210G",
|
| 132 |
+
timeout=4300,
|
| 133 |
+
nodes=1,
|
| 134 |
+
tasks_per_node=1,
|
| 135 |
+
cpus_per_task=12,
|
| 136 |
+
exclude_nodes=None,
|
| 137 |
+
):
|
| 138 |
+
args_for_pretrain = update_folder_with_timestamp(args_for_pretrain)
|
| 139 |
+
for ap in args_for_pretrain:
|
| 140 |
+
folder = ap["folder"]
|
| 141 |
+
Path(folder).mkdir(parents=True, exist_ok=True)
|
| 142 |
+
folder = args_for_pretrain[0]["folder"]
|
| 143 |
+
|
| 144 |
+
# -------------- Copy code --------------
|
| 145 |
+
code_folder = os.path.join(folder, "code")
|
| 146 |
+
ignore_patterns = [
|
| 147 |
+
"__pycache__",
|
| 148 |
+
".vscode",
|
| 149 |
+
".git",
|
| 150 |
+
"core",
|
| 151 |
+
]
|
| 152 |
+
ignore_paths = [
|
| 153 |
+
"./evals/ava/alphaction/data",
|
| 154 |
+
"./demos",
|
| 155 |
+
"./traces",
|
| 156 |
+
]
|
| 157 |
+
copy_code_folder(code_folder, ignore_patterns, ignore_paths)
|
| 158 |
+
os.chdir(code_folder)
|
| 159 |
+
# ---------------------------------------
|
| 160 |
+
|
| 161 |
+
# -------------- Save config file --------------
|
| 162 |
+
params_path = os.path.join(folder, "params-pretrain.yaml")
|
| 163 |
+
if not os.path.exists(params_path):
|
| 164 |
+
with open(params_path, "w") as f:
|
| 165 |
+
yaml.dump(args_for_pretrain, f)
|
| 166 |
+
# ----------------------------------------------
|
| 167 |
+
|
| 168 |
+
# -------------- Save git info file --------------
|
| 169 |
+
git_info_fpath = os.path.join(folder, "git-info.txt")
|
| 170 |
+
with open(git_info_fpath, "w") as f:
|
| 171 |
+
f.write(git_information())
|
| 172 |
+
# ----------------------------------------------
|
| 173 |
+
|
| 174 |
+
# -------------- SET JOB NAME --------------
|
| 175 |
+
folder_ = folder
|
| 176 |
+
if folder[-1] == "/":
|
| 177 |
+
folder_ = folder[:-1]
|
| 178 |
+
job_name = folder_.split("/")[-1]
|
| 179 |
+
# ------------------------------------------
|
| 180 |
+
|
| 181 |
+
executor = submitit.AutoExecutor(folder=os.path.join(folder, "job_%j"), slurm_max_num_timeout=20)
|
| 182 |
+
executor.update_parameters(
|
| 183 |
+
name=job_name,
|
| 184 |
+
slurm_partition=partition,
|
| 185 |
+
slurm_account=account,
|
| 186 |
+
slurm_qos=qos,
|
| 187 |
+
slurm_mem_per_gpu=mem_per_gpu,
|
| 188 |
+
timeout_min=timeout,
|
| 189 |
+
nodes=nodes,
|
| 190 |
+
tasks_per_node=tasks_per_node,
|
| 191 |
+
cpus_per_task=cpus_per_task,
|
| 192 |
+
gpus_per_node=tasks_per_node,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
if exclude_nodes is not None:
|
| 196 |
+
executor.update_parameters(slurm_exclude=exclude_nodes)
|
| 197 |
+
|
| 198 |
+
jobs, trainers = [], []
|
| 199 |
+
with executor.batch():
|
| 200 |
+
for ap in args_for_pretrain:
|
| 201 |
+
# TODO Create sub folder and ap['folder']=subfolder
|
| 202 |
+
fb_trainer = Trainer(ap)
|
| 203 |
+
job = executor.submit(
|
| 204 |
+
fb_trainer,
|
| 205 |
+
)
|
| 206 |
+
trainers.append(fb_trainer)
|
| 207 |
+
jobs.append(job)
|
| 208 |
+
|
| 209 |
+
for job in jobs:
|
| 210 |
+
print(job.job_id)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def launch():
|
| 214 |
+
# ---------------------------------------------------------------------- #
|
| 215 |
+
# 1. Put config file names in a list
|
| 216 |
+
# ---------------------------------------------------------------------- #
|
| 217 |
+
config_fnames = [args.fname]
|
| 218 |
+
|
| 219 |
+
# -- If batch-launch is True, then the args.fname yaml file is not a
|
| 220 |
+
# -- config, but actually specifies a list of other config files
|
| 221 |
+
# -- to run in a slurm job array
|
| 222 |
+
if args.batch_launch:
|
| 223 |
+
with open(args.fname, "r") as y_file:
|
| 224 |
+
config_fnames = yaml.load(y_file, Loader=yaml.FullLoader)
|
| 225 |
+
# ---------------------------------------------------------------------- #
|
| 226 |
+
|
| 227 |
+
# ---------------------------------------------------------------------- #
|
| 228 |
+
# 2. Parse each yaml config file as a dict and place in list
|
| 229 |
+
# ---------------------------------------------------------------------- #
|
| 230 |
+
nodes, tasks_per_node = None, None
|
| 231 |
+
configs = []
|
| 232 |
+
for f in config_fnames:
|
| 233 |
+
with open(f, "r") as y_file:
|
| 234 |
+
_params = yaml.load(y_file, Loader=yaml.FullLoader)
|
| 235 |
+
if args.use_fname_as_folder:
|
| 236 |
+
assert not args.folder, "Don't specify --folder if adding fname to folder"
|
| 237 |
+
_params["folder"] = str(Path(_params["folder"]) / f.split("/")[-1].split(".yaml")[0])
|
| 238 |
+
elif args.folder:
|
| 239 |
+
_params["folder"] = args.folder
|
| 240 |
+
nodes = int(_params.get("nodes"))
|
| 241 |
+
tasks_per_node = int(_params.get("tasks_per_node"))
|
| 242 |
+
cpus_per_task = int(_params.get("cpus_per_task", 32))
|
| 243 |
+
mem_per_gpu = _params.get("mem_per_gpu", "210G")
|
| 244 |
+
configs += [_params]
|
| 245 |
+
logger.info(f"Loaded {len(configs)} config files")
|
| 246 |
+
logger.info(f"Running all jobs with {nodes=} / {tasks_per_node=}")
|
| 247 |
+
# ---------------------------------------------------------------------- #
|
| 248 |
+
|
| 249 |
+
# ---------------------------------------------------------------------- #
|
| 250 |
+
# 3. Launch evals with parsed config files
|
| 251 |
+
# ---------------------------------------------------------------------- #
|
| 252 |
+
launch_app_with_parsed_args(
|
| 253 |
+
args_for_pretrain=configs,
|
| 254 |
+
account=args.account,
|
| 255 |
+
partition=args.partition,
|
| 256 |
+
qos=args.qos,
|
| 257 |
+
mem_per_gpu=mem_per_gpu,
|
| 258 |
+
cpus_per_task=cpus_per_task,
|
| 259 |
+
timeout=args.time,
|
| 260 |
+
nodes=nodes,
|
| 261 |
+
tasks_per_node=tasks_per_node,
|
| 262 |
+
exclude_nodes=args.exclude,
|
| 263 |
+
)
|
| 264 |
+
# ---------------------------------------------------------------------- #
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
if __name__ == "__main__":
|
| 268 |
+
args = parser.parse_args()
|
| 269 |
+
launch()
|
vjepa2/app/scaffold.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import importlib
|
| 7 |
+
import logging
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
| 11 |
+
logger = logging.getLogger()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def main(app, args, resume_preempt=False):
|
| 15 |
+
|
| 16 |
+
logger.info(f"Running pre-training of app: {app}")
|
| 17 |
+
return importlib.import_module(f"app.{app}.train").main(args=args, resume_preempt=resume_preempt)
|
vjepa2/app/vjepa/train.py
ADDED
|
@@ -0,0 +1,536 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS
|
| 9 |
+
try:
|
| 10 |
+
# -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE
|
| 11 |
+
# -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE
|
| 12 |
+
# -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE
|
| 13 |
+
# -- TO EACH PROCESS
|
| 14 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"]
|
| 15 |
+
except Exception:
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
import copy
|
| 19 |
+
import gc
|
| 20 |
+
import random
|
| 21 |
+
import time
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
import torch.multiprocessing as mp
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 28 |
+
|
| 29 |
+
from app.vjepa.transforms import make_transforms
|
| 30 |
+
from app.vjepa.utils import init_opt, init_video_model, load_checkpoint
|
| 31 |
+
from src.datasets.data_manager import init_data
|
| 32 |
+
from src.masks.multiseq_multiblock3d import MaskCollator
|
| 33 |
+
from src.masks.utils import apply_masks
|
| 34 |
+
from src.utils.distributed import init_distributed
|
| 35 |
+
from src.utils.logging import AverageMeter, CSVLogger, get_logger, gpu_timer
|
| 36 |
+
|
| 37 |
+
# --
|
| 38 |
+
log_timings = True
|
| 39 |
+
log_freq = 10
|
| 40 |
+
CHECKPOINT_FREQ = 1
|
| 41 |
+
GARBAGE_COLLECT_ITR_FREQ = 50
|
| 42 |
+
# --
|
| 43 |
+
|
| 44 |
+
_GLOBAL_SEED = 0
|
| 45 |
+
random.seed(_GLOBAL_SEED)
|
| 46 |
+
np.random.seed(_GLOBAL_SEED)
|
| 47 |
+
torch.manual_seed(_GLOBAL_SEED)
|
| 48 |
+
torch.backends.cudnn.benchmark = True
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
logger = get_logger(__name__, force=True)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def main(args, resume_preempt=False):
|
| 55 |
+
# ----------------------------------------------------------------------- #
|
| 56 |
+
# PASSED IN PARAMS FROM CONFIG FILE
|
| 57 |
+
# ----------------------------------------------------------------------- #
|
| 58 |
+
|
| 59 |
+
# -- META
|
| 60 |
+
folder = args.get("folder")
|
| 61 |
+
cfgs_meta = args.get("meta")
|
| 62 |
+
load_model = cfgs_meta.get("load_checkpoint") or resume_preempt
|
| 63 |
+
r_file = cfgs_meta.get("read_checkpoint", None)
|
| 64 |
+
seed = cfgs_meta.get("seed", _GLOBAL_SEED)
|
| 65 |
+
save_every_freq = cfgs_meta.get("save_every_freq", -1)
|
| 66 |
+
skip_batches = cfgs_meta.get("skip_batches", -1)
|
| 67 |
+
use_sdpa = cfgs_meta.get("use_sdpa", False)
|
| 68 |
+
sync_gc = cfgs_meta.get("sync_gc", False)
|
| 69 |
+
which_dtype = cfgs_meta.get("dtype")
|
| 70 |
+
logger.info(f"{which_dtype=}")
|
| 71 |
+
if which_dtype.lower() == "bfloat16":
|
| 72 |
+
dtype = torch.bfloat16
|
| 73 |
+
mixed_precision = True
|
| 74 |
+
elif which_dtype.lower() == "float16":
|
| 75 |
+
dtype = torch.float16
|
| 76 |
+
mixed_precision = True
|
| 77 |
+
else:
|
| 78 |
+
dtype = torch.float32
|
| 79 |
+
mixed_precision = False
|
| 80 |
+
|
| 81 |
+
# -- MASK
|
| 82 |
+
cfgs_mask = args.get("mask")
|
| 83 |
+
|
| 84 |
+
# -- MODEL
|
| 85 |
+
cfgs_model = args.get("model")
|
| 86 |
+
compile_model = cfgs_model.get("compile_model", False)
|
| 87 |
+
use_activation_checkpointing = cfgs_model.get("use_activation_checkpointing", False)
|
| 88 |
+
model_name = cfgs_model.get("model_name")
|
| 89 |
+
pred_depth = cfgs_model.get("pred_depth")
|
| 90 |
+
pred_num_heads = cfgs_model.get("pred_num_heads", None)
|
| 91 |
+
pred_embed_dim = cfgs_model.get("pred_embed_dim")
|
| 92 |
+
uniform_power = cfgs_model.get("uniform_power", False)
|
| 93 |
+
use_mask_tokens = cfgs_model.get("use_mask_tokens", False)
|
| 94 |
+
zero_init_mask_tokens = cfgs_model.get("zero_init_mask_tokens", True)
|
| 95 |
+
use_rope = cfgs_model.get("use_rope", False)
|
| 96 |
+
use_silu = cfgs_model.get("use_silu", False)
|
| 97 |
+
use_pred_silu = cfgs_model.get("use_pred_silu", False)
|
| 98 |
+
wide_silu = cfgs_model.get("wide_silu", True)
|
| 99 |
+
|
| 100 |
+
# -- DATA
|
| 101 |
+
cfgs_data = args.get("data")
|
| 102 |
+
dataset_type = cfgs_data.get("dataset_type", "videodataset")
|
| 103 |
+
dataset_paths = cfgs_data.get("datasets", [])
|
| 104 |
+
datasets_weights = cfgs_data.get("datasets_weights")
|
| 105 |
+
dataset_fpcs = cfgs_data.get("dataset_fpcs")
|
| 106 |
+
max_num_frames = max(dataset_fpcs)
|
| 107 |
+
if datasets_weights is not None:
|
| 108 |
+
assert len(datasets_weights) == len(dataset_paths), "Must have one sampling weight specified for each dataset"
|
| 109 |
+
batch_size = cfgs_data.get("batch_size")
|
| 110 |
+
tubelet_size = cfgs_data.get("tubelet_size")
|
| 111 |
+
fps = cfgs_data.get("fps")
|
| 112 |
+
crop_size = cfgs_data.get("crop_size", 224)
|
| 113 |
+
patch_size = cfgs_data.get("patch_size")
|
| 114 |
+
pin_mem = cfgs_data.get("pin_mem", False)
|
| 115 |
+
num_workers = cfgs_data.get("num_workers", 1)
|
| 116 |
+
persistent_workers = cfgs_data.get("persistent_workers", True)
|
| 117 |
+
|
| 118 |
+
# -- DATA AUGS
|
| 119 |
+
cfgs_data_aug = args.get("data_aug")
|
| 120 |
+
ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3])
|
| 121 |
+
rr_scale = cfgs_data_aug.get("random_resize_scale", [0.3, 1.0])
|
| 122 |
+
motion_shift = cfgs_data_aug.get("motion_shift", False)
|
| 123 |
+
reprob = cfgs_data_aug.get("reprob", 0.0)
|
| 124 |
+
use_aa = cfgs_data_aug.get("auto_augment", False)
|
| 125 |
+
|
| 126 |
+
# -- LOSS
|
| 127 |
+
cfgs_loss = args.get("loss")
|
| 128 |
+
loss_exp = cfgs_loss.get("loss_exp")
|
| 129 |
+
|
| 130 |
+
# -- OPTIMIZATION
|
| 131 |
+
cfgs_opt = args.get("optimization")
|
| 132 |
+
is_anneal = cfgs_opt.get("is_anneal", False)
|
| 133 |
+
anneal_ckpt = cfgs_opt.get("anneal_ckpt", None)
|
| 134 |
+
if is_anneal and anneal_ckpt is None:
|
| 135 |
+
raise ValueError("Must specify anneal_ckpt if is_anneal is True")
|
| 136 |
+
resume_anneal = cfgs_opt.get("resume_anneal", False) or (is_anneal and resume_preempt)
|
| 137 |
+
ipe = cfgs_opt.get("ipe", None)
|
| 138 |
+
ipe_scale = cfgs_opt.get("ipe_scale", 1.0)
|
| 139 |
+
wd = float(cfgs_opt.get("weight_decay"))
|
| 140 |
+
final_wd = float(cfgs_opt.get("final_weight_decay"))
|
| 141 |
+
num_epochs = cfgs_opt.get("epochs")
|
| 142 |
+
warmup = cfgs_opt.get("warmup")
|
| 143 |
+
start_lr = cfgs_opt.get("start_lr")
|
| 144 |
+
lr = cfgs_opt.get("lr")
|
| 145 |
+
final_lr = cfgs_opt.get("final_lr")
|
| 146 |
+
ema = cfgs_opt.get("ema")
|
| 147 |
+
betas = cfgs_opt.get("betas", (0.9, 0.999))
|
| 148 |
+
eps = cfgs_opt.get("eps", 1.0e-8)
|
| 149 |
+
# ----------------------------------------------------------------------- #
|
| 150 |
+
# ----------------------------------------------------------------------- #
|
| 151 |
+
|
| 152 |
+
np.random.seed(seed)
|
| 153 |
+
torch.manual_seed(seed)
|
| 154 |
+
torch.backends.cudnn.benchmark = True
|
| 155 |
+
try:
|
| 156 |
+
mp.set_start_method("spawn")
|
| 157 |
+
except Exception:
|
| 158 |
+
pass
|
| 159 |
+
|
| 160 |
+
# -- init torch distributed backend
|
| 161 |
+
world_size, rank = init_distributed()
|
| 162 |
+
logger.info(f"Initialized (rank/world-size) {rank}/{world_size}")
|
| 163 |
+
|
| 164 |
+
# -- set device
|
| 165 |
+
if not torch.cuda.is_available():
|
| 166 |
+
device = torch.device("cpu")
|
| 167 |
+
else:
|
| 168 |
+
device = torch.device("cuda:0")
|
| 169 |
+
torch.cuda.set_device(device)
|
| 170 |
+
|
| 171 |
+
# -- log/checkpointing paths
|
| 172 |
+
log_file = os.path.join(folder, f"log_r{rank}.csv")
|
| 173 |
+
latest_file = "latest.pt"
|
| 174 |
+
latest_path = os.path.join(folder, latest_file)
|
| 175 |
+
load_path = None
|
| 176 |
+
if load_model:
|
| 177 |
+
if is_anneal:
|
| 178 |
+
if os.path.exists(latest_path) and resume_anneal:
|
| 179 |
+
load_path = latest_path
|
| 180 |
+
else:
|
| 181 |
+
load_path = anneal_ckpt
|
| 182 |
+
resume_anneal = False
|
| 183 |
+
else:
|
| 184 |
+
load_path = r_file if r_file is not None else latest_path
|
| 185 |
+
if not os.path.exists(load_path):
|
| 186 |
+
load_path = None
|
| 187 |
+
load_model = False
|
| 188 |
+
|
| 189 |
+
# -- make csv_logger
|
| 190 |
+
csv_logger = CSVLogger(
|
| 191 |
+
log_file,
|
| 192 |
+
("%d", "epoch"),
|
| 193 |
+
("%d", "itr"),
|
| 194 |
+
("%.5f", "loss"),
|
| 195 |
+
("%d", "iter-time(ms)"),
|
| 196 |
+
("%d", "gpu-time(ms)"),
|
| 197 |
+
("%d", "dataload-time(ms)"),
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# -- init model
|
| 201 |
+
encoder, predictor = init_video_model(
|
| 202 |
+
uniform_power=uniform_power,
|
| 203 |
+
use_mask_tokens=use_mask_tokens,
|
| 204 |
+
num_mask_tokens=int(len(cfgs_mask) * len(dataset_fpcs)),
|
| 205 |
+
zero_init_mask_tokens=zero_init_mask_tokens,
|
| 206 |
+
device=device,
|
| 207 |
+
patch_size=patch_size,
|
| 208 |
+
max_num_frames=max_num_frames,
|
| 209 |
+
tubelet_size=tubelet_size,
|
| 210 |
+
model_name=model_name,
|
| 211 |
+
crop_size=crop_size,
|
| 212 |
+
pred_depth=pred_depth,
|
| 213 |
+
pred_num_heads=pred_num_heads,
|
| 214 |
+
pred_embed_dim=pred_embed_dim,
|
| 215 |
+
use_sdpa=use_sdpa,
|
| 216 |
+
use_silu=use_silu,
|
| 217 |
+
use_pred_silu=use_pred_silu,
|
| 218 |
+
wide_silu=wide_silu,
|
| 219 |
+
use_rope=use_rope,
|
| 220 |
+
use_activation_checkpointing=use_activation_checkpointing,
|
| 221 |
+
)
|
| 222 |
+
target_encoder = copy.deepcopy(encoder)
|
| 223 |
+
|
| 224 |
+
if compile_model:
|
| 225 |
+
logger.info("Compiling encoder, target_encoder, and predictor.")
|
| 226 |
+
torch._dynamo.config.optimize_ddp = False
|
| 227 |
+
encoder.compile()
|
| 228 |
+
target_encoder.compile()
|
| 229 |
+
predictor.compile()
|
| 230 |
+
|
| 231 |
+
mask_collator = MaskCollator(
|
| 232 |
+
cfgs_mask=cfgs_mask,
|
| 233 |
+
dataset_fpcs=dataset_fpcs,
|
| 234 |
+
crop_size=crop_size,
|
| 235 |
+
patch_size=patch_size,
|
| 236 |
+
tubelet_size=tubelet_size,
|
| 237 |
+
)
|
| 238 |
+
transform = make_transforms(
|
| 239 |
+
random_horizontal_flip=True,
|
| 240 |
+
random_resize_aspect_ratio=ar_range,
|
| 241 |
+
random_resize_scale=rr_scale,
|
| 242 |
+
reprob=reprob,
|
| 243 |
+
auto_augment=use_aa,
|
| 244 |
+
motion_shift=motion_shift,
|
| 245 |
+
crop_size=crop_size,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# -- init data-loaders/samplers
|
| 249 |
+
(unsupervised_loader, unsupervised_sampler) = init_data(
|
| 250 |
+
data=dataset_type,
|
| 251 |
+
root_path=dataset_paths,
|
| 252 |
+
batch_size=batch_size,
|
| 253 |
+
training=True,
|
| 254 |
+
dataset_fpcs=dataset_fpcs,
|
| 255 |
+
fps=fps,
|
| 256 |
+
transform=transform,
|
| 257 |
+
rank=rank,
|
| 258 |
+
world_size=world_size,
|
| 259 |
+
datasets_weights=datasets_weights,
|
| 260 |
+
persistent_workers=persistent_workers,
|
| 261 |
+
collator=mask_collator,
|
| 262 |
+
num_workers=num_workers,
|
| 263 |
+
pin_mem=pin_mem,
|
| 264 |
+
log_dir=None,
|
| 265 |
+
)
|
| 266 |
+
try:
|
| 267 |
+
_dlen = len(unsupervised_loader)
|
| 268 |
+
except Exception: # Different interface for webdataset
|
| 269 |
+
_dlen = unsupervised_loader.num_batches
|
| 270 |
+
if ipe is None:
|
| 271 |
+
ipe = _dlen
|
| 272 |
+
logger.info(f"iterations per epoch/dataset length: {ipe}/{_dlen}")
|
| 273 |
+
|
| 274 |
+
# -- init optimizer and scheduler
|
| 275 |
+
optimizer, scaler, scheduler, wd_scheduler = init_opt(
|
| 276 |
+
is_anneal=is_anneal,
|
| 277 |
+
encoder=encoder,
|
| 278 |
+
predictor=predictor,
|
| 279 |
+
wd=wd,
|
| 280 |
+
final_wd=final_wd,
|
| 281 |
+
start_lr=start_lr,
|
| 282 |
+
ref_lr=lr,
|
| 283 |
+
final_lr=final_lr,
|
| 284 |
+
iterations_per_epoch=ipe,
|
| 285 |
+
warmup=warmup,
|
| 286 |
+
num_epochs=num_epochs,
|
| 287 |
+
ipe_scale=ipe_scale,
|
| 288 |
+
mixed_precision=mixed_precision,
|
| 289 |
+
betas=betas,
|
| 290 |
+
eps=eps,
|
| 291 |
+
)
|
| 292 |
+
encoder = DistributedDataParallel(encoder, static_graph=True, find_unused_parameters=False)
|
| 293 |
+
predictor = DistributedDataParallel(predictor, static_graph=False, find_unused_parameters=False)
|
| 294 |
+
target_encoder = DistributedDataParallel(target_encoder, static_graph=True, find_unused_parameters=False)
|
| 295 |
+
for p in target_encoder.parameters():
|
| 296 |
+
p.requires_grad = False
|
| 297 |
+
|
| 298 |
+
# -- momentum schedule
|
| 299 |
+
momentum_scheduler = (
|
| 300 |
+
ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale)
|
| 301 |
+
for i in range(int(ipe * num_epochs * ipe_scale) + 1)
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
start_epoch = 0
|
| 305 |
+
# -- load training checkpoint
|
| 306 |
+
if load_model or os.path.exists(latest_path):
|
| 307 |
+
(
|
| 308 |
+
encoder,
|
| 309 |
+
predictor,
|
| 310 |
+
target_encoder,
|
| 311 |
+
optimizer,
|
| 312 |
+
scaler,
|
| 313 |
+
start_epoch,
|
| 314 |
+
) = load_checkpoint(
|
| 315 |
+
r_path=load_path,
|
| 316 |
+
encoder=encoder,
|
| 317 |
+
predictor=predictor,
|
| 318 |
+
target_encoder=target_encoder,
|
| 319 |
+
opt=optimizer,
|
| 320 |
+
scaler=scaler,
|
| 321 |
+
is_anneal=is_anneal and not resume_anneal,
|
| 322 |
+
)
|
| 323 |
+
if not is_anneal or resume_anneal:
|
| 324 |
+
for _ in range(start_epoch * ipe):
|
| 325 |
+
scheduler.step()
|
| 326 |
+
wd_scheduler.step()
|
| 327 |
+
next(momentum_scheduler)
|
| 328 |
+
mask_collator.step()
|
| 329 |
+
|
| 330 |
+
def save_checkpoint(epoch, path):
|
| 331 |
+
if rank != 0:
|
| 332 |
+
return
|
| 333 |
+
save_dict = {
|
| 334 |
+
"encoder": encoder.state_dict(),
|
| 335 |
+
"predictor": predictor.state_dict(),
|
| 336 |
+
"opt": optimizer.state_dict(),
|
| 337 |
+
"scaler": None if scaler is None else scaler.state_dict(),
|
| 338 |
+
"target_encoder": target_encoder.state_dict(),
|
| 339 |
+
"epoch": epoch,
|
| 340 |
+
"loss": loss_meter.avg,
|
| 341 |
+
"batch_size": batch_size,
|
| 342 |
+
"world_size": world_size,
|
| 343 |
+
"lr": lr,
|
| 344 |
+
}
|
| 345 |
+
try:
|
| 346 |
+
torch.save(save_dict, path)
|
| 347 |
+
except Exception as e:
|
| 348 |
+
logger.info(f"Encountered exception when saving checkpoint: {e}")
|
| 349 |
+
|
| 350 |
+
logger.info("Initializing loader...")
|
| 351 |
+
unsupervised_sampler.set_epoch(start_epoch)
|
| 352 |
+
loader = iter(unsupervised_loader)
|
| 353 |
+
|
| 354 |
+
if skip_batches > 0:
|
| 355 |
+
logger.info(f"Skip {skip_batches} batches")
|
| 356 |
+
# -- update distributed-data-loader epoch
|
| 357 |
+
|
| 358 |
+
for itr in range(skip_batches):
|
| 359 |
+
if itr % 10 == 0:
|
| 360 |
+
logger.info(f"Skip {itr}/{skip_batches} batches")
|
| 361 |
+
try:
|
| 362 |
+
_ = next(loader)
|
| 363 |
+
except Exception:
|
| 364 |
+
loader = iter(unsupervised_loader)
|
| 365 |
+
_ = next(loader)
|
| 366 |
+
|
| 367 |
+
if sync_gc:
|
| 368 |
+
gc.disable()
|
| 369 |
+
gc.collect()
|
| 370 |
+
|
| 371 |
+
# -- TRAINING LOOP
|
| 372 |
+
for epoch in range(start_epoch, num_epochs):
|
| 373 |
+
logger.info("Epoch %d" % (epoch + 1))
|
| 374 |
+
|
| 375 |
+
loss_meter = AverageMeter()
|
| 376 |
+
mask_meters = {fpc: AverageMeter() for fpc in dataset_fpcs}
|
| 377 |
+
iter_time_meter = AverageMeter()
|
| 378 |
+
gpu_time_meter = AverageMeter()
|
| 379 |
+
data_elapsed_time_meter = AverageMeter()
|
| 380 |
+
|
| 381 |
+
for itr in range(ipe):
|
| 382 |
+
itr_start_time = time.time()
|
| 383 |
+
|
| 384 |
+
iter_retries = 0
|
| 385 |
+
iter_successful = False
|
| 386 |
+
while not iter_successful:
|
| 387 |
+
try:
|
| 388 |
+
sample = next(loader)
|
| 389 |
+
iter_successful = True
|
| 390 |
+
except StopIteration:
|
| 391 |
+
logger.info("Exhausted data loaders. Refreshing...")
|
| 392 |
+
unsupervised_sampler.set_epoch(epoch)
|
| 393 |
+
loader = iter(unsupervised_loader)
|
| 394 |
+
except Exception as e:
|
| 395 |
+
NUM_RETRIES = 5
|
| 396 |
+
if iter_retries < NUM_RETRIES:
|
| 397 |
+
logger.warning(f"Encountered exception when loading data (num retries {iter_retries}):\n{e}")
|
| 398 |
+
iter_retries += 1
|
| 399 |
+
time.sleep(5)
|
| 400 |
+
else:
|
| 401 |
+
logger.warning(f"Exceeded max retries ({NUM_RETRIES}) when loading data. Skipping batch.")
|
| 402 |
+
raise e
|
| 403 |
+
|
| 404 |
+
for _fpc_sample in sample:
|
| 405 |
+
bs, fpc = _fpc_sample[0][-1][0].size()
|
| 406 |
+
mask_meters[fpc].update(bs / batch_size)
|
| 407 |
+
|
| 408 |
+
def load_clips():
|
| 409 |
+
all_clips, all_masks_enc, all_masks_pred = [], [], []
|
| 410 |
+
for fpc_sample in sample:
|
| 411 |
+
udata, masks_enc, masks_pred = fpc_sample
|
| 412 |
+
all_clips += [udata[0][0].to(device, non_blocking=True)]
|
| 413 |
+
all_masks_enc += [[m.to(device, non_blocking=True) for m in masks_enc]]
|
| 414 |
+
all_masks_pred += [[m.to(device, non_blocking=True) for m in masks_pred]]
|
| 415 |
+
return all_clips, all_masks_enc, all_masks_pred
|
| 416 |
+
|
| 417 |
+
clips, masks_enc, masks_pred = load_clips()
|
| 418 |
+
data_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0
|
| 419 |
+
|
| 420 |
+
if sync_gc and (itr + 1) % GARBAGE_COLLECT_ITR_FREQ == 0:
|
| 421 |
+
logger.info("Running garbage collection...")
|
| 422 |
+
gc.collect()
|
| 423 |
+
|
| 424 |
+
def train_step():
|
| 425 |
+
_new_lr = scheduler.step()
|
| 426 |
+
_new_wd = wd_scheduler.step()
|
| 427 |
+
# --
|
| 428 |
+
|
| 429 |
+
def forward_target(c):
|
| 430 |
+
with torch.no_grad():
|
| 431 |
+
h = target_encoder(c)
|
| 432 |
+
h = [F.layer_norm(hi, (hi.size(-1),)) for hi in h]
|
| 433 |
+
return h
|
| 434 |
+
|
| 435 |
+
def forward_context(c):
|
| 436 |
+
z = encoder(c, masks_enc)
|
| 437 |
+
z = predictor(z, masks_enc, masks_pred)
|
| 438 |
+
return z
|
| 439 |
+
|
| 440 |
+
def loss_fn(z, h):
|
| 441 |
+
# Assumption: predictor will have returned only masked tokens for z
|
| 442 |
+
h = [apply_masks(hi, mi, concat=False) for hi, mi in zip(h, masks_pred)]
|
| 443 |
+
|
| 444 |
+
loss, n = 0, 0
|
| 445 |
+
for zi, hi in zip(z, h):
|
| 446 |
+
for zij, hij in zip(zi, hi):
|
| 447 |
+
loss += torch.mean(torch.abs(zij - hij) ** loss_exp) / loss_exp
|
| 448 |
+
n += 1
|
| 449 |
+
loss /= n
|
| 450 |
+
return loss
|
| 451 |
+
|
| 452 |
+
# Step 1. Forward
|
| 453 |
+
with torch.amp.autocast('cuda', dtype=dtype, enabled=mixed_precision):
|
| 454 |
+
h = forward_target(clips)
|
| 455 |
+
z = forward_context(clips)
|
| 456 |
+
loss = loss_fn(z, h) # jepa prediction loss
|
| 457 |
+
|
| 458 |
+
# Step 2. Backward & step
|
| 459 |
+
if mixed_precision:
|
| 460 |
+
scaler.scale(loss).backward()
|
| 461 |
+
scaler.unscale_(optimizer)
|
| 462 |
+
else:
|
| 463 |
+
loss.backward()
|
| 464 |
+
if mixed_precision:
|
| 465 |
+
scaler.step(optimizer)
|
| 466 |
+
scaler.update()
|
| 467 |
+
else:
|
| 468 |
+
optimizer.step()
|
| 469 |
+
optimizer.zero_grad()
|
| 470 |
+
|
| 471 |
+
# Step 3. momentum update of target encoder
|
| 472 |
+
m = next(momentum_scheduler)
|
| 473 |
+
with torch.no_grad():
|
| 474 |
+
params_k = []
|
| 475 |
+
params_q = []
|
| 476 |
+
for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()):
|
| 477 |
+
params_k.append(param_k)
|
| 478 |
+
params_q.append(param_q)
|
| 479 |
+
torch._foreach_mul_(params_k, m)
|
| 480 |
+
torch._foreach_add_(params_k, params_q, alpha=1 - m)
|
| 481 |
+
|
| 482 |
+
return (
|
| 483 |
+
loss.detach().item(),
|
| 484 |
+
_new_lr,
|
| 485 |
+
_new_wd,
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
(
|
| 489 |
+
loss,
|
| 490 |
+
_new_lr,
|
| 491 |
+
_new_wd,
|
| 492 |
+
), gpu_etime_ms = gpu_timer(train_step)
|
| 493 |
+
iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0
|
| 494 |
+
loss_meter.update(loss)
|
| 495 |
+
iter_time_meter.update(iter_elapsed_time_ms)
|
| 496 |
+
gpu_time_meter.update(gpu_etime_ms)
|
| 497 |
+
data_elapsed_time_meter.update(data_elapsed_time_ms)
|
| 498 |
+
|
| 499 |
+
# -- Logging
|
| 500 |
+
def log_stats():
|
| 501 |
+
csv_logger.log(epoch + 1, itr, loss, iter_elapsed_time_ms, gpu_etime_ms, data_elapsed_time_ms)
|
| 502 |
+
if (itr % log_freq == 0) or (itr == ipe - 1) or np.isnan(loss) or np.isinf(loss):
|
| 503 |
+
logger.info(
|
| 504 |
+
"[%d, %5d] loss: %.3f "
|
| 505 |
+
"masks: %s "
|
| 506 |
+
"[wd: %.2e] [lr: %.2e] "
|
| 507 |
+
"[mem: %.2e] "
|
| 508 |
+
"[iter: %.1f ms] "
|
| 509 |
+
"[gpu: %.1f ms] "
|
| 510 |
+
"[data: %.1f ms]"
|
| 511 |
+
% (
|
| 512 |
+
epoch + 1,
|
| 513 |
+
itr,
|
| 514 |
+
loss_meter.avg,
|
| 515 |
+
"[" + ", ".join([f"{k}: " + "%.1f" % mask_meters[k].avg for k in mask_meters]) + "]",
|
| 516 |
+
_new_wd,
|
| 517 |
+
_new_lr,
|
| 518 |
+
torch.cuda.max_memory_allocated() / 1024.0**2,
|
| 519 |
+
iter_time_meter.avg,
|
| 520 |
+
gpu_time_meter.avg,
|
| 521 |
+
data_elapsed_time_meter.avg,
|
| 522 |
+
)
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
log_stats()
|
| 526 |
+
assert not np.isnan(loss), "loss is nan"
|
| 527 |
+
|
| 528 |
+
# -- Save Checkpoint
|
| 529 |
+
logger.info("avg. loss %.3f" % loss_meter.avg)
|
| 530 |
+
# -- Save Last
|
| 531 |
+
if epoch % CHECKPOINT_FREQ == 0 or epoch == (num_epochs - 1):
|
| 532 |
+
save_checkpoint(epoch + 1, latest_path)
|
| 533 |
+
if save_every_freq > 0 and epoch % save_every_freq == 0:
|
| 534 |
+
save_every_file = f"e{epoch}.pt"
|
| 535 |
+
save_every_path = os.path.join(folder, save_every_file)
|
| 536 |
+
save_checkpoint(epoch + 1, save_every_path)
|
vjepa2/app/vjepa/transforms.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torchvision.transforms as transforms
|
| 8 |
+
|
| 9 |
+
import src.datasets.utils.video.transforms as video_transforms
|
| 10 |
+
from src.datasets.utils.video.randerase import RandomErasing
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def make_transforms(
|
| 14 |
+
random_horizontal_flip=True,
|
| 15 |
+
random_resize_aspect_ratio=(3 / 4, 4 / 3),
|
| 16 |
+
random_resize_scale=(0.3, 1.0),
|
| 17 |
+
reprob=0.0,
|
| 18 |
+
auto_augment=False,
|
| 19 |
+
motion_shift=False,
|
| 20 |
+
crop_size=224,
|
| 21 |
+
normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
| 22 |
+
):
|
| 23 |
+
|
| 24 |
+
_frames_augmentation = VideoTransform(
|
| 25 |
+
random_horizontal_flip=random_horizontal_flip,
|
| 26 |
+
random_resize_aspect_ratio=random_resize_aspect_ratio,
|
| 27 |
+
random_resize_scale=random_resize_scale,
|
| 28 |
+
reprob=reprob,
|
| 29 |
+
auto_augment=auto_augment,
|
| 30 |
+
motion_shift=motion_shift,
|
| 31 |
+
crop_size=crop_size,
|
| 32 |
+
normalize=normalize,
|
| 33 |
+
)
|
| 34 |
+
return _frames_augmentation
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class VideoTransform(object):
|
| 38 |
+
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
random_horizontal_flip=True,
|
| 42 |
+
random_resize_aspect_ratio=(3 / 4, 4 / 3),
|
| 43 |
+
random_resize_scale=(0.3, 1.0),
|
| 44 |
+
reprob=0.0,
|
| 45 |
+
auto_augment=False,
|
| 46 |
+
motion_shift=False,
|
| 47 |
+
crop_size=224,
|
| 48 |
+
normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
| 49 |
+
):
|
| 50 |
+
|
| 51 |
+
self.random_horizontal_flip = random_horizontal_flip
|
| 52 |
+
self.random_resize_aspect_ratio = random_resize_aspect_ratio
|
| 53 |
+
self.random_resize_scale = random_resize_scale
|
| 54 |
+
self.auto_augment = auto_augment
|
| 55 |
+
self.motion_shift = motion_shift
|
| 56 |
+
self.crop_size = crop_size
|
| 57 |
+
self.mean = torch.tensor(normalize[0], dtype=torch.float32)
|
| 58 |
+
self.std = torch.tensor(normalize[1], dtype=torch.float32)
|
| 59 |
+
if not self.auto_augment:
|
| 60 |
+
# Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255.
|
| 61 |
+
self.mean *= 255.0
|
| 62 |
+
self.std *= 255.0
|
| 63 |
+
|
| 64 |
+
self.autoaug_transform = video_transforms.create_random_augment(
|
| 65 |
+
input_size=(crop_size, crop_size),
|
| 66 |
+
# auto_augment="rand-m4-n4-w1-mstd0.5-inc1",
|
| 67 |
+
auto_augment="rand-m7-n4-mstd0.5-inc1",
|
| 68 |
+
interpolation="bicubic",
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
self.spatial_transform = (
|
| 72 |
+
video_transforms.random_resized_crop_with_shift if motion_shift else video_transforms.random_resized_crop
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
self.reprob = reprob
|
| 76 |
+
self.erase_transform = RandomErasing(
|
| 77 |
+
reprob,
|
| 78 |
+
mode="pixel",
|
| 79 |
+
max_count=1,
|
| 80 |
+
num_splits=1,
|
| 81 |
+
device="cpu",
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
def __call__(self, buffer):
|
| 85 |
+
|
| 86 |
+
if self.auto_augment:
|
| 87 |
+
buffer = [transforms.ToPILImage()(frame) for frame in buffer]
|
| 88 |
+
buffer = self.autoaug_transform(buffer)
|
| 89 |
+
buffer = [transforms.ToTensor()(img) for img in buffer]
|
| 90 |
+
buffer = torch.stack(buffer) # T C H W
|
| 91 |
+
buffer = buffer.permute(0, 2, 3, 1) # T H W C
|
| 92 |
+
elif torch.is_tensor(buffer):
|
| 93 |
+
# TODO: ensure input is always a tensor?
|
| 94 |
+
buffer = buffer.to(torch.float32)
|
| 95 |
+
else:
|
| 96 |
+
buffer = torch.tensor(buffer, dtype=torch.float32)
|
| 97 |
+
|
| 98 |
+
buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W
|
| 99 |
+
|
| 100 |
+
buffer = self.spatial_transform(
|
| 101 |
+
images=buffer,
|
| 102 |
+
target_height=self.crop_size,
|
| 103 |
+
target_width=self.crop_size,
|
| 104 |
+
scale=self.random_resize_scale,
|
| 105 |
+
ratio=self.random_resize_aspect_ratio,
|
| 106 |
+
)
|
| 107 |
+
if self.random_horizontal_flip:
|
| 108 |
+
buffer, _ = video_transforms.horizontal_flip(0.5, buffer)
|
| 109 |
+
|
| 110 |
+
buffer = _tensor_normalize_inplace(buffer, self.mean, self.std)
|
| 111 |
+
if self.reprob > 0:
|
| 112 |
+
buffer = buffer.permute(1, 0, 2, 3)
|
| 113 |
+
buffer = self.erase_transform(buffer)
|
| 114 |
+
buffer = buffer.permute(1, 0, 2, 3)
|
| 115 |
+
|
| 116 |
+
return buffer
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def tensor_normalize(tensor, mean, std):
|
| 120 |
+
"""
|
| 121 |
+
Normalize a given tensor by subtracting the mean and dividing the std.
|
| 122 |
+
Args:
|
| 123 |
+
tensor (tensor): tensor to normalize.
|
| 124 |
+
mean (tensor or list): mean value to subtract.
|
| 125 |
+
std (tensor or list): std to divide.
|
| 126 |
+
"""
|
| 127 |
+
if tensor.dtype == torch.uint8:
|
| 128 |
+
tensor = tensor.float()
|
| 129 |
+
tensor = tensor / 255.0
|
| 130 |
+
if isinstance(mean, list):
|
| 131 |
+
mean = torch.tensor(mean)
|
| 132 |
+
if isinstance(std, list):
|
| 133 |
+
std = torch.tensor(std)
|
| 134 |
+
tensor = tensor - mean
|
| 135 |
+
tensor = tensor / std
|
| 136 |
+
return tensor
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _tensor_normalize_inplace(tensor, mean, std):
|
| 140 |
+
"""
|
| 141 |
+
Normalize a given tensor by subtracting the mean and dividing the std.
|
| 142 |
+
Args:
|
| 143 |
+
tensor (tensor): tensor to normalize (with dimensions C, T, H, W).
|
| 144 |
+
mean (tensor): mean value to subtract (in 0 to 255 floats).
|
| 145 |
+
std (tensor): std to divide (in 0 to 255 floats).
|
| 146 |
+
"""
|
| 147 |
+
if tensor.dtype == torch.uint8:
|
| 148 |
+
tensor = tensor.float()
|
| 149 |
+
|
| 150 |
+
C, T, H, W = tensor.shape
|
| 151 |
+
tensor = tensor.view(C, -1).permute(1, 0) # Make C the last dimension
|
| 152 |
+
tensor.sub_(mean).div_(std)
|
| 153 |
+
tensor = tensor.permute(1, 0).view(C, T, H, W) # Put C back in front
|
| 154 |
+
return tensor
|
vjepa2/app/vjepa/utils.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import sys
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import yaml
|
| 12 |
+
|
| 13 |
+
import src.models.predictor as vit_pred
|
| 14 |
+
import src.models.vision_transformer as video_vit
|
| 15 |
+
from src.utils.checkpoint_loader import robust_checkpoint_loader
|
| 16 |
+
from src.utils.schedulers import CosineWDSchedule, LinearDecaySchedule, WarmupCosineSchedule
|
| 17 |
+
from src.utils.wrappers import MultiSeqWrapper, PredictorMultiSeqWrapper
|
| 18 |
+
|
| 19 |
+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
| 20 |
+
logger = logging.getLogger()
|
| 21 |
+
|
| 22 |
+
MAX_RETRIES = 3
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def build_eval_args(
|
| 26 |
+
model_name,
|
| 27 |
+
patch_size,
|
| 28 |
+
tubelet_size,
|
| 29 |
+
num_frames,
|
| 30 |
+
logging_folder,
|
| 31 |
+
checkpoint,
|
| 32 |
+
write_tag,
|
| 33 |
+
eval_cfg_paths,
|
| 34 |
+
uniform_power=False,
|
| 35 |
+
use_sdpa=False,
|
| 36 |
+
clip_duration=None,
|
| 37 |
+
use_silu=False,
|
| 38 |
+
wide_silu=True,
|
| 39 |
+
tag=None,
|
| 40 |
+
):
|
| 41 |
+
"""
|
| 42 |
+
Helper function to parse the pre-training configs to construct the
|
| 43 |
+
evaluation configs, return as a list of eval configs.
|
| 44 |
+
"""
|
| 45 |
+
# By convention, the pre-training config should specify any required evals
|
| 46 |
+
# in the 'evals' key
|
| 47 |
+
if eval_cfg_paths is None:
|
| 48 |
+
logger.info("No evaluations specified!")
|
| 49 |
+
return
|
| 50 |
+
|
| 51 |
+
eval_nodes = None
|
| 52 |
+
eval_tasks_per_node = None
|
| 53 |
+
args_eval = []
|
| 54 |
+
for i, f in enumerate(eval_cfg_paths):
|
| 55 |
+
with open(f, "r") as y_file:
|
| 56 |
+
_args = yaml.load(y_file, Loader=yaml.FullLoader)
|
| 57 |
+
_tag = _args.get("tag", "")
|
| 58 |
+
_args["tag"] = f"{tag}-{_tag}"
|
| 59 |
+
_nodes = _args.get("nodes", None)
|
| 60 |
+
_tasks = _args.get("tasks_per_node", 8)
|
| 61 |
+
eval_nodes = _nodes if eval_nodes is None else eval_nodes
|
| 62 |
+
eval_tasks_per_node = _tasks if eval_tasks_per_node is None else eval_tasks_per_node
|
| 63 |
+
if (eval_nodes != _nodes) or (eval_tasks_per_node != _tasks):
|
| 64 |
+
warnings.warn("Configs for online evals must use same number of nodes for slurm-batch processing")
|
| 65 |
+
|
| 66 |
+
# Model params
|
| 67 |
+
_args["pretrain"] = {}
|
| 68 |
+
_args["pretrain"]["model_name"] = model_name
|
| 69 |
+
_args["pretrain"]["patch_size"] = patch_size
|
| 70 |
+
_args["pretrain"]["tubelet_size"] = tubelet_size
|
| 71 |
+
_args["pretrain"]["uniform_power"] = uniform_power
|
| 72 |
+
_args["pretrain"]["use_sdpa"] = use_sdpa
|
| 73 |
+
_args["pretrain"]["clip_duration"] = clip_duration
|
| 74 |
+
_args["pretrain"]["use_silu"] = use_silu
|
| 75 |
+
_args["pretrain"]["wide_silu"] = wide_silu
|
| 76 |
+
|
| 77 |
+
# Data params
|
| 78 |
+
_args["pretrain"]["frames_per_clip"] = num_frames
|
| 79 |
+
|
| 80 |
+
# Misc
|
| 81 |
+
_args["pretrain"]["folder"] = logging_folder
|
| 82 |
+
_args["pretrain"]["checkpoint"] = checkpoint
|
| 83 |
+
_args["pretrain"]["write_tag"] = write_tag
|
| 84 |
+
|
| 85 |
+
args_eval += [_args]
|
| 86 |
+
|
| 87 |
+
return eval_nodes, eval_tasks_per_node, args_eval
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def load_checkpoint(
|
| 91 |
+
r_path,
|
| 92 |
+
encoder,
|
| 93 |
+
predictor,
|
| 94 |
+
target_encoder,
|
| 95 |
+
opt,
|
| 96 |
+
scaler,
|
| 97 |
+
is_anneal=False,
|
| 98 |
+
):
|
| 99 |
+
logger.info(f"Loading checkpoint from {r_path}")
|
| 100 |
+
checkpoint = robust_checkpoint_loader(r_path, map_location=torch.device("cpu"))
|
| 101 |
+
|
| 102 |
+
epoch = 0
|
| 103 |
+
if not is_anneal:
|
| 104 |
+
epoch = checkpoint["epoch"]
|
| 105 |
+
|
| 106 |
+
# -- loading encoder
|
| 107 |
+
pretrained_dict = checkpoint["encoder"]
|
| 108 |
+
msg = encoder.load_state_dict(pretrained_dict)
|
| 109 |
+
logger.info(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}")
|
| 110 |
+
|
| 111 |
+
# -- loading predictor
|
| 112 |
+
pretrained_dict = checkpoint["predictor"]
|
| 113 |
+
msg = predictor.load_state_dict(pretrained_dict)
|
| 114 |
+
logger.info(f"loaded pretrained predictor from epoch {epoch} with msg: {msg}")
|
| 115 |
+
|
| 116 |
+
# -- loading target_encoder
|
| 117 |
+
if target_encoder is not None:
|
| 118 |
+
print(list(checkpoint.keys()))
|
| 119 |
+
pretrained_dict = checkpoint["target_encoder"]
|
| 120 |
+
msg = target_encoder.load_state_dict(pretrained_dict)
|
| 121 |
+
logger.info(f"loaded pretrained target encoder from epoch {epoch} with msg: {msg}")
|
| 122 |
+
|
| 123 |
+
# -- loading optimizer
|
| 124 |
+
opt.load_state_dict(checkpoint["opt"])
|
| 125 |
+
if scaler is not None:
|
| 126 |
+
scaler.load_state_dict(checkpoint["scaler"])
|
| 127 |
+
logger.info(f"loaded optimizers from epoch {epoch}")
|
| 128 |
+
logger.info(f"read-path: {r_path}")
|
| 129 |
+
del checkpoint
|
| 130 |
+
|
| 131 |
+
return (
|
| 132 |
+
encoder,
|
| 133 |
+
predictor,
|
| 134 |
+
target_encoder,
|
| 135 |
+
opt,
|
| 136 |
+
scaler,
|
| 137 |
+
epoch,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def init_video_model(
|
| 142 |
+
device,
|
| 143 |
+
patch_size=16,
|
| 144 |
+
max_num_frames=16,
|
| 145 |
+
tubelet_size=2,
|
| 146 |
+
model_name="vit_base",
|
| 147 |
+
crop_size=224,
|
| 148 |
+
pred_depth=6,
|
| 149 |
+
pred_num_heads=None,
|
| 150 |
+
pred_embed_dim=384,
|
| 151 |
+
uniform_power=False,
|
| 152 |
+
use_mask_tokens=False,
|
| 153 |
+
num_mask_tokens=2,
|
| 154 |
+
zero_init_mask_tokens=True,
|
| 155 |
+
use_sdpa=False,
|
| 156 |
+
use_rope=False,
|
| 157 |
+
use_silu=False,
|
| 158 |
+
use_pred_silu=False,
|
| 159 |
+
wide_silu=False,
|
| 160 |
+
use_activation_checkpointing=False,
|
| 161 |
+
):
|
| 162 |
+
encoder = video_vit.__dict__[model_name](
|
| 163 |
+
img_size=crop_size,
|
| 164 |
+
patch_size=patch_size,
|
| 165 |
+
num_frames=max_num_frames,
|
| 166 |
+
tubelet_size=tubelet_size,
|
| 167 |
+
uniform_power=uniform_power,
|
| 168 |
+
use_sdpa=use_sdpa,
|
| 169 |
+
use_silu=use_silu,
|
| 170 |
+
wide_silu=wide_silu,
|
| 171 |
+
use_activation_checkpointing=use_activation_checkpointing,
|
| 172 |
+
use_rope=use_rope,
|
| 173 |
+
)
|
| 174 |
+
encoder = MultiSeqWrapper(encoder)
|
| 175 |
+
predictor = vit_pred.__dict__["vit_predictor"](
|
| 176 |
+
img_size=crop_size,
|
| 177 |
+
use_mask_tokens=use_mask_tokens,
|
| 178 |
+
patch_size=patch_size,
|
| 179 |
+
num_frames=max_num_frames,
|
| 180 |
+
tubelet_size=tubelet_size,
|
| 181 |
+
embed_dim=encoder.backbone.embed_dim,
|
| 182 |
+
predictor_embed_dim=pred_embed_dim,
|
| 183 |
+
depth=pred_depth,
|
| 184 |
+
num_heads=encoder.backbone.num_heads if pred_num_heads is None else pred_num_heads,
|
| 185 |
+
uniform_power=uniform_power,
|
| 186 |
+
num_mask_tokens=num_mask_tokens,
|
| 187 |
+
zero_init_mask_tokens=zero_init_mask_tokens,
|
| 188 |
+
use_rope=use_rope,
|
| 189 |
+
use_sdpa=use_sdpa,
|
| 190 |
+
use_silu=use_pred_silu,
|
| 191 |
+
wide_silu=wide_silu,
|
| 192 |
+
use_activation_checkpointing=use_activation_checkpointing,
|
| 193 |
+
)
|
| 194 |
+
predictor = PredictorMultiSeqWrapper(predictor)
|
| 195 |
+
|
| 196 |
+
encoder.to(device)
|
| 197 |
+
predictor.to(device)
|
| 198 |
+
logger.info(encoder)
|
| 199 |
+
logger.info(predictor)
|
| 200 |
+
|
| 201 |
+
def count_parameters(model):
|
| 202 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 203 |
+
|
| 204 |
+
logger.info(f"Encoder number of parameters: {count_parameters(encoder)}")
|
| 205 |
+
logger.info(f"Predictor number of parameters: {count_parameters(predictor)}")
|
| 206 |
+
|
| 207 |
+
return encoder, predictor
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def init_opt(
|
| 211 |
+
is_anneal,
|
| 212 |
+
encoder,
|
| 213 |
+
predictor,
|
| 214 |
+
iterations_per_epoch,
|
| 215 |
+
start_lr,
|
| 216 |
+
ref_lr,
|
| 217 |
+
warmup,
|
| 218 |
+
num_epochs,
|
| 219 |
+
wd=1e-6,
|
| 220 |
+
final_wd=1e-6,
|
| 221 |
+
final_lr=0.0,
|
| 222 |
+
mixed_precision=False,
|
| 223 |
+
ipe_scale=1.25,
|
| 224 |
+
betas=(0.9, 0.999),
|
| 225 |
+
eps=1e-8,
|
| 226 |
+
zero_init_bias_wd=True,
|
| 227 |
+
):
|
| 228 |
+
param_groups = [
|
| 229 |
+
{"params": (p for n, p in encoder.named_parameters() if ("bias" not in n) and (len(p.shape) != 1))},
|
| 230 |
+
{"params": (p for n, p in predictor.named_parameters() if ("bias" not in n) and (len(p.shape) != 1))},
|
| 231 |
+
{
|
| 232 |
+
"params": (p for n, p in encoder.named_parameters() if ("bias" in n) or (len(p.shape) == 1)),
|
| 233 |
+
"WD_exclude": zero_init_bias_wd,
|
| 234 |
+
"weight_decay": 0,
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"params": (p for n, p in predictor.named_parameters() if ("bias" in n) or (len(p.shape) == 1)),
|
| 238 |
+
"WD_exclude": zero_init_bias_wd,
|
| 239 |
+
"weight_decay": 0,
|
| 240 |
+
},
|
| 241 |
+
]
|
| 242 |
+
|
| 243 |
+
optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps)
|
| 244 |
+
if not is_anneal:
|
| 245 |
+
scheduler = WarmupCosineSchedule(
|
| 246 |
+
optimizer,
|
| 247 |
+
warmup_steps=int(warmup * iterations_per_epoch),
|
| 248 |
+
start_lr=start_lr,
|
| 249 |
+
ref_lr=ref_lr,
|
| 250 |
+
final_lr=final_lr,
|
| 251 |
+
T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
|
| 252 |
+
)
|
| 253 |
+
else:
|
| 254 |
+
scheduler = LinearDecaySchedule(
|
| 255 |
+
optimizer,
|
| 256 |
+
ref_lr=ref_lr,
|
| 257 |
+
final_lr=final_lr,
|
| 258 |
+
T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
|
| 259 |
+
)
|
| 260 |
+
wd_scheduler = CosineWDSchedule(
|
| 261 |
+
optimizer,
|
| 262 |
+
ref_wd=wd,
|
| 263 |
+
final_wd=final_wd,
|
| 264 |
+
T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
|
| 265 |
+
)
|
| 266 |
+
scaler = torch.amp.GradScaler('cuda') if mixed_precision else None
|
| 267 |
+
return optimizer, scaler, scheduler, wd_scheduler
|
vjepa2/app/vjepa_droid/droid.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
from logging import getLogger
|
| 11 |
+
from math import ceil
|
| 12 |
+
|
| 13 |
+
import h5py
|
| 14 |
+
import numpy as np
|
| 15 |
+
import pandas as pd
|
| 16 |
+
import torch
|
| 17 |
+
import torch.utils.data
|
| 18 |
+
from decord import VideoReader, cpu
|
| 19 |
+
from scipy.spatial.transform import Rotation
|
| 20 |
+
|
| 21 |
+
_GLOBAL_SEED = 0
|
| 22 |
+
logger = getLogger()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def init_data(
|
| 26 |
+
data_path,
|
| 27 |
+
batch_size,
|
| 28 |
+
frames_per_clip=16,
|
| 29 |
+
fps=5,
|
| 30 |
+
crop_size=224,
|
| 31 |
+
rank=0,
|
| 32 |
+
world_size=1,
|
| 33 |
+
camera_views=0,
|
| 34 |
+
stereo_view=False,
|
| 35 |
+
drop_last=True,
|
| 36 |
+
num_workers=10,
|
| 37 |
+
pin_mem=True,
|
| 38 |
+
persistent_workers=True,
|
| 39 |
+
collator=None,
|
| 40 |
+
transform=None,
|
| 41 |
+
camera_frame=False,
|
| 42 |
+
tubelet_size=2,
|
| 43 |
+
):
|
| 44 |
+
dataset = DROIDVideoDataset(
|
| 45 |
+
data_path=data_path,
|
| 46 |
+
frames_per_clip=frames_per_clip,
|
| 47 |
+
transform=transform,
|
| 48 |
+
fps=fps,
|
| 49 |
+
camera_views=camera_views,
|
| 50 |
+
frameskip=tubelet_size,
|
| 51 |
+
camera_frame=camera_frame,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
dist_sampler = torch.utils.data.distributed.DistributedSampler(
|
| 55 |
+
dataset, num_replicas=world_size, rank=rank, shuffle=True
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
data_loader = torch.utils.data.DataLoader(
|
| 59 |
+
dataset,
|
| 60 |
+
collate_fn=collator,
|
| 61 |
+
sampler=dist_sampler,
|
| 62 |
+
batch_size=batch_size,
|
| 63 |
+
drop_last=drop_last,
|
| 64 |
+
pin_memory=pin_mem,
|
| 65 |
+
num_workers=num_workers,
|
| 66 |
+
persistent_workers=(num_workers > 0) and persistent_workers,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
logger.info("VideoDataset unsupervised data loader created")
|
| 70 |
+
|
| 71 |
+
return data_loader, dist_sampler
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_json(directory):
|
| 75 |
+
for filename in os.listdir(directory):
|
| 76 |
+
if filename.endswith(".json"):
|
| 77 |
+
file_path = os.path.join(directory, filename)
|
| 78 |
+
try:
|
| 79 |
+
with open(file_path, "r") as f:
|
| 80 |
+
return json.load(f)
|
| 81 |
+
except json.JSONDecodeError:
|
| 82 |
+
print(f"Error decoding JSON in file: {filename}")
|
| 83 |
+
except Exception as e:
|
| 84 |
+
print(f"An unexpected error occurred while processing {filename}: {e}")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class DROIDVideoDataset(torch.utils.data.Dataset):
|
| 88 |
+
"""Video classification dataset."""
|
| 89 |
+
|
| 90 |
+
def __init__(
|
| 91 |
+
self,
|
| 92 |
+
data_path,
|
| 93 |
+
camera_views=["left_mp4_path", "right_mp4_path"],
|
| 94 |
+
frameskip=2,
|
| 95 |
+
frames_per_clip=16,
|
| 96 |
+
fps=5,
|
| 97 |
+
transform=None,
|
| 98 |
+
camera_frame=False,
|
| 99 |
+
):
|
| 100 |
+
self.data_path = data_path
|
| 101 |
+
self.frames_per_clip = frames_per_clip
|
| 102 |
+
self.frameskip = frameskip
|
| 103 |
+
self.fps = fps
|
| 104 |
+
self.transform = transform
|
| 105 |
+
self.camera_frame = camera_frame
|
| 106 |
+
if VideoReader is None:
|
| 107 |
+
raise ImportError('Unable to import "decord" which is required to read videos.')
|
| 108 |
+
|
| 109 |
+
# Camera views
|
| 110 |
+
# ---
|
| 111 |
+
# wrist camera view
|
| 112 |
+
# left camera view
|
| 113 |
+
# right camera view
|
| 114 |
+
self.camera_views = camera_views
|
| 115 |
+
self.h5_name = "trajectory.h5"
|
| 116 |
+
|
| 117 |
+
samples = list(pd.read_csv(data_path, header=None, delimiter=" ").values[:, 0])
|
| 118 |
+
self.samples = samples
|
| 119 |
+
|
| 120 |
+
def __getitem__(self, index):
|
| 121 |
+
path = self.samples[index]
|
| 122 |
+
|
| 123 |
+
# -- keep trying to load videos until you find a valid sample
|
| 124 |
+
loaded_video = False
|
| 125 |
+
while not loaded_video:
|
| 126 |
+
try:
|
| 127 |
+
buffer, actions, states, extrinsics, indices = self.loadvideo_decord(path)
|
| 128 |
+
loaded_video = True
|
| 129 |
+
except Exception as e:
|
| 130 |
+
logger.info(f"Encountered exception when loading video {path=} {e=}")
|
| 131 |
+
loaded_video = False
|
| 132 |
+
index = np.random.randint(self.__len__())
|
| 133 |
+
path = self.samples[index]
|
| 134 |
+
|
| 135 |
+
return buffer, actions, states, extrinsics, indices
|
| 136 |
+
|
| 137 |
+
def poses_to_diffs(self, poses):
|
| 138 |
+
xyz = poses[:, :3] # shape [T, 3]
|
| 139 |
+
thetas = poses[:, 3:6] # euler angles, shape [T, 3]
|
| 140 |
+
matrices = [Rotation.from_euler("xyz", theta, degrees=False).as_matrix() for theta in thetas]
|
| 141 |
+
xyz_diff = xyz[1:] - xyz[:-1]
|
| 142 |
+
angle_diff = [matrices[t + 1] @ matrices[t].T for t in range(len(matrices) - 1)]
|
| 143 |
+
angle_diff = [Rotation.from_matrix(mat).as_euler("xyz", degrees=False) for mat in angle_diff]
|
| 144 |
+
angle_diff = np.stack([d for d in angle_diff], axis=0)
|
| 145 |
+
closedness = poses[:, -1:]
|
| 146 |
+
closedness_delta = closedness[1:] - closedness[:-1]
|
| 147 |
+
return np.concatenate([xyz_diff, angle_diff, closedness_delta], axis=1)
|
| 148 |
+
|
| 149 |
+
def transform_frame(self, poses, extrinsics):
|
| 150 |
+
gripper = poses[:, -1:]
|
| 151 |
+
poses = poses[:, :-1]
|
| 152 |
+
|
| 153 |
+
def pose_to_transform(pose):
|
| 154 |
+
trans = pose[:3] # shape [3]
|
| 155 |
+
theta = pose[3:6] # euler angles, shape [3]
|
| 156 |
+
Rot = Rotation.from_euler("xyz", theta, degrees=False).as_matrix()
|
| 157 |
+
T = np.eye(4)
|
| 158 |
+
T[:3, :3] = Rot
|
| 159 |
+
T[:3, 3] = trans
|
| 160 |
+
return T
|
| 161 |
+
|
| 162 |
+
def transform_to_pose(transform):
|
| 163 |
+
trans = transform[:3, 3]
|
| 164 |
+
Rot = transform[:3, :3]
|
| 165 |
+
angle = Rotation.from_matrix(Rot).as_euler("xyz", degrees=False)
|
| 166 |
+
return np.concatenate([trans, angle], axis=0)
|
| 167 |
+
|
| 168 |
+
new_pose = []
|
| 169 |
+
for p, e in zip(poses, extrinsics):
|
| 170 |
+
p_transform = pose_to_transform(p)
|
| 171 |
+
e_transform = pose_to_transform(e)
|
| 172 |
+
new_pose_transform = np.linalg.inv(e_transform) @ p_transform
|
| 173 |
+
new_pose += [transform_to_pose(new_pose_transform)]
|
| 174 |
+
new_pose = np.stack(new_pose, axis=0)
|
| 175 |
+
|
| 176 |
+
return np.concatenate([new_pose, gripper], axis=1)
|
| 177 |
+
|
| 178 |
+
def loadvideo_decord(self, path):
|
| 179 |
+
# -- load metadata
|
| 180 |
+
metadata = get_json(path)
|
| 181 |
+
if metadata is None:
|
| 182 |
+
raise Exception(f"No metadata for video {path=}")
|
| 183 |
+
|
| 184 |
+
# -- load trajectory info
|
| 185 |
+
tpath = os.path.join(path, self.h5_name)
|
| 186 |
+
trajectory = h5py.File(tpath)
|
| 187 |
+
|
| 188 |
+
# -- randomly sample a camera view
|
| 189 |
+
camera_view = self.camera_views[torch.randint(0, len(self.camera_views), (1,))]
|
| 190 |
+
mp4_name = metadata[camera_view].split("recordings/MP4/")[-1]
|
| 191 |
+
camera_name = mp4_name.split(".")[0]
|
| 192 |
+
extrinsics = trajectory["observation"]["camera_extrinsics"][f"{camera_name}_left"]
|
| 193 |
+
states = np.concatenate(
|
| 194 |
+
[
|
| 195 |
+
np.array(trajectory["observation"]["robot_state"]["cartesian_position"]),
|
| 196 |
+
np.array(trajectory["observation"]["robot_state"]["gripper_position"])[:, None],
|
| 197 |
+
],
|
| 198 |
+
axis=1,
|
| 199 |
+
) # [T, 7]
|
| 200 |
+
vpath = os.path.join(path, "recordings/MP4", mp4_name)
|
| 201 |
+
vr = VideoReader(vpath, num_threads=-1, ctx=cpu(0))
|
| 202 |
+
# --
|
| 203 |
+
vfps = vr.get_avg_fps()
|
| 204 |
+
fpc = self.frames_per_clip
|
| 205 |
+
fps = self.fps if self.fps is not None else vfps
|
| 206 |
+
fstp = ceil(vfps / fps)
|
| 207 |
+
nframes = int(fpc * fstp)
|
| 208 |
+
vlen = len(vr)
|
| 209 |
+
|
| 210 |
+
if vlen < nframes:
|
| 211 |
+
raise Exception(f"Video is too short {vpath=}, {nframes=}, {vlen=}")
|
| 212 |
+
|
| 213 |
+
# sample a random window of nframes
|
| 214 |
+
ef = np.random.randint(nframes, vlen)
|
| 215 |
+
sf = ef - nframes
|
| 216 |
+
indices = np.arange(sf, sf + nframes, fstp).astype(np.int64)
|
| 217 |
+
# --
|
| 218 |
+
states = states[indices, :][:: self.frameskip]
|
| 219 |
+
extrinsics = extrinsics[indices, :][:: self.frameskip]
|
| 220 |
+
if self.camera_frame:
|
| 221 |
+
states = self.transform_frame(states, extrinsics)
|
| 222 |
+
actions = self.poses_to_diffs(states)
|
| 223 |
+
# --
|
| 224 |
+
vr.seek(0) # go to start of video before sampling frames
|
| 225 |
+
buffer = vr.get_batch(indices).asnumpy()
|
| 226 |
+
if self.transform is not None:
|
| 227 |
+
buffer = self.transform(buffer)
|
| 228 |
+
|
| 229 |
+
return buffer, actions, states, extrinsics, indices
|
| 230 |
+
|
| 231 |
+
def __len__(self):
|
| 232 |
+
return len(self.samples)
|
vjepa2/app/vjepa_droid/train.py
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS
|
| 11 |
+
try:
|
| 12 |
+
# -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE
|
| 13 |
+
# -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE
|
| 14 |
+
# -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE
|
| 15 |
+
# -- TO EACH PROCESS
|
| 16 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"]
|
| 17 |
+
except Exception:
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
import copy
|
| 21 |
+
import gc
|
| 22 |
+
import random
|
| 23 |
+
import time
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torch.multiprocessing as mp
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 30 |
+
|
| 31 |
+
from app.vjepa_droid.droid import init_data
|
| 32 |
+
from app.vjepa_droid.transforms import make_transforms
|
| 33 |
+
from app.vjepa_droid.utils import init_opt, init_video_model, load_checkpoint, load_pretrained
|
| 34 |
+
from src.utils.distributed import init_distributed
|
| 35 |
+
from src.utils.logging import AverageMeter, CSVLogger, get_logger, gpu_timer
|
| 36 |
+
|
| 37 |
+
# --
|
| 38 |
+
log_timings = True
|
| 39 |
+
log_freq = 10
|
| 40 |
+
CHECKPOINT_FREQ = 1
|
| 41 |
+
GARBAGE_COLLECT_ITR_FREQ = 50
|
| 42 |
+
# --
|
| 43 |
+
|
| 44 |
+
_GLOBAL_SEED = 0
|
| 45 |
+
random.seed(_GLOBAL_SEED)
|
| 46 |
+
np.random.seed(_GLOBAL_SEED)
|
| 47 |
+
torch.manual_seed(_GLOBAL_SEED)
|
| 48 |
+
torch.backends.cudnn.benchmark = True
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
logger = get_logger(__name__, force=True)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def main(args, resume_preempt=False):
|
| 55 |
+
# ----------------------------------------------------------------------- #
|
| 56 |
+
# PASSED IN PARAMS FROM CONFIG FILE
|
| 57 |
+
# ----------------------------------------------------------------------- #
|
| 58 |
+
|
| 59 |
+
# -- META
|
| 60 |
+
folder = args.get("folder")
|
| 61 |
+
cfgs_meta = args.get("meta")
|
| 62 |
+
r_file = cfgs_meta.get("resume_checkpoint", None)
|
| 63 |
+
p_file = cfgs_meta.get("pretrain_checkpoint", None)
|
| 64 |
+
load_predictor = cfgs_meta.get("load_predictor", False)
|
| 65 |
+
context_encoder_key = cfgs_meta.get("context_encoder_key", "encoder")
|
| 66 |
+
target_encoder_key = cfgs_meta.get("target_encoder_key", "target_encoder")
|
| 67 |
+
load_encoder = cfgs_meta.get("load_encoder", True)
|
| 68 |
+
seed = cfgs_meta.get("seed", _GLOBAL_SEED)
|
| 69 |
+
save_every_freq = cfgs_meta.get("save_every_freq", -1)
|
| 70 |
+
skip_batches = cfgs_meta.get("skip_batches", -1)
|
| 71 |
+
use_sdpa = cfgs_meta.get("use_sdpa", False)
|
| 72 |
+
sync_gc = cfgs_meta.get("sync_gc", False)
|
| 73 |
+
which_dtype = cfgs_meta.get("dtype")
|
| 74 |
+
logger.info(f"{which_dtype=}")
|
| 75 |
+
if which_dtype.lower() == "bfloat16":
|
| 76 |
+
dtype = torch.bfloat16
|
| 77 |
+
mixed_precision = True
|
| 78 |
+
elif which_dtype.lower() == "float16":
|
| 79 |
+
dtype = torch.float16
|
| 80 |
+
mixed_precision = True
|
| 81 |
+
else:
|
| 82 |
+
dtype = torch.float32
|
| 83 |
+
mixed_precision = False
|
| 84 |
+
|
| 85 |
+
# -- MODEL
|
| 86 |
+
cfgs_model = args.get("model")
|
| 87 |
+
compile_model = cfgs_model.get("compile_model", False)
|
| 88 |
+
use_activation_checkpointing = cfgs_model.get("use_activation_checkpointing", False)
|
| 89 |
+
model_name = cfgs_model.get("model_name")
|
| 90 |
+
pred_depth = cfgs_model.get("pred_depth")
|
| 91 |
+
pred_num_heads = cfgs_model.get("pred_num_heads", None)
|
| 92 |
+
pred_embed_dim = cfgs_model.get("pred_embed_dim")
|
| 93 |
+
pred_is_frame_causal = cfgs_model.get("pred_is_frame_causal", True)
|
| 94 |
+
uniform_power = cfgs_model.get("uniform_power", False)
|
| 95 |
+
use_rope = cfgs_model.get("use_rope", False)
|
| 96 |
+
use_silu = cfgs_model.get("use_silu", False)
|
| 97 |
+
use_pred_silu = cfgs_model.get("use_pred_silu", False)
|
| 98 |
+
wide_silu = cfgs_model.get("wide_silu", True)
|
| 99 |
+
use_extrinsics = cfgs_model.get("use_extrinsics", False)
|
| 100 |
+
|
| 101 |
+
# -- DATA
|
| 102 |
+
cfgs_data = args.get("data")
|
| 103 |
+
datasets = cfgs_data.get("datasets", [])
|
| 104 |
+
dataset_path = datasets[0]
|
| 105 |
+
dataset_fpcs = cfgs_data.get("dataset_fpcs")
|
| 106 |
+
max_num_frames = max(dataset_fpcs)
|
| 107 |
+
camera_frame = cfgs_data.get("camera_frame", False)
|
| 108 |
+
camera_views = cfgs_data.get("camera_views", ["left_mp4_path"])
|
| 109 |
+
stereo_view = cfgs_data.get("stereo_view", False)
|
| 110 |
+
batch_size = cfgs_data.get("batch_size")
|
| 111 |
+
tubelet_size = cfgs_data.get("tubelet_size")
|
| 112 |
+
fps = cfgs_data.get("fps")
|
| 113 |
+
crop_size = cfgs_data.get("crop_size", 256)
|
| 114 |
+
patch_size = cfgs_data.get("patch_size")
|
| 115 |
+
pin_mem = cfgs_data.get("pin_mem", False)
|
| 116 |
+
num_workers = cfgs_data.get("num_workers", 1)
|
| 117 |
+
persistent_workers = cfgs_data.get("persistent_workers", True)
|
| 118 |
+
|
| 119 |
+
# -- DATA AUGS
|
| 120 |
+
cfgs_data_aug = args.get("data_aug")
|
| 121 |
+
horizontal_flip = cfgs_data_aug.get("horizontal_flip", False)
|
| 122 |
+
ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3])
|
| 123 |
+
rr_scale = cfgs_data_aug.get("random_resize_scale", [0.3, 1.0])
|
| 124 |
+
motion_shift = cfgs_data_aug.get("motion_shift", False)
|
| 125 |
+
reprob = cfgs_data_aug.get("reprob", 0.0)
|
| 126 |
+
use_aa = cfgs_data_aug.get("auto_augment", False)
|
| 127 |
+
|
| 128 |
+
# -- LOSS
|
| 129 |
+
cfgs_loss = args.get("loss")
|
| 130 |
+
loss_exp = cfgs_loss.get("loss_exp")
|
| 131 |
+
normalize_reps = cfgs_loss.get("normalize_reps")
|
| 132 |
+
auto_steps = min(cfgs_loss.get("auto_steps", 1), max_num_frames)
|
| 133 |
+
# --
|
| 134 |
+
tokens_per_frame = int((crop_size // patch_size) ** 2)
|
| 135 |
+
|
| 136 |
+
# -- OPTIMIZATION
|
| 137 |
+
cfgs_opt = args.get("optimization")
|
| 138 |
+
ipe = cfgs_opt.get("ipe", None)
|
| 139 |
+
wd = float(cfgs_opt.get("weight_decay"))
|
| 140 |
+
final_wd = float(cfgs_opt.get("final_weight_decay"))
|
| 141 |
+
num_epochs = cfgs_opt.get("epochs")
|
| 142 |
+
anneal = cfgs_opt.get("anneal")
|
| 143 |
+
warmup = cfgs_opt.get("warmup")
|
| 144 |
+
start_lr = cfgs_opt.get("start_lr")
|
| 145 |
+
lr = cfgs_opt.get("lr")
|
| 146 |
+
final_lr = cfgs_opt.get("final_lr")
|
| 147 |
+
enc_lr_scale = cfgs_opt.get("enc_lr_scale", 1.0)
|
| 148 |
+
betas = cfgs_opt.get("betas", (0.9, 0.999))
|
| 149 |
+
eps = cfgs_opt.get("eps", 1.0e-8)
|
| 150 |
+
# ----------------------------------------------------------------------- #
|
| 151 |
+
# ----------------------------------------------------------------------- #
|
| 152 |
+
|
| 153 |
+
np.random.seed(seed)
|
| 154 |
+
torch.manual_seed(seed)
|
| 155 |
+
torch.backends.cudnn.benchmark = True
|
| 156 |
+
try:
|
| 157 |
+
mp.set_start_method("spawn")
|
| 158 |
+
except Exception:
|
| 159 |
+
pass
|
| 160 |
+
|
| 161 |
+
# -- init torch distributed backend
|
| 162 |
+
world_size, rank = init_distributed()
|
| 163 |
+
logger.info(f"Initialized (rank/world-size) {rank}/{world_size}")
|
| 164 |
+
|
| 165 |
+
# -- set device
|
| 166 |
+
if not torch.cuda.is_available():
|
| 167 |
+
device = torch.device("cpu")
|
| 168 |
+
else:
|
| 169 |
+
device = torch.device("cuda:0")
|
| 170 |
+
torch.cuda.set_device(device)
|
| 171 |
+
|
| 172 |
+
# -- log/checkpointing paths
|
| 173 |
+
log_file = os.path.join(folder, f"log_r{rank}.csv")
|
| 174 |
+
latest_path = os.path.join(folder, "latest.pt")
|
| 175 |
+
resume_path = os.path.join(folder, r_file) if r_file is not None else latest_path
|
| 176 |
+
if not os.path.exists(resume_path):
|
| 177 |
+
resume_path = None
|
| 178 |
+
|
| 179 |
+
# -- make csv_logger
|
| 180 |
+
csv_logger = CSVLogger(
|
| 181 |
+
log_file,
|
| 182 |
+
("%d", "epoch"),
|
| 183 |
+
("%d", "itr"),
|
| 184 |
+
("%.5f", "loss"),
|
| 185 |
+
("%d", "iter-time(ms)"),
|
| 186 |
+
("%d", "gpu-time(ms)"),
|
| 187 |
+
("%d", "dataload-time(ms)"),
|
| 188 |
+
mode="+a",
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
# -- init model
|
| 192 |
+
encoder, predictor = init_video_model(
|
| 193 |
+
uniform_power=uniform_power,
|
| 194 |
+
device=device,
|
| 195 |
+
patch_size=patch_size,
|
| 196 |
+
max_num_frames=512,
|
| 197 |
+
tubelet_size=tubelet_size,
|
| 198 |
+
model_name=model_name,
|
| 199 |
+
crop_size=crop_size,
|
| 200 |
+
pred_depth=pred_depth,
|
| 201 |
+
pred_num_heads=pred_num_heads,
|
| 202 |
+
pred_embed_dim=pred_embed_dim,
|
| 203 |
+
action_embed_dim=7,
|
| 204 |
+
pred_is_frame_causal=pred_is_frame_causal,
|
| 205 |
+
use_extrinsics=use_extrinsics,
|
| 206 |
+
use_sdpa=use_sdpa,
|
| 207 |
+
use_silu=use_silu,
|
| 208 |
+
use_pred_silu=use_pred_silu,
|
| 209 |
+
wide_silu=wide_silu,
|
| 210 |
+
use_rope=use_rope,
|
| 211 |
+
use_activation_checkpointing=use_activation_checkpointing,
|
| 212 |
+
)
|
| 213 |
+
target_encoder = copy.deepcopy(encoder)
|
| 214 |
+
|
| 215 |
+
if compile_model:
|
| 216 |
+
logger.info("Compiling encoder, target_encoder, and predictor.")
|
| 217 |
+
torch._dynamo.config.optimize_ddp = False
|
| 218 |
+
encoder.compile()
|
| 219 |
+
target_encoder.compile()
|
| 220 |
+
predictor.compile()
|
| 221 |
+
|
| 222 |
+
video_collator = torch.utils.data.default_collate
|
| 223 |
+
transform = make_transforms(
|
| 224 |
+
random_horizontal_flip=horizontal_flip,
|
| 225 |
+
random_resize_aspect_ratio=ar_range,
|
| 226 |
+
random_resize_scale=rr_scale,
|
| 227 |
+
reprob=reprob,
|
| 228 |
+
auto_augment=use_aa,
|
| 229 |
+
motion_shift=motion_shift,
|
| 230 |
+
crop_size=crop_size,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# -- init data-loaders/samplers
|
| 234 |
+
(unsupervised_loader, unsupervised_sampler) = init_data(
|
| 235 |
+
data_path=dataset_path,
|
| 236 |
+
batch_size=batch_size,
|
| 237 |
+
frames_per_clip=max_num_frames,
|
| 238 |
+
tubelet_size=1,
|
| 239 |
+
fps=fps,
|
| 240 |
+
camera_views=camera_views,
|
| 241 |
+
camera_frame=camera_frame,
|
| 242 |
+
stereo_view=stereo_view,
|
| 243 |
+
transform=transform,
|
| 244 |
+
collator=video_collator,
|
| 245 |
+
num_workers=num_workers,
|
| 246 |
+
world_size=world_size,
|
| 247 |
+
pin_mem=pin_mem,
|
| 248 |
+
persistent_workers=persistent_workers,
|
| 249 |
+
rank=rank,
|
| 250 |
+
)
|
| 251 |
+
_dlen = len(unsupervised_loader)
|
| 252 |
+
if ipe is None:
|
| 253 |
+
ipe = _dlen
|
| 254 |
+
logger.info(f"iterations per epoch/dataset length: {ipe}/{_dlen}")
|
| 255 |
+
|
| 256 |
+
# -- init optimizer and scheduler
|
| 257 |
+
optimizer, scaler, scheduler, wd_scheduler = init_opt(
|
| 258 |
+
encoder=encoder,
|
| 259 |
+
predictor=predictor,
|
| 260 |
+
wd=wd,
|
| 261 |
+
final_wd=final_wd,
|
| 262 |
+
start_lr=start_lr,
|
| 263 |
+
ref_lr=lr,
|
| 264 |
+
final_lr=final_lr,
|
| 265 |
+
enc_lr_scale=enc_lr_scale,
|
| 266 |
+
iterations_per_epoch=ipe,
|
| 267 |
+
anneal=anneal,
|
| 268 |
+
warmup=warmup,
|
| 269 |
+
num_epochs=num_epochs,
|
| 270 |
+
mixed_precision=mixed_precision,
|
| 271 |
+
betas=betas,
|
| 272 |
+
eps=eps,
|
| 273 |
+
)
|
| 274 |
+
encoder = DistributedDataParallel(encoder, static_graph=True)
|
| 275 |
+
predictor = DistributedDataParallel(predictor, static_graph=False, find_unused_parameters=True)
|
| 276 |
+
target_encoder = DistributedDataParallel(target_encoder)
|
| 277 |
+
for p in target_encoder.parameters():
|
| 278 |
+
p.requires_grad = False
|
| 279 |
+
|
| 280 |
+
# -- looad pretrained weights
|
| 281 |
+
encoder, predictor, target_encoder = load_pretrained(
|
| 282 |
+
r_path=p_file,
|
| 283 |
+
encoder=encoder,
|
| 284 |
+
predictor=predictor,
|
| 285 |
+
context_encoder_key=context_encoder_key,
|
| 286 |
+
target_encoder_key=target_encoder_key,
|
| 287 |
+
target_encoder=target_encoder,
|
| 288 |
+
load_predictor=load_predictor,
|
| 289 |
+
load_encoder=load_encoder,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
start_epoch = 0
|
| 293 |
+
# -- load training checkpoint
|
| 294 |
+
if os.path.exists(latest_path):
|
| 295 |
+
(
|
| 296 |
+
encoder,
|
| 297 |
+
predictor,
|
| 298 |
+
target_encoder,
|
| 299 |
+
optimizer,
|
| 300 |
+
scaler,
|
| 301 |
+
start_epoch,
|
| 302 |
+
) = load_checkpoint(
|
| 303 |
+
r_path=resume_path,
|
| 304 |
+
encoder=encoder,
|
| 305 |
+
predictor=predictor,
|
| 306 |
+
target_encoder=target_encoder,
|
| 307 |
+
opt=optimizer,
|
| 308 |
+
scaler=scaler,
|
| 309 |
+
)
|
| 310 |
+
for _ in range(start_epoch * ipe):
|
| 311 |
+
scheduler.step()
|
| 312 |
+
wd_scheduler.step()
|
| 313 |
+
|
| 314 |
+
def save_checkpoint(epoch, path):
|
| 315 |
+
if rank != 0:
|
| 316 |
+
return
|
| 317 |
+
save_dict = {
|
| 318 |
+
"encoder": encoder.state_dict(),
|
| 319 |
+
"predictor": predictor.state_dict(),
|
| 320 |
+
"opt": optimizer.state_dict(),
|
| 321 |
+
"scaler": None if scaler is None else scaler.state_dict(),
|
| 322 |
+
"target_encoder": target_encoder.state_dict(),
|
| 323 |
+
"epoch": epoch,
|
| 324 |
+
"loss": loss_meter.avg,
|
| 325 |
+
"batch_size": batch_size,
|
| 326 |
+
"world_size": world_size,
|
| 327 |
+
"lr": lr,
|
| 328 |
+
}
|
| 329 |
+
try:
|
| 330 |
+
torch.save(save_dict, path)
|
| 331 |
+
except Exception as e:
|
| 332 |
+
logger.info(f"Encountered exception when saving checkpoint: {e}")
|
| 333 |
+
|
| 334 |
+
logger.info("Initializing loader...")
|
| 335 |
+
unsupervised_sampler.set_epoch(start_epoch)
|
| 336 |
+
loader = iter(unsupervised_loader)
|
| 337 |
+
|
| 338 |
+
if skip_batches > 0:
|
| 339 |
+
logger.info(f"Skip {skip_batches} batches")
|
| 340 |
+
# -- update distributed-data-loader epoch
|
| 341 |
+
|
| 342 |
+
for itr in range(skip_batches):
|
| 343 |
+
if itr % 10 == 0:
|
| 344 |
+
logger.info(f"Skip {itr}/{skip_batches} batches")
|
| 345 |
+
try:
|
| 346 |
+
_ = next(loader)
|
| 347 |
+
except Exception:
|
| 348 |
+
loader = iter(unsupervised_loader)
|
| 349 |
+
_ = next(loader)
|
| 350 |
+
|
| 351 |
+
if sync_gc:
|
| 352 |
+
gc.disable()
|
| 353 |
+
gc.collect()
|
| 354 |
+
|
| 355 |
+
# -- TRAINING LOOP
|
| 356 |
+
for epoch in range(start_epoch, num_epochs):
|
| 357 |
+
logger.info("Epoch %d" % (epoch + 1))
|
| 358 |
+
|
| 359 |
+
loss_meter = AverageMeter()
|
| 360 |
+
jloss_meter = AverageMeter()
|
| 361 |
+
sloss_meter = AverageMeter()
|
| 362 |
+
iter_time_meter = AverageMeter()
|
| 363 |
+
gpu_time_meter = AverageMeter()
|
| 364 |
+
data_elapsed_time_meter = AverageMeter()
|
| 365 |
+
|
| 366 |
+
for itr in range(ipe):
|
| 367 |
+
itr_start_time = time.time()
|
| 368 |
+
|
| 369 |
+
iter_retries = 0
|
| 370 |
+
iter_successful = False
|
| 371 |
+
while not iter_successful:
|
| 372 |
+
try:
|
| 373 |
+
sample = next(loader)
|
| 374 |
+
iter_successful = True
|
| 375 |
+
except StopIteration:
|
| 376 |
+
logger.info("Exhausted data loaders. Refreshing...")
|
| 377 |
+
unsupervised_sampler.set_epoch(epoch)
|
| 378 |
+
loader = iter(unsupervised_loader)
|
| 379 |
+
except Exception as e:
|
| 380 |
+
NUM_RETRIES = 5
|
| 381 |
+
if iter_retries < NUM_RETRIES:
|
| 382 |
+
logger.warning(f"Encountered exception when loading data (num retries {iter_retries}):\n{e}")
|
| 383 |
+
iter_retries += 1
|
| 384 |
+
time.sleep(5)
|
| 385 |
+
else:
|
| 386 |
+
logger.warning(f"Exceeded max retries ({NUM_RETRIES}) when loading data. Skipping batch.")
|
| 387 |
+
raise e
|
| 388 |
+
|
| 389 |
+
def load_clips():
|
| 390 |
+
clips = sample[0].to(device, non_blocking=True) # [B C T H W]
|
| 391 |
+
actions = sample[1].to(device, dtype=torch.float, non_blocking=True) # [B T-1 7]
|
| 392 |
+
states = sample[2].to(device, dtype=torch.float, non_blocking=True) # [B T 7]
|
| 393 |
+
extrinsics = sample[3].to(device, dtype=torch.float, non_blocking=True) # [B T 7]
|
| 394 |
+
return (clips, actions, states, extrinsics)
|
| 395 |
+
|
| 396 |
+
clips, actions, states, extrinsics = load_clips()
|
| 397 |
+
data_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0
|
| 398 |
+
|
| 399 |
+
if sync_gc and (itr + 1) % GARBAGE_COLLECT_ITR_FREQ == 0:
|
| 400 |
+
logger.info("Running garbage collection...")
|
| 401 |
+
gc.collect()
|
| 402 |
+
|
| 403 |
+
def train_step():
|
| 404 |
+
_new_lr = scheduler.step()
|
| 405 |
+
_new_wd = wd_scheduler.step()
|
| 406 |
+
# --
|
| 407 |
+
|
| 408 |
+
def forward_target(c):
|
| 409 |
+
with torch.no_grad():
|
| 410 |
+
c = c.permute(0, 2, 1, 3, 4).flatten(0, 1).unsqueeze(2).repeat(1, 1, 2, 1, 1)
|
| 411 |
+
h = target_encoder(c)
|
| 412 |
+
h = h.view(batch_size, max_num_frames, -1, h.size(-1)).flatten(1, 2)
|
| 413 |
+
if normalize_reps:
|
| 414 |
+
h = F.layer_norm(h, (h.size(-1),))
|
| 415 |
+
return h
|
| 416 |
+
|
| 417 |
+
def forward_predictions(z):
|
| 418 |
+
|
| 419 |
+
def _step_predictor(_z, _a, _s, _e):
|
| 420 |
+
_z = predictor(_z, _a, _s, _e)
|
| 421 |
+
if normalize_reps:
|
| 422 |
+
_z = F.layer_norm(_z, (_z.size(-1),))
|
| 423 |
+
return _z
|
| 424 |
+
|
| 425 |
+
# -- one step of predictor with teacher forcing
|
| 426 |
+
_z, _a, _s, _e = z[:, :-tokens_per_frame], actions, states[:, :-1], extrinsics[:, :-1]
|
| 427 |
+
z_tf = _step_predictor(_z, _a, _s, _e)
|
| 428 |
+
|
| 429 |
+
# -- full auto-regressive rollouts of predictor
|
| 430 |
+
_z = torch.cat([z[:, : tokens_per_frame], z_tf[:, : tokens_per_frame]], dim=1)
|
| 431 |
+
for n in range(1, auto_steps):
|
| 432 |
+
_a, _s, _e = actions[:, : n + 1], states[:, : n + 1], extrinsics[:, : n + 1]
|
| 433 |
+
_z_nxt = _step_predictor(_z, _a, _s, _e)[:, -tokens_per_frame:]
|
| 434 |
+
_z = torch.cat([_z, _z_nxt], dim=1)
|
| 435 |
+
z_ar = _z[:, tokens_per_frame:]
|
| 436 |
+
|
| 437 |
+
return z_tf, z_ar
|
| 438 |
+
|
| 439 |
+
def loss_fn(z, h):
|
| 440 |
+
_h = h[:, tokens_per_frame : z.size(1) + tokens_per_frame]
|
| 441 |
+
return torch.mean(torch.abs(z - _h) ** loss_exp) / loss_exp
|
| 442 |
+
|
| 443 |
+
# Step 1. Forward
|
| 444 |
+
with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision):
|
| 445 |
+
h = forward_target(clips)
|
| 446 |
+
z_tf, z_ar = forward_predictions(h)
|
| 447 |
+
jloss = loss_fn(z_tf, h)
|
| 448 |
+
sloss = loss_fn(z_ar, h)
|
| 449 |
+
loss = jloss + sloss
|
| 450 |
+
|
| 451 |
+
# Step 2. Backward & step
|
| 452 |
+
if mixed_precision:
|
| 453 |
+
scaler.scale(loss).backward()
|
| 454 |
+
scaler.unscale_(optimizer)
|
| 455 |
+
else:
|
| 456 |
+
loss.backward()
|
| 457 |
+
if mixed_precision:
|
| 458 |
+
scaler.step(optimizer)
|
| 459 |
+
scaler.update()
|
| 460 |
+
else:
|
| 461 |
+
optimizer.step()
|
| 462 |
+
optimizer.zero_grad()
|
| 463 |
+
|
| 464 |
+
return (
|
| 465 |
+
float(loss),
|
| 466 |
+
float(jloss),
|
| 467 |
+
float(sloss),
|
| 468 |
+
_new_lr,
|
| 469 |
+
_new_wd,
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
(
|
| 473 |
+
loss,
|
| 474 |
+
jloss,
|
| 475 |
+
sloss,
|
| 476 |
+
_new_lr,
|
| 477 |
+
_new_wd,
|
| 478 |
+
), gpu_etime_ms = gpu_timer(train_step)
|
| 479 |
+
iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0
|
| 480 |
+
loss_meter.update(loss)
|
| 481 |
+
jloss_meter.update(jloss)
|
| 482 |
+
sloss_meter.update(sloss)
|
| 483 |
+
iter_time_meter.update(iter_elapsed_time_ms)
|
| 484 |
+
gpu_time_meter.update(gpu_etime_ms)
|
| 485 |
+
data_elapsed_time_meter.update(data_elapsed_time_ms)
|
| 486 |
+
|
| 487 |
+
# -- Logging
|
| 488 |
+
def log_stats():
|
| 489 |
+
csv_logger.log(epoch + 1, itr, loss, iter_elapsed_time_ms, gpu_etime_ms, data_elapsed_time_ms)
|
| 490 |
+
if (itr % log_freq == 0) or (itr == ipe - 1) or np.isnan(loss) or np.isinf(loss):
|
| 491 |
+
logger.info(
|
| 492 |
+
"[%d, %5d] loss: %.3f [%.2f, %.2f] "
|
| 493 |
+
"[wd: %.2e] [lr: %.2e] "
|
| 494 |
+
"[mem: %.2e] "
|
| 495 |
+
"[iter: %.1f ms] "
|
| 496 |
+
"[gpu: %.1f ms] "
|
| 497 |
+
"[data: %.1f ms]"
|
| 498 |
+
% (
|
| 499 |
+
epoch + 1,
|
| 500 |
+
itr,
|
| 501 |
+
loss_meter.avg,
|
| 502 |
+
jloss_meter.avg,
|
| 503 |
+
sloss_meter.avg,
|
| 504 |
+
_new_wd,
|
| 505 |
+
_new_lr,
|
| 506 |
+
torch.cuda.max_memory_allocated() / 1024.0**2,
|
| 507 |
+
iter_time_meter.avg,
|
| 508 |
+
gpu_time_meter.avg,
|
| 509 |
+
data_elapsed_time_meter.avg,
|
| 510 |
+
)
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
log_stats()
|
| 514 |
+
assert not np.isnan(loss), "loss is nan"
|
| 515 |
+
|
| 516 |
+
# -- Save Checkpoint
|
| 517 |
+
logger.info("avg. loss %.3f" % loss_meter.avg)
|
| 518 |
+
# -- Save Last
|
| 519 |
+
if epoch % CHECKPOINT_FREQ == 0 or epoch == (num_epochs - 1):
|
| 520 |
+
save_checkpoint(epoch + 1, latest_path)
|
| 521 |
+
if save_every_freq > 0 and epoch % save_every_freq == 0:
|
| 522 |
+
save_every_file = f"e{epoch}.pt"
|
| 523 |
+
save_every_path = os.path.join(folder, save_every_file)
|
| 524 |
+
save_checkpoint(epoch + 1, save_every_path)
|
vjepa2/app/vjepa_droid/transforms.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torchvision.transforms as transforms
|
| 10 |
+
|
| 11 |
+
import src.datasets.utils.video.transforms as video_transforms
|
| 12 |
+
from src.datasets.utils.video.randerase import RandomErasing
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def make_transforms(
|
| 16 |
+
random_horizontal_flip=True,
|
| 17 |
+
random_resize_aspect_ratio=(3 / 4, 4 / 3),
|
| 18 |
+
random_resize_scale=(0.3, 1.0),
|
| 19 |
+
reprob=0.0,
|
| 20 |
+
auto_augment=False,
|
| 21 |
+
motion_shift=False,
|
| 22 |
+
crop_size=224,
|
| 23 |
+
normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
| 24 |
+
):
|
| 25 |
+
|
| 26 |
+
_frames_augmentation = VideoTransform(
|
| 27 |
+
random_horizontal_flip=random_horizontal_flip,
|
| 28 |
+
random_resize_aspect_ratio=random_resize_aspect_ratio,
|
| 29 |
+
random_resize_scale=random_resize_scale,
|
| 30 |
+
reprob=reprob,
|
| 31 |
+
auto_augment=auto_augment,
|
| 32 |
+
motion_shift=motion_shift,
|
| 33 |
+
crop_size=crop_size,
|
| 34 |
+
normalize=normalize,
|
| 35 |
+
)
|
| 36 |
+
return _frames_augmentation
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class VideoTransform(object):
|
| 40 |
+
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
random_horizontal_flip=True,
|
| 44 |
+
random_resize_aspect_ratio=(3 / 4, 4 / 3),
|
| 45 |
+
random_resize_scale=(0.3, 1.0),
|
| 46 |
+
reprob=0.0,
|
| 47 |
+
auto_augment=False,
|
| 48 |
+
motion_shift=False,
|
| 49 |
+
crop_size=224,
|
| 50 |
+
normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
| 51 |
+
):
|
| 52 |
+
|
| 53 |
+
self.random_horizontal_flip = random_horizontal_flip
|
| 54 |
+
self.random_resize_aspect_ratio = random_resize_aspect_ratio
|
| 55 |
+
self.random_resize_scale = random_resize_scale
|
| 56 |
+
self.auto_augment = auto_augment
|
| 57 |
+
self.motion_shift = motion_shift
|
| 58 |
+
self.crop_size = crop_size
|
| 59 |
+
self.mean = torch.tensor(normalize[0], dtype=torch.float32)
|
| 60 |
+
self.std = torch.tensor(normalize[1], dtype=torch.float32)
|
| 61 |
+
if not self.auto_augment:
|
| 62 |
+
# Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255.
|
| 63 |
+
self.mean *= 255.0
|
| 64 |
+
self.std *= 255.0
|
| 65 |
+
|
| 66 |
+
self.autoaug_transform = video_transforms.create_random_augment(
|
| 67 |
+
input_size=(crop_size, crop_size),
|
| 68 |
+
# auto_augment="rand-m4-n4-w1-mstd0.5-inc1",
|
| 69 |
+
auto_augment="rand-m7-n4-mstd0.5-inc1",
|
| 70 |
+
interpolation="bicubic",
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
self.spatial_transform = (
|
| 74 |
+
video_transforms.random_resized_crop_with_shift if motion_shift else video_transforms.random_resized_crop
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
self.reprob = reprob
|
| 78 |
+
self.erase_transform = RandomErasing(
|
| 79 |
+
reprob,
|
| 80 |
+
mode="pixel",
|
| 81 |
+
max_count=1,
|
| 82 |
+
num_splits=1,
|
| 83 |
+
device="cpu",
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def __call__(self, buffer):
|
| 87 |
+
|
| 88 |
+
if self.auto_augment:
|
| 89 |
+
buffer = [transforms.ToPILImage()(frame) for frame in buffer]
|
| 90 |
+
buffer = self.autoaug_transform(buffer)
|
| 91 |
+
buffer = [transforms.ToTensor()(img) for img in buffer]
|
| 92 |
+
buffer = torch.stack(buffer) # T C H W
|
| 93 |
+
buffer = buffer.permute(0, 2, 3, 1) # T H W C
|
| 94 |
+
elif torch.is_tensor(buffer):
|
| 95 |
+
# TODO: ensure input is always a tensor?
|
| 96 |
+
buffer = buffer.to(torch.float32)
|
| 97 |
+
else:
|
| 98 |
+
buffer = torch.tensor(buffer, dtype=torch.float32)
|
| 99 |
+
|
| 100 |
+
buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W
|
| 101 |
+
|
| 102 |
+
buffer = self.spatial_transform(
|
| 103 |
+
images=buffer,
|
| 104 |
+
target_height=self.crop_size,
|
| 105 |
+
target_width=self.crop_size,
|
| 106 |
+
scale=self.random_resize_scale,
|
| 107 |
+
ratio=self.random_resize_aspect_ratio,
|
| 108 |
+
)
|
| 109 |
+
if self.random_horizontal_flip:
|
| 110 |
+
buffer, _ = video_transforms.horizontal_flip(0.5, buffer)
|
| 111 |
+
|
| 112 |
+
buffer = _tensor_normalize_inplace(buffer, self.mean, self.std)
|
| 113 |
+
if self.reprob > 0:
|
| 114 |
+
buffer = buffer.permute(1, 0, 2, 3)
|
| 115 |
+
buffer = self.erase_transform(buffer)
|
| 116 |
+
buffer = buffer.permute(1, 0, 2, 3)
|
| 117 |
+
|
| 118 |
+
return buffer
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def tensor_normalize(tensor, mean, std):
|
| 122 |
+
"""
|
| 123 |
+
Normalize a given tensor by subtracting the mean and dividing the std.
|
| 124 |
+
Args:
|
| 125 |
+
tensor (tensor): tensor to normalize.
|
| 126 |
+
mean (tensor or list): mean value to subtract.
|
| 127 |
+
std (tensor or list): std to divide.
|
| 128 |
+
"""
|
| 129 |
+
if tensor.dtype == torch.uint8:
|
| 130 |
+
tensor = tensor.float()
|
| 131 |
+
tensor = tensor / 255.0
|
| 132 |
+
if type(mean) == list:
|
| 133 |
+
mean = torch.tensor(mean)
|
| 134 |
+
if type(std) == list:
|
| 135 |
+
std = torch.tensor(std)
|
| 136 |
+
tensor = tensor - mean
|
| 137 |
+
tensor = tensor / std
|
| 138 |
+
return tensor
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _tensor_normalize_inplace(tensor, mean, std):
|
| 142 |
+
"""
|
| 143 |
+
Normalize a given tensor by subtracting the mean and dividing the std.
|
| 144 |
+
Args:
|
| 145 |
+
tensor (tensor): tensor to normalize (with dimensions C, T, H, W).
|
| 146 |
+
mean (tensor): mean value to subtract (in 0 to 255 floats).
|
| 147 |
+
std (tensor): std to divide (in 0 to 255 floats).
|
| 148 |
+
"""
|
| 149 |
+
if tensor.dtype == torch.uint8:
|
| 150 |
+
tensor = tensor.float()
|
| 151 |
+
|
| 152 |
+
C, T, H, W = tensor.shape
|
| 153 |
+
tensor = tensor.view(C, -1).permute(1, 0) # Make C the last dimension
|
| 154 |
+
tensor.sub_(mean).div_(std)
|
| 155 |
+
tensor = tensor.permute(1, 0).view(C, T, H, W) # Put C back in front
|
| 156 |
+
return tensor
|
vjepa2/app/vjepa_droid/utils.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
import src.models.ac_predictor as vit_ac_pred
|
| 14 |
+
import src.models.vision_transformer as video_vit
|
| 15 |
+
from src.utils.checkpoint_loader import robust_checkpoint_loader
|
| 16 |
+
from src.utils.schedulers import CosineWDSchedule, WSDSchedule
|
| 17 |
+
|
| 18 |
+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
| 19 |
+
logger = logging.getLogger()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_pretrained(
|
| 23 |
+
r_path,
|
| 24 |
+
encoder=None,
|
| 25 |
+
predictor=None,
|
| 26 |
+
target_encoder=None,
|
| 27 |
+
context_encoder_key="encoder",
|
| 28 |
+
target_encoder_key="target_encoder",
|
| 29 |
+
load_predictor=False,
|
| 30 |
+
load_encoder=True,
|
| 31 |
+
):
|
| 32 |
+
logger.info(f"Loading pretrained model from {r_path}")
|
| 33 |
+
checkpoint = robust_checkpoint_loader(r_path, map_location=torch.device("cpu"))
|
| 34 |
+
|
| 35 |
+
epoch = checkpoint["epoch"]
|
| 36 |
+
|
| 37 |
+
if load_encoder:
|
| 38 |
+
# -- loading encoder
|
| 39 |
+
pretrained_dict = checkpoint[context_encoder_key]
|
| 40 |
+
pretrained_dict = {k.replace("backbone.", ""): v for k, v in pretrained_dict.items()}
|
| 41 |
+
msg = encoder.load_state_dict(pretrained_dict, strict=False)
|
| 42 |
+
logger.info(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}")
|
| 43 |
+
|
| 44 |
+
if load_predictor:
|
| 45 |
+
# -- loading predictor
|
| 46 |
+
pretrained_dict = checkpoint["predictor"]
|
| 47 |
+
pretrained_dict = {k.replace("backbone.", ""): v for k, v in pretrained_dict.items()}
|
| 48 |
+
msg = predictor.load_state_dict(pretrained_dict, strict=False)
|
| 49 |
+
logger.info(f"loaded pretrained predictor from epoch {epoch} with msg: {msg}")
|
| 50 |
+
|
| 51 |
+
# -- loading target_encoder
|
| 52 |
+
if load_encoder:
|
| 53 |
+
if target_encoder is not None:
|
| 54 |
+
print(list(checkpoint.keys()))
|
| 55 |
+
pretrained_dict = checkpoint[target_encoder_key]
|
| 56 |
+
pretrained_dict = {k.replace("backbone.", ""): v for k, v in pretrained_dict.items()}
|
| 57 |
+
msg = target_encoder.load_state_dict(pretrained_dict, strict=False)
|
| 58 |
+
logger.info(f"loaded pretrained target encoder from epoch {epoch} with msg: {msg}")
|
| 59 |
+
|
| 60 |
+
del checkpoint
|
| 61 |
+
|
| 62 |
+
return (
|
| 63 |
+
encoder,
|
| 64 |
+
predictor,
|
| 65 |
+
target_encoder,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_checkpoint(
|
| 70 |
+
r_path,
|
| 71 |
+
encoder,
|
| 72 |
+
predictor,
|
| 73 |
+
target_encoder,
|
| 74 |
+
opt=None,
|
| 75 |
+
scaler=None,
|
| 76 |
+
replace_kw=["backbone."],
|
| 77 |
+
):
|
| 78 |
+
logger.info(f"Loading checkpoint from {r_path}")
|
| 79 |
+
checkpoint = robust_checkpoint_loader(r_path, map_location=torch.device("cpu"))
|
| 80 |
+
|
| 81 |
+
epoch = checkpoint["epoch"]
|
| 82 |
+
|
| 83 |
+
# -- loading encoder
|
| 84 |
+
pretrained_dict = checkpoint["encoder"]
|
| 85 |
+
for kw in replace_kw:
|
| 86 |
+
pretrained_dict = {k.replace(kw, ""): v for k, v in pretrained_dict.items()}
|
| 87 |
+
msg = encoder.load_state_dict(pretrained_dict, strict=False)
|
| 88 |
+
logger.info(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}")
|
| 89 |
+
|
| 90 |
+
# -- loading predictor
|
| 91 |
+
pretrained_dict = checkpoint["predictor"]
|
| 92 |
+
for kw in replace_kw:
|
| 93 |
+
pretrained_dict = {k.replace(kw, ""): v for k, v in pretrained_dict.items()}
|
| 94 |
+
msg = predictor.load_state_dict(pretrained_dict, strict=False)
|
| 95 |
+
logger.info(f"loaded pretrained predictor from epoch {epoch} with msg: {msg}")
|
| 96 |
+
|
| 97 |
+
# -- loading target_encoder
|
| 98 |
+
if target_encoder is not None:
|
| 99 |
+
print(list(checkpoint.keys()))
|
| 100 |
+
pretrained_dict = checkpoint["target_encoder"]
|
| 101 |
+
for kw in replace_kw:
|
| 102 |
+
pretrained_dict = {k.replace(kw, ""): v for k, v in pretrained_dict.items()}
|
| 103 |
+
msg = target_encoder.load_state_dict(pretrained_dict, strict=False)
|
| 104 |
+
logger.info(f"loaded pretrained target encoder from epoch {epoch} with msg: {msg}")
|
| 105 |
+
|
| 106 |
+
# -- loading optimizer
|
| 107 |
+
if opt is not None:
|
| 108 |
+
opt.load_state_dict(checkpoint["opt"])
|
| 109 |
+
|
| 110 |
+
if scaler is not None:
|
| 111 |
+
scaler.load_state_dict(checkpoint["scaler"])
|
| 112 |
+
|
| 113 |
+
logger.info(f"loaded optimizers from epoch {epoch}")
|
| 114 |
+
logger.info(f"read-path: {r_path}")
|
| 115 |
+
del checkpoint
|
| 116 |
+
|
| 117 |
+
return (
|
| 118 |
+
encoder,
|
| 119 |
+
predictor,
|
| 120 |
+
target_encoder,
|
| 121 |
+
opt,
|
| 122 |
+
scaler,
|
| 123 |
+
epoch,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def init_video_model(
|
| 128 |
+
device,
|
| 129 |
+
patch_size=16,
|
| 130 |
+
max_num_frames=16,
|
| 131 |
+
tubelet_size=2,
|
| 132 |
+
model_name="vit_base",
|
| 133 |
+
crop_size=224,
|
| 134 |
+
pred_depth=6,
|
| 135 |
+
pred_num_heads=None,
|
| 136 |
+
pred_embed_dim=384,
|
| 137 |
+
uniform_power=False,
|
| 138 |
+
use_sdpa=False,
|
| 139 |
+
use_rope=False,
|
| 140 |
+
use_silu=False,
|
| 141 |
+
use_pred_silu=False,
|
| 142 |
+
wide_silu=False,
|
| 143 |
+
pred_is_frame_causal=True,
|
| 144 |
+
use_activation_checkpointing=False,
|
| 145 |
+
return_all_tokens=False,
|
| 146 |
+
action_embed_dim=7,
|
| 147 |
+
use_extrinsics=False,
|
| 148 |
+
old_pred=False,
|
| 149 |
+
):
|
| 150 |
+
encoder = video_vit.__dict__[model_name](
|
| 151 |
+
img_size=crop_size,
|
| 152 |
+
patch_size=patch_size,
|
| 153 |
+
num_frames=max_num_frames,
|
| 154 |
+
tubelet_size=tubelet_size,
|
| 155 |
+
uniform_power=uniform_power,
|
| 156 |
+
use_sdpa=use_sdpa,
|
| 157 |
+
use_silu=use_silu,
|
| 158 |
+
wide_silu=wide_silu,
|
| 159 |
+
use_activation_checkpointing=use_activation_checkpointing,
|
| 160 |
+
use_rope=use_rope,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
predictor = vit_ac_pred.__dict__["vit_ac_predictor"](
|
| 164 |
+
img_size=crop_size,
|
| 165 |
+
patch_size=patch_size,
|
| 166 |
+
num_frames=max_num_frames,
|
| 167 |
+
tubelet_size=tubelet_size,
|
| 168 |
+
embed_dim=encoder.embed_dim,
|
| 169 |
+
predictor_embed_dim=pred_embed_dim,
|
| 170 |
+
action_embed_dim=action_embed_dim,
|
| 171 |
+
depth=pred_depth,
|
| 172 |
+
is_frame_causal=pred_is_frame_causal,
|
| 173 |
+
num_heads=encoder.num_heads if pred_num_heads is None else pred_num_heads,
|
| 174 |
+
uniform_power=uniform_power,
|
| 175 |
+
use_rope=use_rope,
|
| 176 |
+
use_sdpa=use_sdpa,
|
| 177 |
+
use_silu=use_pred_silu,
|
| 178 |
+
wide_silu=wide_silu,
|
| 179 |
+
use_extrinsics=use_extrinsics,
|
| 180 |
+
use_activation_checkpointing=use_activation_checkpointing,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
encoder.to(device)
|
| 184 |
+
predictor.to(device)
|
| 185 |
+
logger.info(encoder)
|
| 186 |
+
logger.info(predictor)
|
| 187 |
+
|
| 188 |
+
def count_parameters(model):
|
| 189 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 190 |
+
|
| 191 |
+
logger.info(f"Encoder number of parameters: {count_parameters(encoder)}")
|
| 192 |
+
logger.info(f"Predictor number of parameters: {count_parameters(predictor)}")
|
| 193 |
+
|
| 194 |
+
return encoder, predictor
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def init_opt(
|
| 198 |
+
encoder,
|
| 199 |
+
predictor,
|
| 200 |
+
iterations_per_epoch,
|
| 201 |
+
start_lr,
|
| 202 |
+
ref_lr,
|
| 203 |
+
warmup,
|
| 204 |
+
anneal,
|
| 205 |
+
num_epochs,
|
| 206 |
+
wd=1e-6,
|
| 207 |
+
final_wd=1e-6,
|
| 208 |
+
final_lr=0.0,
|
| 209 |
+
mixed_precision=False,
|
| 210 |
+
betas=(0.9, 0.999),
|
| 211 |
+
eps=1e-8,
|
| 212 |
+
zero_init_bias_wd=True,
|
| 213 |
+
enc_lr_scale=1.0,
|
| 214 |
+
):
|
| 215 |
+
param_groups = [
|
| 216 |
+
{
|
| 217 |
+
"params": (p for n, p in encoder.named_parameters() if ("bias" not in n) and (len(p.shape) != 1)),
|
| 218 |
+
"lr_scale": enc_lr_scale,
|
| 219 |
+
},
|
| 220 |
+
{
|
| 221 |
+
"params": (p for n, p in predictor.named_parameters() if ("bias" not in n) and (len(p.shape) != 1)),
|
| 222 |
+
},
|
| 223 |
+
{
|
| 224 |
+
"params": (p for n, p in encoder.named_parameters() if ("bias" in n) or (len(p.shape) == 1)),
|
| 225 |
+
"WD_exclude": zero_init_bias_wd,
|
| 226 |
+
"weight_decay": 0,
|
| 227 |
+
"lr_scale": enc_lr_scale,
|
| 228 |
+
},
|
| 229 |
+
{
|
| 230 |
+
"params": (p for n, p in predictor.named_parameters() if ("bias" in n) or (len(p.shape) == 1)),
|
| 231 |
+
"WD_exclude": zero_init_bias_wd,
|
| 232 |
+
"weight_decay": 0,
|
| 233 |
+
},
|
| 234 |
+
]
|
| 235 |
+
|
| 236 |
+
optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps)
|
| 237 |
+
scheduler = WSDSchedule(
|
| 238 |
+
optimizer,
|
| 239 |
+
warmup_steps=int(warmup * iterations_per_epoch),
|
| 240 |
+
anneal_steps=int(anneal * iterations_per_epoch),
|
| 241 |
+
start_lr=start_lr,
|
| 242 |
+
ref_lr=ref_lr,
|
| 243 |
+
final_lr=final_lr,
|
| 244 |
+
T_max=int(num_epochs * iterations_per_epoch),
|
| 245 |
+
)
|
| 246 |
+
wd_scheduler = CosineWDSchedule(
|
| 247 |
+
optimizer,
|
| 248 |
+
ref_wd=wd,
|
| 249 |
+
final_wd=final_wd,
|
| 250 |
+
T_max=int(num_epochs * iterations_per_epoch),
|
| 251 |
+
)
|
| 252 |
+
scaler = torch.cuda.amp.GradScaler() if mixed_precision else None
|
| 253 |
+
return optimizer, scaler, scheduler, wd_scheduler
|
vjepa2/configs/eval/vitg-384/coin.yaml
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cpus_per_task: 16
|
| 2 |
+
eval_name: video_classification_frozen
|
| 3 |
+
folder: /your_folder/evals/vitg-384/coin
|
| 4 |
+
mem_per_gpu: 220G
|
| 5 |
+
nodes: 16
|
| 6 |
+
resume_checkpoint: true
|
| 7 |
+
tag: coin-vitg16-384-16x8x3
|
| 8 |
+
tasks_per_node: 8
|
| 9 |
+
experiment:
|
| 10 |
+
classifier:
|
| 11 |
+
num_heads: 16
|
| 12 |
+
num_probe_blocks: 4
|
| 13 |
+
data:
|
| 14 |
+
dataset_type: VideoDataset
|
| 15 |
+
dataset_train: /your_data_folder/COIN/train_paths.csv
|
| 16 |
+
dataset_val: /your_data_folder/COIN/val_paths.csv
|
| 17 |
+
frame_step: 4
|
| 18 |
+
frames_per_clip: 16
|
| 19 |
+
num_classes: 180
|
| 20 |
+
num_segments: 8
|
| 21 |
+
num_views_per_segment: 3
|
| 22 |
+
resolution: 384
|
| 23 |
+
optimization:
|
| 24 |
+
batch_size: 1
|
| 25 |
+
multihead_kwargs:
|
| 26 |
+
- final_lr: 0.0
|
| 27 |
+
final_weight_decay: 0.01
|
| 28 |
+
lr: 0.005
|
| 29 |
+
start_lr: 0.005
|
| 30 |
+
warmup: 0.0
|
| 31 |
+
weight_decay: 0.01
|
| 32 |
+
- final_lr: 0.0
|
| 33 |
+
final_weight_decay: 0.01
|
| 34 |
+
lr: 0.003
|
| 35 |
+
start_lr: 0.003
|
| 36 |
+
warmup: 0.0
|
| 37 |
+
weight_decay: 0.01
|
| 38 |
+
- final_lr: 0.0
|
| 39 |
+
final_weight_decay: 0.01
|
| 40 |
+
lr: 0.001
|
| 41 |
+
start_lr: 0.001
|
| 42 |
+
warmup: 0.0
|
| 43 |
+
weight_decay: 0.01
|
| 44 |
+
- final_lr: 0.0
|
| 45 |
+
final_weight_decay: 0.01
|
| 46 |
+
lr: 0.0003
|
| 47 |
+
start_lr: 0.0003
|
| 48 |
+
warmup: 0.0
|
| 49 |
+
weight_decay: 0.01
|
| 50 |
+
- final_lr: 0.0
|
| 51 |
+
final_weight_decay: 0.01
|
| 52 |
+
lr: 0.0001
|
| 53 |
+
start_lr: 0.0001
|
| 54 |
+
warmup: 0.0
|
| 55 |
+
weight_decay: 0.01
|
| 56 |
+
- final_lr: 0.0
|
| 57 |
+
final_weight_decay: 0.1
|
| 58 |
+
lr: 0.005
|
| 59 |
+
start_lr: 0.005
|
| 60 |
+
warmup: 0.0
|
| 61 |
+
weight_decay: 0.1
|
| 62 |
+
- final_lr: 0.0
|
| 63 |
+
final_weight_decay: 0.1
|
| 64 |
+
lr: 0.003
|
| 65 |
+
start_lr: 0.003
|
| 66 |
+
warmup: 0.0
|
| 67 |
+
weight_decay: 0.1
|
| 68 |
+
- final_lr: 0.0
|
| 69 |
+
final_weight_decay: 0.1
|
| 70 |
+
lr: 0.001
|
| 71 |
+
start_lr: 0.001
|
| 72 |
+
warmup: 0.0
|
| 73 |
+
weight_decay: 0.1
|
| 74 |
+
- final_lr: 0.0
|
| 75 |
+
final_weight_decay: 0.1
|
| 76 |
+
lr: 0.0003
|
| 77 |
+
start_lr: 0.0003
|
| 78 |
+
warmup: 0.0
|
| 79 |
+
weight_decay: 0.1
|
| 80 |
+
- final_lr: 0.0
|
| 81 |
+
final_weight_decay: 0.1
|
| 82 |
+
lr: 0.0001
|
| 83 |
+
start_lr: 0.0001
|
| 84 |
+
warmup: 0.0
|
| 85 |
+
weight_decay: 0.1
|
| 86 |
+
- final_lr: 0.0
|
| 87 |
+
final_weight_decay: 0.4
|
| 88 |
+
lr: 0.005
|
| 89 |
+
start_lr: 0.005
|
| 90 |
+
warmup: 0.0
|
| 91 |
+
weight_decay: 0.4
|
| 92 |
+
- final_lr: 0.0
|
| 93 |
+
final_weight_decay: 0.4
|
| 94 |
+
lr: 0.003
|
| 95 |
+
start_lr: 0.003
|
| 96 |
+
warmup: 0.0
|
| 97 |
+
weight_decay: 0.4
|
| 98 |
+
- final_lr: 0.0
|
| 99 |
+
final_weight_decay: 0.4
|
| 100 |
+
lr: 0.001
|
| 101 |
+
start_lr: 0.001
|
| 102 |
+
warmup: 0.0
|
| 103 |
+
weight_decay: 0.4
|
| 104 |
+
- final_lr: 0.0
|
| 105 |
+
final_weight_decay: 0.4
|
| 106 |
+
lr: 0.0003
|
| 107 |
+
start_lr: 0.0003
|
| 108 |
+
warmup: 0.0
|
| 109 |
+
weight_decay: 0.4
|
| 110 |
+
- final_lr: 0.0
|
| 111 |
+
final_weight_decay: 0.4
|
| 112 |
+
lr: 0.0001
|
| 113 |
+
start_lr: 0.0001
|
| 114 |
+
warmup: 0.0
|
| 115 |
+
weight_decay: 0.4
|
| 116 |
+
- final_lr: 0.0
|
| 117 |
+
final_weight_decay: 0.8
|
| 118 |
+
lr: 0.005
|
| 119 |
+
start_lr: 0.005
|
| 120 |
+
warmup: 0.0
|
| 121 |
+
weight_decay: 0.8
|
| 122 |
+
- final_lr: 0.0
|
| 123 |
+
final_weight_decay: 0.8
|
| 124 |
+
lr: 0.003
|
| 125 |
+
start_lr: 0.003
|
| 126 |
+
warmup: 0.0
|
| 127 |
+
weight_decay: 0.8
|
| 128 |
+
- final_lr: 0.0
|
| 129 |
+
final_weight_decay: 0.8
|
| 130 |
+
lr: 0.001
|
| 131 |
+
start_lr: 0.001
|
| 132 |
+
warmup: 0.0
|
| 133 |
+
weight_decay: 0.8
|
| 134 |
+
- final_lr: 0.0
|
| 135 |
+
final_weight_decay: 0.8
|
| 136 |
+
lr: 0.0003
|
| 137 |
+
start_lr: 0.0003
|
| 138 |
+
warmup: 0.0
|
| 139 |
+
weight_decay: 0.8
|
| 140 |
+
- final_lr: 0.0
|
| 141 |
+
final_weight_decay: 0.8
|
| 142 |
+
lr: 0.0001
|
| 143 |
+
start_lr: 0.0001
|
| 144 |
+
warmup: 0.0
|
| 145 |
+
weight_decay: 0.8
|
| 146 |
+
num_epochs: 20
|
| 147 |
+
use_bfloat16: true
|
| 148 |
+
use_pos_embed: false
|
| 149 |
+
model_kwargs:
|
| 150 |
+
checkpoint: /your_vjepa2_checkpoints/vitg-384.pt
|
| 151 |
+
module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
|
| 152 |
+
pretrain_kwargs:
|
| 153 |
+
encoder:
|
| 154 |
+
checkpoint_key: target_encoder
|
| 155 |
+
img_temporal_dim_size: null
|
| 156 |
+
model_name: vit_giant_xformers
|
| 157 |
+
patch_size: 16
|
| 158 |
+
tubelet_size: 2
|
| 159 |
+
uniform_power: true
|
| 160 |
+
use_rope: true
|
| 161 |
+
wrapper_kwargs:
|
| 162 |
+
max_frames: 128
|
| 163 |
+
use_pos_embed: false
|