Copilot committed

Commit d28c63e · Parent: d9e4621

Add AI-Endo project hub UI


Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

This view is limited to 50 files because it contains too many changes.

Files changed (50):
  1. Dockerfile +1 -1
  2. README.md +33 -11
  3. app.py +544 -114
  4. dinov2/.github/workflows/lint.yaml +38 -0
  5. dinov2/.gitignore +11 -0
  6. dinov2/CODE_OF_CONDUCT.md +80 -0
  7. dinov2/CONTRIBUTING.md +31 -0
  8. dinov2/LICENSE +203 -0
  9. dinov2/MODEL_CARD.md +272 -0
  10. dinov2/README.md +620 -0
  11. dinov2/conda-extras.yaml +24 -0
  12. dinov2/conda.yaml +21 -0
  13. dinov2/pyproject.toml +29 -0
  14. dinov2/requirements-dev.txt +3 -0
  15. dinov2/requirements-extras.txt +2 -0
  16. dinov2/requirements.txt +11 -0
  17. dinov2/scripts/lint.sh +28 -0
  18. dinov2/setup.cfg +8 -0
  19. dinov2/setup.py +88 -0
  20. explainability.py +112 -0
  21. model/transformer.py +28 -8
  22. model_manager.py +10 -0
  23. predictor.py +372 -15
  24. runtime-requirements.txt +1 -1
  25. scripts/publish_model_repo.py +156 -0
  26. scripts/publish_space_repo.py +98 -0
  27. scripts/stage_space_bundle.py +104 -0
  28. scripts/stage_vendor_sources.py +38 -0
  29. vjepa2/.flake8 +5 -0
  30. vjepa2/.github/workflows/base_tests.yaml +29 -0
  31. vjepa2/.github/workflows/linters.yaml +48 -0
  32. vjepa2/.gitignore +32 -0
  33. vjepa2/APACHE-LICENSE +201 -0
  34. vjepa2/CHANGELOG.md +5 -0
  35. vjepa2/CODE_OF_CONDUCT.md +80 -0
  36. vjepa2/CONTRIBUTING.md +39 -0
  37. vjepa2/LICENSE +21 -0
  38. vjepa2/README.md +450 -0
  39. vjepa2/app/__init__.py +0 -0
  40. vjepa2/app/main.py +84 -0
  41. vjepa2/app/main_distributed.py +269 -0
  42. vjepa2/app/scaffold.py +17 -0
  43. vjepa2/app/vjepa/train.py +536 -0
  44. vjepa2/app/vjepa/transforms.py +154 -0
  45. vjepa2/app/vjepa/utils.py +267 -0
  46. vjepa2/app/vjepa_droid/droid.py +232 -0
  47. vjepa2/app/vjepa_droid/train.py +524 -0
  48. vjepa2/app/vjepa_droid/transforms.py +156 -0
  49. vjepa2/app/vjepa_droid/utils.py +253 -0
  50. vjepa2/configs/eval/vitg-384/coin.yaml +163 -0
Dockerfile CHANGED

@@ -4,7 +4,7 @@ ENV DEBIAN_FRONTEND=noninteractive \
      PYTHONUNBUFFERED=1 \
      PIP_NO_CACHE_DIR=1 \
      SPACE_MODEL_DIR=/app/model \
-     SPACE_ENABLED_MODELS=dinov2 \
+     SPACE_ENABLED_MODELS=dinov2,aiendo,vjepa2 \
      SPACE_DEFAULT_MODEL=dinov2

  RUN apt-get update && apt-get install -y --no-install-recommends \
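The Dockerfile change above flips the Space from a single-model build to the three-model configuration purely through `SPACE_ENABLED_MODELS`. As a minimal sketch of how such a comma-separated allow-list can be parsed (the helper name and the `KNOWN_MODEL_KEYS` tuple here are illustrative; the repository's own parsing lives in `_enabled_model_keys` in app.py):

```python
import os

# Assumed model-family keys; the real registry lives in model_registry.py.
KNOWN_MODEL_KEYS = ("dinov2", "aiendo", "vjepa2")

def enabled_model_keys(default: str = "dinov2") -> list[str]:
    """Parse SPACE_ENABLED_MODELS into an ordered, de-duplicated list of known keys."""
    configured = os.getenv("SPACE_ENABLED_MODELS", "").strip()
    if not configured:
        return [default]
    keys: list[str] = []
    for raw in configured.split(","):
        key = raw.strip().lower()
        if key in KNOWN_MODEL_KEYS and key not in keys:
            keys.append(key)
    return keys or [default]  # fall back to the default model if nothing valid remains

os.environ["SPACE_ENABLED_MODELS"] = "dinov2,aiendo,vjepa2"
print(enabled_model_keys())  # ['dinov2', 'aiendo', 'vjepa2']
```

Ignoring unknown keys instead of raising keeps a mistyped Space setting from taking the whole demo down.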
README.md CHANGED

@@ -1,5 +1,4 @@
- ---
- title: DINO-ENDO Phase Recognition
+ title: AI-Endo Project Hub
  emoji: 🩺
  colorFrom: blue
  colorTo: green

@@ -7,11 +6,12 @@ sdk: docker
  app_port: 7860
  ---

- # DINO-ENDO Streamlit Space
+ # AI-Endo Project Hub

  This folder is an isolated Hugging Face Space scaffold for the phase-recognition models in this repository.
- It is intentionally separate from the existing FastAPI webapp and defaults to a **DINO-Endo demo** on paid GPU hardware such as **1x A10G (24 GB VRAM)**.
- The same code can still expose AI-Endo and V-JEPA2 when you opt into them through environment variables.
+ It is intentionally separate from the existing FastAPI webapp and is designed to expose **DINO-Endo, AI-Endo, and V-JEPA2** on paid GPU hardware such as **1x A10G (24 GB VRAM)**.
+ The public UI now behaves like a small **project hub**: DINO-Endo Surgery is the first featured workspace, and the same landing page can later host additional projects without rebuilding the overall shell.
+ The default featured model remains **DINO-Endo**, but the same Space can load and unload all three model families one at a time.

  ## Supported model families

@@ -42,15 +42,15 @@ A fully local `model/` folder is still supported as a fallback.

  ## Default Space behavior

- The Docker Space is configured to boot as a **DINO-Endo-first demo**:
+ The Docker Space is configured to boot as a **three-model public demo** with **DINO-Endo** selected by default:

- - `SPACE_ENABLED_MODELS=dinov2`
+ - `SPACE_ENABLED_MODELS=dinov2,aiendo,vjepa2`
  - `SPACE_DEFAULT_MODEL=dinov2`

- If you want the same Space build to expose multiple model families again, override those environment variables in Space Settings, for example:
+ If you want to narrow the public picker to a subset of models, override those environment variables in Space Settings, for example:

  ```text
- SPACE_ENABLED_MODELS=dinov2,aiendo,vjepa2
+ SPACE_ENABLED_MODELS=dinov2
  SPACE_DEFAULT_MODEL=dinov2
  ```

@@ -67,11 +67,22 @@ If a required checkpoint is missing locally, it will try to download it from the

  ### Upload and dashboard behavior

- - The Space now keeps a single active predictor loaded at a time and unloads the previous model when the picker changes.
+ - The top of the app is a reusable project-hub landing section, with DINO-Endo Surgery as the current live workspace.
+ - The active model family is selected through a visible **model slider** in the workspace rather than a hidden picker.
+ - The Space now keeps a single active predictor loaded at a time and unloads the previous model when the model slider changes.
  - MP4 is the primary video upload format, while `mov`, `avi`, `mkv`, `webm`, and `m4v` remain enabled as fallback containers.
  - `.streamlit/config.toml` raises the default Streamlit single-file upload ceiling to **4096 MB** for this Space.
  - Uploaded videos are immediately spooled to local disk for metadata probing and analysis, instead of repeatedly reading the in-memory upload object on every rerun.
  - The UI shows file size, duration, fps, frame count, resolution, working-storage headroom, and suppresses inline preview for very large uploads to keep the browser path lighter.
+ - V-JEPA2 is labeled as a slower first load so users understand the cold-cache cost of its very large encoder checkpoint.
+
+ ### Explainability behavior
+
+ - The sidebar includes an opt-in live explainability toggle for encoder/decoder visualizations.
+ - DINO-Endo and V-JEPA2 use true encoder self-attention maps, while AI-Endo uses a labeled proxy encoder overlay from ResNet activations.
+ - AI-Endo and DINO-Endo render decoder-side temporal attention strips from the custom Transformer path.
+ - V-JEPA2 renders a labeled proxy temporal strip from decoder feature energy because its classifier head is an MLP, not an attention block.
+ - Encoder controls expose **layer/head sliders** when the loaded model supports true encoder attention.

  ### Common environment variables

@@ -116,6 +127,15 @@ That script refreshes the vendored source copies inside this folder before publishing.
  4. Upload your checkpoints to HF **model repo(s)**.
  5. Configure the relevant repo IDs (and `HF_TOKEN` only if the repos are private).

+ ### Deployment helper scripts
+
+ - `python scripts/stage_space_bundle.py --overwrite --output-dir /tmp/dino_space_minimal_upload`
+   - stages a code-only upload bundle for the current multi-model Space without local caches or checkpoints.
+ - `python scripts/publish_model_repo.py --family aiendo --repo-id <owner/repo> --model-dir /path/to/model`
+   - publishes one model family to a Hugging Face **model repo** and automatically switches to `upload_large_folder()` for very large bundles.
+ - `python scripts/publish_space_repo.py --repo-id <owner/space> --dino-model-repo-id <owner/dino-repo> --aiendo-model-repo-id <owner/aiendo-repo> --vjepa2-model-repo-id <owner/vjepa2-repo>`
+   - stages/uploads the Docker Space bundle and updates the key Space environment variables for the three-model demo.
+
  ## Local smoke test

  Once the Space dependencies are installed, you can smoke test a predictor directly:

@@ -129,12 +149,14 @@ python scripts/smoke_test.py --model vjepa2 --model-dir /path/to/model

  ## Scope of v1

  - Streamlit UI
- - DINO-Endo demo by default, with optional multi-model selector when enabled
+ - project-hub landing page with DINO-Endo Surgery as the first hosted workspace
+ - three-model slider for DINO-Endo, AI-Endo, and V-JEPA2, with DINO-Endo selected by default
  - image upload and video upload
  - dashboard-style model/runtime status
  - robust video metadata probing with OpenCV + ffprobe fallback
  - large single-file uploads up to the configured Streamlit cap
  - per-frame phase timeline output for video
+ - optional live encoder/decoder explainability sidebar with true attention where available and labeled proxies elsewhere
  - JSON / CSV export

  Not included in v1:
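The README above repeatedly stresses that only one predictor stays loaded at a time and that the previous model is unloaded when the slider changes. A minimal sketch of that one-model-at-a-time policy (class and loader names here are hypothetical; the commit's actual implementation is `SpaceModelManager` in model_manager.py):

```python
import gc

class OneActiveModelManager:
    """Keep at most one predictor resident, as the Space README describes."""

    def __init__(self, loaders):
        self._loaders = loaders          # model_key -> zero-arg factory returning a predictor
        self._active_key = None
        self._active_predictor = None

    def get_predictor(self, model_key: str):
        if self._active_key != model_key:
            self.unload_model()          # free the previous model before loading the next one
            self._active_predictor = self._loaders[model_key]()
            self._active_key = model_key
        return self._active_predictor

    def unload_model(self) -> None:
        self._active_predictor = None
        self._active_key = None
        gc.collect()                     # drop references so GPU/host memory can be reclaimed
```

Loading lazily and unloading before every switch is what keeps a 24 GB A10G workable when each of the three checkpoints is large on its own.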
app.py CHANGED
@@ -4,6 +4,7 @@ import json
4
  import os
5
  import time
6
  from collections import Counter
 
7
  from pathlib import Path
8
 
9
  import cv2
@@ -13,6 +14,7 @@ import streamlit as st
13
  import torch
14
  from PIL import Image
15
 
 
16
  from model_manager import SpaceModelManager
17
  from model_registry import MODEL_SPECS, get_model_source_summary
18
  from predictor import MODEL_LABELS, PHASE_LABELS, normalize_model_key
@@ -28,7 +30,80 @@ from video_utils import (
28
  spool_uploaded_video,
29
  )
30
 
31
- st.set_page_config(page_title="DINO-Endo Phase Recognition", layout="wide")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
34
  def _phase_index(phase: str) -> int:
@@ -43,6 +118,10 @@ def _image_to_rgb(uploaded_file) -> np.ndarray:
43
  return np.array(image)
44
 
45
 
 
 
 
 
46
  def _enabled_model_keys() -> list[str]:
47
  configured = os.getenv("SPACE_ENABLED_MODELS", "").strip()
48
  if not configured:
@@ -82,7 +161,223 @@ def _default_model_key(enabled_model_keys: list[str]) -> str:
82
  def _space_caption(enabled_model_keys: list[str]) -> str:
83
  if enabled_model_keys == ["dinov2"]:
84
  return "Streamlit Hugging Face Space demo for the DINO-Endo phase-recognition stack."
85
- return "DINO-first Streamlit Hugging Face Space demo for DINO-Endo, AI-Endo, and V-JEPA2."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
 
88
  def _get_model_manager() -> SpaceModelManager:
@@ -126,7 +421,100 @@ def _prepare_staged_video(uploaded_file):
126
  return temp_path, meta
127
 
128
 
129
- def _analyse_video(video_path: str | Path, predictor, frame_stride: int, max_frames: int):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  temp_path = Path(video_path)
131
  capture = cv2.VideoCapture(str(temp_path))
132
  if not capture.isOpened():
@@ -141,6 +529,7 @@ def _analyse_video(video_path: str | Path, predictor, frame_stride: int, max_fra
141
  records = []
142
  processed = 0
143
  frame_index = 0
 
144
 
145
  try:
146
  while True:
@@ -154,7 +543,7 @@ def _analyse_video(video_path: str | Path, predictor, frame_stride: int, max_fra
154
 
155
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
156
  started = time.perf_counter()
157
- result = predictor.predict(rgb)
158
  elapsed_ms = (time.perf_counter() - started) * 1000.0
159
 
160
  probs = result.get("probs", [0.0, 0.0, 0.0, 0.0])
@@ -174,6 +563,9 @@ def _analyse_video(video_path: str | Path, predictor, frame_stride: int, max_fra
174
  records.append(record)
175
  processed += 1
176
 
 
 
 
177
  if total_frames > 0:
178
  progress.progress(min(frame_index + 1, total_frames) / total_frames)
179
  else:
@@ -192,18 +584,6 @@ def _analyse_video(video_path: str | Path, predictor, frame_stride: int, max_fra
192
  return records, {"fps": fps, "total_frames": total_frames, "sampled_frames": processed}
193
 
194
 
195
- def _records_to_frame(records):
196
- if not records:
197
- return pd.DataFrame(columns=["frame_index", "timestamp_sec", "phase", "confidence"])
198
- return pd.DataFrame.from_records(records)
199
-
200
-
201
- def _download_payloads(df: pd.DataFrame):
202
- json_payload = df.to_json(orient="records", indent=2).encode("utf-8")
203
- csv_payload = df.to_csv(index=False).encode("utf-8")
204
- return json_payload, csv_payload
205
-
206
-
207
  def _render_single_result(result: dict):
208
  probs = result.get("probs", [0.0, 0.0, 0.0, 0.0])
209
  metrics = st.columns(3)
@@ -215,7 +595,7 @@ def _render_single_result(result: dict):
215
  st.bar_chart(prob_df.set_index("phase"))
216
  st.download_button(
217
  label="Download JSON",
218
- data=json.dumps(result, indent=2).encode("utf-8"),
219
  file_name="phase_prediction.json",
220
  mime="application/json",
221
  key="download-single-json",
@@ -269,31 +649,23 @@ def main():
269
  enabled_model_keys = _enabled_model_keys()
270
  default_model_key = _default_model_key(enabled_model_keys)
271
  manager = _get_model_manager()
 
 
 
272
 
273
- st.title("DINO-Endo Surgical Phase Recognition")
274
  st.caption(_space_caption(enabled_model_keys))
275
 
276
- st.sidebar.markdown("### Model")
277
- if len(enabled_model_keys) == 1:
278
- model_key = enabled_model_keys[0]
279
- st.sidebar.write(MODEL_LABELS[model_key])
280
- else:
281
- model_key = st.sidebar.selectbox(
282
- "Model",
283
- options=enabled_model_keys,
284
- index=enabled_model_keys.index(default_model_key),
285
- format_func=lambda key: MODEL_LABELS[key],
286
- )
287
-
288
- previous_selected_model_key = st.session_state.get("selected_model_key")
289
- st.session_state["selected_model_key"] = model_key
290
  if previous_selected_model_key is not None and previous_selected_model_key != model_key:
291
  manager.unload_model()
292
 
 
 
293
  source_summary = get_model_source_summary(model_key)
294
- manager_status = manager.status()
295
  st.sidebar.markdown("### Runtime")
296
  st.sidebar.write(f"Selected model: `{MODEL_LABELS[model_key]}`")
 
297
  st.sidebar.write(f"CUDA available: `{torch.cuda.is_available()}`")
298
  if torch.cuda.is_available():
299
  st.sidebar.write(f"Device: `{torch.cuda.get_device_name(torch.cuda.current_device())}`")
@@ -308,16 +680,13 @@ def main():
308
  st.sidebar.write(f"HF repo: `{source_summary['repo_id'] or 'local-only'}`")
309
  if source_summary["subfolder"]:
310
  st.sidebar.write(f"Repo subfolder: `{source_summary['subfolder']}`")
 
 
 
 
311
  st.sidebar.write(f"Video upload cap: `{STREAMLIT_SERVER_MAX_UPLOAD_MB} MB`")
312
  st.sidebar.write(f"Working storage free: `{format_bytes(get_workspace_free_bytes())}`")
313
 
314
- if manager_status.is_loaded and manager_status.active_model_label:
315
- st.sidebar.success(f"Loaded model: {manager_status.active_model_label}")
316
- else:
317
- st.sidebar.info("No model is currently loaded.")
318
- if manager_status.last_error:
319
- st.sidebar.error(manager_status.last_error)
320
-
321
  prepare_col, unload_col = st.sidebar.columns(2)
322
  if prepare_col.button("Load model", use_container_width=True):
323
  try:
@@ -331,90 +700,151 @@ def main():
331
  manager.unload_model()
332
  st.sidebar.success("Model unloaded")
333
 
 
 
 
 
 
 
 
 
334
  image_tab, video_tab = st.tabs(["Image", "Video"])
335
 
336
  with image_tab:
337
- uploaded_image = st.file_uploader("Upload an RGB frame", type=["png", "jpg", "jpeg"], key="image-uploader")
338
- if uploaded_image is not None:
339
- rgb = _image_to_rgb(uploaded_image)
340
- st.image(rgb, caption=uploaded_image.name, use_container_width=True)
341
- if st.button("Run image inference", key="run-image"):
342
- try:
343
- with st.spinner(f"Running {MODEL_LABELS[model_key]} on {uploaded_image.name}..."):
344
- predictor = manager.get_predictor(model_key)
345
- predictor.reset_state()
346
- started = time.perf_counter()
347
- result = predictor.predict(rgb)
348
- result["inference_ms"] = round((time.perf_counter() - started) * 1000.0, 3)
349
- predictor.reset_state()
350
- except Exception as exc:
351
- st.error(str(exc))
352
- else:
353
- _render_single_result(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
 
355
  with video_tab:
356
- frame_stride = st.slider("Analyze every Nth frame", min_value=1, max_value=30, value=5, step=1)
357
- max_frames = st.slider("Maximum sampled frames", min_value=10, max_value=600, value=180, step=10)
358
- uploaded_video = st.file_uploader(
359
- "Upload a video (MP4 preferred)",
360
- type=SUPPORTED_VIDEO_TYPES,
361
- key="video-uploader",
362
- help=(
363
- f"Single-file uploads are enabled up to {STREAMLIT_SERVER_MAX_UPLOAD_MB} MB. "
364
- "MP4 is preferred; MOV/AVI/MKV/WEBM/M4V stay enabled as fallback containers."
365
- ),
366
- max_upload_size=STREAMLIT_SERVER_MAX_UPLOAD_MB,
367
  )
368
- if uploaded_video is not None:
369
- try:
370
- temp_path, video_meta = _prepare_staged_video(uploaded_video)
371
- except Exception as exc:
372
- st.error(str(exc))
373
- else:
374
- info_cols = st.columns(5)
375
- info_cols[0].metric("File size", video_meta["file_size_label"])
376
- info_cols[1].metric("Duration", video_meta["duration_label"])
377
- info_cols[2].metric("FPS", f"{video_meta.get('fps', 0.0):.2f}" if video_meta.get("fps") else "Unknown")
378
- info_cols[3].metric("Frames", int(video_meta.get("frame_count", 0)))
379
- info_cols[4].metric("Resolution", video_meta["resolution_label"])
380
- if video_meta.get("format_name"):
381
- st.caption(f"Container detected by ffprobe: {video_meta['format_name']}")
382
-
383
- recommended_stride = recommended_frame_stride(video_meta.get("duration_seconds"))
384
- st.caption(
385
- f"Recommended frame stride for this video: every {recommended_stride} frame(s). "
386
- "Use higher values for very long videos to keep analysis times reasonable."
387
- )
388
-
389
- if should_show_inline_preview(video_meta["file_size_bytes"]):
390
- st.video(uploaded_video)
391
  else:
392
- st.info(
393
- "Inline preview is disabled for uploads larger than "
394
- "256 MB to avoid pushing very large media back through the browser. "
395
- "The staged video on disk is still used for analysis."
 
 
 
 
 
 
 
 
 
396
  )
397
 
398
- if st.button("Analyze video", key="run-video"):
399
- try:
400
- with st.spinner(f"Running {MODEL_LABELS[model_key]} on {uploaded_video.name}..."):
401
- predictor = manager.get_predictor(model_key)
402
- records, analysis_meta = _analyse_video(
403
- temp_path,
404
- predictor,
405
- frame_stride=frame_stride,
406
- max_frames=max_frames,
407
- )
408
- meta = {
409
- **video_meta,
410
- **analysis_meta,
411
- }
412
- except Exception as exc:
413
- st.error(str(exc))
414
  else:
415
- _render_video_results(records, meta)
416
- else:
417
- _clear_video_stage()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
 
419
 
420
  if __name__ == "__main__":
 
4
  import os
5
  import time
6
  from collections import Counter
7
+ from dataclasses import dataclass
8
  from pathlib import Path
9
 
10
  import cv2
 
14
  import torch
15
  from PIL import Image
16
 
17
+ from explainability import ExplainabilitySpec
18
  from model_manager import SpaceModelManager
19
  from model_registry import MODEL_SPECS, get_model_source_summary
20
  from predictor import MODEL_LABELS, PHASE_LABELS, normalize_model_key
 
30
  spool_uploaded_video,
31
  )
32
 
33
+ st.set_page_config(page_title="AI-Endo Project Hub", layout="wide")
34
+
35
+ MODEL_OPTION_LABELS = {
36
+ "aiendo": "AI-Endo",
37
+ "dinov2": "DINO-Endo",
38
+ "vjepa2": "V-JEPA2 (slower first load)",
39
+ }
40
+
41
+ MODEL_LOAD_NOTES = {
42
+ "aiendo": "AI-Endo uses the ResNet + MS-TCN + Transformer stack.",
43
+ "dinov2": "DINO-Endo remains the default public model in this demo.",
44
+ "vjepa2": "V-JEPA2 can take longer on the first load because the encoder checkpoint is several gigabytes.",
45
+ }
46
+
47
+ FALLBACK_EXPLAINABILITY_SPECS = {
48
+ "aiendo": ExplainabilitySpec(
49
+ encoder_mode="proxy",
50
+ encoder_label="ResNet layer4 activation energy (proxy)",
51
+ decoder_mode="attention",
52
+ decoder_label="Temporal Transformer attention",
53
+ ),
54
+ "dinov2": ExplainabilitySpec(
55
+ encoder_mode="attention",
56
+ encoder_label="DINOv2 encoder self-attention",
57
+ decoder_mode="attention",
58
+ decoder_label="Fusion Transformer temporal attention",
59
+ encoder_layer_count=12,
60
+ encoder_head_count=6,
61
+ ),
62
+ "vjepa2": ExplainabilitySpec(
63
+ encoder_mode="attention",
64
+ encoder_label="V-JEPA2 encoder self-attention",
65
+ decoder_mode="proxy",
66
+ decoder_label="MLP decoder feature energy (proxy)",
67
+ encoder_layer_count=24,
68
+ encoder_head_count=16,
69
+ ),
70
+ }
71
+
72
+
73
+ SPACE_TITLE = "AI-Endo Project Hub"
74
+ FEATURED_PROJECT_TITLE = "DINO-Endo Surgery Workspace"
75
+ MODEL_SLIDER_KEY = "workspace-model-slider"
76
+ SELECTED_MODEL_STATE_KEY = "selected_model_key"
77
+
78
+
79
+ @dataclass(frozen=True)
80
+ class HostedProject:
81
+ key: str
82
+ title: str
83
+ status: str
84
+ summary: str
85
+ highlights: tuple[str, ...]
86
+ tags: tuple[str, ...]
87
+
88
+
89
+ HOSTED_PROJECTS = (
90
+ HostedProject(
91
+ key="dino-endo-surgery",
92
+ title=FEATURED_PROJECT_TITLE,
93
+ status="Live now",
94
+ summary=(
95
+ "Upload single frames or full videos, swap between DINO-Endo, AI-Endo, and V-JEPA2, "
96
+ "and inspect optional explainability overlays inside one surgical phase-recognition workspace."
97
+ ),
98
+ highlights=(
99
+ "Large video uploads with on-disk staging",
100
+ "One-click JSON and CSV export",
101
+ "Live encoder and decoder explainability",
102
+ "Manual load and unload for GPU-safe model switching",
103
+ ),
104
+ tags=("Computer vision", "Medical video", "Multi-model inference"),
105
+ ),
106
+ )
107
 
108
 
109
  def _phase_index(phase: str) -> int:
 
118
  return np.array(image)
119
 
120
 
121
+ def _model_option_label(model_key: str) -> str:
122
+ return MODEL_OPTION_LABELS.get(model_key, MODEL_LABELS.get(model_key, model_key))
123
+
124
+
125
  def _enabled_model_keys() -> list[str]:
126
  configured = os.getenv("SPACE_ENABLED_MODELS", "").strip()
127
  if not configured:
 
161
  def _space_caption(enabled_model_keys: list[str]) -> str:
162
  if enabled_model_keys == ["dinov2"]:
163
  return "Streamlit Hugging Face Space demo for the DINO-Endo phase-recognition stack."
164
+ return "Streamlit Hugging Face Space demo for DINO-Endo, AI-Endo, and V-JEPA2 with one active model loaded at a time."
165
+
166
+
167
+ def _inject_app_styles() -> None:
168
+ st.markdown(
169
+ """
170
+ <style>
171
+ .block-container {
172
+ padding-top: 2.4rem;
173
+ padding-bottom: 2rem;
174
+ }
175
+
176
+ .hub-hero,
177
+ .hub-card,
178
+ .workspace-card {
179
+ border-radius: 22px;
180
+ border: 1px solid rgba(148, 163, 184, 0.22);
181
+ background: linear-gradient(180deg, rgba(15, 23, 42, 0.86), rgba(15, 23, 42, 0.66));
182
+ box-shadow: 0 20px 45px rgba(15, 23, 42, 0.18);
183
+ }
184
+
185
+ .hub-hero {
186
+ padding: 2rem 2.25rem;
187
+ margin-bottom: 1rem;
188
+ background: linear-gradient(135deg, rgba(14, 165, 233, 0.18), rgba(16, 185, 129, 0.18), rgba(15, 23, 42, 0.9));
189
+ }
190
+
191
+ .hub-eyebrow {
192
+ margin: 0;
193
+ color: #67e8f9;
194
+ font-size: 0.78rem;
195
+ font-weight: 700;
196
+ letter-spacing: 0.18em;
197
+ text-transform: uppercase;
198
+ }
199
+
200
+ .hub-hero h1,
201
+ .workspace-card h2,
202
+ .hub-card h3 {
203
+ margin: 0.4rem 0 0 0;
204
+ color: #f8fafc;
205
+ }
206
+
207
+ .hub-subtitle,
208
+ .workspace-copy,
209
+ .hub-card p,
210
+ .hub-card li {
211
+ color: rgba(226, 232, 240, 0.92);
212
+ line-height: 1.55;
213
+ }
214
+
215
+ .hub-subtitle {
216
+ margin-top: 0.8rem;
217
+ max-width: 62rem;
218
+ font-size: 1.03rem;
219
+ }
220
+
221
+ .hub-chip-row {
222
+ display: flex;
223
+ flex-wrap: wrap;
224
+ gap: 0.55rem;
225
+ margin-top: 1rem;
226
+ }
227
+
228
+ .hub-chip,
229
+ .hub-status {
230
+ display: inline-flex;
231
+ align-items: center;
232
+ border-radius: 999px;
233
+ padding: 0.32rem 0.78rem;
234
+ font-size: 0.82rem;
235
+ font-weight: 600;
236
+ }
237
+
238
+ .hub-chip {
239
+ background: rgba(15, 23, 42, 0.56);
240
+ border: 1px solid rgba(103, 232, 249, 0.24);
241
+ color: #e2e8f0;
242
+ }
243
+
244
+ .hub-status {
245
+ background: rgba(34, 197, 94, 0.18);
246
+ border: 1px solid rgba(34, 197, 94, 0.28);
247
+ color: #bbf7d0;
248
+ margin-bottom: 0.7rem;
249
+ }
250
+
251
+ .hub-card,
252
+ .workspace-card {
253
+ padding: 1.25rem 1.4rem;
254
+ height: 100%;
255
+ }
256
+
257
+ .hub-card ul {
258
+ margin: 0.8rem 0 0 1rem;
259
+ padding: 0;
260
+ }
261
+
262
+ .workspace-card {
263
+ margin: 0.3rem 0 1rem 0;
264
+ }
265
+ </style>
266
+ """,
267
+ unsafe_allow_html=True,
268
+ )
269
+
270
+
271
+ def _render_hub_chips(labels: list[str] | tuple[str, ...]) -> str:
272
+ return "".join(f'<span class="hub-chip">{label}</span>' for label in labels)
273
+
274
+
275
+ def _render_project_hub(enabled_model_keys: list[str]) -> None:
276
+ featured = HOSTED_PROJECTS[0]
277
+ enabled_labels = [_model_option_label(key) for key in enabled_model_keys]
278
+ st.markdown(
279
+ f"""
280
+ <section class="hub-hero">
281
+ <p class="hub-eyebrow">Multi-project landing page</p>
282
+ <h1>{SPACE_TITLE}</h1>
283
+ <p class="hub-subtitle">
284
+ A polished landing page for applied vision demos. {FEATURED_PROJECT_TITLE} is the first live workspace,
285
+ and the layout is ready to host more projects later without rebuilding the app shell.
286
+ </p>
287
+ <div class="hub-chip-row">
288
+ {_render_hub_chips(tuple(enabled_labels) + ("Future-project ready", "Streamlit + Docker Space"))}
289
+ </div>
290
+ </section>
291
+ """,
292
+ unsafe_allow_html=True,
293
+ )
294
+
295
+ metrics = st.columns(4)
296
+ metrics[0].metric("Hosted projects", len(HOSTED_PROJECTS))
297
+ metrics[1].metric("Model families", len(enabled_model_keys))
298
+ metrics[2].metric("Explainability", "Opt-in")
299
+ metrics[3].metric("Exports", "JSON + CSV")
300
+
301
+ left_col, right_col = st.columns([1.8, 1.2], gap="large")
302
+ with left_col:
303
+ highlights_html = "".join(f"<li>{item}</li>" for item in featured.highlights)
304
+ st.markdown(
305
+ f"""
306
+ <section class="hub-card">
307
+ <span class="hub-status">{featured.status}</span>
308
+ <h3>{featured.title}</h3>
309
+ <p>{featured.summary}</p>
310
+ <div class="hub-chip-row">{_render_hub_chips(featured.tags)}</div>
311
+ <ul>{highlights_html}</ul>
312
+ </section>
313
+ """,
314
+ unsafe_allow_html=True,
315
+ )
316
+
317
+ with right_col:
318
+ st.markdown(
319
+ """
320
+ <section class="hub-card">
321
+ <span class="hub-status">Platform shell</span>
322
+ <h3>Ready for more demos</h3>
323
+ <p>
324
+ The top section now works as a reusable project hub instead of a one-off page. Add more project cards
325
+ and workspace blocks here later, while keeping one shared brand, layout, and deployment target.
326
+ </p>
327
+ <ul>
328
+ <li>Keep each project's controls inside its own workspace section.</li>
329
+ <li>Reuse the same landing-page hero, metrics, and project-card layout.</li>
330
+ <li>Preserve one-model-at-a-time loading so future demos stay GPU-friendly.</li>
331
+ </ul>
332
+ </section>
333
+ """,
334
+ unsafe_allow_html=True,
335
+ )
336
+
337
+
338
+ def _render_workspace_header(enabled_model_keys: list[str], model_key: str) -> None:
339
+ selected_label = _model_option_label(model_key)
340
+ selection_note = (
341
+ "Use the model slider to move between DINO-Endo, AI-Endo, and V-JEPA2. "
342
+ "Only one model stays loaded at a time so the Space remains responsive on shared GPU hardware."
343
+ )
344
+ st.markdown(
345
+ f"""
346
+ <section class="workspace-card">
347
+        <p class="hub-eyebrow">Featured project</p>
+        <h2>{FEATURED_PROJECT_TITLE}</h2>
+        <p class="workspace-copy">
+          {selection_note}
+        </p>
+        <div class="hub-chip-row">
+          {_render_hub_chips(tuple(_model_option_label(key) for key in enabled_model_keys))}
+          <span class="hub-chip">Selected: {selected_label}</span>
+        </div>
+        </section>
+        """,
+        unsafe_allow_html=True,
+    )
+
+
+def _resolve_model_selection(enabled_model_keys: list[str], default_model_key: str) -> tuple[str | None, str]:
+    previous_selected_model_key = st.session_state.get(SELECTED_MODEL_STATE_KEY)
+    current_slider_value = st.session_state.get(MODEL_SLIDER_KEY)
+    if current_slider_value not in enabled_model_keys:
+        st.session_state[MODEL_SLIDER_KEY] = default_model_key
+
+    if len(enabled_model_keys) == 1:
+        model_key = enabled_model_keys[0]
+        st.session_state[MODEL_SLIDER_KEY] = model_key
+        return previous_selected_model_key, model_key
+
+    model_key = st.select_slider(
+        "Project model slider",
+        options=enabled_model_keys,
+        key=MODEL_SLIDER_KEY,
+        format_func=_model_option_label,
+        help="Prominent model-family slider for the DINO-Endo project workspace.",
+    )
+    return previous_selected_model_key, model_key
 
 
 def _get_model_manager() -> SpaceModelManager:

     return temp_path, meta
 
 
+def _records_to_frame(records):
+    if not records:
+        return pd.DataFrame(columns=["frame_index", "timestamp_sec", "phase", "confidence"])
+    return pd.DataFrame.from_records(records)
+
+
+def _download_payloads(df: pd.DataFrame):
+    json_payload = df.to_json(orient="records", indent=2).encode("utf-8")
+    csv_payload = df.to_csv(index=False).encode("utf-8")
+    return json_payload, csv_payload
+
+
+def _get_explainability_spec(manager: SpaceModelManager, model_key: str) -> ExplainabilitySpec:
+    predictor = manager.get_loaded_predictor(model_key)
+    if predictor is not None and hasattr(predictor, "get_explainability_spec"):
+        return predictor.get_explainability_spec()
+    return FALLBACK_EXPLAINABILITY_SPECS[model_key]
+
+
+def _build_explainability_config(manager: SpaceModelManager, model_key: str):
+    spec = _get_explainability_spec(manager, model_key)
+    st.sidebar.markdown("### Explainability")
+    enabled = st.sidebar.toggle(
+        "Enable live encoder/decoder maps",
+        value=False,
+        help="Shows encoder heatmaps and decoder temporal strips on every processed frame. Leave this off if you want the fastest video analysis path.",
+    )
+    config = {"enabled": enabled}
+    if not enabled:
+        return config, spec
+
+    st.sidebar.caption(f"Encoder view: {spec.encoder_label}")
+    st.sidebar.caption(f"Decoder view: {spec.decoder_label}")
+    if spec.encoder_mode == "attention" and spec.encoder_layer_count > 0 and spec.encoder_head_count > 0:
+        default_layer = spec.encoder_layer_count - 1
+        config["encoder_layer"] = st.sidebar.slider(
+            "Encoder layer",
+            min_value=1,
+            max_value=spec.encoder_layer_count,
+            value=default_layer + 1,
+            key=f"explainability-layer-{model_key}",
+        ) - 1
+        config["encoder_head"] = st.sidebar.slider(
+            "Encoder head",
+            min_value=1,
+            max_value=spec.encoder_head_count,
+            value=1,
+            key=f"explainability-head-{model_key}",
+        ) - 1
+    else:
+        st.sidebar.info("This model uses a proxy encoder overlay instead of true encoder attention.")
+
+    st.sidebar.caption("Decoder strips are rendered as temporal heat strips rather than projected back onto the frame.")
+    return config, spec
+
+
+def _render_explainability_panel(target, payload: dict | None, *, enabled: bool, spec: ExplainabilitySpec, title: str) -> None:
+    with target.container():
+        st.markdown(f"### {title}")
+        if not enabled:
+            st.caption("Turn on the explainability toggle in the sidebar to inspect encoder heatmaps and decoder temporal strips.")
+            return
+
+        st.caption(f"Encoder default: {spec.encoder_label}")
+        st.caption(f"Decoder default: {spec.decoder_label}")
+        if payload is None:
+            st.info("Run image or video inference to populate this live explainability panel.")
+            return
+
+        layer_index = payload.get("encoder_layer")
+        head_index = payload.get("encoder_head")
+        encoder_caption = f"{payload['encoder_label']} ({payload['encoder_kind']})"
+        if layer_index is not None and head_index is not None:
+            encoder_caption += f" · layer {int(layer_index) + 1}, head {int(head_index) + 1}"
+        st.caption(encoder_caption)
+        st.image(payload["encoder_visualization"], use_container_width=True)
+
+        st.caption(f"{payload['decoder_label']} ({payload['decoder_kind']})")
+        st.image(payload["decoder_visualization"], use_container_width=True)
+
+        notes = payload.get("notes")
+        if notes:
+            st.caption(notes)
+
+
+def _analyse_video(
+    video_path: str | Path,
+    predictor,
+    frame_stride: int,
+    max_frames: int,
+    *,
+    explainability_config: dict | None = None,
+    explainability_callback=None,
+):
     temp_path = Path(video_path)
     capture = cv2.VideoCapture(str(temp_path))
     if not capture.isOpened():

     records = []
     processed = 0
     frame_index = 0
+    explain_enabled = bool(explainability_config and explainability_config.get("enabled"))
 
     try:
         while True:

             rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             started = time.perf_counter()
+            result = predictor.predict(rgb, explainability=explainability_config if explain_enabled else None)
             elapsed_ms = (time.perf_counter() - started) * 1000.0
 
             probs = result.get("probs", [0.0, 0.0, 0.0, 0.0])

             records.append(record)
             processed += 1
 
+            if explain_enabled and explainability_callback is not None:
+                explainability_callback(result.get("explainability"), processed, frame_index)
+
             if total_frames > 0:
                 progress.progress(min(frame_index + 1, total_frames) / total_frames)
             else:

     return records, {"fps": fps, "total_frames": total_frames, "sampled_frames": processed}
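The every-Nth-frame sampling that `_analyse_video` drives with its `frame_stride` and `max_frames` parameters can be sketched in isolation (a hypothetical standalone helper, not part of the app's API):

```python
def sample_indices(total_frames: int, stride: int, max_frames: int) -> list[int]:
    """Indices of frames to analyse: every `stride`-th frame, capped at `max_frames`."""
    indices = []
    for frame_index in range(total_frames):
        if frame_index % stride == 0:  # keep only every stride-th frame
            indices.append(frame_index)
            if len(indices) >= max_frames:  # stop once the sample budget is spent
                break
    return indices

print(sample_indices(100, 5, 10))  # [0, 5, 10, 15, 20, 25, 30, 35, 40, 45]
```

For a long recording, raising `stride` (or lowering `max_frames`) trades temporal resolution for a shorter analysis run.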
 def _render_single_result(result: dict):
     probs = result.get("probs", [0.0, 0.0, 0.0, 0.0])
     metrics = st.columns(3)

     st.bar_chart(prob_df.set_index("phase"))
     st.download_button(
         label="Download JSON",
+        data=json.dumps(result, indent=2, default=str).encode("utf-8"),
         file_name="phase_prediction.json",
         mime="application/json",
         key="download-single-json",

     enabled_model_keys = _enabled_model_keys()
     default_model_key = _default_model_key(enabled_model_keys)
     manager = _get_model_manager()
+    _inject_app_styles()
+    _render_project_hub(enabled_model_keys)
+    previous_selected_model_key, model_key = _resolve_model_selection(enabled_model_keys, default_model_key)
 
+    _render_workspace_header(enabled_model_keys, model_key)
     st.caption(_space_caption(enabled_model_keys))
 
+    st.session_state[SELECTED_MODEL_STATE_KEY] = model_key
 
     if previous_selected_model_key is not None and previous_selected_model_key != model_key:
         manager.unload_model()
 
+    explainability_config, explainability_spec = _build_explainability_config(manager, model_key)
+
     source_summary = get_model_source_summary(model_key)
     st.sidebar.markdown("### Runtime")
     st.sidebar.write(f"Selected model: `{MODEL_LABELS[model_key]}`")
+    st.sidebar.caption(MODEL_LOAD_NOTES[model_key])
     st.sidebar.write(f"CUDA available: `{torch.cuda.is_available()}`")
     if torch.cuda.is_available():
         st.sidebar.write(f"Device: `{torch.cuda.get_device_name(torch.cuda.current_device())}`")

     st.sidebar.write(f"HF repo: `{source_summary['repo_id'] or 'local-only'}`")
     if source_summary["subfolder"]:
         st.sidebar.write(f"Repo subfolder: `{source_summary['subfolder']}`")
+    with st.sidebar.expander("Checkpoint requirements", expanded=False):
+        st.write(", ".join(source_summary["required_files"]))
+        if source_summary["optional_files"]:
+            st.caption("Optional: " + ", ".join(source_summary["optional_files"]))
     st.sidebar.write(f"Video upload cap: `{STREAMLIT_SERVER_MAX_UPLOAD_MB} MB`")
     st.sidebar.write(f"Working storage free: `{format_bytes(get_workspace_free_bytes())}`")
 
     prepare_col, unload_col = st.sidebar.columns(2)
     if prepare_col.button("Load model", use_container_width=True):
         try:

         manager.unload_model()
         st.sidebar.success("Model unloaded")
 
+    manager_status = manager.status()
+    if manager_status.is_loaded and manager_status.active_model_label:
+        st.sidebar.success(f"Loaded model: {manager_status.active_model_label}")
+    else:
+        st.sidebar.info("No model is currently loaded.")
+    if manager_status.last_error:
+        st.sidebar.error(manager_status.last_error)
+
     image_tab, video_tab = st.tabs(["Image", "Video"])
 
     with image_tab:
+        image_main_col, image_explain_col = st.columns([3, 2], gap="large")
+        image_explain_placeholder = image_explain_col.empty()
+        image_result = None
+
+        with image_main_col:
+            uploaded_image = st.file_uploader("Upload an RGB frame", type=["png", "jpg", "jpeg"], key="image-uploader")
+            if uploaded_image is not None:
+                rgb = _image_to_rgb(uploaded_image)
+                st.image(rgb, caption=uploaded_image.name, use_container_width=True)
+                if st.button("Run image inference", key="run-image"):
+                    try:
+                        with st.spinner(f"Running {MODEL_LABELS[model_key]} on {uploaded_image.name}..."):
+                            predictor = manager.get_predictor(model_key)
+                            predictor.reset_state()
+                            started = time.perf_counter()
+                            image_result = predictor.predict(
+                                rgb,
+                                explainability=explainability_config if explainability_config.get("enabled") else None,
+                            )
+                            image_result["inference_ms"] = round((time.perf_counter() - started) * 1000.0, 3)
+                            predictor.reset_state()
+                    except Exception as exc:
+                        st.error(str(exc))
+                    else:
+                        _render_single_result(image_result)
+
+        _render_explainability_panel(
+            image_explain_placeholder,
+            image_result.get("explainability") if image_result else None,
+            enabled=bool(explainability_config.get("enabled")),
+            spec=explainability_spec,
+            title="Explainability",
+        )
 
     with video_tab:
+        video_main_col, video_explain_col = st.columns([3, 2], gap="large")
+        video_explain_placeholder = video_explain_col.empty()
+        _render_explainability_panel(
+            video_explain_placeholder,
+            None,
+            enabled=bool(explainability_config.get("enabled")),
+            spec=explainability_spec,
+            title="Explainability",
         )
+
+        with video_main_col:
+            frame_stride = st.slider("Analyze every Nth frame", min_value=1, max_value=30, value=5, step=1)
+            max_frames = st.slider("Maximum sampled frames", min_value=10, max_value=600, value=180, step=10)
+            uploaded_video = st.file_uploader(
+                "Upload a video (MP4 preferred)",
+                type=SUPPORTED_VIDEO_TYPES,
+                key="video-uploader",
+                help=(
+                    f"Single-file uploads are enabled up to {STREAMLIT_SERVER_MAX_UPLOAD_MB} MB. "
+                    "MP4 is preferred; MOV/AVI/MKV/WEBM/M4V stay enabled as fallback containers."
+                ),
+                max_upload_size=STREAMLIT_SERVER_MAX_UPLOAD_MB,
+            )
+            if uploaded_video is not None:
+                try:
+                    temp_path, video_meta = _prepare_staged_video(uploaded_video)
+                except Exception as exc:
+                    st.error(str(exc))
                 else:
+                    info_cols = st.columns(5)
+                    info_cols[0].metric("File size", video_meta["file_size_label"])
+                    info_cols[1].metric("Duration", video_meta["duration_label"])
+                    info_cols[2].metric("FPS", f"{video_meta.get('fps', 0.0):.2f}" if video_meta.get("fps") else "Unknown")
+                    info_cols[3].metric("Frames", int(video_meta.get("frame_count", 0)))
+                    info_cols[4].metric("Resolution", video_meta["resolution_label"])
+                    if video_meta.get("format_name"):
+                        st.caption(f"Container detected by ffprobe: {video_meta['format_name']}")
+
+                    recommended_stride = recommended_frame_stride(video_meta.get("duration_seconds"))
+                    st.caption(
+                        f"Recommended frame stride for this video: every {recommended_stride} frame(s). "
+                        "Use higher values for very long videos to keep analysis times reasonable."
                     )
 
+                    if should_show_inline_preview(video_meta["file_size_bytes"]):
+                        st.video(uploaded_video)
                     else:
+                        st.info(
+                            "Inline preview is disabled for uploads larger than "
+                            "256 MB to avoid pushing very large media back through the browser. "
+                            "The staged video on disk is still used for analysis."
+                        )
+
+                    if st.button("Analyze video", key="run-video"):
+                        latest_payload = {"value": None}
+
+                        def _video_explainability_callback(payload, processed_count: int, current_frame_index: int):
+                            latest_payload["value"] = payload
+                            _render_explainability_panel(
+                                video_explain_placeholder,
+                                payload,
+                                enabled=True,
+                                spec=explainability_spec,
+                                title=f"Live explainability · sampled frame {processed_count}",
+                            )
+
+                        try:
+                            with st.spinner(f"Running {MODEL_LABELS[model_key]} on {uploaded_video.name}..."):
+                                predictor = manager.get_predictor(model_key)
+                                records, analysis_meta = _analyse_video(
+                                    temp_path,
+                                    predictor,
+                                    frame_stride=frame_stride,
+                                    max_frames=max_frames,
+                                    explainability_config=explainability_config if explainability_config.get("enabled") else None,
+                                    explainability_callback=(
+                                        _video_explainability_callback
+                                        if explainability_config.get("enabled")
+                                        else None
+                                    ),
+                                )
+                                meta = {
+                                    **video_meta,
+                                    **analysis_meta,
+                                }
+                        except Exception as exc:
+                            st.error(str(exc))
+                        else:
+                            _render_video_results(records, meta)
+                            if explainability_config.get("enabled"):
+                                _render_explainability_panel(
+                                    video_explain_placeholder,
+                                    latest_payload["value"],
+                                    enabled=True,
+                                    spec=explainability_spec,
+                                    title="Explainability",
+                                )
+            else:
+                _clear_video_stage()
 
 
 if __name__ == "__main__":
dinov2/.github/workflows/lint.yaml ADDED
@@ -0,0 +1,38 @@
+name: Lint
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  run-linters:
+    name: Run linters
+    runs-on: ubuntu-20.04
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.9
+          cache: 'pip'
+          cache-dependency-path: '**/requirements*.txt'
+      - name: Install Python (development) dependencies
+        run: |
+          pip install -r requirements-dev.txt
+      - name: Run flake8
+        run: |
+          flake8
+      - name: Run black
+        if: always()
+        run: |
+          black --check dinov2
+      - name: Run pylint
+        if: always()
+        run: |
+          pylint --exit-zero dinov2
dinov2/.gitignore ADDED
@@ -0,0 +1,11 @@
+build/
+dist/
+*.egg-info/
+**/__pycache__/
+
+**/.ipynb_checkpoints
+**/.ipynb_checkpoints/**
+
+*.swp
+
+.vscode/
dinov2/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+  advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+  address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+This Code of Conduct also applies outside the project spaces when there is a
+reasonable belief that an individual's behavior may have a negative impact on
+the project or its community.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at <opensource-conduct@meta.com>. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
dinov2/CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
+# Contributing to DINOv2
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `main`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Meta's open source projects.
+
+Complete your CLA here: <https://code.facebook.com/cla>
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+disclosure of security bugs. In those cases, please go through the process
+outlined on that page and do not file a public issue.
+
+## License
+By contributing to DINOv2, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree.
dinov2/LICENSE ADDED
@@ -0,0 +1,203 @@
+
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
dinov2/MODEL_CARD.md ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Model Card for DINOv2-S/B/L/g

These are Vision Transformer models trained following the method described in the papers:
"DINOv2: Learning Robust Visual Features without Supervision"
and
"Vision Transformers Need Registers".

We provide 8 models:
- 1 ViT-g trained from scratch with 3 ViT-S/B/L models distilled from the ViT-g, without registers.
- 1 ViT-g trained from scratch with 3 ViT-S/B/L models distilled from the ViT-g, with registers.

## Model Details
The model takes an image as input and returns a class token and patch tokens, and optionally 4 register tokens.

The embedding dimension is:
- 384 for ViT-S.
- 768 for ViT-B.
- 1024 for ViT-L.
- 1536 for ViT-g.

The models follow a Transformer architecture, with a patch size of 14. In the case of registers, we add 4 register tokens, learned during training, to the input sequence after the patch embedding.

For a 224x224 image, this results in 1 class token + 256 patch tokens, and optionally 4 register tokens.

The models can accept larger images provided the image shapes are multiples of the patch size (14).
If this condition is not met, the model will crop to the closest smaller multiple of the patch size.
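The sizing rules above are easy to check with a few lines of arithmetic (an illustrative sketch, not part of the DINOv2 code; the helper name is made up):

```python
def token_counts(height: int, width: int, patch_size: int = 14, num_registers: int = 0) -> int:
    """Crop each side down to the closest smaller multiple of the patch size, then count tokens."""
    h = (height // patch_size) * patch_size  # cropped height
    w = (width // patch_size) * patch_size   # cropped width
    patch_tokens = (h // patch_size) * (w // patch_size)
    return 1 + patch_tokens + num_registers  # 1 class token + patch tokens (+ registers)

print(token_counts(224, 224))                   # 257 = 1 class token + 16*16 patch tokens
print(token_counts(224, 224, num_registers=4))  # 261
print(token_counts(230, 230))                   # also 257: 230 is cropped down to 224
```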

### Model Description

- **Developed by:** Meta AI
- **Model type:** Vision Transformer
- **License:** Apache License 2.0

- **Repository:** https://github.com/facebookresearch/dinov2
- **Paper:** https://arxiv.org/abs/2304.07193
- **Demo:** https://dinov2.metademolab.com/

## Uses

The models are vision backbones providing multi-purpose features for downstream tasks.

### Direct Use

The models can be used without fine-tuning, with downstream classifiers as simple as linear layers, to obtain competitive results:
- on depth estimation and semantic segmentation, using linear layers.
- on image classification, using k-NN classifiers on the class token.
- on image classification, with logistic regression classifiers applied on the class token.
- on image classification, with a linear layer applied on the class token and the average of the patch tokens.
- on image retrieval, using nearest neighbors.

### Downstream Use

It is technically possible to fine-tune the models, for small gains (we measured +2% on ImageNet-1k classification).
We recommend keeping this as a very last step and only when necessary, as the features already provide good performance out of the box.

## Bias, Risks, and Limitations

Despite improvements thanks to the training method not using annotations, we still observe significant biases in our models toward rich households from Western countries.

### Recommendations

We expect fine-tuning to increase the biases in the features produced by the model, as they will be tuned to the fine-tuning labels.

## How to Get Started with the Model

Use the code below to get started with the model.

```python
import torch

# DINOv2
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')

# DINOv2 with registers
dinov2_vits14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
dinov2_vitl14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
dinov2_vitg14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')
```
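As with most torchvision-style pipelines, inputs are typically RGB images normalized with the ImageNet mean and standard deviation before being fed to the backbone (an assumption drawn from the standard DINOv2 transforms, not stated in this card). A dependency-free sketch of the per-channel arithmetic, with `normalize_pixel` being a hypothetical helper for illustration:

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb):
    """Map one RGB pixel from [0, 1] floats to ImageNet-normalized values."""
    return tuple((c - m) / s for c, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))

# A pixel equal to the channel means maps to (0.0, 0.0, 0.0):
print(normalize_pixel((0.485, 0.456, 0.406)))
```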

## Training Details

### Training Data

- **Training data:** LVD-142M (see paper)
- **Training regime:** fp16 using PyTorch-FSDP mixed-precision.

### Training Procedure

- **Training objective:**
  - DINO self-distillation loss with multi-crop
  - iBOT masked-image modeling loss
  - KoLeo regularization on [CLS] tokens
- **Architectures:**
  - ViT-S (21M params): Patch size 14, embedding dimension 384, 6 heads, MLP FFN
  - ViT-B (86M params): Patch size 14, embedding dimension 768, 12 heads, MLP FFN
  - ViT-L (0.3B params): Patch size 14, embedding dimension 1024, 16 heads, MLP FFN
  - ViT-g (1.1B params): Patch size 14, embedding dimension 1536, 24 heads, SwiGLU FFN
- **Distillation:**
  - Distillation follows the standard DINOv2 pretraining procedure, except the teacher is a pretrained, frozen ViT-g.

## Evaluation

We refer users to the associated papers for the evaluation protocols.

<table>
<tr>
<th colspan="2"></th>
<th colspan="3">ImageNet-1k</th>
<th>NYU-Depth v2</th>
<th>SUN-RGBD</th>
<th>ADE20k</th>
<th>iNaturalist 2018</th>
<th>Oxford-H</th>
</tr>
<tr>
<th rowspan="2">model</th>
<th rowspan="2">with <br /> registers</th>
<th>classif. (acc)</th>
<th>classif. (acc)</th>
<th>classif. V2 (acc)</th>
<th>depth (RMSE)</th>
<th>depth (RMSE)</th>
<th>segm. (mIoU)</th>
<th>classif. (acc)</th>
<th>retrieval (mAP)</th>
</tr>
<tr>
<!-- <th>^</th> -->
<th>k-NN</th>
<th>linear</th>
<th>linear</th>
<th>linear<br />4 layers</th>
<th>NYU-D transfer</th>
<th>multiscale</th>
<th>linear</th>
<th>nearest neighbor</th>
</tr>
<tr>
<td>ViT-S/14</td>
<td align="center">:x:</td>
<td align="right">79.0%</td>
<td align="right">81.1%</td>
<td align="right">70.8%</td>
<td align="right">0.417</td>
<td align="right">0.431</td>
<td align="right">47.2</td>
<td align="right">69.5%</td>
<td align="right">43.2</td>
</tr>
<tr>
<td>ViT-S/14</td>
<td align="center">:white_check_mark:</td>
<td align="right">79.1%</td>
<td align="right">80.9%</td>
<td align="right">71.0%</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">67.6%</td>
<td align="right">39.5</td>
</tr>
<tr>
<td>ViT-B/14</td>
<td align="center">:x:</td>
<td align="right">82.1%</td>
<td align="right">84.5%</td>
<td align="right">74.9%</td>
<td align="right">0.362</td>
<td align="right">0.400</td>
<td align="right">51.3</td>
<td align="right">76.3%</td>
<td align="right">49.5</td>
</tr>
<tr>
<td>ViT-B/14</td>
<td align="center">:white_check_mark:</td>
<td align="right">82.0%</td>
<td align="right">84.6%</td>
<td align="right">75.6%</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">73.8%</td>
<td align="right">51.0</td>
</tr>
<tr>
<td>ViT-L/14</td>
<td align="center">:x:</td>
<td align="right">83.5%</td>
<td align="right">86.3%</td>
<td align="right">77.6%</td>
<td align="right">0.333</td>
<td align="right">0.396</td>
<td align="right">53.1</td>
<td align="right">79.8%</td>
<td align="right">54.0</td>
</tr>
<tr>
<td>ViT-L/14</td>
<td align="center">:white_check_mark:</td>
<td align="right">83.8%</td>
<td align="right">86.7%</td>
<td align="right">78.5%</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">80.9%</td>
<td align="right">55.7</td>
</tr>
<tr>
<td>ViT-g/14</td>
<td align="center">:x:</td>
<td align="right">83.5%</td>
<td align="right">86.5%</td>
<td align="right">78.4%</td>
<td align="right">0.298</td>
<td align="right">0.362</td>
<td align="right">53.0</td>
<td align="right">81.6%</td>
<td align="right">52.3</td>
</tr>
<tr>
<td>ViT-g/14</td>
<td align="center">:white_check_mark:</td>
<td align="right">83.7%</td>
<td align="right">87.1%</td>
<td align="right">78.8%</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">81.5%</td>
<td align="right">58.2</td>
</tr>
</table>

## Environmental Impact

- **Hardware Type:** Nvidia A100
- **Hours used:** 22,000 for ViT-g, 4,500 for ViT-S distillation, 5,300 for ViT-B distillation, 8,000 for ViT-L distillation
- **Cloud Provider:** Private infra
- **Compute Region:** USA
- **Carbon Emitted:** 7t CO2eq

#### Hardware

Nvidia A100 GPUs

#### Software

PyTorch 2.0,
xFormers 0.0.18

**BibTeX**

```
@misc{oquab2023dinov2,
  title={DINOv2: Learning Robust Visual Features without Supervision},
  author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
  journal={arXiv:2304.07193},
  year={2023}
}
@misc{darcet2023vitneedreg,
  title={Vision Transformers Need Registers},
  author={Darcet, Timothée and Oquab, Maxime and Mairal, Julien and Bojanowski, Piotr},
  journal={arXiv:2309.16588},
  year={2023}
}
```
dinov2/README.md ADDED
@@ -0,0 +1,620 @@
:new: [2023-10-26] *Added DINOv2 backbones with registers, following [Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588).*

# DINOv2: Learning Robust Visual Features without Supervision

**[Meta AI Research, FAIR](https://ai.facebook.com/research/)**

Maxime Oquab,
Timothée Darcet,
Théo Moutakanni,
Huy V. Vo,
Marc Szafraniec,
Vasil Khalidov,
Patrick Labatut,
Armand Joulin,
Piotr Bojanowski

[[`Paper #1`](https://arxiv.org/abs/2304.07193)] [[`Paper #2`](https://arxiv.org/abs/2309.16588)] [[`Blog`](https://ai.facebook.com/blog/dino-v2-computer-vision-self-supervised-learning/)] [[`Demo`](https://dinov2.metademolab.com)] [[`BibTeX`](#citing-dinov2)]

PyTorch implementation and pretrained models for DINOv2. For details, see the papers: **[DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)** and **[Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588)**.

DINOv2 models produce high-performance visual features that can be directly employed with classifiers as simple as linear layers on a variety of computer vision tasks; these visual features are robust and perform well across domains without any requirement for fine-tuning. The models were pretrained on a dataset of 142 M images without using any labels or annotations.

https://github.com/facebookresearch/dinov2/assets/60359573/f168823e-7922-415a-b429-578badf5c356

<div align="center">
Visualization of the first three principal components of the patch features of all frames, mapped to RGB values.
</div>

## Pretrained models

<table style="margin: auto">
<thead>
<tr>
<th>model</th>
<th># of<br />params</th>
<th>with<br />registers</th>
<th>ImageNet<br />k-NN</th>
<th>ImageNet<br />linear</th>
<th>download</th>
</tr>
</thead>
<tbody>
<tr>
<td>ViT-S/14 distilled</td>
<td align="right">21 M</td>
<td align="center">:x:</td>
<td align="right">79.0%</td>
<td align="right">81.1%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth">backbone only</a></td>
</tr>
<tr>
<td>ViT-S/14 distilled</td>
<td align="right">21 M</td>
<td align="center">:white_check_mark:</td>
<td align="right">79.1%</td>
<td align="right">80.9%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth">backbone only</a></td>
</tr>
<tr>
<td>ViT-B/14 distilled</td>
<td align="right">86 M</td>
<td align="center">:x:</td>
<td align="right">82.1%</td>
<td align="right">84.5%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth">backbone only</a></td>
</tr>
<tr>
<td>ViT-B/14 distilled</td>
<td align="right">86 M</td>
<td align="center">:white_check_mark:</td>
<td align="right">82.0%</td>
<td align="right">84.6%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth">backbone only</a></td>
</tr>
<tr>
<td>ViT-L/14 distilled</td>
<td align="right">300 M</td>
<td align="center">:x:</td>
<td align="right">83.5%</td>
<td align="right">86.3%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth">backbone only</a></td>
</tr>
<tr>
<td>ViT-L/14 distilled</td>
<td align="right">300 M</td>
<td align="center">:white_check_mark:</td>
<td align="right">83.8%</td>
<td align="right">86.7%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth">backbone only</a></td>
</tr>
<tr>
<td>ViT-g/14</td>
<td align="right">1,100 M</td>
<td align="center">:x:</td>
<td align="right">83.5%</td>
<td align="right">86.5%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth">backbone only</a></td>
</tr>
<tr>
<td>ViT-g/14</td>
<td align="right">1,100 M</td>
<td align="center">:white_check_mark:</td>
<td align="right">83.7%</td>
<td align="right">87.1%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth">backbone only</a></td>
</tr>
</tbody>
</table>

### Pretrained backbones (via PyTorch Hub)

Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install PyTorch (the only required dependency for loading the model). Installing PyTorch with CUDA support is strongly recommended.

A corresponding [model card](MODEL_CARD.md) is included in the repository.

```python
import torch

# DINOv2
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')

# DINOv2 with registers
dinov2_vits14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
dinov2_vitl14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
dinov2_vitg14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')
```

### Pretrained heads - Image classification

<table style="margin: auto">
<thead>
<tr>
<th rowspan="2">backbone</th>
<th rowspan="2">with<br />registers</th>
<th>download</th>
</tr>
<tr>
<th>ImageNet</th>
</tr>
</thead>
<tbody>
<tr>
<td>ViT-S/14 distilled</td>
<td align="center">:x:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear4_head.pth">4 layers</a>)
</td>
</tr>
<tr>
<td>ViT-S/14 distilled</td>
<td align="center">:white_check_mark:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear4_head.pth">4 layers</a>)
</td>
</tr>
<tr>
<td>ViT-B/14 distilled</td>
<td align="center">:x:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear4_head.pth">4 layers</a>)
</td>
</tr>
<tr>
<td>ViT-B/14 distilled</td>
<td align="center">:white_check_mark:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear4_head.pth">4 layers</a>)
</td>
</tr>
<tr>
<td>ViT-L/14 distilled</td>
<td align="center">:x:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear4_head.pth">4 layers</a>)
</td>
</tr>
<tr>
<td>ViT-L/14 distilled</td>
<td align="center">:white_check_mark:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear4_head.pth">4 layers</a>)
</td>
</tr>
<tr>
<td>ViT-g/14</td>
<td align="center">:x:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear4_head.pth">4 layers</a>)
</td>
</tr>
<tr>
<td>ViT-g/14</td>
<td align="center">:white_check_mark:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear4_head.pth">4 layers</a>)
</td>
</tr>
</tbody>
</table>

The (full) classifier models can be loaded via PyTorch Hub:

```python
import torch

# DINOv2
dinov2_vits14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_lc')
dinov2_vitb14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_lc')
dinov2_vitl14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_lc')
dinov2_vitg14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_lc')

# DINOv2 with registers
dinov2_vits14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg_lc')
dinov2_vitb14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg_lc')
dinov2_vitl14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg_lc')
dinov2_vitg14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg_lc')
```

### Pretrained heads - Depth estimation

<table style="margin: auto">
<thead>
<tr>
<th rowspan="2">backbone</th>
<th colspan="2">download head</th>
</tr>
<tr>
<th>NYUd</th>
<th>KITTI</th>
</tr>
</thead>
<tbody>
<tr>
<td>ViT-S/14 distilled</td>
<td>
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_nyu_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_nyu_linear4_head.pth">4 layers</a>),
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_nyu_dpt_head.pth">DPT</a>
</td>
<td>
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_kitti_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_kitti_linear4_head.pth">4 layers</a>),
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_kitti_dpt_head.pth">DPT</a>
</td>
</tr>
<tr>
<td>ViT-B/14 distilled</td>
<td>
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_nyu_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_nyu_linear4_head.pth">4 layers</a>),
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_nyu_dpt_head.pth">DPT</a>
</td>
<td>
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_kitti_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_kitti_linear4_head.pth">4 layers</a>),
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_kitti_dpt_head.pth">DPT</a>
</td>
</tr>
<tr>
<td>ViT-L/14 distilled</td>
<td>
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_nyu_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_nyu_linear4_head.pth">4 layers</a>),
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_nyu_dpt_head.pth">DPT</a>
</td>
<td>
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_kitti_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_kitti_linear4_head.pth">4 layers</a>),
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_kitti_dpt_head.pth">DPT</a>
</td>
</tr>
<tr>
<td>ViT-g/14</td>
<td>
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_nyu_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_nyu_linear4_head.pth">4 layers</a>),
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_nyu_dpt_head.pth">DPT</a>
</td>
<td>
linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_kitti_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_kitti_linear4_head.pth">4 layers</a>),
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_kitti_dpt_head.pth">DPT</a>
</td>
</tr>
</tbody>
</table>

### Pretrained heads - Semantic segmentation

<table style="margin: auto">
<thead>
<tr>
<th rowspan="2">backbone</th>
<th>download model</th>
<th colspan="2">download head</th>
</tr>
<tr>
<th>ADE20K</th>
<th>ADE20K</th>
<th>VOC2012</th>
</tr>
</thead>
<tbody>
<tr>
<td>ViT-S/14 distilled</td>
<td></td>
<td>
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_ade20k_linear_head.pth">linear</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_ade20k_ms_head.pth">multi-scale</a>
</td>
<td>
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_voc2012_linear_head.pth">linear</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_voc2012_ms_head.pth">multi-scale</a>
</td>
</tr>
<tr>
<td>ViT-B/14 distilled</td>
<td></td>
<td>
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_ade20k_linear_head.pth">linear</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_ade20k_ms_head.pth">multi-scale</a>
</td>
<td>
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_voc2012_linear_head.pth">linear</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_voc2012_ms_head.pth">multi-scale</a>
</td>
</tr>
<tr>
<td>ViT-L/14 distilled</td>
<td></td>
<td>
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_ade20k_linear_head.pth">linear</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_ade20k_ms_head.pth">multi-scale</a>
</td>
<td>
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_voc2012_linear_head.pth">linear</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_voc2012_ms_head.pth">multi-scale</a>
</td>
</tr>
<tr>
<td>ViT-g/14</td>
<td>
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_ade20k_m2f.pth">Mask2Former</a>
</td>
<td>
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_ade20k_linear_head.pth">linear</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_ade20k_ms_head.pth">multi-scale</a>
</td>
<td>
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_voc2012_linear_head.pth">linear</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_voc2012_ms_head.pth">multi-scale</a>
</td>
</tr>
</tbody>
</table>

## Installation

The training and evaluation code requires PyTorch 2.0 and [xFormers](https://github.com/facebookresearch/xformers) 0.0.18, as well as a number of other 3rd party packages. Note that the code has only been tested with the specified versions and also expects a Linux environment. To set up all the required dependencies for training and evaluation, please follow the instructions below:

*[conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html)* **(Recommended)** - Clone the repository and then create and activate a `dinov2` conda environment using the provided environment definition:

```shell
conda env create -f conda.yaml
conda activate dinov2
```

*[pip](https://pip.pypa.io/en/stable/getting-started/)* - Clone the repository and then use the provided `requirements.txt` to install the dependencies:

```shell
pip install -r requirements.txt
```

For dense tasks (depth estimation and semantic segmentation), there are additional dependencies (specific versions of `mmcv` and `mmsegmentation`) which are captured in the `extras` dependency specifications:

*[conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html)* **(Recommended)**:

```shell
conda env create -f conda-extras.yaml
conda activate dinov2-extras
```

*[pip](https://pip.pypa.io/en/stable/getting-started/)*:

```shell
pip install -r requirements.txt -r requirements-extras.txt
```

## Data preparation

### ImageNet-1k

The root directory of the dataset should hold the following contents:

- `<ROOT>/test/ILSVRC2012_test_00000001.JPEG`
- `<ROOT>/test/[..]`
- `<ROOT>/test/ILSVRC2012_test_00100000.JPEG`
- `<ROOT>/train/n01440764/n01440764_10026.JPEG`
- `<ROOT>/train/[...]`
- `<ROOT>/train/n15075141/n15075141_9993.JPEG`
- `<ROOT>/val/n01440764/ILSVRC2012_val_00000293.JPEG`
- `<ROOT>/val/[...]`
- `<ROOT>/val/n15075141/ILSVRC2012_val_00049174.JPEG`
- `<ROOT>/labels.txt`

The provided dataset implementation expects a few additional metadata files to be present under the extra directory:

- `<EXTRA>/class-ids-TRAIN.npy`
- `<EXTRA>/class-ids-VAL.npy`
- `<EXTRA>/class-names-TRAIN.npy`
- `<EXTRA>/class-names-VAL.npy`
- `<EXTRA>/entries-TEST.npy`
- `<EXTRA>/entries-TRAIN.npy`
- `<EXTRA>/entries-VAL.npy`

These metadata files can be generated (once) with the following lines of Python code:

```python
from dinov2.data.datasets import ImageNet

for split in ImageNet.Split:
    dataset = ImageNet(split=split, root="<ROOT>", extra="<EXTRA>")
    dataset.dump_extra()
```

Note that the root and extra directories do not have to be distinct directories.

### ImageNet-22k

Please adapt the [dataset class](dinov2/data/datasets/image_net_22k.py) to match your local setup.

<br />

:warning: To execute the commands provided in the next sections for training and evaluation, the `dinov2` package should be included in the Python module search path, i.e. simply prefix the command to run with `PYTHONPATH=.`.

## Training

### Fast setup: training DINOv2 ViT-L/16 on ImageNet-1k

Run DINOv2 training on 4 A100-80GB nodes (32 GPUs) in a SLURM cluster environment with submitit:

```shell
python dinov2/run/train/train.py \
    --nodes 4 \
    --config-file dinov2/configs/train/vitl16_short.yaml \
    --output-dir <PATH/TO/OUTPUT/DIR> \
    train.dataset_path=ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
```
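The `train.dataset_path` argument packs a dataset name and its keyword arguments into a single colon-separated string. The rough parser below is only an illustration of that format (the actual parsing happens inside the dinov2 codebase, not in this function):

```python
def parse_dataset_spec(spec: str):
    """Split e.g. 'ImageNet:split=TRAIN:root=/data:extra=/data' into (name, kwargs)."""
    name, *parts = spec.split(":")
    kwargs = dict(part.split("=", 1) for part in parts)
    return name, kwargs

name, kwargs = parse_dataset_spec("ImageNet:split=TRAIN:root=/data/in1k:extra=/data/in1k")
print(name)    # ImageNet
print(kwargs)  # {'split': 'TRAIN', 'root': '/data/in1k', 'extra': '/data/in1k'}
```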

Training time is approximately 1 day and the resulting checkpoint should reach 81.6% on k-NN eval and 82.9% on linear eval.

The training code saves the weights of the teacher in the `eval` folder every 12500 iterations for evaluation.
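Since a teacher checkpoint is written every 12500 iterations and the evaluation commands later in this README reference folders such as `eval/training_24999`, the folder for the k-th periodic checkpoint appears to be named after the last completed (0-indexed) iteration. A small inferred sketch, not code from this repository:

```python
def eval_checkpoint_dir(k: int, period: int = 12500) -> str:
    """Folder holding the k-th periodic teacher checkpoint (1-based k, 0-indexed iterations)."""
    return f"eval/training_{period * k - 1}"

print(eval_checkpoint_dir(1))  # eval/training_12499
print(eval_checkpoint_dir(2))  # eval/training_24999
```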
459
+
460
+ ### Long setup: training DINOv2 ViT-L/14 on ImageNet-22k
461
+
462
+ Run DINOv2 training on 12 A100-80GB nodes (96 GPUs) in a SLURM cluster environment with submitit:
463
+
464
+ ```shell
465
+ python dinov2/run/train/train.py \
466
+ --nodes 12 \
467
+ --config-file dinov2/configs/train/vitl14.yaml \
468
+ --output-dir <PATH/TO/OUTPUT/DIR> \
469
+ train.dataset_path=ImageNet22k:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
470
+ ```
471
+
472
+ Training time is approximately 3.3 days and the resulting checkpoint should reach 82.0% on k-NN eval and 84.5% on linear eval.
473
+
474
+ The training code saves the weights of the teacher in the `eval` folder every 12500 iterations for evaluation.
475
+
476
+
477
+ ## Evaluation
478
+
479
+ The training code regularly saves the teacher weights. To evaluate the model, run the following commands on a single node:
480
+
481
+ ### k-NN classification on ImageNet-1k
482
+
483
+ ```shell
484
+ python dinov2/run/eval/knn.py \
485
+ --config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
486
+ --pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
487
+ --output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/knn \
488
+ --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
489
+ --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
490
+ ```
491
+
492
+ ### Logistic regression classification on ImageNet-1k
493
+
494
+ ```shell
495
+ python dinov2/run/eval/log_regression.py \
496
+ --config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
497
+ --pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
498
+ --output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/logreg \
499
+ --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
500
+ --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
501
+ ```
502
+
503
+ ### Linear classification with data augmentation on ImageNet-1k
504
+
505
+ ```shell
506
+ python dinov2/run/eval/linear.py \
507
+ --config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
508
+ --pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
509
+ --output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/linear \
510
+ --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
511
+ --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
512
+ ```
513
+
514
+ We also release the weights of the linear classification heads obtained when evaluating the different models:
515
+
516
+ <table style="margin: auto">
517
+ <tr>
518
+ <th>model</th>
519
+ <th>with<br />registers</th>
520
+ <th>ImageNet<br />top-1</th>
521
+ <th>linear evaluation</th>
522
+ </tr>
523
+ <tr>
524
+ <td>ViT-S/14 distilled</td>
525
+ <td align="center">:x:</td>
526
+ <td align="right">81.1%</td>
527
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth">linear head weights</a></td>
528
+ </tr>
529
+ <tr>
530
+ <td>ViT-S/14 distilled</td>
531
+ <td align="center">:white_check_mark:</td>
532
+ <td align="right">80.8%</td>
533
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear_head.pth">linear head weights</a></td>
534
+ </tr>
535
+ <tr>
536
+ <td>ViT-B/14 distilled</td>
537
+ <td align="center">:x:</td>
538
+ <td align="right">84.5%</td>
539
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">linear head weights</a></td>
540
+ </tr>
541
+ <tr>
542
+ <td>ViT-B/14 distilled</td>
543
+ <td align="center">:white_check_mark:</td>
544
+ <td align="right">84.4%</td>
545
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear_head.pth">linear head weights</a></td>
546
+ </tr>
547
+ <tr>
548
+ <td>ViT-L/14 distilled</td>
549
+ <td align="center">:x:</td>
550
+ <td align="right">86.3%</td>
551
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">linear head weights</a></td>
552
+ </tr>
553
+ <tr>
554
+ <td>ViT-L/14 distilled</td>
555
+ <td align="center">:white_check_mark:</td>
556
+ <td align="right">86.5%</td>
557
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear_head.pth">linear head weights</a></td>
558
+ </tr>
559
+ <tr>
560
+ <td>ViT-g/14</td>
561
+ <td align="center">:x:</td>
562
+ <td align="right">86.5%</td>
563
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">linear head weights</a></td>
564
+ </tr>
565
+ <tr>
566
+ <td>ViT-g/14</td>
567
+ <td align="center">:white_check_mark:</td>
568
+ <td align="right">87.0%</td>
569
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear_head.pth">linear head weights</a></td>
570
+ </tr>
571
+ </table>
572
+
573
+ The performance of the provided pretrained model weights can be evaluated as follows on ImageNet-1k:
574
+
575
+ ```shell
576
+ python dinov2/run/eval/linear.py \
577
+ --config-file dinov2/configs/eval/vitg14_pretrain.yaml \
578
+ --pretrained-weights https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth \
579
+ --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
580
+ --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
581
+ ```
582
+
583
+ ## Notebooks
584
+
585
+ A few notebooks are provided to help the community leverage the models and code:
586
+
587
+ <ul>
588
+ <li><a href="https://github.com/facebookresearch/dinov2/blob/main/notebooks/depth_estimation.ipynb">Depth estimation</a> - How to load and use the depth heads in combination with a matching backbone via mmcv</li>
589
+ <li><a href="https://github.com/facebookresearch/dinov2/blob/main/notebooks/semantic_segmentation.ipynb">Semantic segmentation</a> - How to load and use the segmentation heads in combination with a matching backbone via mmcv, and also how to load and use the Mask2Former-based segmentation model trained on ADE20K</li>
590
+ </ul>
591
+
592
+ ## License
593
+
594
+ DINOv2 code and model weights are released under the Apache License 2.0. See [LICENSE](LICENSE) for additional details.
595
+
596
+ ## Contributing
597
+
598
+ See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
599
+
600
+ ## Citing DINOv2
601
+
602
+ If you find this repository useful, please consider giving a star :star: and citation :t-rex::
603
+
604
+ ```
605
+ @misc{oquab2023dinov2,
606
+ title={DINOv2: Learning Robust Visual Features without Supervision},
607
+ author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
608
+ journal={arXiv:2304.07193},
609
+ year={2023}
610
+ }
611
+ ```
612
+
613
+ ```
614
+ @misc{darcet2023vitneedreg,
615
+ title={Vision Transformers Need Registers},
616
+ author={Darcet, Timothée and Oquab, Maxime and Mairal, Julien and Bojanowski, Piotr},
617
+ journal={arXiv:2309.16588},
618
+ year={2023}
619
+ }
620
+ ```
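The training and evaluation commands above all pass datasets as colon-separated `Name:key=value` strings (e.g. `ImageNet:split=TRAIN:root=...:extra=...`). A minimal illustrative parser for that convention (a sketch with made-up paths, not dinov2's actual dataset loader, and assuming paths contain no colons):

```python
def parse_dataset_spec(spec):
    # Split "Name:key=value:key=value" into a dataset name plus kwargs.
    # Illustrative only: dinov2's real parser lives in dinov2/data.
    name, _, rest = spec.partition(":")
    kwargs = {}
    if rest:
        for item in rest.split(":"):
            key, _, value = item.partition("=")
            kwargs[key] = value
    return name, kwargs


print(parse_dataset_spec("ImageNet:split=TRAIN:root=/data:extra=/data"))
print(parse_dataset_spec("ImageNet22k"))
```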
dinov2/conda-extras.yaml ADDED
@@ -0,0 +1,24 @@
1
+ name: dinov2-extras
2
+ channels:
3
+ - defaults
4
+ - pytorch
5
+ - nvidia
6
+ - xformers
7
+ - conda-forge
8
+ dependencies:
9
+ - python=3.9
10
+ - pytorch::pytorch=2.0.0
11
+ - pytorch::pytorch-cuda=11.7.0
12
+ - pytorch::torchvision=0.15.0
13
+ - omegaconf
14
+ - torchmetrics=0.10.3
15
+ - fvcore
16
+ - iopath
17
+ - xformers::xformers=0.0.18
18
+ - pip
19
+ - pip:
20
+ - git+https://github.com/facebookincubator/submitit
21
+ - --extra-index-url https://pypi.nvidia.com
22
+ - cuml-cu11
23
+ - mmcv-full==1.5.0
24
+ - mmsegmentation==0.27.0
dinov2/conda.yaml ADDED
@@ -0,0 +1,21 @@
1
+ name: dinov2
2
+ channels:
3
+ - defaults
4
+ - pytorch
5
+ - nvidia
6
+ - conda-forge
7
+ dependencies:
8
+ - python=3.9
9
+ - pytorch=2.0.0
10
+ - pytorch-cuda=11.7
11
+ - torchvision=0.15.0
12
+ - omegaconf
13
+ - torchmetrics=0.10.3
14
+ - fvcore
15
+ - iopath
16
+ - pip
17
+ - pip:
18
+ - git+https://github.com/facebookincubator/submitit
19
+ - --extra-index-url https://pypi.nvidia.com
20
+ - cuml-cu11
21
+ - xformers==0.0.20 # Updated xformers version compatible with PyTorch 2.0
dinov2/pyproject.toml ADDED
@@ -0,0 +1,29 @@
1
+ [tool.black]
2
+ line-length = 120
3
+
4
+ [tool.pylint.master]
5
+ persistent = false
6
+ score = false
7
+
8
+ [tool.pylint.messages_control]
9
+ disable = "all"
10
+ enable = [
11
+ "miscellaneous",
12
+ "similarities",
13
+ ]
14
+
15
+ [tool.pylint.similarities]
16
+ ignore-comments = true
17
+ ignore-docstrings = true
18
+ ignore-imports = true
19
+ min-similarity-lines = 8
20
+
21
+ [tool.pylint.reports]
22
+ reports = false
23
+
24
+ [tool.pylint.miscellaneous]
25
+ notes = [
26
+ "FIXME",
27
+ "XXX",
28
+ "TODO",
29
+ ]
dinov2/requirements-dev.txt ADDED
@@ -0,0 +1,3 @@
1
+ black==22.6.0
2
+ flake8==5.0.4
3
+ pylint==2.15.0
dinov2/requirements-extras.txt ADDED
@@ -0,0 +1,2 @@
1
+ mmcv-full==1.5.0
2
+ mmsegmentation==0.27.0
dinov2/requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu117
2
+ torch==2.0.0
3
+ torchvision==0.15.0
4
+ omegaconf
5
+ torchmetrics==0.10.3
6
+ fvcore
7
+ iopath
8
+ xformers==0.0.18
9
+ submitit
10
+ --extra-index-url https://pypi.nvidia.com
11
+ cuml-cu11
dinov2/scripts/lint.sh ADDED
@@ -0,0 +1,28 @@
1
+ #!/bin/sh
2
+
3
+ if [ -n "$1" ]; then
4
+ echo "linting \"$1\""
5
+ fi
6
+
7
+ echo "running black"
8
+ if [ -n "$1" ]; then
9
+ black "$1"
10
+ else
11
+ black dinov2
12
+ fi
13
+
14
+ echo "running flake8"
15
+ if [ -n "$1" ]; then
16
+ flake8 "$1"
17
+ else
18
+ flake8
19
+ fi
20
+
21
+ echo "running pylint"
22
+ if [ -n "$1" ]; then
23
+ pylint "$1"
24
+ else
25
+ pylint dinov2
26
+ fi
27
+
28
+ exit 0
dinov2/setup.cfg ADDED
@@ -0,0 +1,8 @@
1
+ [flake8]
2
+ max-line-length = 120
3
+ ignore = E203,E501,W503
4
+ per-file-ignores =
5
+ __init__.py:F401
6
+ hubconf.py:F401
7
+ exclude =
8
+ venv
dinov2/setup.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from pathlib import Path
7
+ import re
8
+ from typing import List, Tuple
9
+
10
+ from setuptools import setup, find_packages
11
+
12
+
13
+ NAME = "dinov2"
14
+ DESCRIPTION = "PyTorch code and models for the DINOv2 self-supervised learning method."
15
+
16
+ URL = "https://github.com/facebookresearch/dinov2"
17
+ AUTHOR = "FAIR"
18
+ REQUIRES_PYTHON = ">=3.9.0"
19
+ HERE = Path(__file__).parent
20
+
21
+
22
+ try:
23
+ with open(HERE / "README.md", encoding="utf-8") as f:
24
+ long_description = "\n" + f.read()
25
+ except FileNotFoundError:
26
+ long_description = DESCRIPTION
27
+
28
+
29
+ def get_requirements(path: str = HERE / "requirements.txt") -> Tuple[List[str], List[str]]:
30
+ requirements = []
31
+ extra_indices = []
32
+ with open(path) as f:
33
+ for line in f.readlines():
34
+ line = line.rstrip("\r\n")
35
+ if line.startswith("--extra-index-url "):
36
+ extra_indices.append(line[18:])
37
+ continue
38
+ requirements.append(line)
39
+ return requirements, extra_indices
40
+
41
+
42
+ def get_package_version() -> str:
43
+ with open(HERE / "dinov2/__init__.py") as f:
44
+ result = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M)
45
+ if result:
46
+ return result.group(1)
47
+ raise RuntimeError("Can't get package version")
48
+
49
+
50
+ requirements, extra_indices = get_requirements()
51
+ version = get_package_version()
52
+ dev_requirements, _ = get_requirements(HERE / "requirements-dev.txt")
53
+ extras_requirements, _ = get_requirements(HERE / "requirements-extras.txt")
54
+
55
+
56
+ setup(
57
+ name=NAME,
58
+ version=version,
59
+ description=DESCRIPTION,
60
+ long_description=long_description,
61
+ long_description_content_type="text/markdown",
62
+ author=AUTHOR,
63
+ python_requires=REQUIRES_PYTHON,
64
+ url=URL,
65
+ packages=find_packages(),
66
+ package_data={
67
+ "": ["*.yaml"],
68
+ },
69
+ install_requires=requirements,
70
+ extras_require={
71
+ "dev": dev_requirements,
72
+ "extras": extras_requirements,
73
+ },
74
+ dependency_links=extra_indices,
75
+ include_package_data=True,
76
+ license="Apache",
77
+ license_files=("LICENSE",),
78
+ classifiers=[
79
+ # Trove classifiers: https://github.com/pypa/trove-classifiers/blob/main/src/trove_classifiers/__init__.py
80
+ "Development Status :: 3 - Alpha",
81
+ "Intended Audience :: Developers",
82
+ "Intended Audience :: Science/Research",
83
+ "License :: OSI Approved :: Apache Software License",
84
+ "Programming Language :: Python :: 3.9",
85
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
86
+ "Topic :: Software Development :: Libraries :: Python Modules",
87
+ ],
88
+ )
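`setup.py`'s `get_requirements` strips `--extra-index-url` lines out of the requirements files before handing the rest to setuptools. A standalone sketch of that split (the helper is re-declared here so it runs on its own):

```python
def split_requirements(lines):
    # Mirrors setup.py's get_requirements(): pull --extra-index-url entries
    # out into their own list, keep everything else as requirements.
    requirements, extra_indices = [], []
    for line in lines:
        line = line.rstrip("\r\n")
        if line.startswith("--extra-index-url "):
            extra_indices.append(line[len("--extra-index-url "):])
            continue
        requirements.append(line)
    return requirements, extra_indices


reqs, indices = split_requirements([
    "--extra-index-url https://download.pytorch.org/whl/cu117\n",
    "torch==2.0.0\n",
    "omegaconf\n",
])
print(reqs)     # ['torch==2.0.0', 'omegaconf']
print(indices)  # ['https://download.pytorch.org/whl/cu117']
```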
explainability.py ADDED
@@ -0,0 +1,112 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+
9
+
10
+ @dataclass(frozen=True)
11
+ class ExplainabilitySpec:
12
+ encoder_mode: str
13
+ encoder_label: str
14
+ decoder_mode: str
15
+ decoder_label: str
16
+ encoder_layer_count: int = 0
17
+ encoder_head_count: int = 0
18
+
19
+
20
+ class ModuleOutputRecorder:
21
+ def __init__(self) -> None:
22
+ self.handle = None
23
+ self.output = None
24
+
25
+ def attach(self, module) -> None:
26
+ self.remove()
27
+ self.handle = module.register_forward_hook(self._hook)
28
+
29
+ def clear(self) -> None:
30
+ self.output = None
31
+
32
+ def remove(self) -> None:
33
+ if self.handle is not None:
34
+ self.handle.remove()
35
+ self.handle = None
36
+ self.output = None
37
+
38
+ def _hook(self, module, inputs, output) -> None: # pragma: no cover - hook signature
39
+ if torch.is_tensor(output):
40
+ self.output = output.detach()
41
+ else:
42
+ self.output = output
43
+
44
+
45
+ def clamp_index(index: int | None, upper_bound: int) -> int:
46
+ if upper_bound <= 0:
47
+ return 0
48
+ if index is None:
49
+ return upper_bound - 1
50
+ return max(0, min(int(index), upper_bound - 1))
51
+
52
+
53
+ def normalize_map(values) -> np.ndarray:
54
+ array = np.asarray(values, dtype=np.float32)
55
+ if array.ndim != 2:
56
+ raise ValueError(f"Expected a 2D array, got shape {array.shape}")
57
+
58
+ array = array.copy()
59
+ min_value = float(array.min(initial=0.0))
60
+ array -= min_value
61
+ max_value = float(array.max(initial=0.0))
62
+ if max_value > 0:
63
+ array /= max_value
64
+ return array
65
+
66
+
67
+ def resize_rgb_image(rgb_image: np.ndarray, size: tuple[int, int]) -> np.ndarray:
68
+ width, height = size
69
+ return cv2.resize(rgb_image, (width, height), interpolation=cv2.INTER_LINEAR)
70
+
71
+
72
+ def feature_energy_map(feature_tensor: torch.Tensor, output_shape: tuple[int, int]) -> np.ndarray:
73
+ tensor = feature_tensor.detach().float()
74
+ while tensor.dim() > 3:
75
+ tensor = tensor[0]
76
+ if tensor.dim() == 3:
77
+ tensor = tensor.abs().mean(dim=0)
78
+ elif tensor.dim() != 2:
79
+ raise ValueError(f"Unexpected feature tensor shape: {tuple(feature_tensor.shape)}")
80
+
81
+ heatmap = normalize_map(tensor.cpu().numpy())
82
+ height, width = output_shape
83
+ return cv2.resize(heatmap, (width, height), interpolation=cv2.INTER_CUBIC)
84
+
85
+
86
+ def render_heatmap_overlay(rgb_image: np.ndarray, heatmap: np.ndarray, alpha: float = 0.45) -> np.ndarray:
87
+ if heatmap.shape != rgb_image.shape[:2]:
88
+ heatmap = cv2.resize(heatmap, (rgb_image.shape[1], rgb_image.shape[0]), interpolation=cv2.INTER_CUBIC)
89
+ colored = cv2.applyColorMap((normalize_map(heatmap) * 255.0).astype(np.uint8), cv2.COLORMAP_TURBO)
90
+ colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
91
+ return cv2.addWeighted(rgb_image, 1.0 - alpha, colored, alpha, 0.0)
92
+
93
+
94
+ def render_temporal_strip(values, *, active_index: int | None = None, cell_width: int = 12, height: int = 72) -> np.ndarray:
95
+ sequence = np.asarray(values, dtype=np.float32).reshape(1, -1)
96
+ if sequence.size == 0:
97
+ sequence = np.zeros((1, 1), dtype=np.float32)
98
+
99
+ normalized = normalize_map(sequence)
100
+ strip = (normalized * 255.0).astype(np.uint8)
101
+ strip = np.repeat(strip, height, axis=0)
102
+ strip = np.repeat(strip, cell_width, axis=1)
103
+ colored = cv2.applyColorMap(strip, cv2.COLORMAP_TURBO)
104
+ colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
105
+
106
+ if active_index is not None and sequence.shape[1] > 0:
107
+ clamped = clamp_index(active_index, sequence.shape[1])
108
+ x0 = clamped * cell_width
109
+ x1 = min(colored.shape[1] - 1, x0 + cell_width - 1)
110
+ cv2.rectangle(colored, (x0, 0), (x1, colored.shape[0] - 1), (255, 255, 255), 2)
111
+
112
+ return colored
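The index-clamping and normalization helpers in `explainability.py` have simple invariants worth pinning down: `clamp_index(None, n)` selects the last valid index, out-of-range indices clamp into `[0, n-1]`, and `normalize_map` anchors at zero (because of `min(initial=0.0)`) and scales the peak to 1.0. A minimal standalone check, with the two pure functions re-declared so it runs on its own:

```python
import numpy as np


def clamp_index(index, upper_bound):
    # Mirrors explainability.clamp_index: None selects the last valid index.
    if upper_bound <= 0:
        return 0
    if index is None:
        return upper_bound - 1
    return max(0, min(int(index), upper_bound - 1))


def normalize_map(values):
    # Mirrors explainability.normalize_map: anchor at zero (min(initial=0.0)
    # never exceeds 0 for non-negative maps), then scale the peak to 1.0.
    array = np.asarray(values, dtype=np.float32).copy()
    array -= float(array.min(initial=0.0))
    max_value = float(array.max(initial=0.0))
    if max_value > 0:
        array /= max_value
    return array


print(clamp_index(None, 5))  # → 4
print(clamp_index(99, 5))    # → 4
heat = normalize_map([[2.0, 4.0], [6.0, 8.0]])
print(heat.min(), heat.max())
```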
model/transformer.py CHANGED
@@ -146,7 +146,7 @@ class Decoder(nn.Module):
146
  super(Decoder, self).__init__()
147
  self.layers = nn.ModuleList([DecoderLayer(d_model, d_ff, d_k, d_v, n_heads, len_q) for _ in range(n_layers)])
148
 
149
- def forward(self, dec_inputs, enc_outputs):
150
  '''
151
  dec_inputs: [batch_size, tgt_len, d_model] [512, 1, 5]
152
  enc_intpus: [batch_size, src_len, d_model] [512, 30, 5]
@@ -160,6 +160,8 @@ class Decoder(nn.Module):
160
  # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, h_heads, tgt_len, src_len]
161
  dec_outputs, dec_enc_attn = layer(dec_outputs, enc_outputs)
162
  dec_enc_attns.append(dec_enc_attn)
 
 
163
  return dec_outputs
164
 
165
 
@@ -175,7 +177,7 @@ class Transformer2_3_1(nn.Module):
175
  self.encoder = Encoder(d_model, d_ff, d_k, d_v, n_layers, n_heads, len_q)
176
  self.decoder = Decoder(d_model, d_ff, d_k, d_v, 1, n_heads, len_q)
177
 
178
- def forward(self, enc_inputs, dec_inputs):
179
  '''
180
  enc_inputs: [Frames, src_len, d_model] [512, 30, 5]
181
  dec_inputs: [Frames, 1, d_model] [512, 1, 5]
@@ -185,8 +187,11 @@ class Transformer2_3_1(nn.Module):
185
 
186
  # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
187
  enc_outputs, enc_self_attns = self.encoder(enc_inputs) # Self-attention for temporal features
188
- dec_outputs = self.decoder(dec_inputs, enc_outputs)
189
- return dec_outputs
 
 
 
190
 
191
 
192
  class Transformer(nn.Module):
@@ -210,7 +215,7 @@ class Transformer(nn.Module):
210
  nn.Linear(self.d_model, out_features, bias=False)
211
  )
212
 
213
- def forward(self, x, long_feature):
214
  # x: [B, 256, T]; long_feature: [B, T, 256]
215
  B, D, T = x.shape
216
  out_features = x.transpose(1, 2) # [B, T, 256]
@@ -238,9 +243,24 @@ class Transformer(nn.Module):
238
  win = out_features[:, i - spa_len + 1:i + 1, :]
239
  out_feas.append(win)
240
  out_feas = torch.stack(out_feas, dim=0).squeeze(1)
241
- out_feas, _ = self.spatial_encoder(out_feas)
242
 
243
  # Temporal-spatial fusion
244
- output = self.transformer(inputs, out_feas) # [T, B, 1, 256] collapsed → [T, B, 256]
 
 
 
 
245
  output = self.out(output) # [T, B, C]
246
- return output.transpose(0, 1) # [B, T, C]
 
 
 
 
 
 
 
 
 
 
 
 
146
  super(Decoder, self).__init__()
147
  self.layers = nn.ModuleList([DecoderLayer(d_model, d_ff, d_k, d_v, n_heads, len_q) for _ in range(n_layers)])
148
 
149
+ def forward(self, dec_inputs, enc_outputs, return_attentions=False):
150
  '''
151
  dec_inputs: [batch_size, tgt_len, d_model] [512, 1, 5]
152
  enc_inputs: [batch_size, src_len, d_model] [512, 30, 5]
 
160
  # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, h_heads, tgt_len, src_len]
161
  dec_outputs, dec_enc_attn = layer(dec_outputs, enc_outputs)
162
  dec_enc_attns.append(dec_enc_attn)
163
+ if return_attentions:
164
+ return dec_outputs, dec_enc_attns
165
  return dec_outputs
166
 
167
 
 
177
  self.encoder = Encoder(d_model, d_ff, d_k, d_v, n_layers, n_heads, len_q)
178
  self.decoder = Decoder(d_model, d_ff, d_k, d_v, 1, n_heads, len_q)
179
 
180
+ def forward(self, enc_inputs, dec_inputs, return_attentions=False):
181
  '''
182
  enc_inputs: [Frames, src_len, d_model] [512, 30, 5]
183
  dec_inputs: [Frames, 1, d_model] [512, 1, 5]
 
187
 
188
  # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
189
  enc_outputs, enc_self_attns = self.encoder(enc_inputs) # Self-attention for temporal features
190
+ decoder_outputs = self.decoder(dec_inputs, enc_outputs, return_attentions=return_attentions)
191
+ if return_attentions:
192
+ dec_outputs, dec_enc_attns = decoder_outputs
193
+ return dec_outputs, {"encoder_self_attns": enc_self_attns, "decoder_cross_attns": dec_enc_attns}
194
+ return decoder_outputs
195
 
196
 
197
  class Transformer(nn.Module):
 
215
  nn.Linear(self.d_model, out_features, bias=False)
216
  )
217
 
218
+ def forward(self, x, long_feature, return_attention=False):
219
  # x: [B, 256, T]; long_feature: [B, T, 256]
220
  B, D, T = x.shape
221
  out_features = x.transpose(1, 2) # [B, T, 256]
 
243
  win = out_features[:, i - spa_len + 1:i + 1, :]
244
  out_feas.append(win)
245
  out_feas = torch.stack(out_feas, dim=0).squeeze(1)
246
+ out_feas, spatial_attn = self.spatial_encoder(out_feas)
247
 
248
  # Temporal-spatial fusion
249
+ transformer_outputs = self.transformer(inputs, out_feas, return_attentions=return_attention)
250
+ if return_attention:
251
+ output, attention_meta = transformer_outputs
252
+ else:
253
+ output = transformer_outputs
254
  output = self.out(output) # [T, B, C]
255
+ output = output.transpose(0, 1) # [B, T, C]
256
+ if not return_attention:
257
+ return output
258
+
259
+ decoder_attn = attention_meta["decoder_cross_attns"][-1]
260
+ spatial_attn_last = spatial_attn[-1]
261
+ decoder_strip = decoder_attn[-1].mean(dim=0).mean(dim=0).detach()
262
+ spatial_strip = spatial_attn_last.mean(dim=0).mean(dim=0).detach()
263
+ return output, {
264
+ "decoder_strip": decoder_strip,
265
+ "spatial_strip": spatial_strip,
266
+ }
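The `decoder_strip` computed above collapses a `[batch, heads, tgt_len, src_len]` cross-attention tensor from the last decoder layer into one weight per source frame: index the last batch element, then average over heads and target positions. A shape-only sketch of that reduction (NumPy stands in for torch here; sizes are hypothetical):

```python
import numpy as np

# Hypothetical cross-attention from the last decoder layer:
# [batch=4, heads=8, tgt_len=1, src_len=30].
attn = np.random.rand(4, 8, 1, 30).astype(np.float32)

# attn[-1]            -> [heads, tgt_len, src_len]
# .mean(axis=0)       -> [tgt_len, src_len]   (average over heads)
# .mean(axis=0)       -> [src_len]            (average over target positions)
strip = attn[-1].mean(axis=0).mean(axis=0)
print(strip.shape)  # one attention weight per source frame
```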
model_manager.py CHANGED
@@ -56,6 +56,16 @@ class SpaceModelManager:
56
  self.current_predictor = predictor
57
  return predictor
58
 
 
 
 
 
 
 
 
 
 
 
59
  def reset_predictor_state(self) -> None:
60
  if self.current_predictor is not None and hasattr(self.current_predictor, "reset_state"):
61
  self.current_predictor.reset_state()
 
56
  self.current_predictor = predictor
57
  return predictor
58
 
59
+ def get_loaded_predictor(self, model_key: str | None = None):
60
+ if self.current_predictor is None:
61
+ return None
62
+ if model_key is None:
63
+ return self.current_predictor
64
+ normalized_key = normalize_model_key(model_key)
65
+ if self.current_model_key != normalized_key:
66
+ return None
67
+ return self.current_predictor
68
+
69
  def reset_predictor_state(self) -> None:
70
  if self.current_predictor is not None and hasattr(self.current_predictor, "reset_state"):
71
  self.current_predictor.reset_state()
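The new `get_loaded_predictor` guard can be exercised in isolation: no key returns whatever is loaded, a matching (normalized) key returns it, and a mismatched key returns `None`. A minimal sketch, with a hypothetical `normalize_model_key` standing in for the module's real helper:

```python
def normalize_model_key(key):
    # Hypothetical stand-in for the project's normalize_model_key helper.
    return key.strip().lower()


class ManagerSketch:
    # Mirrors SpaceModelManager.get_loaded_predictor's guard logic.
    def __init__(self, key, predictor):
        self.current_model_key = key
        self.current_predictor = predictor

    def get_loaded_predictor(self, model_key=None):
        if self.current_predictor is None:
            return None
        if model_key is None:
            return self.current_predictor
        if self.current_model_key != normalize_model_key(model_key):
            return None
        return self.current_predictor


mgr = ManagerSketch("dino-endo", object())
print(mgr.get_loaded_predictor() is not None)              # loaded, no key
print(mgr.get_loaded_predictor("  DINO-ENDO ") is not None)  # normalized match
print(mgr.get_loaded_predictor("vjepa2") is None)          # key mismatch
```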
predictor.py CHANGED
@@ -1,5 +1,6 @@
1
  from __future__ import annotations
2
 
 
3
  import os
4
  import sys
5
  from contextlib import nullcontext
@@ -21,6 +22,15 @@ except ImportError: # pragma: no cover
21
  from model.resnet import ResNet
22
  from model.mstcn import MultiStageModel
23
  from model.transformer import Transformer
 
 
 
 
 
 
 
 
 
24
 
25
  PHASE_LABELS = ("idle", "marking", "injection", "dissection")
26
  MODEL_LABELS = {
@@ -108,6 +118,37 @@ def _resolve_vendor_repo(repo_name: str, extra_candidates=()):
108
  raise FileNotFoundError(f"Required vendor repo '{repo_name}' not found. Stage it into this folder or keep the repo-root copy available.")
109
 
110
 
111
  class Predictor:
112
  def __init__(self, model_dir: str | None = None, device: str = "cuda"):
113
  self.device = torch.device(device if torch.cuda.is_available() else "cpu")
@@ -118,6 +159,14 @@ class Predictor:
118
  self.frame_feature_cache = None
119
  self.label_dict = dict(enumerate(PHASE_LABELS))
120
  self.available = False
 
 
 
 
 
 
 
 
121
 
122
  self._norm_mean = None
123
  self._norm_std = None
@@ -134,6 +183,9 @@ class Predictor:
134
  paras = {k.replace("share.", "resnet."): v for k, v in paras.items()}
135
  self.resnet.load_state_dict(paras, strict=True)
136
  self.resnet.to(self.device).eval()
 
 
 
137
 
138
  self.fusion = MultiStageModel(
139
  mstcn_stages=2,
@@ -174,11 +226,18 @@ class Predictor:
174
  self.predict(dummy)
175
  self.reset_state()
176
 
 
 
 
177
  def reset_state(self):
178
  self.frame_feature_cache = None
 
179
  if torch.cuda.is_available():
180
  torch.cuda.empty_cache()
181
 
 
 
 
182
  def unload(self):
183
  self.available = False
184
  self.resnet.to("cpu")
@@ -188,6 +247,10 @@ class Predictor:
188
  self.fusion = None
189
  self.transformer = None
190
  self.frame_feature_cache = None
 
 
 
 
191
  if torch.cuda.is_available():
192
  torch.cuda.empty_cache()
193
 
@@ -200,7 +263,10 @@ class Predictor:
200
  self.frame_feature_cache = torch.cat([self.frame_feature_cache, feature], dim=0)
201
 
202
  @torch.inference_mode()
203
- def predict(self, rgb_image: np.ndarray):
 
 
 
204
  if self._norm_mean is not None:
205
  tensor = self._preprocess_gpu(rgb_image)
206
  else:
@@ -216,33 +282,91 @@ class Predictor:
216
  single_frame_feature = feature.unsqueeze(1)
217
  temporal_input = single_frame_feature.transpose(1, 2)
218
  temporal_feature = self.fusion(temporal_input)
219
- outputs = self.transformer(temporal_feature.detach(), single_frame_feature)
 
 
 
 
 
 
 
 
220
  final_logits = outputs[-1, -1, :]
221
  probs = F.softmax(final_logits.float(), dim=-1)
222
  pred_np = probs.detach().cpu().numpy()
223
  confidence = float(np.max(pred_np))
224
  phase_idx = max(0, min(3, int(np.argmax(pred_np))))
225
  phase = self.label_dict.get(phase_idx, "idle")
226
- return {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": 1}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  if self.frame_feature_cache.shape[0] < 30:
229
  available_frames = self.frame_feature_cache.shape[0] + 1
230
  cat_frame_feature = torch.cat([self.frame_feature_cache, feature], dim=0).unsqueeze(0)
231
  temporal_input = cat_frame_feature.transpose(1, 2)
232
  temporal_feature = self.fusion(temporal_input)
233
- outputs = self.transformer(temporal_feature.detach(), cat_frame_feature)
 
 
 
 
 
 
 
 
234
  final_logits = outputs[-1, -1, :]
235
  probs = F.softmax(final_logits.float(), dim=-1)
236
  pred_np = probs.detach().cpu().numpy()
237
  confidence = float(np.max(pred_np))
238
  phase_idx = max(0, min(3, int(np.argmax(pred_np))))
239
  phase = self.label_dict.get(phase_idx, "idle")
240
- return {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": available_frames}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  cat_frame_feature = self.frame_feature_cache.unsqueeze(0)
243
  temporal_input = cat_frame_feature.transpose(1, 2)
244
  temporal_feature = self.fusion(temporal_input)
245
- outputs = self.transformer(temporal_feature.detach(), cat_frame_feature)
 
 
 
 
 
 
 
 
246
  final_logits = outputs[-1, -1, :]
247
  probs = F.softmax(final_logits.float(), dim=-1)
248
  pred_np = probs.detach().cpu().numpy()
@@ -250,7 +374,22 @@ class Predictor:
250
  confidence = float(np.max(pred_np))
251
  phase_idx = max(0, min(3, int(np.argmax(pred_np))))
252
  phase = self.label_dict.get(phase_idx, "idle")
253
- return {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": min(self.trans_seq, self.frame_feature_cache.shape[0])}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
 
256
  class PredictorDinoV2:
@@ -267,7 +406,19 @@ class PredictorDinoV2:
267
  A.CenterCrop(height=224, width=224),
268
  A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0),
269
  ])
 
 
 
 
270
  self.frame_features = []
 
 
 
 
 
 
 
 
271
  self._load_models(self.model_dir)
272
 
273
  def _amp_context(self):
@@ -297,6 +448,14 @@ class PredictorDinoV2:
297
  encoder_load = self.backbone.load_state_dict(encoder_state, strict=False)
298
  _validate_load_result(encoder_load, "DINOv2 backbone")
299
  self.backbone.to(self.device).eval()
 
 
 
 
 
 
 
 
300
 
301
  decoder_path = os.path.join(model_dir, "fusion_transformer_decoder_best_model.pth")
302
  if not os.path.exists(decoder_path):
@@ -326,7 +485,7 @@ class PredictorDinoV2:
326
  d_model=d_model,
327
  )
328
 
329
- def forward(self, x):
330
  x = x.permute(0, 2, 1)
331
  x_reduced = self.reduce(x)
332
  mstcn_input = x_reduced.permute(0, 2, 1)
@@ -341,8 +500,15 @@ class PredictorDinoV2:
341
  else:
342
  transformer_input = mstcn_input.detach()
343
 
344
- transformer_out = self.transformer(transformer_input, x_reduced)
345
- return transformer_out.permute(0, 2, 1)
 
 
 
 
 
 
 
346
 
347
  self.decoder = FusionTransformerDecoder()
348
  decoder_load = self.decoder.load_state_dict(decoder_state, strict=False)
@@ -359,6 +525,7 @@ class PredictorDinoV2:
359
 
360
  def reset_state(self):
361
  self.frame_features = []
 
362
  if torch.cuda.is_available():
363
  torch.cuda.empty_cache()
364
 
@@ -367,6 +534,40 @@ class PredictorDinoV2:
367
  self.predict(dummy_img)
368
  self.reset_state()
369
370
  def unload(self):
371
  if self.backbone is not None:
372
  self.backbone.to("cpu")
@@ -375,15 +576,33 @@ class PredictorDinoV2:
375
  self.backbone = None
376
  self.decoder = None
377
  self.frame_features = []
 
 
378
  self.available = False
379
  if torch.cuda.is_available():
380
  torch.cuda.empty_cache()
381
 
382
      @torch.inference_mode()
-     def predict(self, rgb_image: np.ndarray):
          if not self.available or self.backbone is None or self.decoder is None:
              raise RuntimeError("DINO-Endo predictor is not available")

          processed = self.aug(image=rgb_image)["image"]
          chw = np.transpose(processed, (2, 0, 1))
          tensor = torch.tensor(chw, dtype=torch.float32).unsqueeze(0).to(self.device)
@@ -408,7 +627,11 @@ class PredictorDinoV2:

          decoder_input = seq.transpose(1, 2)
          with self._amp_context():
-             logits = self.decoder(decoder_input)

          if logits.dim() != 3:
              raise ValueError(f"Unexpected DINOv2 decoder output shape: {tuple(logits.shape)}")
@@ -424,7 +647,22 @@ class PredictorDinoV2:
          confidence = float(np.max(pred_np))
          phase_idx = int(np.argmax(pred_np))
          phase = self.label_dict.get(phase_idx, "idle")
-         return {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": available_frames}


  class PredictorVJEPA2:
@@ -443,6 +681,15 @@ class PredictorVJEPA2:
          self._feature_buffer = []
          self._vjepa_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(3, 1, 1, 1)
          self._vjepa_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(3, 1, 1, 1)
          self._load_models(self.model_dir)

      def _amp_context(self):
@@ -509,7 +756,9 @@ class PredictorVJEPA2:
          sys.path.insert(0, str(vjepa2_path))

          from src.models import vision_transformer as vjepa_vit
          from src.utils.checkpoint_loader import robust_checkpoint_loader

          encoder_path = os.path.join(model_dir, "vjepa_encoder_human.pt")
          if not os.path.exists(encoder_path):
@@ -530,6 +779,14 @@
          encoder_load = self.encoder.load_state_dict(encoder_state, strict=False)
          self._validate_load_result(encoder_load, "V-JEPA2 encoder")
          self.encoder.to(self.device).eval()

          decoder_path = os.path.join(model_dir, "mlp_decoder_human.pth")
          if not os.path.exists(decoder_path):
@@ -566,6 +823,7 @@
      def reset_state(self):
          self._frame_buffer = []
          self._feature_buffer = []
          if torch.cuda.is_available():
              torch.cuda.empty_cache()

@@ -574,6 +832,67 @@
          self.predict(dummy)
          self.reset_state()

      def unload(self):
          if self.encoder is not None:
              self.encoder.to("cpu")
@@ -583,15 +902,30 @@
          self.decoder = None
          self._frame_buffer = []
          self._feature_buffer = []
          self.available = False
          if torch.cuda.is_available():
              torch.cuda.empty_cache()

      @torch.inference_mode()
-     def predict(self, rgb_image: np.ndarray):
          if not self.available:
              raise RuntimeError("V-JEPA2 predictor is not available")

          frame = np.ascontiguousarray(rgb_image, dtype=np.uint8)
          self._frame_buffer.append(frame)
          if len(self._frame_buffer) > self._clip_frames:
@@ -625,7 +959,30 @@
          confidence = float(np.max(pred_np))
          phase_idx = int(np.argmax(pred_np))
          phase = self.label_dict.get(phase_idx, "idle")
-         return {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": available_frames}


  def create_predictor(model_key: str, model_dir: str | None = None, device: str | None = None):
  from __future__ import annotations

+ import math
  import os
  import sys
  from contextlib import nullcontext

  from model.resnet import ResNet
  from model.mstcn import MultiStageModel
  from model.transformer import Transformer
+ from explainability import (
+     ExplainabilitySpec,
+     ModuleOutputRecorder,
+     clamp_index,
+     feature_energy_map,
+     render_heatmap_overlay,
+     render_temporal_strip,
+     resize_rgb_image,
+ )

  PHASE_LABELS = ("idle", "marking", "injection", "dissection")
  MODEL_LABELS = {

          raise FileNotFoundError(f"Required vendor repo '{repo_name}' not found. Stage it into this folder or keep the repo-root copy available.")


+ def _build_explainability_payload(
+     *,
+     display_image: np.ndarray,
+     encoder_heatmap: np.ndarray,
+     encoder_kind: str,
+     encoder_label: str,
+     decoder_values,
+     decoder_kind: str,
+     decoder_label: str,
+     active_decoder_index: int | None = None,
+     encoder_layer: int | None = None,
+     encoder_head: int | None = None,
+     notes: str | None = None,
+ ) -> dict:
+     payload = {
+         "encoder_kind": encoder_kind,
+         "encoder_label": encoder_label,
+         "encoder_visualization": render_heatmap_overlay(display_image, encoder_heatmap),
+         "decoder_kind": decoder_kind,
+         "decoder_label": decoder_label,
+         "decoder_visualization": render_temporal_strip(decoder_values, active_index=active_decoder_index),
+     }
+     if encoder_layer is not None:
+         payload["encoder_layer"] = int(encoder_layer)
+     if encoder_head is not None:
+         payload["encoder_head"] = int(encoder_head)
+     if notes:
+         payload["notes"] = notes
+     return payload


  class Predictor:
      def __init__(self, model_dir: str | None = None, device: str = "cuda"):
          self.device = torch.device(device if torch.cuda.is_available() else "cpu")

          self.frame_feature_cache = None
          self.label_dict = dict(enumerate(PHASE_LABELS))
          self.available = False
+         self._resnet_activation = None
+         self._resnet_activation_hook = None
+         self._explainability_spec = ExplainabilitySpec(
+             encoder_mode="proxy",
+             encoder_label="ResNet layer4 activation energy (proxy)",
+             decoder_mode="attention",
+             decoder_label="Temporal Transformer attention",
+         )

          self._norm_mean = None
          self._norm_std = None

          paras = {k.replace("share.", "resnet."): v for k, v in paras.items()}
          self.resnet.load_state_dict(paras, strict=True)
          self.resnet.to(self.device).eval()
+         self._resnet_activation_hook = self.resnet.resnet.layer4[-1].relu.register_forward_hook(
+             self._capture_resnet_activation
+         )

          self.fusion = MultiStageModel(
              mstcn_stages=2,

          self.predict(dummy)
          self.reset_state()

+     def _capture_resnet_activation(self, module, inputs, output):  # pragma: no cover - hook signature
+         self._resnet_activation = output.detach()
+
      def reset_state(self):
          self.frame_feature_cache = None
+         self._resnet_activation = None
          if torch.cuda.is_available():
              torch.cuda.empty_cache()

+     def get_explainability_spec(self) -> ExplainabilitySpec:
+         return self._explainability_spec
+
      def unload(self):
          self.available = False
          self.resnet.to("cpu")
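The `ModuleOutputRecorder` / `register_forward_hook` pattern above caches a layer's latest output without touching the model itself. A minimal framework-free sketch of the same pattern — `TinyModule`, `_Handle`, and `Recorder` are illustrative stand-ins for `torch.nn.Module`, `RemovableHandle`, and the `ModuleOutputRecorder` defined in `explainability.py` (not shown in this diff):

```python
class _Handle:
    # mimics torch.utils.hooks.RemovableHandle: detaches exactly one hook
    def __init__(self, hooks, fn):
        self._hooks, self._fn = hooks, fn

    def remove(self):
        if self._fn in self._hooks:
            self._hooks.remove(self._fn)


class TinyModule:
    # mimics the hook half of torch.nn.Module's forward-hook API
    def __init__(self):
        self._hooks = []

    def register_forward_hook(self, fn):
        self._hooks.append(fn)
        return _Handle(self._hooks, fn)

    def __call__(self, x):
        out = x * 2  # stand-in for the real forward computation
        for fn in self._hooks:
            fn(self, (x,), out)  # hooks receive (module, inputs, output)
        return out


class Recorder:
    # caches the most recent forward output of whatever module it is attached to
    def __init__(self):
        self.handle = None
        self.output = None

    def attach(self, module):
        self.remove()  # never hold two hooks at once
        self.handle = module.register_forward_hook(lambda m, i, o: setattr(self, "output", o))

    def clear(self):
        self.output = None

    def remove(self):
        if self.handle is not None:
            self.handle.remove()
            self.handle = None


layer = TinyModule()
rec = Recorder()
rec.attach(layer)
layer(21)          # forward pass fires the hook
print(rec.output)  # 42
```

The real recorder stores `output.detach()` so the cached tensor does not keep the autograd graph alive, which is why `_capture_resnet_activation` calls `detach()` before stashing the activation.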
          self.fusion = None
          self.transformer = None
          self.frame_feature_cache = None
+         self._resnet_activation = None
+         if self._resnet_activation_hook is not None:
+             self._resnet_activation_hook.remove()
+             self._resnet_activation_hook = None
          if torch.cuda.is_available():
              torch.cuda.empty_cache()

          self.frame_feature_cache = torch.cat([self.frame_feature_cache, feature], dim=0)

      @torch.inference_mode()
+     def predict(self, rgb_image: np.ndarray, explainability: dict | None = None):
+         explain_enabled = bool(explainability and explainability.get("enabled"))
+         attention_meta = None
+         display_image = resize_rgb_image(rgb_image, (224, 224)) if explain_enabled else None
          if self._norm_mean is not None:
              tensor = self._preprocess_gpu(rgb_image)
          else:

          single_frame_feature = feature.unsqueeze(1)
          temporal_input = single_frame_feature.transpose(1, 2)
          temporal_feature = self.fusion(temporal_input)
+         transformer_outputs = self.transformer(
+             temporal_feature.detach(),
+             single_frame_feature,
+             return_attention=explain_enabled,
+         )
+         if explain_enabled:
+             outputs, attention_meta = transformer_outputs
+         else:
+             outputs = transformer_outputs
          final_logits = outputs[-1, -1, :]
          probs = F.softmax(final_logits.float(), dim=-1)
          pred_np = probs.detach().cpu().numpy()
          confidence = float(np.max(pred_np))
          phase_idx = max(0, min(3, int(np.argmax(pred_np))))
          phase = self.label_dict.get(phase_idx, "idle")
+         frames_used = 1
+         result = {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": frames_used}
+         if explain_enabled and attention_meta is not None and display_image is not None and self._resnet_activation is not None:
+             encoder_heatmap = feature_energy_map(self._resnet_activation, display_image.shape[:2])
+             result["explainability"] = _build_explainability_payload(
+                 display_image=display_image,
+                 encoder_heatmap=encoder_heatmap,
+                 encoder_kind="proxy",
+                 encoder_label=self._explainability_spec.encoder_label,
+                 decoder_values=attention_meta["decoder_strip"].detach().cpu().numpy(),
+                 decoder_kind="attention",
+                 decoder_label=self._explainability_spec.decoder_label,
+                 active_decoder_index=frames_used - 1,
+                 notes="Encoder view is a proxy activation map because the ResNet backbone is not attention-based.",
+             )
+         return result

          if self.frame_feature_cache.shape[0] < 30:
              available_frames = self.frame_feature_cache.shape[0] + 1
              cat_frame_feature = torch.cat([self.frame_feature_cache, feature], dim=0).unsqueeze(0)
              temporal_input = cat_frame_feature.transpose(1, 2)
              temporal_feature = self.fusion(temporal_input)
+             transformer_outputs = self.transformer(
+                 temporal_feature.detach(),
+                 cat_frame_feature,
+                 return_attention=explain_enabled,
+             )
+             if explain_enabled:
+                 outputs, attention_meta = transformer_outputs
+             else:
+                 outputs = transformer_outputs
              final_logits = outputs[-1, -1, :]
              probs = F.softmax(final_logits.float(), dim=-1)
              pred_np = probs.detach().cpu().numpy()
              confidence = float(np.max(pred_np))
              phase_idx = max(0, min(3, int(np.argmax(pred_np))))
              phase = self.label_dict.get(phase_idx, "idle")
+             result = {
+                 "phase": phase,
+                 "probs": pred_np.tolist(),
+                 "confidence": confidence,
+                 "frames_used": available_frames,
+             }
+             if explain_enabled and attention_meta is not None and display_image is not None and self._resnet_activation is not None:
+                 encoder_heatmap = feature_energy_map(self._resnet_activation, display_image.shape[:2])
+                 result["explainability"] = _build_explainability_payload(
+                     display_image=display_image,
+                     encoder_heatmap=encoder_heatmap,
+                     encoder_kind="proxy",
+                     encoder_label=self._explainability_spec.encoder_label,
+                     decoder_values=attention_meta["decoder_strip"].detach().cpu().numpy(),
+                     decoder_kind="attention",
+                     decoder_label=self._explainability_spec.decoder_label,
+                     active_decoder_index=available_frames - 1,
+                     notes="Encoder view is a proxy activation map because the ResNet backbone is not attention-based.",
+                 )
+             return result

          cat_frame_feature = self.frame_feature_cache.unsqueeze(0)
          temporal_input = cat_frame_feature.transpose(1, 2)
          temporal_feature = self.fusion(temporal_input)
+         transformer_outputs = self.transformer(
+             temporal_feature.detach(),
+             cat_frame_feature,
+             return_attention=explain_enabled,
+         )
+         if explain_enabled:
+             outputs, attention_meta = transformer_outputs
+         else:
+             outputs = transformer_outputs
          final_logits = outputs[-1, -1, :]
          probs = F.softmax(final_logits.float(), dim=-1)
          pred_np = probs.detach().cpu().numpy()

          confidence = float(np.max(pred_np))
          phase_idx = max(0, min(3, int(np.argmax(pred_np))))
          phase = self.label_dict.get(phase_idx, "idle")
+         frames_used = min(self.trans_seq, self.frame_feature_cache.shape[0])
+         result = {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": frames_used}
+         if explain_enabled and attention_meta is not None and display_image is not None and self._resnet_activation is not None:
+             encoder_heatmap = feature_energy_map(self._resnet_activation, display_image.shape[:2])
+             result["explainability"] = _build_explainability_payload(
+                 display_image=display_image,
+                 encoder_heatmap=encoder_heatmap,
+                 encoder_kind="proxy",
+                 encoder_label=self._explainability_spec.encoder_label,
+                 decoder_values=attention_meta["decoder_strip"].detach().cpu().numpy(),
+                 decoder_kind="attention",
+                 decoder_label=self._explainability_spec.decoder_label,
+                 active_decoder_index=frames_used - 1,
+                 notes="Encoder view is a proxy activation map because the ResNet backbone is not attention-based.",
+             )
+         return result

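Because the ResNet backbone has no attention maps, the `Predictor` path builds its encoder view from `feature_energy_map`, imported from `explainability.py` (not shown in this diff). A plausible reading of such a proxy, sketched in NumPy — the `energy_map` helper below is hypothetical, not the project's implementation, and the real helper presumably also resizes the map onto the display image with OpenCV:

```python
import numpy as np

def energy_map(activation: np.ndarray) -> np.ndarray:
    """Collapse a (C, H, W) activation block to an (H, W) map scaled to [0, 1]."""
    energy = np.abs(activation).mean(axis=0)  # per-location mean |activation| over channels
    lo, hi = energy.min(), energy.max()
    if hi - lo < 1e-8:                        # flat activations: avoid divide-by-zero
        return np.zeros_like(energy)
    return (energy - lo) / (hi - lo)

act = np.array([[[1.0, -3.0], [0.0, 1.0]],
                [[1.0, 3.0], [0.0, 1.0]]])   # toy C=2, H=2, W=2 activation
heat = energy_map(act)                       # [[1/3, 1.0], [0.0, 1/3]]
```

The min-max normalization matters for overlay rendering: `render_heatmap_overlay` can then treat the map as alpha-like weights without knowing the activation scale of the layer being probed.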
  class PredictorDinoV2:

          A.CenterCrop(height=224, width=224),
          A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0),
          ])
+         self.display_aug = A.Compose([
+             A.SmallestMaxSize(max_size=256, interpolation=cv2.INTER_LINEAR),
+             A.CenterCrop(height=224, width=224),
+         ])
          self.frame_features = []
+         self._attention_recorder = ModuleOutputRecorder()
+         self._attention_layer_index = None
+         self._explainability_spec = ExplainabilitySpec(
+             encoder_mode="attention",
+             encoder_label="DINOv2 encoder self-attention",
+             decoder_mode="attention",
+             decoder_label="Fusion Transformer temporal attention",
+         )
          self._load_models(self.model_dir)

      def _amp_context(self):

          encoder_load = self.backbone.load_state_dict(encoder_state, strict=False)
          _validate_load_result(encoder_load, "DINOv2 backbone")
          self.backbone.to(self.device).eval()
+         self._explainability_spec = ExplainabilitySpec(
+             encoder_mode="attention",
+             encoder_label="DINOv2 encoder self-attention",
+             decoder_mode="attention",
+             decoder_label="Fusion Transformer temporal attention",
+             encoder_layer_count=len(self.backbone.blocks),
+             encoder_head_count=int(self.backbone.num_heads),
+         )

          decoder_path = os.path.join(model_dir, "fusion_transformer_decoder_best_model.pth")
          if not os.path.exists(decoder_path):

              d_model=d_model,
          )

+         def forward(self, x, return_attention=False):
              x = x.permute(0, 2, 1)
              x_reduced = self.reduce(x)
              mstcn_input = x_reduced.permute(0, 2, 1)

              else:
                  transformer_input = mstcn_input.detach()

+             transformer_outputs = self.transformer(
+                 transformer_input,
+                 x_reduced,
+                 return_attention=return_attention,
+             )
+             if return_attention:
+                 transformer_out, attention_meta = transformer_outputs
+                 return transformer_out.permute(0, 2, 1), attention_meta
+             return transformer_outputs.permute(0, 2, 1)

          self.decoder = FusionTransformerDecoder()
          decoder_load = self.decoder.load_state_dict(decoder_state, strict=False)

      def reset_state(self):
          self.frame_features = []
+         self._attention_recorder.clear()
          if torch.cuda.is_available():
              torch.cuda.empty_cache()

          self.predict(dummy_img)
          self.reset_state()

+     def get_explainability_spec(self) -> ExplainabilitySpec:
+         return self._explainability_spec
+
+     def _ensure_attention_hook(self, layer_index: int) -> None:
+         clamped_layer = clamp_index(layer_index, self._explainability_spec.encoder_layer_count)
+         if self._attention_layer_index == clamped_layer and self._attention_recorder.handle is not None:
+             return
+         self._attention_recorder.attach(self.backbone.blocks[clamped_layer].norm1)
+         self._attention_layer_index = clamped_layer
+
+     def _compute_encoder_attention_map(self, head_index: int, output_shape: tuple[int, int]) -> np.ndarray:
+         if self._attention_recorder.output is None or self._attention_layer_index is None:
+             raise RuntimeError("DINO encoder attention recorder did not capture any tokens")
+
+         tokens = self._attention_recorder.output.to(self.device)
+         block = self.backbone.blocks[self._attention_layer_index]
+         attn_module = block.attn
+         qkv = attn_module.qkv(tokens).reshape(tokens.shape[0], tokens.shape[1], 3, attn_module.num_heads, -1).permute(
+             2, 0, 3, 1, 4
+         )
+         q = qkv[0] * attn_module.scale
+         k = qkv[1]
+         attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
+
+         head = clamp_index(head_index, attn.shape[1])
+         patch_start = 1 + int(getattr(self.backbone, "num_register_tokens", 0))
+         cls_attention = attn[0, head, 0, patch_start:]
+         patch_count = int(cls_attention.numel())
+         grid_size = int(math.sqrt(patch_count))
+         if grid_size * grid_size != patch_count:
+             raise RuntimeError(f"Unexpected DINO patch attention size: {patch_count}")
+         heatmap = cls_attention.view(grid_size, grid_size).detach().cpu().numpy()
+         return cv2.resize(heatmap, (output_shape[1], output_shape[0]), interpolation=cv2.INTER_CUBIC)
+
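`_compute_encoder_attention_map` recomputes q and k from the hooked pre-attention tokens, softmaxes q·kᵀ, and keeps the CLS-to-patch row of one head reshaped to a square grid. A NumPy sketch of that reduction with toy shapes — `cls_attention_grid` and its dimensions are illustrative only; the real code runs on the hooked DINOv2 tensors and skips register tokens via `patch_start`:

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cls_attention_grid(q, k, head, n_prefix_tokens):
    """q, k: (heads, tokens, head_dim). Return one head's CLS->patch attention
    reshaped to a square patch grid, mirroring the reduction above."""
    scale = q.shape[-1] ** -0.5                                  # 1/sqrt(head_dim)
    attn = softmax((q * scale) @ k.transpose(0, 2, 1), axis=-1)  # (heads, T, T)
    cls_row = attn[head, 0, n_prefix_tokens:]                    # CLS query -> patch keys
    grid = int(np.sqrt(cls_row.size))
    assert grid * grid == cls_row.size, "patch tokens must form a square grid"
    return cls_row.reshape(grid, grid)

rng = np.random.default_rng(0)
q = rng.normal(size=(2, 5, 8))   # 2 heads, 1 CLS + 4 patch tokens, head_dim 8
k = rng.normal(size=(2, 5, 8))
heat = cls_attention_grid(q, k, head=0, n_prefix_tokens=1)  # 2x2 patch grid
```

Hooking `norm1` (the pre-attention LayerNorm output) and re-running `qkv` is what lets the predictor recover attention weights from a backbone whose attention module never exposes them.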
      def unload(self):
          if self.backbone is not None:
              self.backbone.to("cpu")

          self.backbone = None
          self.decoder = None
          self.frame_features = []
+         self._attention_recorder.remove()
+         self._attention_layer_index = None
          self.available = False
          if torch.cuda.is_available():
              torch.cuda.empty_cache()

      @torch.inference_mode()
+     def predict(self, rgb_image: np.ndarray, explainability: dict | None = None):
          if not self.available or self.backbone is None or self.decoder is None:
              raise RuntimeError("DINO-Endo predictor is not available")

+         explain_enabled = bool(explainability and explainability.get("enabled"))
+         encoder_layer = clamp_index(
+             explainability.get("encoder_layer") if explainability else None,
+             self._explainability_spec.encoder_layer_count,
+         )
+         encoder_head = clamp_index(
+             explainability.get("encoder_head") if explainability else None,
+             self._explainability_spec.encoder_head_count,
+         )
+         if explain_enabled:
+             self._ensure_attention_hook(encoder_layer)
+             self._attention_recorder.clear()
+             display_image = self.display_aug(image=rgb_image)["image"]
+         else:
+             display_image = None
+
          processed = self.aug(image=rgb_image)["image"]
          chw = np.transpose(processed, (2, 0, 1))
          tensor = torch.tensor(chw, dtype=torch.float32).unsqueeze(0).to(self.device)

          decoder_input = seq.transpose(1, 2)
          with self._amp_context():
+             decoder_outputs = self.decoder(decoder_input, return_attention=explain_enabled)
+             if explain_enabled:
+                 logits, attention_meta = decoder_outputs
+             else:
+                 logits = decoder_outputs

          if logits.dim() != 3:
              raise ValueError(f"Unexpected DINOv2 decoder output shape: {tuple(logits.shape)}")

          confidence = float(np.max(pred_np))
          phase_idx = int(np.argmax(pred_np))
          phase = self.label_dict.get(phase_idx, "idle")
+         result = {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": available_frames}
+         if explain_enabled and display_image is not None:
+             encoder_heatmap = self._compute_encoder_attention_map(encoder_head, display_image.shape[:2])
+             result["explainability"] = _build_explainability_payload(
+                 display_image=display_image,
+                 encoder_heatmap=encoder_heatmap,
+                 encoder_kind="attention",
+                 encoder_label=self._explainability_spec.encoder_label,
+                 decoder_values=attention_meta["decoder_strip"].detach().cpu().numpy(),
+                 decoder_kind="attention",
+                 decoder_label=self._explainability_spec.decoder_label,
+                 active_decoder_index=available_frames - 1,
+                 encoder_layer=encoder_layer,
+                 encoder_head=encoder_head,
+             )
+         return result


  class PredictorVJEPA2:

          self._feature_buffer = []
          self._vjepa_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(3, 1, 1, 1)
          self._vjepa_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(3, 1, 1, 1)
+         self._attention_recorder = ModuleOutputRecorder()
+         self._attention_layer_index = None
+         self._rotate_queries_or_keys = None
+         self._explainability_spec = ExplainabilitySpec(
+             encoder_mode="attention",
+             encoder_label="V-JEPA2 encoder self-attention",
+             decoder_mode="proxy",
+             decoder_label="MLP decoder feature energy (proxy)",
+         )
          self._load_models(self.model_dir)

      def _amp_context(self):

          sys.path.insert(0, str(vjepa2_path))

          from src.models import vision_transformer as vjepa_vit
+         from src.models.utils.modules import rotate_queries_or_keys
          from src.utils.checkpoint_loader import robust_checkpoint_loader
+         self._rotate_queries_or_keys = rotate_queries_or_keys

          encoder_path = os.path.join(model_dir, "vjepa_encoder_human.pt")
          if not os.path.exists(encoder_path):

          encoder_load = self.encoder.load_state_dict(encoder_state, strict=False)
          self._validate_load_result(encoder_load, "V-JEPA2 encoder")
          self.encoder.to(self.device).eval()
+         self._explainability_spec = ExplainabilitySpec(
+             encoder_mode="attention",
+             encoder_label="V-JEPA2 encoder self-attention",
+             decoder_mode="proxy",
+             decoder_label="MLP decoder feature energy (proxy)",
+             encoder_layer_count=len(self.encoder.blocks),
+             encoder_head_count=int(self.encoder.num_heads),
+         )

          decoder_path = os.path.join(model_dir, "mlp_decoder_human.pth")
          if not os.path.exists(decoder_path):

      def reset_state(self):
          self._frame_buffer = []
          self._feature_buffer = []
+         self._attention_recorder.clear()
          if torch.cuda.is_available():
              torch.cuda.empty_cache()

          self.predict(dummy)
          self.reset_state()

+     def get_explainability_spec(self) -> ExplainabilitySpec:
+         return self._explainability_spec
+
+     def _ensure_attention_hook(self, layer_index: int) -> None:
+         clamped_layer = clamp_index(layer_index, self._explainability_spec.encoder_layer_count)
+         if self._attention_layer_index == clamped_layer and self._attention_recorder.handle is not None:
+             return
+         self._attention_recorder.attach(self.encoder.blocks[clamped_layer].norm1)
+         self._attention_layer_index = clamped_layer
+
+     def _compute_encoder_attention_map(
+         self,
+         *,
+         head_index: int,
+         temporal_group_index: int,
+         output_shape: tuple[int, int],
+     ) -> np.ndarray:
+         if self._attention_recorder.output is None or self._attention_layer_index is None:
+             raise RuntimeError("V-JEPA2 encoder attention recorder did not capture any tokens")
+         if self._rotate_queries_or_keys is None:
+             raise RuntimeError("V-JEPA2 rotation helper is unavailable")
+
+         tokens = self._attention_recorder.output.to(self.device)
+         block = self.encoder.blocks[self._attention_layer_index]
+         attn_module = block.attn
+         qkv = attn_module.qkv(tokens).unflatten(-1, (3, attn_module.num_heads, -1)).permute(2, 0, 3, 1, 4)
+         q, k = qkv[0], qkv[1]
+
+         patch_grid = self._crop_size // 16
+         temporal_groups = self._clip_frames // self._tubelet_size
+         if hasattr(attn_module, "separate_positions"):
+             mask = torch.arange(int(temporal_groups * patch_grid * patch_grid), device=tokens.device)
+             d_mask, h_mask, w_mask = attn_module.separate_positions(mask, patch_grid, patch_grid)
+             offset = 0
+             qd = self._rotate_queries_or_keys(q[..., offset : offset + attn_module.d_dim], pos=d_mask)
+             kd = self._rotate_queries_or_keys(k[..., offset : offset + attn_module.d_dim], pos=d_mask)
+             offset += attn_module.d_dim
+             qh = self._rotate_queries_or_keys(q[..., offset : offset + attn_module.h_dim], pos=h_mask)
+             kh = self._rotate_queries_or_keys(k[..., offset : offset + attn_module.h_dim], pos=h_mask)
+             offset += attn_module.h_dim
+             qw = self._rotate_queries_or_keys(q[..., offset : offset + attn_module.w_dim], pos=w_mask)
+             kw = self._rotate_queries_or_keys(k[..., offset : offset + attn_module.w_dim], pos=w_mask)
+             offset += attn_module.w_dim
+             q_parts = [qd, qh, qw]
+             k_parts = [kd, kh, kw]
+             if offset < attn_module.head_dim:
+                 q_parts.append(q[..., offset:])
+                 k_parts.append(k[..., offset:])
+             q = torch.cat(q_parts, dim=-1)
+             k = torch.cat(k_parts, dim=-1)
+
+         attn = ((q @ k.transpose(-2, -1)) * attn_module.scale).softmax(dim=-1)
+         head = clamp_index(head_index, attn.shape[1])
+         group_size = patch_grid * patch_grid
+         group_index = clamp_index(temporal_group_index, temporal_groups)
+         start = group_index * group_size
+         end = start + group_size
+         group_attention = attn[0, head, start:end, start:end].mean(dim=0)
+         heatmap = group_attention.view(patch_grid, patch_grid).detach().cpu().numpy()
+         return cv2.resize(heatmap, (output_shape[1], output_shape[0]), interpolation=cv2.INTER_CUBIC)
+
      def unload(self):
          if self.encoder is not None:
              self.encoder.to("cpu")

          self.decoder = None
          self._frame_buffer = []
          self._feature_buffer = []
+         self._attention_recorder.remove()
+         self._attention_layer_index = None
          self.available = False
          if torch.cuda.is_available():
              torch.cuda.empty_cache()

      @torch.inference_mode()
+     def predict(self, rgb_image: np.ndarray, explainability: dict | None = None):
          if not self.available:
              raise RuntimeError("V-JEPA2 predictor is not available")

+         explain_enabled = bool(explainability and explainability.get("enabled"))
+         encoder_layer = clamp_index(
+             explainability.get("encoder_layer") if explainability else None,
+             self._explainability_spec.encoder_layer_count,
+         )
+         encoder_head = clamp_index(
+             explainability.get("encoder_head") if explainability else None,
+             self._explainability_spec.encoder_head_count,
+         )
+         if explain_enabled:
+             self._ensure_attention_hook(encoder_layer)
+             self._attention_recorder.clear()
+
          frame = np.ascontiguousarray(rgb_image, dtype=np.uint8)
          self._frame_buffer.append(frame)
          if len(self._frame_buffer) > self._clip_frames:

          confidence = float(np.max(pred_np))
          phase_idx = int(np.argmax(pred_np))
          phase = self.label_dict.get(phase_idx, "idle")
+         result = {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": available_frames}
+         if explain_enabled:
+             latest_group_index = latest_feature_idx // self._tubelet_size
+             display_image = resize_rgb_image(frame, (self._crop_size, self._crop_size))
+             encoder_heatmap = self._compute_encoder_attention_map(
+                 head_index=encoder_head,
+                 temporal_group_index=latest_group_index,
+                 output_shape=display_image.shape[:2],
+             )
+             decoder_proxy_values = [feature.abs().mean().item() for feature in self._feature_buffer]
+             result["explainability"] = _build_explainability_payload(
+                 display_image=display_image,
+                 encoder_heatmap=encoder_heatmap,
+                 encoder_kind="attention",
+                 encoder_label=self._explainability_spec.encoder_label,
+                 decoder_values=decoder_proxy_values,
+                 decoder_kind="proxy",
+                 decoder_label=self._explainability_spec.decoder_label,
+                 active_decoder_index=available_frames - 1,
+                 encoder_layer=encoder_layer,
+                 encoder_head=encoder_head,
+                 notes="Decoder view is a proxy feature-energy strip because the V-JEPA2 classifier head is an MLP.",
+             )
+         return result

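The V-JEPA2 path pairs the attention heatmap with a proxy decoder strip: one scalar per buffered clip feature (`feature.abs().mean()`), handed to `render_temporal_strip` from `explainability.py`, which this diff does not show. A small sketch of the likely normalization, under the assumption that the strip is a min-max scaled sequence with the active frame flagged — `temporal_strip` is a hypothetical stand-in, not the project's renderer:

```python
def temporal_strip(values, active_index=None):
    """Normalize per-frame decoder scores to [0, 1] for a strip visualization."""
    lo, hi = min(values), max(values)
    span = hi - lo
    # flat inputs map to all-zero rather than dividing by zero
    norm = [0.0 if span == 0 else (v - lo) / span for v in values]
    flags = [i == active_index for i in range(len(values))]
    return norm, flags

norm, flags = temporal_strip([0.2, 0.8, 0.5], active_index=2)
```

Normalizing per call keeps the strip comparable frame-to-frame even though raw feature magnitudes drift as the clip buffer fills.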

  def create_predictor(model_key: str, model_dir: str | None = None, device: str | None = None):
runtime-requirements.txt CHANGED
@@ -1,5 +1,5 @@
  --extra-index-url https://download.pytorch.org/whl/cu121
- streamlit>=1.40,<2
+ streamlit>=1.55,<2
  torch==2.5.1
  torchvision==0.20.1
  numpy>=1.26,<3
scripts/publish_model_repo.py ADDED
@@ -0,0 +1,156 @@
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import shutil
+ import tempfile
+ from pathlib import Path
+
+ from huggingface_hub import HfApi
+
+ import sys
+
+ SCRIPT_PATH = Path(__file__).resolve()
+ SPACE_ROOT = SCRIPT_PATH.parents[1]
+ if str(SPACE_ROOT) not in sys.path:
+     sys.path.insert(0, str(SPACE_ROOT))
+
+ from model_registry import MODEL_SPECS
+
+
+ ENV_VAR_BY_FAMILY = {
+     "aiendo": "AIENDO_MODEL_REPO_ID",
+     "dinov2": "DINO_MODEL_REPO_ID",
+     "vjepa2": "VJEPA2_MODEL_REPO_ID",
+ }
+
+
+ def _render_model_card(*, family: str, repo_id: str, copied_files: list[str]) -> str:
+     spec = MODEL_SPECS[family]
+     file_list = "\n".join(f"- `{name}`" for name in copied_files)
+     return f"""---
+ tags:
+ - medical-imaging
+ - endoscopy
+ - surgical-phase-recognition
+ - {family}
+ ---
+
+ # {spec.label} checkpoints for the AI-Endo Hugging Face Space
+
+ This repository stores the published checkpoint set for the **{spec.label}** phase-recognition path used by `hf_spaces/DINO-ENDO/`.
+
+ ## Files
+
+ {file_list}
+
+ ## Consumed by the Space
+
+ Set the following Space environment variable so the Streamlit Space can download these files lazily at runtime:
+
+ ```text
+ {ENV_VAR_BY_FAMILY[family]}={repo_id}
+ ```
+ """
+
+
+ def _stage_model_family(*, family: str, model_dir: Path, staging_dir: Path, repo_id: str) -> int:
+     spec = MODEL_SPECS[family]
+     copied_files: list[str] = []
+     total_bytes = 0
+
+     for filename in spec.required_files:
+         src = model_dir / filename
+         if not src.exists():
+             raise FileNotFoundError(f"Missing required checkpoint: {src}")
+         dst = staging_dir / filename
+         shutil.copy2(src, dst)
+         copied_files.append(filename)
+         total_bytes += src.stat().st_size
+
+     for filename in spec.optional_files:
+         src = model_dir / filename
+         if not src.exists():
+             continue
+         dst = staging_dir / filename
+         shutil.copy2(src, dst)
+         copied_files.append(filename)
+         total_bytes += src.stat().st_size
+
+     (staging_dir / "README.md").write_text(
+         _render_model_card(family=family, repo_id=repo_id, copied_files=copied_files),
+         encoding="utf-8",
+     )
+     return total_bytes
+
+
+ def _should_use_large_upload(mode: str, total_bytes: int) -> bool:
+     if mode == "always":
+         return True
+     if mode == "never":
+         return False
+     return total_bytes >= 2 * 1024 * 1024 * 1024
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Publish a model-family checkpoint repo for the HF Space.")
+     parser.add_argument("--family", choices=sorted(MODEL_SPECS), required=True)
+     parser.add_argument("--repo-id", required=True, help="Target Hugging Face model repo ID.")
+     parser.add_argument(
+         "--model-dir",
+         default=str(SPACE_ROOT / "model"),
+         help="Directory containing the local checkpoints to publish.",
+     )
+     parser.add_argument(
+         "--upload-mode",
+         choices=("auto", "never", "always"),
+         default="auto",
+         help="Choose whether to force upload_large_folder for this family.",
+     )
+     parser.add_argument("--revision", default=None, help="Optional target revision or branch.")
+     parser.add_argument("--private", action="store_true", help="Create the model repo as private.")
+     parser.add_argument(
+         "--token-env",
+         default="HF_TOKEN",
+         help="Environment variable name containing the Hugging Face write token.",
+     )
+     return parser.parse_args()
+
+
+ def main() -> None:
+     args = parse_args()
+     model_dir = Path(args.model_dir).expanduser().resolve()
+     token = os.getenv(args.token_env) or None
+     api = HfApi(token=token)
+     api.create_repo(repo_id=args.repo_id, repo_type="model", private=args.private, exist_ok=True)
+
+     with tempfile.TemporaryDirectory(prefix=f"hf-space-{args.family}-") as temp_dir:
+         staging_dir = Path(temp_dir)
+         total_bytes = _stage_model_family(
+             family=args.family,
+             model_dir=model_dir,
+             staging_dir=staging_dir,
+             repo_id=args.repo_id,
+         )
+
+         upload_kwargs = {
+             "repo_id": args.repo_id,
+             "repo_type": "model",
+             "folder_path": str(staging_dir),
+         }
+         if args.revision:
+             upload_kwargs["revision"] = args.revision
+
+         if _should_use_large_upload(args.upload_mode, total_bytes):
+             api.upload_large_folder(**upload_kwargs)
+             mode = "upload_large_folder"
+         else:
+             api.upload_folder(**upload_kwargs)
+             mode = "upload_folder"
+
+     print(f"Published {args.family} checkpoints to {args.repo_id} via {mode}")
+     print(f"Suggested Space variable: {ENV_VAR_BY_FAMILY[args.family]}={args.repo_id}")
+
+
+ if __name__ == "__main__":
+     main()
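`_should_use_large_upload` above switches from `upload_folder` to `upload_large_folder` once the staged checkpoints reach 2 GiB in `auto` mode. A standalone sketch of that decision rule for quick experimentation — it mirrors the function in the diff, with a `GIB` constant added here for readability:

```python
GIB = 1024 ** 3  # one gibibyte in bytes

def should_use_large_upload(mode: str, total_bytes: int) -> bool:
    # "always"/"never" are explicit overrides; "auto" applies the 2 GiB threshold
    if mode == "always":
        return True
    if mode == "never":
        return False
    return total_bytes >= 2 * GIB

print(should_use_large_upload("auto", 3 * GIB))  # True
```

The threshold matters because `upload_large_folder` chunks and resumes uploads, which is the safer path for multi-gigabyte encoder checkpoints, while `upload_folder` is a single commit and faster for small decoder heads.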
scripts/publish_space_repo.py ADDED
@@ -0,0 +1,98 @@
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import tempfile
+ from pathlib import Path
+
+ from huggingface_hub import HfApi
+
+ import sys
+
+ SCRIPT_PATH = Path(__file__).resolve()
+ SPACE_ROOT = SCRIPT_PATH.parents[1]
+ if str(SPACE_ROOT) not in sys.path:
+     sys.path.insert(0, str(SPACE_ROOT))
+
+ from stage_space_bundle import stage_bundle
+
+
+ def _space_variables(args: argparse.Namespace) -> dict[str, str]:
+     variables = {
+         "SPACE_ENABLED_MODELS": args.enabled_models,
+         "SPACE_DEFAULT_MODEL": args.default_model,
+     }
+     if args.aiendo_model_repo_id:
+         variables["AIENDO_MODEL_REPO_ID"] = args.aiendo_model_repo_id
+     if args.dino_model_repo_id:
+         variables["DINO_MODEL_REPO_ID"] = args.dino_model_repo_id
+     if args.vjepa2_model_repo_id:
+         variables["VJEPA2_MODEL_REPO_ID"] = args.vjepa2_model_repo_id
+     return variables
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Publish the staged Docker Space bundle and set its variables.")
+     parser.add_argument("--repo-id", required=True, help="Target Hugging Face Space repo ID.")
+     parser.add_argument(
+         "--bundle-dir",
+         default=None,
+         help="Optional pre-staged bundle directory. If omitted, a temporary bundle is staged automatically.",
+     )
+     parser.add_argument("--enabled-models", default="dinov2,aiendo,vjepa2")
+     parser.add_argument("--default-model", default="dinov2")
+     parser.add_argument("--aiendo-model-repo-id", default=None)
+     parser.add_argument("--dino-model-repo-id", default=None)
+     parser.add_argument("--vjepa2-model-repo-id", default=None)
+     parser.add_argument("--revision", default=None, help="Optional target revision or branch.")
+     parser.add_argument("--private", action="store_true", help="Create the Space repo as private.")
+     parser.add_argument(
+         "--token-env",
+         default="HF_TOKEN",
+         help="Environment variable name containing the Hugging Face write token.",
+     )
+     return parser.parse_args()
+
+
+ def _publish_bundle(api: HfApi, *, repo_id: str, bundle_dir: Path, revision: str | None) -> None:
+     upload_kwargs = {
+         "repo_id": repo_id,
+         "repo_type": "space",
+         "folder_path": str(bundle_dir),
+     }
+     if revision:
+         upload_kwargs["revision"] = revision
+     api.upload_folder(**upload_kwargs)
+
+
+ def main() -> None:
+     args = parse_args()
+     token = os.getenv(args.token_env) or None
+     api = HfApi(token=token)
+     api.create_repo(repo_id=args.repo_id, repo_type="space", space_sdk="docker", private=args.private, exist_ok=True)
+
+     if args.bundle_dir:
+         bundle_dir = Path(args.bundle_dir).expanduser().resolve()
+         if not bundle_dir.exists():
+             raise FileNotFoundError(f"Bundle directory not found: {bundle_dir}")
+         _publish_bundle(api, repo_id=args.repo_id, bundle_dir=bundle_dir, revision=args.revision)
+     else:
+         with tempfile.TemporaryDirectory(prefix="hf-space-bundle-") as temp_dir:
+             bundle_dir = stage_bundle(SPACE_ROOT, Path(temp_dir), overwrite=True)
+             _publish_bundle(api, repo_id=args.repo_id, bundle_dir=bundle_dir, revision=args.revision)
+
+     for key, value in _space_variables(args).items():
+         api.add_space_variable(
+             repo_id=args.repo_id,
+             key=key,
+             value=value,
+             description=f"Managed by publish_space_repo.py for {key}",
+         )
+
+     print(f"Published Space bundle to {args.repo_id}")
+     for key, value in _space_variables(args).items():
+         print(f"{key}={value}")
+
+
+ if __name__ == "__main__":
+     main()
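The `_space_variables` helper in `publish_space_repo.py` only emits a per-family repo ID when one was passed on the command line. A minimal standalone re-statement, using `types.SimpleNamespace` in place of the parsed args (the repo ID below is a made-up example):

```python
from types import SimpleNamespace


def space_variables(args) -> dict[str, str]:
    # Mirrors _space_variables: always include the two required Space
    # settings, then add only the model repo IDs that were provided.
    variables = {
        "SPACE_ENABLED_MODELS": args.enabled_models,
        "SPACE_DEFAULT_MODEL": args.default_model,
    }
    if args.aiendo_model_repo_id:
        variables["AIENDO_MODEL_REPO_ID"] = args.aiendo_model_repo_id
    if args.dino_model_repo_id:
        variables["DINO_MODEL_REPO_ID"] = args.dino_model_repo_id
    if args.vjepa2_model_repo_id:
        variables["VJEPA2_MODEL_REPO_ID"] = args.vjepa2_model_repo_id
    return variables


args = SimpleNamespace(
    enabled_models="dinov2,aiendo,vjepa2",
    default_model="dinov2",
    aiendo_model_repo_id=None,
    dino_model_repo_id="user/dino-endo-models",  # hypothetical repo ID
    vjepa2_model_repo_id=None,
)
```

Omitted repo IDs simply never appear as Space variables, so the Space falls back to whatever defaults the app defines.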
scripts/stage_space_bundle.py ADDED
@@ -0,0 +1,104 @@
+ from __future__ import annotations
+
+ import argparse
+ import shutil
+ from pathlib import Path
+
+
+ ROOT_FILES = (
+     ".dockerignore",
+     ".gitattributes",
+     ".gitignore",
+     "Dockerfile",
+     "README.md",
+     "app.py",
+     "explainability.py",
+     "model_manager.py",
+     "model_registry.py",
+     "predictor.py",
+     "requirements.txt",
+     "runtime-requirements.txt",
+     "video_utils.py",
+ )
+
+ ROOT_DIRS = (
+     ".streamlit",
+     "dinov2",
+     "model",
+     "scripts",
+     "vjepa2",
+ )
+
+ IGNORE_PATTERNS = (
+     ".git",
+     ".cache",
+     "__pycache__",
+     ".pytest_cache",
+     ".mypy_cache",
+     "*.egg-info",
+     "*.ipynb",
+     "*.pt",
+     "*.pth",
+     "*.pyc",
+     "*.pyo",
+     "assets",
+     "notebooks",
+     "tests",
+ )
+
+
+ def _copy_item(src: Path, dst: Path) -> None:
+     if not src.exists():
+         raise FileNotFoundError(f"Missing required Space item: {src}")
+
+     if src.is_dir():
+         shutil.copytree(src, dst, ignore=shutil.ignore_patterns(*IGNORE_PATTERNS))
+     else:
+         dst.parent.mkdir(parents=True, exist_ok=True)
+         shutil.copy2(src, dst)
+
+
+ def stage_bundle(space_root: Path, output_dir: Path, overwrite: bool) -> Path:
+     if output_dir.exists():
+         if not overwrite:
+             raise FileExistsError(f"Destination already exists: {output_dir}")
+         shutil.rmtree(output_dir)
+
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     for name in ROOT_FILES:
+         _copy_item(space_root / name, output_dir / name)
+     for name in ROOT_DIRS:
+         _copy_item(space_root / name, output_dir / name)
+
+     return output_dir
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(
+         description="Stage a code-only Hugging Face Space bundle from the local DINO-ENDO scaffold."
+     )
+     parser.add_argument(
+         "--output-dir",
+         default="/tmp/dino_space_minimal_upload",
+         help="Destination directory for the staged bundle.",
+     )
+     parser.add_argument(
+         "--overwrite",
+         action="store_true",
+         help="Replace the destination directory if it already exists.",
+     )
+     return parser.parse_args()
+
+
+ def main() -> None:
+     args = parse_args()
+     script_path = Path(__file__).resolve()
+     space_root = script_path.parents[1]
+     output_dir = Path(args.output_dir).expanduser().resolve()
+     staged_dir = stage_bundle(space_root, output_dir, overwrite=args.overwrite)
+     print(f"Staged Space bundle at {staged_dir}")
+
+
+ if __name__ == "__main__":
+     main()
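`_copy_item` keeps caches and checkpoint weights out of the bundle by passing `shutil.ignore_patterns` to `copytree`. A quick sketch of how that filter behaves on a sample directory listing (the filter is called per directory with the entry names and returns the names to skip):

```python
import shutil

# Subset of the IGNORE_PATTERNS tuple above, enough to show the behavior.
ignore = shutil.ignore_patterns("__pycache__", "*.pt", "*.pth", "*.pyc")

# copytree invokes the callable as ignore(directory, entry_names) and
# omits every returned name from the copy.
names = ["app.py", "weights.pt", "__pycache__", "notes.md"]
skipped = ignore("some/dir", names)
print(sorted(skipped))  # ['__pycache__', 'weights.pt']
```

Because matching is per-name with `fnmatch`, directory entries like `__pycache__` and glob matches like `*.pt` are both pruned wherever they appear in the tree.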
scripts/stage_vendor_sources.py ADDED
@@ -0,0 +1,38 @@
+ from __future__ import annotations
+
+ import argparse
+ import shutil
+ from pathlib import Path
+
+
+ def copy_tree(src: Path, dst: Path, overwrite: bool) -> None:
+     if not src.exists():
+         raise FileNotFoundError(f"Source directory not found: {src}")
+     if dst.exists():
+         if not overwrite:
+             print(f"Skipping existing {dst}")
+             return
+         shutil.rmtree(dst)
+     shutil.copytree(
+         src,
+         dst,
+         ignore=shutil.ignore_patterns('.git', '__pycache__', '.pytest_cache', '.mypy_cache', '*.pyc', '*.pyo'),
+     )
+     print(f"Copied {src} -> {dst}")
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(description='Copy vendored dinov2/ and vjepa2/ source trees into the Space folder.')
+     parser.add_argument('--overwrite', action='store_true', help='Replace existing destination directories.')
+     args = parser.parse_args()
+
+     script_path = Path(__file__).resolve()
+     space_root = script_path.parents[1]
+     repo_root = script_path.parents[3]
+
+     copy_tree(repo_root / 'dinov2', space_root / 'dinov2', overwrite=args.overwrite)
+     copy_tree(repo_root / 'vjepa2', space_root / 'vjepa2', overwrite=args.overwrite)
+
+
+ if __name__ == '__main__':
+     main()
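`stage_vendor_sources.py` derives both roots from its own location with `Path.parents`: one level above `scripts/` is the Space folder, three levels up is the outer repository that holds the vendored `dinov2/` and `vjepa2/` trees. A sketch of that indexing on an illustrative (made-up) layout:

```python
from pathlib import PurePosixPath

# Illustrative path only; the real layout is whatever repo the script lives in.
script_path = PurePosixPath("/repo/projects/space/scripts/stage_vendor_sources.py")

space_root = script_path.parents[1]  # parents[0] is scripts/, parents[1] the Space folder
repo_root = script_path.parents[3]   # three levels above scripts/: the outer repo

print(space_root)  # /repo/projects/space
print(repo_root)   # /repo
```

Note `parents[0]` is the containing directory, so `parents[3]` assumes the Space folder is nested two directories deep inside the outer repo; moving the script breaks that assumption.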
vjepa2/.flake8 ADDED
@@ -0,0 +1,5 @@
+ [flake8]
+ max-line-length = 119
+ select = E,F,W
+ ignore = E203,E701,W503
+ per-file-ignores=__init__.py:F401 version.py:F401
vjepa2/.github/workflows/base_tests.yaml ADDED
@@ -0,0 +1,29 @@
+ name: UnitTests
+
+ on: [push]
+
+ jobs:
+   unittests:
+     runs-on: ubuntu-latest
+     strategy:
+       max-parallel: 4
+
+     steps:
+       - uses: actions/checkout@v4
+       - name: Set up Python 3.12
+         uses: actions/setup-python@v5
+         with:
+           python-version: '3.12'
+       - name: Add conda to system path
+         run: |
+           # $CONDA is an environment variable pointing to the root of the miniconda directory
+           echo $CONDA/bin >> $GITHUB_PATH
+       - name: Install dependencies
+         run: |
+           conda create --name test-env python=3.12
+           conda install pytest
+           echo "Starting setup from $PWD"
+           pip install -e .
+       - name: Test with pytest
+         run: |
+           pytest tests
vjepa2/.github/workflows/linters.yaml ADDED
@@ -0,0 +1,48 @@
+ name: Lint (Common Code)
+
+ on:
+   push:
+     branches:
+       - master
+     paths:
+       - 'app/'
+       - 'evals/*.py'
+       - 'src/'
+       - 'tests/'
+   pull_request:
+     branches:
+       - master
+       - 'gh/**'
+     paths:
+       - 'app/'
+       - 'evals/*.py'
+       - 'src/'
+       - 'tests/'
+
+ jobs:
+   run-linters:
+     name: Run linters
+     runs-on: ubuntu-latest
+
+     steps:
+       - uses: actions/checkout@v4
+       - name: Set up Python 3.12
+         uses: actions/setup-python@v5
+         with:
+           python-version: '3.12'
+       - name: Install Python lint dependencies
+         run: |
+           pip install -r requirements-test.txt
+       - name: Set lint paths
+         run: echo "lint_paths=app evals/*.py src tests" >> "$GITHUB_ENV"
+       - name: Run isort
+         run: |
+           python -m isort $lint_paths --check
+       - name: Run flake8
+         if: always()
+         run: |
+           python -m flake8 --config .flake8 --show-source --statistics $lint_paths
+       - name: Run black
+         if: always()
+         run: |
+           python -m black --check $lint_paths
vjepa2/.gitignore ADDED
@@ -0,0 +1,32 @@
+ *.pyc
+ .vscode/
+ .*.swp
+
+ run_vjepa_aws.py
+ run.py
+ main_distributed_video.py
+ main_video.py
+
+ app/vjepa/configs/temp_aws
+ app/main_dev.py
+ app/main_distributed_dev.py
+ evals/ava/alphaction/data
+
+ run_evals.py
+ run_evals_v2.py
+ run_pretrain.py
+
+ *.egg-info/
+ *.ipynb_checkpoints/
+
+ traces/
+ third_party/*
+
+ evals/simu_env_planning/local/
+ evals/simu_env_planning/docker2/
+ evals/simu_env_planning/docker/
+ app/vjepa_droid/local/
+ app/vjepa_droid_v2/local/
+ app/vjepa_droid_v3/local/
+ app/vjepa_droid_v4/local/
+ configs/local
vjepa2/APACHE-LICENSE ADDED
@@ -0,0 +1,201 @@
+                                  Apache License
+                            Version 2.0, January 2004
+                         http://www.apache.org/licenses/
+
+    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+    1. Definitions.
+
+       "License" shall mean the terms and conditions for use, reproduction,
+       and distribution as defined by Sections 1 through 9 of this document.
+
+       "Licensor" shall mean the copyright owner or entity authorized by
+       the copyright owner that is granting the License.
+
+       "Legal Entity" shall mean the union of the acting entity and all
+       other entities that control, are controlled by, or are under common
+       control with that entity. For the purposes of this definition,
+       "control" means (i) the power, direct or indirect, to cause the
+       direction or management of such entity, whether by contract or
+       otherwise, or (ii) ownership of fifty percent (50%) or more of the
+       outstanding shares, or (iii) beneficial ownership of such entity.
+
+       "You" (or "Your") shall mean an individual or Legal Entity
+       exercising permissions granted by this License.
+
+       "Source" form shall mean the preferred form for making modifications,
+       including but not limited to software source code, documentation
+       source, and configuration files.
+
+       "Object" form shall mean any form resulting from mechanical
+       transformation or translation of a Source form, including but
+       not limited to compiled object code, generated documentation,
+       and conversions to other media types.
+
+       "Work" shall mean the work of authorship, whether in Source or
+       Object form, made available under the License, as indicated by a
+       copyright notice that is included in or attached to the work
+       (an example is provided in the Appendix below).
+
+       "Derivative Works" shall mean any work, whether in Source or Object
+       form, that is based on (or derived from) the Work and for which the
+       editorial revisions, annotations, elaborations, or other modifications
+       represent, as a whole, an original work of authorship. For the purposes
+       of this License, Derivative Works shall not include works that remain
+       separable from, or merely link (or bind by name) to the interfaces of,
+       the Work and Derivative Works thereof.
+
+       "Contribution" shall mean any work of authorship, including
+       the original version of the Work and any modifications or additions
+       to that Work or Derivative Works thereof, that is intentionally
+       submitted to Licensor for inclusion in the Work by the copyright owner
+       or by an individual or Legal Entity authorized to submit on behalf of
+       the copyright owner. For the purposes of this definition, "submitted"
+       means any form of electronic, verbal, or written communication sent
+       to the Licensor or its representatives, including but not limited to
+       communication on electronic mailing lists, source code control systems,
+       and issue tracking systems that are managed by, or on behalf of, the
+       Licensor for the purpose of discussing and improving the Work, but
+       excluding communication that is conspicuously marked or otherwise
+       designated in writing by the copyright owner as "Not a Contribution."
+
+       "Contributor" shall mean Licensor and any individual or Legal Entity
+       on behalf of whom a Contribution has been received by Licensor and
+       subsequently incorporated within the Work.
+
+    2. Grant of Copyright License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       copyright license to reproduce, prepare Derivative Works of,
+       publicly display, publicly perform, sublicense, and distribute the
+       Work and such Derivative Works in Source or Object form.
+
+    3. Grant of Patent License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       (except as stated in this section) patent license to make, have made,
+       use, offer to sell, sell, import, and otherwise transfer the Work,
+       where such license applies only to those patent claims licensable
+       by such Contributor that are necessarily infringed by their
+       Contribution(s) alone or by combination of their Contribution(s)
+       with the Work to which such Contribution(s) was submitted. If You
+       institute patent litigation against any entity (including a
+       cross-claim or counterclaim in a lawsuit) alleging that the Work
+       or a Contribution incorporated within the Work constitutes direct
+       or contributory patent infringement, then any patent licenses
+       granted to You under this License for that Work shall terminate
+       as of the date such litigation is filed.
+
+    4. Redistribution. You may reproduce and distribute copies of the
+       Work or Derivative Works thereof in any medium, with or without
+       modifications, and in Source or Object form, provided that You
+       meet the following conditions:
+
+       (a) You must give any other recipients of the Work or
+           Derivative Works a copy of this License; and
+
+       (b) You must cause any modified files to carry prominent notices
+           stating that You changed the files; and
+
+       (c) You must retain, in the Source form of any Derivative Works
+           that You distribute, all copyright, patent, trademark, and
+           attribution notices from the Source form of the Work,
+           excluding those notices that do not pertain to any part of
+           the Derivative Works; and
+
+       (d) If the Work includes a "NOTICE" text file as part of its
+           distribution, then any Derivative Works that You distribute must
+           include a readable copy of the attribution notices contained
+           within such NOTICE file, excluding those notices that do not
+           pertain to any part of the Derivative Works, in at least one
+           of the following places: within a NOTICE text file distributed
+           as part of the Derivative Works; within the Source form or
+           documentation, if provided along with the Derivative Works; or,
+           within a display generated by the Derivative Works, if and
+           wherever such third-party notices normally appear. The contents
+           of the NOTICE file are for informational purposes only and
+           do not modify the License. You may add Your own attribution
+           notices within Derivative Works that You distribute, alongside
+           or as an addendum to the NOTICE text from the Work, provided
+           that such additional attribution notices cannot be construed
+           as modifying the License.
+
+       You may add Your own copyright statement to Your modifications and
+       may provide additional or different license terms and conditions
+       for use, reproduction, or distribution of Your modifications, or
+       for any such Derivative Works as a whole, provided Your use,
+       reproduction, and distribution of the Work otherwise complies with
+       the conditions stated in this License.
+
+    5. Submission of Contributions. Unless You explicitly state otherwise,
+       any Contribution intentionally submitted for inclusion in the Work
+       by You to the Licensor shall be under the terms and conditions of
+       this License, without any additional terms or conditions.
+       Notwithstanding the above, nothing herein shall supersede or modify
+       the terms of any separate license agreement you may have executed
+       with Licensor regarding such Contributions.
+
+    6. Trademarks. This License does not grant permission to use the trade
+       names, trademarks, service marks, or product names of the Licensor,
+       except as required for reasonable and customary use in describing the
+       origin of the Work and reproducing the content of the NOTICE file.
+
+    7. Disclaimer of Warranty. Unless required by applicable law or
+       agreed to in writing, Licensor provides the Work (and each
+       Contributor provides its Contributions) on an "AS IS" BASIS,
+       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+       implied, including, without limitation, any warranties or conditions
+       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+       PARTICULAR PURPOSE. You are solely responsible for determining the
+       appropriateness of using or redistributing the Work and assume any
+       risks associated with Your exercise of permissions under this License.
+
+    8. Limitation of Liability. In no event and under no legal theory,
+       whether in tort (including negligence), contract, or otherwise,
+       unless required by applicable law (such as deliberate and grossly
+       negligent acts) or agreed to in writing, shall any Contributor be
+       liable to You for damages, including any direct, indirect, special,
+       incidental, or consequential damages of any character arising as a
+       result of this License or out of the use or inability to use the
+       Work (including but not limited to damages for loss of goodwill,
+       work stoppage, computer failure or malfunction, or any and all
+       other commercial damages or losses), even if such Contributor
+       has been advised of the possibility of such damages.
+
+    9. Accepting Warranty or Additional Liability. While redistributing
+       the Work or Derivative Works thereof, You may choose to offer,
+       and charge a fee for, acceptance of support, warranty, indemnity,
+       or other liability obligations and/or rights consistent with this
+       License. However, in accepting such obligations, You may act only
+       on Your own behalf and on Your sole responsibility, not on behalf
+       of any other Contributor, and only if You agree to indemnify,
+       defend, and hold each Contributor harmless for any liability
+       incurred by, or claims asserted against, such Contributor by reason
+       of your accepting any such warranty or additional liability.
+
+    END OF TERMS AND CONDITIONS
+
+    APPENDIX: How to apply the Apache License to your work.
+
+       To apply the Apache License to your work, attach the following
+       boilerplate notice, with the fields enclosed by brackets "[]"
+       replaced with your own identifying information. (Don't include
+       the brackets!) The text should be enclosed in the appropriate
+       comment syntax for the file format. We also recommend that a
+       file or class name and description of purpose be included on the
+       same "printed page" as the copyright notice for easier
+       identification within third-party archives.
+
+    Copyright 2018-2021 William Falcon
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
vjepa2/CHANGELOG.md ADDED
@@ -0,0 +1,5 @@
+ # Changelog
+
+ ## [0.0.1] - 2025-06-05
+
+ Initial release of V-JEPA 2 codebase
vjepa2/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
+ # Code of Conduct
+
+ ## Our Pledge
+
+ In the interest of fostering an open and welcoming environment, we as
+ contributors and maintainers pledge to make participation in our project and
+ our community a harassment-free experience for everyone, regardless of age, body
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
+ level of experience, education, socio-economic status, nationality, personal
+ appearance, race, religion, or sexual identity and orientation.
+
+ ## Our Standards
+
+ Examples of behavior that contributes to creating a positive environment
+ include:
+
+ * Using welcoming and inclusive language
+ * Being respectful of differing viewpoints and experiences
+ * Gracefully accepting constructive criticism
+ * Focusing on what is best for the community
+ * Showing empathy towards other community members
+
+ Examples of unacceptable behavior by participants include:
+
+ * The use of sexualized language or imagery and unwelcome sexual attention or
+   advances
+ * Trolling, insulting/derogatory comments, and personal or political attacks
+ * Public or private harassment
+ * Publishing others' private information, such as a physical or electronic
+   address, without explicit permission
+ * Other conduct which could reasonably be considered inappropriate in a
+   professional setting
+
+ ## Our Responsibilities
+
+ Project maintainers are responsible for clarifying the standards of acceptable
+ behavior and are expected to take appropriate and fair corrective action in
+ response to any instances of unacceptable behavior.
+
+ Project maintainers have the right and responsibility to remove, edit, or
+ reject comments, commits, code, wiki edits, issues, and other contributions
+ that are not aligned to this Code of Conduct, or to ban temporarily or
+ permanently any contributor for other behaviors that they deem inappropriate,
+ threatening, offensive, or harmful.
+
+ ## Scope
+
+ This Code of Conduct applies within all project spaces, and it also applies when
+ an individual is representing the project or its community in public spaces.
+ Examples of representing a project or community include using an official
+ project e-mail address, posting via an official social media account, or acting
+ as an appointed representative at an online or offline event. Representation of
+ a project may be further defined and clarified by project maintainers.
+
+ This Code of Conduct also applies outside the project spaces when there is a
+ reasonable belief that an individual's behavior may have a negative impact on
+ the project or its community.
+
+ ## Enforcement
+
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
+ reported by contacting the project team at <opensource-conduct@fb.com>. All
+ complaints will be reviewed and investigated and will result in a response that
+ is deemed necessary and appropriate to the circumstances. The project team is
+ obligated to maintain confidentiality with regard to the reporter of an incident.
+ Further details of specific enforcement policies may be posted separately.
+
+ Project maintainers who do not follow or enforce the Code of Conduct in good
+ faith may face temporary or permanent repercussions as determined by other
+ members of the project's leadership.
+
+ ## Attribution
+
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+ [homepage]: https://www.contributor-covenant.org
+
+ For answers to common questions about this code of conduct, see
+ https://www.contributor-covenant.org/faq
vjepa2/CONTRIBUTING.md ADDED
@@ -0,0 +1,39 @@
+ # Contributing to V-JEPA 2
+ We want to make contributing to this project as easy and transparent as
+ possible.
+
+ ## Pull Requests
+ We welcome your pull requests.
+
+ 1. Fork the repo and create your branch from `main`.
+ 2. If you've added code that should be tested, add tests.
+ 3. If you've changed APIs, update the documentation.
+ 4. Ensure the test suite passes.
+ 5. Make sure your code is consistent with style guidance (below) and lints.
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
+ 7. Add reviewer(s) for approval.
+
+ ## Contributor License Agreement ("CLA")
+ In order to accept your pull request, we need you to submit a CLA. You only need
+ to do this once to work on any of Facebook's open source projects.
+
+ Complete your CLA here: <https://code.facebook.com/cla>
+
+ ## Issues
+ We use GitHub issues to track public bugs. Please ensure your description is
+ clear and has sufficient instructions to be able to reproduce the issue.
+
+ Meta has a [bounty program](https://bugbounty.meta.com/) for the safe
+ disclosure of security bugs. In those cases, please go through the process
+ outlined on that page and do not file a public issue.
+
+ ## Coding Style
+ * 4 spaces for indentation rather than tabs
+ * 119 character line length
+ * PEP8 formatting
+
+ We recommend using `black`, `isort`, and `flake8` to format your code changes.
+
+ ## License
+ By contributing to this repository, you agree that your contributions will be licensed
+ under the LICENSE file in the root directory of this source tree.
vjepa2/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
vjepa2/README.md ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning
2
+
3
+ ### [Meta FAIR](https://ai.meta.com/research/)
4
+
5
+ Mahmoud Assran∗, Adrien Bardes∗, David Fan∗, Quentin Garrido∗, Russell Howes∗, Mojtaba
6
+ Komeili∗, Matthew Muckley∗, Ammar Rizvi∗, Claire Roberts∗, Koustuv Sinha∗, Artem Zholus*,
7
+ Sergio Arnaud*, Abha Gejji*, Ada Martin*, Francois Robert Hogan*, Daniel Dugas*, Piotr
8
+ Bojanowski, Vasil Khalidov, Patrick Labatut, Francisco Massa, Marc Szafraniec, Kapil
9
+ Krishnakumar, Yong Li, Xiaodong Ma, Sarath Chandar, Franziska Meier*, Yann LeCun*, Michael
10
+ Rabbat*, Nicolas Ballas*
11
+
12
+ *Core Team
13
+
14
+ [[`Paper`](https://arxiv.org/abs/2506.09985)] [[`Blog`](https://ai.meta.com/blog/v-jepa-2-world-model-benchmarks)] [[`BibTex`](#Citation)]
15
+
16
+ Official Pytorch codebase for V-JEPA 2 and V-JEPA 2-AC.
17
+
18
+ V-JEPA 2 is a self-supervised approach to training video encoders, using internet-scale video data, that attains state-of-the-art performance on motion understanding and human action anticipation tasks. V-JEPA 2-AC is a latent action-conditioned world model post-trained from V-JEPA 2 (using a small amount of robot trajectory interaction data) that solves robot manipulation tasks without environment-specific data collection or task-specific training or calibration.
19
+
20
+ <p align="center">
21
+ <img src="assets/flowchart.png" width=100%>
22
+ </p>
23
+
24
+ <!---
25
+ ## Updates
26
+
27
+ * **[Jun-6-25]:** V-JEPA 2 is released. [[`Blog`](https://ai.meta.com/blog/v-jepa-2-world-model-benchmarks)]
28
+ --->
29
+
30
+ ## V-JEPA 2 Pre-training
31
+
32
+ **(Top)** The encoder and predictor are pre-trained through self-supervised learning from video using a masked latent feature prediction objective, leveraging abundant natural videos to bootstrap physical world understanding and prediction. **(Bottom)** Performance of V-JEPA 2 on downstream understanding and prediction tasks.
33
+
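The masked latent feature prediction objective can be sketched in a few lines. Everything below is an illustrative toy, not the training code: random linear maps stand in for the online encoder, the EMA target encoder, and the transformer predictor, and a pooled context vector replaces the real per-token prediction.

```python
import numpy as np

rng = np.random.default_rng(0)
D_in, D = 32, 16  # toy token dim and latent dim

# Random linear maps stand in for the networks.
W_enc = rng.standard_normal((D_in, D)) / np.sqrt(D_in)   # online encoder
W_tgt = W_enc + 0.01 * rng.standard_normal((D_in, D))    # EMA target encoder
W_pred = rng.standard_normal((D, D)) / np.sqrt(D)        # predictor

tokens = rng.standard_normal((64, D_in))  # 64 video patch tokens
mask = rng.random(64) < 0.75              # hide ~75% of the tokens

# The online encoder sees only the visible tokens; the target encoder sees all.
context = tokens[~mask] @ W_enc
target = tokens[mask] @ W_tgt             # no gradient flows here in practice

# Predict each masked token's feature from the pooled context (toy predictor).
pred = np.broadcast_to(context.mean(axis=0) @ W_pred, target.shape)

loss = np.mean((pred - target) ** 2)      # L2 regression on masked positions
print(round(float(loss), 4))
```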
34
+ <img align="left" src="https://github.com/user-attachments/assets/914942d8-6a1e-409d-86ff-ff856b7346ab" width=65%>&nbsp;
35
+ <table>
36
+ <tr>
37
+ <th colspan="1">Benchmark</th>
38
+ <th colspan="1">VJEPA 2</th>
39
+ <th colspan="1">Previous Best</th>
40
+ </tr>
41
+ <tr>
42
+ <td>EK100</td>
43
+ <td>39.7%</td>
44
+ <td>27.6% (PlausiVL)</td>
45
+ </tr>
46
+ <tr>
47
+ <td>SSv2 (Probe)</td>
48
+ <td>77.3%</td>
49
+ <td>69.7% (InternVideo2-1B)</td>
50
+ </tr>
51
+ <tr>
52
+ <td>Diving48 (Probe)</td>
53
+ <td>90.2%</td>
54
+ <td>86.4% (InternVideo2-1B)</td>
55
+ </tr>
56
+ <tr>
57
+ <td>MVP (Video QA)</td>
58
+ <td>44.5%</td>
59
+ <td>39.9% (InternVL-2.5)</td>
60
+ </tr>
61
+ <tr>
62
+ <td>TempCompass (Video QA)</td>
63
+ <td>76.9%</td>
64
+ <td>75.3% (Tarsier 2)</td>
65
+ </tr>
66
+ </table>
67
+
68
+ ## V-JEPA 2-AC Post-training
69
+
70
+ **(Top)** After post-training with a small amount of robot data, we can deploy the model on a robot arm in new environments, and tackle foundational tasks like reaching, grasping, and pick-and-place by planning from image goals. **(Bottom)** Performance on robot manipulation tasks using a Franka arm, with input provided through a monocular RGB camera.
71
+
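Planning from an image goal amounts to searching for an action that minimizes an energy: the distance between the latent predicted under a candidate action and the goal-image latent. The sketch below illustrates that loop with the cross-entropy method; the linear `predict` dynamics and all dimensions are toy stand-ins for the action-conditioned predictor, not the optimizer actually used for the robot results above.

```python
import numpy as np

rng = np.random.default_rng(0)
D, A = 8, 4  # toy latent dim and action dim

# Stand-in for the action-conditioned predictor: linear latent dynamics.
M_z = 0.9 * np.eye(D)
M_a = 0.5 * rng.standard_normal((A, D))

def predict(z, a):
    """Predicted next latent for state z under action(s) a."""
    return z @ M_z + a @ M_a

z0 = rng.standard_normal(D)          # embedding of the current frame
a_true = rng.standard_normal(A)      # unknown action we hope to recover
z_goal = predict(z0, a_true)         # embedding of the goal image

# Cross-entropy method: sample actions, keep the lowest-energy elites,
# refit the sampling Gaussian, repeat.
mu, sigma = np.zeros(A), np.ones(A)
for _ in range(20):
    cand = mu + sigma * rng.standard_normal((256, A))
    energy = np.linalg.norm(predict(z0, cand) - z_goal, axis=1)
    elites = cand[np.argsort(energy)[:16]]
    mu, sigma = elites.mean(axis=0), elites.std(axis=0) + 1e-6

# Residual energy of the planned action (should be near zero).
print(round(float(np.linalg.norm(predict(z0, mu) - z_goal)), 4))
```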
72
+ <img align="left" src="https://github.com/user-attachments/assets/c5d42221-0102-4216-911d-061a4369a805" width=65%>&nbsp;
73
+ <table>
74
+ <tr>
75
+ <th colspan="1"></th>
76
+ <th colspan="1"></th>
77
+ <th colspan="2">Grasp</th>
78
+ <th colspan="2">Pick-and-Place</th>
79
+ </tr>
80
+ <tr>
81
+ <th colspan="1">Method</th>
82
+ <th colspan="1">Reach</th>
83
+ <th colspan="1">Cup</th>
84
+ <th colspan="1">Box</th>
85
+ <th colspan="1">Cup</th>
86
+ <th colspan="1">Box</th>
87
+ </tr>
88
+ <tr>
89
+ <td>Octo</td>
90
+ <td>100%</td>
91
+ <td>10%</td>
92
+ <td>0%</td>
93
+ <td>10%</td>
94
+ <td>10%</td>
95
+ </tr>
96
+ <tr>
97
+ <td>Cosmos</td>
98
+ <td>80%</td>
99
+ <td>0%</td>
100
+ <td>20%</td>
101
+ <td>0%</td>
102
+ <td>0%</td>
103
+ </tr>
104
+ <tr>
105
+ <td>VJEPA 2-AC</td>
106
+ <td>100%</td>
107
+ <td>60%</td>
108
+ <td>20%</td>
109
+ <td>80%</td>
110
+ <td>50%</td>
111
+ </tr>
112
+ </table>
113
+
114
+ ## Models
115
+
116
+ ### V-JEPA 2
117
+
118
+ #### HuggingFace
119
+
120
+ See our [HuggingFace collection](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6) for V-JEPA 2.
121
+
122
+ #### Pretrained Checkpoints
123
+
124
+ <table>
125
+ <tr>
126
+ <th colspan="1">Model</th>
127
+ <th colspan="1">#Parameters</th>
128
+ <th colspan="1">Resolution</th>
129
+ <th colspan="1">Download Link</th>
130
+ <th colspan="1">Pretraining Config</th>
131
+ </tr>
132
+ <tr>
133
+ <td>ViT-L/16</td>
134
+ <td>300M</td>
135
+ <td>256</td>
136
+ <td><a href="https://dl.fbaipublicfiles.com/vjepa2/vitl.pt">checkpoint</a></td>
137
+ <td><a href="configs/train/vitl16">configs</a></td>
138
+ </tr>
139
+ <tr>
140
+ <td>ViT-H/16</td>
141
+ <td>600M</td>
142
+ <td>256</td>
143
+ <td><a href="https://dl.fbaipublicfiles.com/vjepa2/vith.pt">checkpoint</a></td>
144
+ <td><a href="configs/train/vith16/">configs</a></td>
145
+ </tr>
146
+ <tr>
147
+ <td>ViT-g/16</td>
148
+ <td>1B</td>
149
+ <td>256</td>
150
+ <td><a href="https://dl.fbaipublicfiles.com/vjepa2/vitg.pt">checkpoint</a></td>
151
+ <td><a href="configs/train/vitg16">configs</a></td>
152
+ </tr>
153
+ <tr>
154
+ <td>ViT-g/16<sub>384</sub></td>
155
+ <td>1B</td>
156
+ <td>384</td>
157
+ <td><a href="https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt">checkpoint</a></td>
158
+ <td><a href="configs/train/vitg16">configs</a></td>
159
+ </tr>
160
+ </table>
161
+
162
+ #### Pretrained backbones (via PyTorch Hub)
163
+
164
+ Please install [PyTorch](https://pytorch.org/get-started/locally/), [timm](https://pypi.org/project/timm/) and [einops](https://pypi.org/project/einops/) locally, then run the following to load each model. Installing PyTorch with CUDA support is strongly recommended.
165
+
166
+ ```python
167
+ import torch
168
+
169
+ # preprocessor
170
+ processor = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_preprocessor')
171
+ # models
172
+ vjepa2_vit_large = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_large')
173
+ vjepa2_vit_huge = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_huge')
174
+ vjepa2_vit_giant = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_giant')
175
+ vjepa2_vit_giant_384 = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_giant_384')
176
+
177
+ ```
178
+
179
+ #### Pretrained checkpoints on Huggingface
180
+
181
+ You can also use our pretrained checkpoints on [Huggingface](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6).
182
+
183
+ ```python
184
+ from transformers import AutoVideoProcessor, AutoModel
185
+
186
+ hf_repo = "facebook/vjepa2-vitg-fpc64-256"
187
+ # facebook/vjepa2-vitl-fpc64-256
188
+ # facebook/vjepa2-vith-fpc64-256
189
+ # facebook/vjepa2-vitg-fpc64-256
190
+ # facebook/vjepa2-vitg-fpc64-384
191
+
192
+
193
+ model = AutoModel.from_pretrained(hf_repo)
194
+ processor = AutoVideoProcessor.from_pretrained(hf_repo)
195
+ ```
196
+
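As a sanity check on input sizing: assuming 16×16 spatial patches and a tubelet depth of 2 frames (values inferred from the model names above, e.g. `fpc64-256`; verify them against the configs), the patch-token count per clip works out as follows.

```python
def num_tokens(frames: int, size: int, patch: int = 16, tubelet: int = 2) -> int:
    """Patch-token count for a `frames`-frame clip at `size`x`size` resolution."""
    return (frames // tubelet) * (size // patch) ** 2

print(num_tokens(64, 256))  # 32 temporal slots x 256 spatial patches = 8192
print(num_tokens(64, 384))  # 32 temporal slots x 576 spatial patches = 18432
```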
197
+ #### Evaluation Attentive Probes
198
+
199
+ We share the trained attentive probes for two of our visual understanding evals (Something-Something v2 and Diving48) and the action anticipation eval EPIC-KITCHENS-100.
200
+
201
+ <table>
202
+ <tr>
203
+ <th colspan="1">Model</th>
204
+ <th colspan="4">SSv2</th>
205
+ <th colspan="4">Diving48</th>
206
+ <th colspan="4">EK100</th>
207
+ </tr>
208
+ <tr>
209
+ <th colspan="1"></th>
210
+ <th colspan="1">Checkpoint</th>
211
+ <th colspan="1">Training Config</th>
212
+ <th colspan="1">Inference Config</th>
213
+ <th colspan="1">Result</th>
214
+ <th colspan="1">Checkpoint</th>
215
+ <th colspan="1">Training Config</th>
216
+ <th colspan="1">Inference Config</th>
217
+ <th colspan="1">Result</th>
218
+ <th colspan="1">Checkpoint</th>
219
+ <th colspan="1">Training Config</th>
220
+ <th colspan="1">Inference Config</th>
221
+ <th colspan="1">Result</th>
222
+ </tr>
223
+ <tr>
224
+ <td>ViT-L/16</td>
225
+ <td><a href="https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitl-16x2x3.pt">checkpoint</a></td>
226
+ <td><a href="configs/eval/vitl/ssv2.yaml">config</a></td>
227
+ <td><a href="configs/inference/vitl/ssv2.yaml">config</a></td>
228
+ <td>73.7%</td>
229
+ <td><a href="https://dl.fbaipublicfiles.com/vjepa2/evals/diving48-vitl-256.pt">checkpoint</a></td>
230
+ <td><a href="configs/eval/vitl/diving48.yaml">config</a></td>
231
+ <td><a href="configs/inference/vitl/diving48.yaml">config</a></td>
232
+ <td>89.0%</td>
233
+ <td><a href="https://dl.fbaipublicfiles.com/vjepa2/evals/ek100-vitl-256.pt">checkpoint</a></td>
234
+ <td><a href="configs/eval/vitl/ek100.yaml">config</a></td>
235
+ <td><a href="configs/inference/vitl/ek100.yaml">config</a></td>
236
+ <td>32.7 R@5</td>
237
+ </tr>
238
+ <tr>
239
+ <td>ViT-g/16<sub>384</sub></td>
240
+ <td><a href="https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt">checkpoint</a></td>
241
+ <td><a href="configs/eval/vitg-384/ssv2.yaml">config</a></td>
242
+ <td><a href="configs/inference/vitg-384/ssv2.yaml">config</a></td>
243
+ <td>77.3%</td>
244
+ <td><a href="https://dl.fbaipublicfiles.com/vjepa2/evals/diving48-vitg-384-32x4x3.pt">checkpoint</a></td>
245
+ <td><a href="configs/eval/vitg-384/diving48.yaml">config</a></td>
246
+ <td><a href="configs/inference/vitg-384/diving48.yaml">config</a></td>
247
+ <td>90.2%</td>
248
+ <td><a href="https://dl.fbaipublicfiles.com/vjepa2/evals/ek100-vitg-384.pt">checkpoint</a></td>
249
+ <td><a href="configs/eval/vitg-384/ek100.yaml">config</a></td>
250
+ <td><a href="configs/inference/vitg-384/ek100.yaml">config</a></td>
251
+ <td>39.7 R@5</td>
252
+ </tr>
253
+ </table>
254
+
255
+ ### V-JEPA 2-AC
256
+
257
+ Our action-conditioned checkpoint was trained from the ViT-g encoder.
258
+ <table>
259
+ <tr>
260
+ <th colspan="1">Model</th>
261
+ <th colspan="1">Download Link</th>
262
+ <th colspan="1">Training Config</th>
263
+ </tr>
264
+ <tr>
265
+ <td>ViT-g/16</td>
266
+ <td><a href="https://dl.fbaipublicfiles.com/vjepa2/vjepa2-ac-vitg.pt">checkpoint</a></td>
267
+ <td><a href="configs/train/vitg16/droid-256px-8f.yaml">config</a></td>
268
+ </tr>
269
+ </table>
270
+
271
+ #### Pretrained action-conditioned backbone (via PyTorch Hub)
272
+
273
+ Please install [PyTorch](https://pytorch.org/get-started/locally/), [timm](https://pypi.org/project/timm/) and [einops](https://pypi.org/project/einops/) locally, then run the following to load each model. Installing PyTorch with CUDA support is strongly recommended.
274
+
275
+ ```python
276
+ import torch
277
+
278
+ vjepa2_encoder, vjepa2_ac_predictor = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_ac_vit_giant')
279
+ ```
280
+
281
+
282
+ See [energy_landscape_example.ipynb](notebooks/energy_landscape_example.ipynb) for an example notebook computing the energy landscape of the pretrained action-conditioned backbone using a robot trajectory collected from our lab.
283
+ To run this notebook, you'll need to additionally install [Jupyter](https://jupyter.org/install) and [Scipy](https://scipy.org/install/) in your conda environment.
284
+
285
+
286
+ ## Getting Started
287
+
288
+ ### Setup
289
+
290
+ ```
291
+ conda create -n vjepa2-312 python=3.12
292
+ conda activate vjepa2-312
293
+ pip install . # or `pip install -e .` for development mode
294
+ ```
295
+
296
+ **Note to macOS users:** V-JEPA 2 relies on [`decord`](https://github.com/dmlc/decord), which does not support macOS (and, unfortunately, is also no longer under development). In order to run the V-JEPA 2 code on macOS, you will need a different `decord` implementation. We do not make specific recommendations, although some users have reported the use of [`eva-decord`](https://github.com/georgia-tech-db/eva-decord) (see [PR 1](https://github.com/facebookresearch/vjepa2/pull/1)) or [`decord2`](https://github.com/johnnynunez/decord2) (see [PR 31](https://github.com/facebookresearch/vjepa2/pull/31)). We leave the selection of the `decord` package up to the user's discretion.
297
+
298
+ ### Usage Demo
299
+
300
+ See [vjepa2_demo.ipynb](notebooks/vjepa2_demo.ipynb) [(Colab Link)](https://colab.research.google.com/github/facebookresearch/vjepa2/blob/main/notebooks/vjepa2_demo.ipynb) or [vjepa2_demo.py](notebooks/vjepa2_demo.py) for an example of how to load both the HuggingFace and PyTorch V-JEPA 2 models and run inference on a sample video to get a sample classification result.
301
+
302
+ The script assumes the presence of downloaded model checkpoints, so you will need to download the model weights and update the corresponding paths in the script. E.g.:
303
+ ```
304
+ wget https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt -P YOUR_DIR
305
+ wget https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt -P YOUR_DIR
306
+
307
+ # Then update your model paths in vjepa2_demo.py.
308
+ pt_model_path = YOUR_DIR/vitg-384.pt
309
+ classifier_model_path = YOUR_DIR/ssv2-vitg-384-64x2x3.pt
310
+
311
+ # Then run the script (assumes your machine has a GPU)
312
+ python -m notebooks.vjepa2_demo
313
+ ```
314
+
315
+ ### Probe-based evaluation
316
+
317
+ Probe-based evaluation consists of training an attentive probe on top of frozen V-JEPA 2 features. We provide scripts for training your own probes, as well as checkpoints to run inference directly.
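An attentive probe pools the frozen patch features with a single learnable query via cross-attention, then applies a linear classifier. The forward pass can be sketched as below; the shapes and parameter initializations are illustrative stand-ins, not the repository's probe implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

D, N, C = 16, 128, 10                  # feature dim, #frozen tokens, #classes
feats = rng.standard_normal((N, D))    # frozen V-JEPA 2 patch features

# Probe parameters (these are trained; the backbone stays frozen).
query = rng.standard_normal(D)         # learnable pooling query
W_k = rng.standard_normal((D, D)) / np.sqrt(D)
W_v = rng.standard_normal((D, D)) / np.sqrt(D)
W_cls = rng.standard_normal((D, C)) / np.sqrt(D)

# Single-query cross-attention pools the N tokens into one vector.
attn = softmax((feats @ W_k) @ query / np.sqrt(D))   # (N,)
pooled = attn @ (feats @ W_v)                        # (D,)
logits = pooled @ W_cls                              # (C,)
print(logits.shape)
```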
318
+
319
+ #### Training probes
320
+
321
+ Evaluations can be run either locally or distributed via SLURM (running locally is useful for debugging and validation).
322
+ These sample commands launch Something-Something v2 video classification; other evals are launched by specifying the corresponding config.
323
+ Use the provided training configs under [Evaluation Attentive Probes](#evaluation-attentive-probes). These configs allow you to train multiple probes in parallel with various optimization parameters.
324
+ Change filepaths as needed (e.g. `folder`, `checkpoint`, `dataset_train`, `dataset_val`) to match locations of data and downloaded checkpoints on your local filesystem.
325
+ Change the number of nodes and the local batch size as needed so as not to exceed available GPU memory.
326
+
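If you script these filepath edits, a small helper can load a config, override the path fields, and write a local copy (PyYAML is already a dependency of this codebase). The key names below (`folder`, `checkpoint`) follow the text above, but real configs may nest some fields, so verify against the file you downloaded.

```python
import yaml  # PyYAML, already used by this codebase

def patch_config(src: str, dst: str, overrides: dict) -> dict:
    """Load a YAML config, override top-level fields, and save a local copy."""
    with open(src) as f:
        cfg = yaml.safe_load(f)
    cfg.update(overrides)  # NB: some path fields may be nested in real configs
    with open(dst, "w") as f:
        yaml.safe_dump(cfg, f)
    return cfg

# Demo on a minimal stand-in config (the real files live under configs/eval).
with open("ssv2-demo.yaml", "w") as f:
    f.write("folder: /tmp/out\ncheckpoint: /tmp/ckpt.pt\n")

cfg = patch_config(
    "ssv2-demo.yaml",
    "ssv2-local.yaml",
    {"folder": "/my/output/dir", "checkpoint": "/my/ckpts/vitl.pt"},
)
print(cfg["folder"])
```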
327
+ ##### Local
328
+
329
+ To run locally, specify the GPUs to use on the command line:
330
+ ```
331
+ python -m evals.main --fname configs/eval/vitl/ssv2.yaml \
332
+ --devices cuda:0 cuda:1
333
+ ```
334
+
335
+ ##### Distributed
336
+
337
+ ```
338
+ python -m evals.main_distributed \
339
+ --fname configs/eval/vitl/ssv2.yaml \
340
+ --time 8600 \
341
+ --account my_account --qos=my_qos
342
+ ```
343
+
344
+ #### Inference from existing probes
345
+
346
+ Use provided inference configs under [Evaluation Attentive Probes](#evaluation-attentive-probes).
347
+ Download the corresponding checkpoint, rename it to `latest.pt`, and place it in a folder whose path matches the variables in the config:
348
+ ```
349
+ [folder]/[eval_name]/[tag]/latest.pt
350
+ ```
351
+ Then run inference, locally or distributed, using the same evaluation commands as above, but with configs from `configs/inference`.
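Concretely, staging the downloaded ViT-L SSv2 probe could look like this (the directory names are placeholders; make them match the `folder`, eval name, and `tag` fields of your inference config):

```shell
# Placeholder layout: [folder]=probe-root, [eval_name]=ssv2, [tag]=vitl16.
mkdir -p probe-root/ssv2/vitl16
# After downloading the probe checkpoint, move it into place:
# wget https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitl-16x2x3.pt
# mv ssv2-vitl-16x2x3.pt probe-root/ssv2/vitl16/latest.pt
ls probe-root/ssv2
```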
352
+
353
+ ### Pretraining
354
+
355
+ Likewise, training can also be run locally or distributed. Pretraining and cooldown training phases are
356
+ run with the same command using different configs.
357
+ These sample commands launch initial training of a ViT-L model. Configs for cooldown (or action-conditioned) training
358
+ can be found in the same directory as the config for initial training.
359
+
360
+ #### Local
361
+
362
+ ```
363
+ python -m app.main --fname configs/train/vitl16/pretrain-256px-16f.yaml \
364
+ --devices cuda:0
365
+ ```
366
+
367
+ #### Distributed
368
+
369
+ ```
370
+ python -m app.main_distributed \
371
+ --fname configs/train/vitl16/pretrain-256px-16f.yaml \
372
+ --time 6000 \
373
+ --account my_account --qos=my_qos
374
+ ```
375
+
376
+ ### Post-training
377
+
378
+ Post-training of the action-conditioned model, starting from the pretrained V-JEPA 2 backbone, also follows a similar interface, and can be run locally or distributed using [this config](configs/train/vitg16/droid-256px-8f.yaml).
379
+ We post-train the model starting from the ViT-g/16 backbone.
380
+
381
+ #### Local
382
+
383
+ ```
384
+ python -m app.main --fname configs/train/vitg16/droid-256px-8f.yaml \
385
+ --devices cuda:0
386
+ ```
387
+
388
+ #### Distributed
389
+
390
+ ```
391
+ python -m app.main_distributed \
392
+ --fname configs/train/vitg16/droid-256px-8f.yaml
393
+ --time 6000
394
+ --account my_account --qos=my_qos
395
+ ```
396
+
397
+
398
+ ## Code Structure
399
+
400
+ ```
401
+ .
402
+ ├── app # training loops
403
+ │ ├── vjepa # video JEPA pre-training
404
+ │ ├── vjepa_droid # training the action-conditioned model
405
+ │ ├── main_distributed.py # entrypoint for launching the app on a SLURM cluster
406
+ │ └── main.py # entrypoint for launching the app locally on your machine
407
+ ├── configs # config files with experiment params for training and evaluation
408
+ │ ├── train # pretraining (phase 1), cooldown (phase 2), and action-conditioned training
409
+ │ └── eval # frozen evaluations
410
+ ├── evals # evaluation loops training an attentive probe with frozen backbone...
411
+ │ ├── action_anticipation_frozen # action anticipation
412
+ │ ├── image_classification_frozen # image understanding
413
+ │ ├── video_classification_frozen # video understanding
414
+ │ ├── main_distributed.py # entrypoint for distributed evaluations
415
+ │ └── main.py # entrypoint for locally-run evaluations
416
+ ├── src # the package
417
+ │ ├── datasets # datasets, data loaders, ...
418
+ │ ├── models # model definitions
419
+ │ ├── masks # mask collators, masking utilities, ...
420
+ │ └── utils # shared utilities
421
+ └── tests # unit tests for some modules in `src`
422
+
423
+ ```
424
+
425
+ ## License
426
+
427
+ The majority of V-JEPA 2 is licensed under MIT; however, portions of the project are available under separate license terms:
428
+
429
+ [src/datasets/utils/video/randaugment.py](src/datasets/utils/video/randaugment.py)<br>
430
+ [src/datasets/utils/video/randerase.py](src/datasets/utils/video/randerase.py)<br>
431
+ [src/datasets/utils/worker_init_fn.py](src/datasets/utils/worker_init_fn.py)<br>
432
+
433
+ are licensed under the Apache 2.0 license.
434
+
435
+
436
+ ## Citation
437
+ If you find this repository useful in your research, please consider giving it a star :star: and citing it:
438
+ ```bibtex
439
+ @article{assran2025vjepa2,
440
+ title={V-JEPA~2: Self-Supervised Video Models Enable Understanding, Prediction and Planning},
441
+ author={Assran, Mahmoud and Bardes, Adrien and Fan, David and Garrido, Quentin and Howes, Russell and
442
+ Komeili, Mojtaba and Muckley, Matthew and Rizvi, Ammar and Roberts, Claire and Sinha, Koustuv and Zholus, Artem and
443
+ Arnaud, Sergio and Gejji, Abha and Martin, Ada and Robert Hogan, Francois and Dugas, Daniel and
444
+ Bojanowski, Piotr and Khalidov, Vasil and Labatut, Patrick and Massa, Francisco and Szafraniec, Marc and
445
+ Krishnakumar, Kapil and Li, Yong and Ma, Xiaodong and Chandar, Sarath and Meier, Franziska and LeCun, Yann and
446
+ Rabbat, Michael and Ballas, Nicolas},
447
+ journal={arXiv preprint arXiv:2506.09985},
448
+ year={2025}
449
+ }
450
+ ```
vjepa2/app/__init__.py ADDED
File without changes
vjepa2/app/main.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import multiprocessing as mp
8
+ import pprint
9
+ from pathlib import Path
10
+
11
+ import yaml
12
+
13
+ from app.scaffold import main as app_main
14
+ from src.utils.distributed import init_distributed
15
+
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument("--fname", type=str, help="name of config file to load", default="configs.yaml")
18
+ parser.add_argument(
19
+ "--devices",
20
+ type=str,
21
+ nargs="+",
22
+ default=["cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7"],
23
+ help="which devices to use on local machine",
24
+ )
25
+ parser.add_argument(
26
+ "--debugmode",
27
+ type=bool,
28
+ default=False,
29
+ help="Setting this to true will not spin up new processes. "
30
+ "The main code runs the main process, which makes it easier to \
31
+ debug with checkpointing.",
32
+ )
33
+
34
+
35
+ def process_main(rank, fname, world_size, devices):
36
+ import os
37
+
38
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[rank].split(":")[-1])
39
+
40
+ import logging
41
+
42
+ from src.utils.logging import get_logger
43
+
44
+ logger = get_logger(force=True)
45
+ if rank == 0:
46
+ logger.setLevel(logging.INFO)
47
+ else:
48
+ logger.setLevel(logging.ERROR)
49
+
50
+ logger.info(f"called-params {fname}")
51
+
52
+ # Load config
53
+ params = None
54
+ with open(fname, "r") as y_file:
55
+ params = yaml.load(y_file, Loader=yaml.FullLoader)
56
+ logger.info("loaded params...")
57
+
58
+ # Log config
59
+ if rank == 0:
60
+ pprint.PrettyPrinter(indent=4).pprint(params)
61
+ folder = params["folder"]
62
+ params_path = os.path.join(folder, "params-pretrain.yaml")
63
+ folder = Path(folder)
64
+ folder.mkdir(parents=True, exist_ok=True)
65
+ with open(params_path, "w") as f:
66
+ yaml.dump(params, f)
67
+
68
+ # Init distributed (access to comm between GPUS on same machine)
69
+ world_size, rank = init_distributed(rank_and_world_size=(rank, world_size))
70
+ logger.info(f"Running... (rank: {rank}/{world_size})")
71
+
72
+ # Launch the app with loaded config
73
+ app_main(params["app"], args=params)
74
+
75
+
76
+ if __name__ == "__main__":
77
+ args = parser.parse_args()
78
+ if args.debugmode:
79
+ process_main(rank=0, fname=args.fname, world_size=1, devices=["cuda:0"])
80
+ else:
81
+ num_gpus = len(args.devices)
82
+ mp.set_start_method("spawn")
83
+ for rank in range(num_gpus):
84
+ mp.Process(target=process_main, args=(rank, args.fname, num_gpus, args.devices)).start()
vjepa2/app/main_distributed.py ADDED
@@ -0,0 +1,269 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import copy
8
+ import datetime
9
+ import os
10
+ import pprint
11
+ import shutil
12
+ from pathlib import Path
13
+
14
+ import submitit
15
+ import yaml
16
+
17
+ from app.scaffold import main as app_main
18
+ from src.utils.logging import get_logger, git_information
19
+
20
+ logger = get_logger(force=True)
21
+
22
+
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument(
25
+ "--fname",
26
+ type=str,
27
+ help="yaml file containing config file names to launch",
28
+ default="configs.yaml",
29
+ )
30
+ parser.add_argument("--exclude", type=str, help="nodes to exclude from training", default=None)
31
+ parser.add_argument(
32
+ "--batch-launch",
33
+ action="store_true",
34
+ help="whether fname points to a file to batch-launch several config files",
35
+ )
36
+ parser.add_argument(
37
+ "--use_fname_as_folder",
38
+ action="store_true",
39
+ help="whether to append fname filename to folder",
40
+ )
41
+ parser.add_argument(
42
+ "--folder",
43
+ type=str,
44
+ default=None,
45
+ help="if specified, override 'folder' field in the .yaml with this",
46
+ )
47
+ parser.add_argument(
48
+ "--account",
49
+ type=str,
50
+ default="jepa",
51
+ help="Cluster account to use when submitting jobs",
52
+ )
53
+ parser.add_argument(
54
+ "--partition",
55
+ type=str,
56
+ default="learn",
57
+ help="Cluster partition to use when submitting jobs",
58
+ )
59
+ parser.add_argument(
60
+ "--qos",
61
+ type=str,
62
+ default=None,
63
+ help="If specified, cluster partition to use when submitting jobs",
64
+ )
65
+ parser.add_argument("--time", type=int, default=4300, help="time in minutes to run job")
66
+
67
+
68
+ class Trainer:
69
+ def __init__(self, args_pretrain, load_model=None):
70
+ self.app = args_pretrain["app"]
71
+ self.args_pretrain = args_pretrain
72
+ self.load_model = load_model
73
+
74
+ def __call__(self):
75
+ app = self.app
76
+ params = self.args_pretrain
77
+ load_model = self.load_model
78
+
79
+ logger.info("loaded pretrain params...")
80
+ pp = pprint.PrettyPrinter(indent=4)
81
+ pp.pprint(params)
82
+
83
+ # Launch app with loaded config
84
+ resume_preempt = False if load_model is None else load_model
85
+ app_main(app, args=params, resume_preempt=resume_preempt)
86
+
87
+ def checkpoint(self):
88
+ fb_trainer = Trainer(self.args_pretrain, True)
89
+ return submitit.helpers.DelayedSubmission(
90
+ fb_trainer,
91
+ )
92
+
93
+
94
+ def copy_code_folder(code_folder, ignore_patterns, ignore_paths):
95
+ path_to_node_folder = {}
96
+
97
+ for path in ignore_paths:
98
+ split_path = path.split("/")
99
+ base_path = "/".join(split_path[:-1])
100
+ node_folder = split_path[-1]
101
+ path_to_node_folder[base_path] = node_folder
102
+
103
+ def ignore_func(path, names):
104
+ ignore_list = ignore_patterns
105
+ if path in path_to_node_folder.keys():
106
+ ignore_list.append(path_to_node_folder[path])
107
+ return ignore_list
108
+
109
+ if not os.path.exists(code_folder):
110
+ shutil.copytree(".", code_folder, ignore=ignore_func)
111
+
112
+
113
+ def update_folder_with_timestamp(args_list):
114
+ new_args_list = copy.deepcopy(args_list)
115
+ for i, args in enumerate(args_list):
116
+ folder = args["folder"]
117
+ load_checkpoint = args["meta"].get("load_checkpoint", False) if "meta" in args else False
118
+ if not load_checkpoint and Path(folder).exists():
119
+ timestamp = datetime.datetime.now().strftime("%y_%m_%d_%H_%M_%S")
120
+ folder = folder.rstrip("/") + f"_{timestamp}"
121
+ logger.info(f"Folder already exists but `load_checkpoint` is False. Logging to new folder {folder}...")
122
+ new_args_list[i]["folder"] = folder
123
+ return new_args_list
124
+
125
+
126
+ def launch_app_with_parsed_args(
127
+ args_for_pretrain,
128
+ account,
129
+ partition,
130
+ qos,
131
+ mem_per_gpu="210G",
132
+ timeout=4300,
133
+ nodes=1,
134
+ tasks_per_node=1,
135
+ cpus_per_task=12,
136
+ exclude_nodes=None,
137
+ ):
138
+ args_for_pretrain = update_folder_with_timestamp(args_for_pretrain)
139
+ for ap in args_for_pretrain:
140
+ folder = ap["folder"]
141
+ Path(folder).mkdir(parents=True, exist_ok=True)
142
+ folder = args_for_pretrain[0]["folder"]
143
+
144
+ # -------------- Copy code --------------
145
+ code_folder = os.path.join(folder, "code")
146
+ ignore_patterns = [
147
+ "__pycache__",
148
+ ".vscode",
149
+ ".git",
150
+ "core",
151
+ ]
152
+ ignore_paths = [
153
+ "./evals/ava/alphaction/data",
154
+ "./demos",
155
+ "./traces",
156
+ ]
157
+ copy_code_folder(code_folder, ignore_patterns, ignore_paths)
158
+ os.chdir(code_folder)
159
+ # ---------------------------------------
160
+
161
+ # -------------- Save config file --------------
162
+ params_path = os.path.join(folder, "params-pretrain.yaml")
163
+ if not os.path.exists(params_path):
164
+ with open(params_path, "w") as f:
165
+ yaml.dump(args_for_pretrain, f)
166
+ # ----------------------------------------------
167
+
168
+ # -------------- Save git info file --------------
169
+ git_info_fpath = os.path.join(folder, "git-info.txt")
170
+ with open(git_info_fpath, "w") as f:
171
+ f.write(git_information())
172
+ # ----------------------------------------------
173
+
174
+ # -------------- SET JOB NAME --------------
175
+ folder_ = folder
176
+ if folder[-1] == "/":
177
+ folder_ = folder[:-1]
178
+ job_name = folder_.split("/")[-1]
179
+ # ------------------------------------------
180
+
181
+ executor = submitit.AutoExecutor(folder=os.path.join(folder, "job_%j"), slurm_max_num_timeout=20)
182
+ executor.update_parameters(
183
+ name=job_name,
184
+ slurm_partition=partition,
185
+ slurm_account=account,
186
+ slurm_qos=qos,
187
+ slurm_mem_per_gpu=mem_per_gpu,
188
+ timeout_min=timeout,
189
+ nodes=nodes,
190
+ tasks_per_node=tasks_per_node,
191
+ cpus_per_task=cpus_per_task,
192
+ gpus_per_node=tasks_per_node,
193
+ )
194
+
195
+ if exclude_nodes is not None:
196
+ executor.update_parameters(slurm_exclude=exclude_nodes)
197
+
198
+ jobs, trainers = [], []
199
+ with executor.batch():
200
+ for ap in args_for_pretrain:
201
+ # TODO Create sub folder and ap['folder']=subfolder
202
+ fb_trainer = Trainer(ap)
203
+ job = executor.submit(
204
+ fb_trainer,
205
+ )
206
+ trainers.append(fb_trainer)
207
+ jobs.append(job)
208
+
209
+ for job in jobs:
210
+ print(job.job_id)
211
+
212
+
213
+ def launch():
214
+ # ---------------------------------------------------------------------- #
215
+ # 1. Put config file names in a list
216
+ # ---------------------------------------------------------------------- #
217
+ config_fnames = [args.fname]
218
+
219
+ # -- If batch-launch is True, then the args.fname yaml file is not a
220
+ # -- config, but actually specifies a list of other config files
221
+ # -- to run in a slurm job array
222
+ if args.batch_launch:
223
+ with open(args.fname, "r") as y_file:
224
+ config_fnames = yaml.load(y_file, Loader=yaml.FullLoader)
225
+ # ---------------------------------------------------------------------- #
226
+
227
+ # ---------------------------------------------------------------------- #
228
+ # 2. Parse each yaml config file as a dict and place in list
229
+ # ---------------------------------------------------------------------- #
230
+ nodes, tasks_per_node = None, None
231
+ configs = []
232
+ for f in config_fnames:
233
+ with open(f, "r") as y_file:
234
+ _params = yaml.load(y_file, Loader=yaml.FullLoader)
235
+ if args.use_fname_as_folder:
236
+ assert not args.folder, "Don't specify --folder if adding fname to folder"
237
+ _params["folder"] = str(Path(_params["folder"]) / f.split("/")[-1].split(".yaml")[0])
238
+ elif args.folder:
239
+ _params["folder"] = args.folder
240
+ nodes = int(_params.get("nodes"))
241
+ tasks_per_node = int(_params.get("tasks_per_node"))
242
+ cpus_per_task = int(_params.get("cpus_per_task", 32))
243
+ mem_per_gpu = _params.get("mem_per_gpu", "210G")
244
+ configs += [_params]
245
+ logger.info(f"Loaded {len(configs)} config files")
246
+ logger.info(f"Running all jobs with {nodes=} / {tasks_per_node=}")
247
+ # ---------------------------------------------------------------------- #
248
+
249
+ # ---------------------------------------------------------------------- #
250
+ # 3. Launch evals with parsed config files
251
+ # ---------------------------------------------------------------------- #
252
+ launch_app_with_parsed_args(
253
+ args_for_pretrain=configs,
254
+ account=args.account,
255
+ partition=args.partition,
256
+ qos=args.qos,
257
+ mem_per_gpu=mem_per_gpu,
258
+ cpus_per_task=cpus_per_task,
259
+ timeout=args.time,
260
+ nodes=nodes,
261
+ tasks_per_node=tasks_per_node,
262
+ exclude_nodes=args.exclude,
263
+ )
264
+ # ---------------------------------------------------------------------- #
265
+
266
+
267
+ if __name__ == "__main__":
268
+ args = parser.parse_args()
269
+ launch()
vjepa2/app/scaffold.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import logging
8
+ import sys
9
+
10
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
11
+ logger = logging.getLogger()
12
+
13
+
14
+ def main(app, args, resume_preempt=False):
15
+
16
+ logger.info(f"Running pre-training of app: {app}")
17
+ return importlib.import_module(f"app.{app}.train").main(args=args, resume_preempt=resume_preempt)
vjepa2/app/vjepa/train.py ADDED
@@ -0,0 +1,536 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+
8
+ # -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS
9
+ try:
10
+ # -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE
11
+ # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE
12
+ # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE
13
+ # -- TO EACH PROCESS
14
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"]
15
+ except Exception:
16
+ pass
17
+
18
+ import copy
19
+ import gc
20
+ import random
21
+ import time
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.multiprocessing as mp
26
+ import torch.nn.functional as F
27
+ from torch.nn.parallel import DistributedDataParallel
28
+
29
+ from app.vjepa.transforms import make_transforms
30
+ from app.vjepa.utils import init_opt, init_video_model, load_checkpoint
31
+ from src.datasets.data_manager import init_data
32
+ from src.masks.multiseq_multiblock3d import MaskCollator
33
+ from src.masks.utils import apply_masks
34
+ from src.utils.distributed import init_distributed
35
+ from src.utils.logging import AverageMeter, CSVLogger, get_logger, gpu_timer
36
+
37
+ # --
38
+ log_timings = True
39
+ log_freq = 10
40
+ CHECKPOINT_FREQ = 1
41
+ GARBAGE_COLLECT_ITR_FREQ = 50
42
+ # --
43
+
44
+ _GLOBAL_SEED = 0
45
+ random.seed(_GLOBAL_SEED)
46
+ np.random.seed(_GLOBAL_SEED)
47
+ torch.manual_seed(_GLOBAL_SEED)
48
+ torch.backends.cudnn.benchmark = True
49
+
50
+
51
+ logger = get_logger(__name__, force=True)
52
+
53
+
54
+ def main(args, resume_preempt=False):
55
+ # ----------------------------------------------------------------------- #
56
+ # PASSED IN PARAMS FROM CONFIG FILE
57
+ # ----------------------------------------------------------------------- #
58
+
59
+ # -- META
60
+ folder = args.get("folder")
61
+ cfgs_meta = args.get("meta")
62
+ load_model = cfgs_meta.get("load_checkpoint") or resume_preempt
63
+ r_file = cfgs_meta.get("read_checkpoint", None)
64
+ seed = cfgs_meta.get("seed", _GLOBAL_SEED)
65
+ save_every_freq = cfgs_meta.get("save_every_freq", -1)
66
+ skip_batches = cfgs_meta.get("skip_batches", -1)
67
+ use_sdpa = cfgs_meta.get("use_sdpa", False)
68
+ sync_gc = cfgs_meta.get("sync_gc", False)
69
+ which_dtype = cfgs_meta.get("dtype")
70
+ logger.info(f"{which_dtype=}")
71
+ if which_dtype.lower() == "bfloat16":
72
+ dtype = torch.bfloat16
73
+ mixed_precision = True
74
+ elif which_dtype.lower() == "float16":
75
+ dtype = torch.float16
76
+ mixed_precision = True
77
+ else:
78
+ dtype = torch.float32
79
+ mixed_precision = False
80
+
81
+ # -- MASK
82
+ cfgs_mask = args.get("mask")
83
+
84
+ # -- MODEL
85
+ cfgs_model = args.get("model")
86
+ compile_model = cfgs_model.get("compile_model", False)
87
+ use_activation_checkpointing = cfgs_model.get("use_activation_checkpointing", False)
88
+ model_name = cfgs_model.get("model_name")
89
+ pred_depth = cfgs_model.get("pred_depth")
90
+ pred_num_heads = cfgs_model.get("pred_num_heads", None)
91
+ pred_embed_dim = cfgs_model.get("pred_embed_dim")
92
+ uniform_power = cfgs_model.get("uniform_power", False)
93
+ use_mask_tokens = cfgs_model.get("use_mask_tokens", False)
94
+ zero_init_mask_tokens = cfgs_model.get("zero_init_mask_tokens", True)
95
+ use_rope = cfgs_model.get("use_rope", False)
96
+ use_silu = cfgs_model.get("use_silu", False)
97
+ use_pred_silu = cfgs_model.get("use_pred_silu", False)
98
+ wide_silu = cfgs_model.get("wide_silu", True)
99
+
100
+ # -- DATA
101
+ cfgs_data = args.get("data")
102
+ dataset_type = cfgs_data.get("dataset_type", "videodataset")
103
+ dataset_paths = cfgs_data.get("datasets", [])
104
+ datasets_weights = cfgs_data.get("datasets_weights")
105
+ dataset_fpcs = cfgs_data.get("dataset_fpcs")
106
+ max_num_frames = max(dataset_fpcs)
107
+ if datasets_weights is not None:
108
+ assert len(datasets_weights) == len(dataset_paths), "Must have one sampling weight specified for each dataset"
109
+ batch_size = cfgs_data.get("batch_size")
110
+ tubelet_size = cfgs_data.get("tubelet_size")
111
+ fps = cfgs_data.get("fps")
112
+ crop_size = cfgs_data.get("crop_size", 224)
113
+ patch_size = cfgs_data.get("patch_size")
114
+ pin_mem = cfgs_data.get("pin_mem", False)
115
+ num_workers = cfgs_data.get("num_workers", 1)
116
+ persistent_workers = cfgs_data.get("persistent_workers", True)
117
+
118
+ # -- DATA AUGS
119
+ cfgs_data_aug = args.get("data_aug")
120
+ ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3])
121
+ rr_scale = cfgs_data_aug.get("random_resize_scale", [0.3, 1.0])
122
+ motion_shift = cfgs_data_aug.get("motion_shift", False)
123
+ reprob = cfgs_data_aug.get("reprob", 0.0)
124
+ use_aa = cfgs_data_aug.get("auto_augment", False)
125
+
126
+ # -- LOSS
127
+ cfgs_loss = args.get("loss")
128
+ loss_exp = cfgs_loss.get("loss_exp")
129
+
130
+ # -- OPTIMIZATION
131
+ cfgs_opt = args.get("optimization")
132
+ is_anneal = cfgs_opt.get("is_anneal", False)
133
+ anneal_ckpt = cfgs_opt.get("anneal_ckpt", None)
134
+ if is_anneal and anneal_ckpt is None:
135
+ raise ValueError("Must specify anneal_ckpt if is_anneal is True")
136
+ resume_anneal = cfgs_opt.get("resume_anneal", False) or (is_anneal and resume_preempt)
137
+ ipe = cfgs_opt.get("ipe", None)
138
+ ipe_scale = cfgs_opt.get("ipe_scale", 1.0)
139
+ wd = float(cfgs_opt.get("weight_decay"))
140
+ final_wd = float(cfgs_opt.get("final_weight_decay"))
141
+ num_epochs = cfgs_opt.get("epochs")
142
+ warmup = cfgs_opt.get("warmup")
143
+ start_lr = cfgs_opt.get("start_lr")
144
+ lr = cfgs_opt.get("lr")
145
+ final_lr = cfgs_opt.get("final_lr")
146
+ ema = cfgs_opt.get("ema")
147
+ betas = cfgs_opt.get("betas", (0.9, 0.999))
148
+ eps = cfgs_opt.get("eps", 1.0e-8)
149
+ # ----------------------------------------------------------------------- #
150
+ # ----------------------------------------------------------------------- #
151
+
152
+ np.random.seed(seed)
153
+ torch.manual_seed(seed)
154
+ torch.backends.cudnn.benchmark = True
155
+ try:
156
+ mp.set_start_method("spawn")
157
+ except Exception:
158
+ pass
159
+
160
+ # -- init torch distributed backend
161
+ world_size, rank = init_distributed()
162
+ logger.info(f"Initialized (rank/world-size) {rank}/{world_size}")
163
+
164
+ # -- set device
165
+ if not torch.cuda.is_available():
166
+ device = torch.device("cpu")
167
+ else:
168
+ device = torch.device("cuda:0")
169
+ torch.cuda.set_device(device)
170
+
171
+ # -- log/checkpointing paths
172
+ log_file = os.path.join(folder, f"log_r{rank}.csv")
173
+ latest_file = "latest.pt"
174
+ latest_path = os.path.join(folder, latest_file)
175
+ load_path = None
176
+ if load_model:
177
+ if is_anneal:
178
+ if os.path.exists(latest_path) and resume_anneal:
179
+ load_path = latest_path
180
+ else:
181
+ load_path = anneal_ckpt
182
+ resume_anneal = False
183
+ else:
184
+ load_path = r_file if r_file is not None else latest_path
185
+ if not os.path.exists(load_path):
186
+ load_path = None
187
+ load_model = False
188
+
189
+ # -- make csv_logger
190
+ csv_logger = CSVLogger(
191
+ log_file,
192
+ ("%d", "epoch"),
193
+ ("%d", "itr"),
194
+ ("%.5f", "loss"),
195
+ ("%d", "iter-time(ms)"),
196
+ ("%d", "gpu-time(ms)"),
197
+ ("%d", "dataload-time(ms)"),
198
+ )
199
+
200
+ # -- init model
201
+ encoder, predictor = init_video_model(
202
+ uniform_power=uniform_power,
203
+ use_mask_tokens=use_mask_tokens,
204
+ num_mask_tokens=int(len(cfgs_mask) * len(dataset_fpcs)),
205
+ zero_init_mask_tokens=zero_init_mask_tokens,
206
+ device=device,
207
+ patch_size=patch_size,
208
+ max_num_frames=max_num_frames,
209
+ tubelet_size=tubelet_size,
210
+ model_name=model_name,
211
+ crop_size=crop_size,
212
+ pred_depth=pred_depth,
213
+ pred_num_heads=pred_num_heads,
214
+ pred_embed_dim=pred_embed_dim,
215
+ use_sdpa=use_sdpa,
216
+ use_silu=use_silu,
217
+ use_pred_silu=use_pred_silu,
218
+ wide_silu=wide_silu,
219
+ use_rope=use_rope,
220
+ use_activation_checkpointing=use_activation_checkpointing,
221
+ )
222
+ target_encoder = copy.deepcopy(encoder)
223
+
224
+ if compile_model:
225
+ logger.info("Compiling encoder, target_encoder, and predictor.")
226
+ torch._dynamo.config.optimize_ddp = False
227
+ encoder.compile()
228
+ target_encoder.compile()
229
+ predictor.compile()
230
+
231
+ mask_collator = MaskCollator(
232
+ cfgs_mask=cfgs_mask,
233
+ dataset_fpcs=dataset_fpcs,
234
+ crop_size=crop_size,
235
+ patch_size=patch_size,
236
+ tubelet_size=tubelet_size,
237
+ )
238
+ transform = make_transforms(
239
+ random_horizontal_flip=True,
240
+ random_resize_aspect_ratio=ar_range,
241
+ random_resize_scale=rr_scale,
242
+ reprob=reprob,
243
+ auto_augment=use_aa,
244
+ motion_shift=motion_shift,
245
+ crop_size=crop_size,
246
+ )
247
+
248
+ # -- init data-loaders/samplers
249
+ (unsupervised_loader, unsupervised_sampler) = init_data(
250
+ data=dataset_type,
251
+ root_path=dataset_paths,
252
+ batch_size=batch_size,
253
+ training=True,
254
+ dataset_fpcs=dataset_fpcs,
255
+ fps=fps,
256
+ transform=transform,
257
+ rank=rank,
258
+ world_size=world_size,
259
+ datasets_weights=datasets_weights,
260
+ persistent_workers=persistent_workers,
261
+ collator=mask_collator,
262
+ num_workers=num_workers,
263
+ pin_mem=pin_mem,
264
+ log_dir=None,
265
+ )
266
+ try:
267
+ _dlen = len(unsupervised_loader)
268
+ except Exception: # Different interface for webdataset
269
+ _dlen = unsupervised_loader.num_batches
270
+ if ipe is None:
271
+ ipe = _dlen
272
+ logger.info(f"iterations per epoch/dataset length: {ipe}/{_dlen}")
273
+
274
+ # -- init optimizer and scheduler
275
+ optimizer, scaler, scheduler, wd_scheduler = init_opt(
276
+ is_anneal=is_anneal,
277
+ encoder=encoder,
278
+ predictor=predictor,
279
+ wd=wd,
280
+ final_wd=final_wd,
281
+ start_lr=start_lr,
282
+ ref_lr=lr,
283
+ final_lr=final_lr,
284
+ iterations_per_epoch=ipe,
285
+ warmup=warmup,
286
+ num_epochs=num_epochs,
287
+ ipe_scale=ipe_scale,
288
+ mixed_precision=mixed_precision,
289
+ betas=betas,
290
+ eps=eps,
291
+ )
292
+ encoder = DistributedDataParallel(encoder, static_graph=True, find_unused_parameters=False)
293
+ predictor = DistributedDataParallel(predictor, static_graph=False, find_unused_parameters=False)
294
+ target_encoder = DistributedDataParallel(target_encoder, static_graph=True, find_unused_parameters=False)
295
+ for p in target_encoder.parameters():
296
+ p.requires_grad = False
297
+
298
+ # -- momentum schedule
299
+ momentum_scheduler = (
300
+ ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale)
301
+ for i in range(int(ipe * num_epochs * ipe_scale) + 1)
302
+ )
303
+
304
+ start_epoch = 0
305
+ # -- load training checkpoint
306
+ if load_model or os.path.exists(latest_path):
307
+ (
308
+ encoder,
309
+ predictor,
310
+ target_encoder,
311
+ optimizer,
312
+ scaler,
313
+ start_epoch,
314
+ ) = load_checkpoint(
315
+ r_path=load_path,
316
+ encoder=encoder,
317
+ predictor=predictor,
318
+ target_encoder=target_encoder,
319
+ opt=optimizer,
320
+ scaler=scaler,
321
+ is_anneal=is_anneal and not resume_anneal,
322
+ )
323
+ if not is_anneal or resume_anneal:
324
+ for _ in range(start_epoch * ipe):
325
+ scheduler.step()
326
+ wd_scheduler.step()
327
+ next(momentum_scheduler)
328
+ mask_collator.step()
329
+
330
+ def save_checkpoint(epoch, path):
331
+ if rank != 0:
332
+ return
333
+ save_dict = {
334
+ "encoder": encoder.state_dict(),
335
+ "predictor": predictor.state_dict(),
336
+ "opt": optimizer.state_dict(),
337
+ "scaler": None if scaler is None else scaler.state_dict(),
338
+ "target_encoder": target_encoder.state_dict(),
339
+ "epoch": epoch,
340
+ "loss": loss_meter.avg,
341
+ "batch_size": batch_size,
342
+ "world_size": world_size,
343
+ "lr": lr,
344
+ }
345
+ try:
346
+ torch.save(save_dict, path)
347
+ except Exception as e:
348
+ logger.info(f"Encountered exception when saving checkpoint: {e}")
349
+
350
+ logger.info("Initializing loader...")
351
+ unsupervised_sampler.set_epoch(start_epoch)
352
+ loader = iter(unsupervised_loader)
353
+
354
+ if skip_batches > 0:
355
+ logger.info(f"Skip {skip_batches} batches")
356
+ # -- update distributed-data-loader epoch
357
+
358
+ for itr in range(skip_batches):
359
+ if itr % 10 == 0:
360
+ logger.info(f"Skip {itr}/{skip_batches} batches")
361
+ try:
362
+ _ = next(loader)
363
+ except Exception:
364
+ loader = iter(unsupervised_loader)
365
+ _ = next(loader)
366
+
367
+ if sync_gc:
368
+ gc.disable()
369
+ gc.collect()
370
+
371
+ # -- TRAINING LOOP
372
+ for epoch in range(start_epoch, num_epochs):
373
+ logger.info("Epoch %d" % (epoch + 1))
374
+
375
+ loss_meter = AverageMeter()
376
+ mask_meters = {fpc: AverageMeter() for fpc in dataset_fpcs}
377
+ iter_time_meter = AverageMeter()
378
+ gpu_time_meter = AverageMeter()
379
+ data_elapsed_time_meter = AverageMeter()
380
+
381
+ for itr in range(ipe):
382
+ itr_start_time = time.time()
383
+
384
+ iter_retries = 0
385
+ iter_successful = False
386
+ while not iter_successful:
387
+ try:
388
+ sample = next(loader)
389
+ iter_successful = True
390
+ except StopIteration:
391
+ logger.info("Exhausted data loaders. Refreshing...")
392
+ unsupervised_sampler.set_epoch(epoch)
393
+ loader = iter(unsupervised_loader)
394
+ except Exception as e:
395
+ NUM_RETRIES = 5
396
+ if iter_retries < NUM_RETRIES:
397
+ logger.warning(f"Encountered exception when loading data (num retries {iter_retries}):\n{e}")
398
+ iter_retries += 1
399
+ time.sleep(5)
400
+ else:
401
+ logger.warning(f"Exceeded max retries ({NUM_RETRIES}) when loading data. Skipping batch.")
402
+ raise e
403
+
404
+ for _fpc_sample in sample:
405
+ bs, fpc = _fpc_sample[0][-1][0].size()
406
+ mask_meters[fpc].update(bs / batch_size)
407
+
408
+ def load_clips():
409
+ all_clips, all_masks_enc, all_masks_pred = [], [], []
410
+ for fpc_sample in sample:
411
+ udata, masks_enc, masks_pred = fpc_sample
412
+ all_clips += [udata[0][0].to(device, non_blocking=True)]
413
+ all_masks_enc += [[m.to(device, non_blocking=True) for m in masks_enc]]
414
+ all_masks_pred += [[m.to(device, non_blocking=True) for m in masks_pred]]
415
+ return all_clips, all_masks_enc, all_masks_pred
416
+
417
+ clips, masks_enc, masks_pred = load_clips()
418
+ data_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0
419
+
420
+ if sync_gc and (itr + 1) % GARBAGE_COLLECT_ITR_FREQ == 0:
421
+ logger.info("Running garbage collection...")
422
+ gc.collect()
423
+
424
+ def train_step():
425
+ _new_lr = scheduler.step()
426
+ _new_wd = wd_scheduler.step()
427
+ # --
428
+
429
+ def forward_target(c):
430
+ with torch.no_grad():
431
+ h = target_encoder(c)
432
+ h = [F.layer_norm(hi, (hi.size(-1),)) for hi in h]
433
+ return h
434
+
435
+ def forward_context(c):
436
+ z = encoder(c, masks_enc)
437
+ z = predictor(z, masks_enc, masks_pred)
438
+ return z
439
+
440
+ def loss_fn(z, h):
441
+ # Assumption: predictor will have returned only masked tokens for z
442
+ h = [apply_masks(hi, mi, concat=False) for hi, mi in zip(h, masks_pred)]
443
+
444
+ loss, n = 0, 0
445
+ for zi, hi in zip(z, h):
446
+ for zij, hij in zip(zi, hi):
447
+ loss += torch.mean(torch.abs(zij - hij) ** loss_exp) / loss_exp
448
+ n += 1
449
+ loss /= n
450
+ return loss
451
+
452
+ # Step 1. Forward
453
+ with torch.amp.autocast('cuda', dtype=dtype, enabled=mixed_precision):
454
+ h = forward_target(clips)
455
+ z = forward_context(clips)
456
+ loss = loss_fn(z, h) # jepa prediction loss
457
+
458
+ # Step 2. Backward & step
459
+ if mixed_precision:
460
+ scaler.scale(loss).backward()
461
+ scaler.unscale_(optimizer)
462
+ else:
463
+ loss.backward()
464
+ if mixed_precision:
465
+ scaler.step(optimizer)
466
+ scaler.update()
467
+ else:
468
+ optimizer.step()
469
+ optimizer.zero_grad()
470
+
471
+ # Step 3. momentum update of target encoder
472
+ m = next(momentum_scheduler)
473
+ with torch.no_grad():
474
+ params_k = []
475
+ params_q = []
476
+ for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()):
477
+ params_k.append(param_k)
478
+ params_q.append(param_q)
479
+ torch._foreach_mul_(params_k, m)
480
+ torch._foreach_add_(params_k, params_q, alpha=1 - m)
481
+
482
+ return (
483
+ loss.detach().item(),
484
+ _new_lr,
485
+ _new_wd,
486
+ )
487
+
488
+ (
489
+ loss,
490
+ _new_lr,
491
+ _new_wd,
492
+ ), gpu_etime_ms = gpu_timer(train_step)
493
+ iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0
494
+ loss_meter.update(loss)
495
+ iter_time_meter.update(iter_elapsed_time_ms)
496
+ gpu_time_meter.update(gpu_etime_ms)
497
+ data_elapsed_time_meter.update(data_elapsed_time_ms)
498
+
499
+ # -- Logging
500
+ def log_stats():
501
+ csv_logger.log(epoch + 1, itr, loss, iter_elapsed_time_ms, gpu_etime_ms, data_elapsed_time_ms)
502
+ if (itr % log_freq == 0) or (itr == ipe - 1) or np.isnan(loss) or np.isinf(loss):
503
+ logger.info(
504
+ "[%d, %5d] loss: %.3f "
505
+ "masks: %s "
506
+ "[wd: %.2e] [lr: %.2e] "
507
+ "[mem: %.2e] "
508
+ "[iter: %.1f ms] "
509
+ "[gpu: %.1f ms] "
510
+ "[data: %.1f ms]"
511
+ % (
512
+ epoch + 1,
513
+ itr,
514
+ loss_meter.avg,
515
+ "[" + ", ".join([f"{k}: " + "%.1f" % mask_meters[k].avg for k in mask_meters]) + "]",
516
+ _new_wd,
517
+ _new_lr,
518
+ torch.cuda.max_memory_allocated() / 1024.0**2,
519
+ iter_time_meter.avg,
520
+ gpu_time_meter.avg,
521
+ data_elapsed_time_meter.avg,
522
+ )
523
+ )
524
+
525
+ log_stats()
526
+ assert not np.isnan(loss), "loss is nan"
527
+
528
+ # -- Save Checkpoint
529
+ logger.info("avg. loss %.3f" % loss_meter.avg)
530
+ # -- Save Last
531
+ if epoch % CHECKPOINT_FREQ == 0 or epoch == (num_epochs - 1):
532
+ save_checkpoint(epoch + 1, latest_path)
533
+ if save_every_freq > 0 and epoch % save_every_freq == 0:
534
+ save_every_file = f"e{epoch}.pt"
535
+ save_every_path = os.path.join(folder, save_every_file)
536
+ save_checkpoint(epoch + 1, save_every_path)
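The training loop above anneals the target encoder's EMA momentum linearly from `ema[0]` to `ema[1]` over the whole schedule via a generator expression, then applies `k = m*k + (1-m)*q` each step. A minimal, pure-Python sketch of that schedule (the endpoint values 0.998 and 1.0 are illustrative, not taken from any config in this commit):

```python
def momentum_schedule(ema_start, ema_end, total_steps):
    """Linearly anneal the EMA momentum, mirroring the generator
    expression used for the target encoder in train.py."""
    for i in range(total_steps + 1):
        yield ema_start + i * (ema_end - ema_start) / total_steps


sched = momentum_schedule(0.998, 1.0, 4)
print([round(m, 4) for m in sched])  # → [0.998, 0.9985, 0.999, 0.9995, 1.0]
```

Driving the momentum toward 1.0 freezes the target encoder progressively, which is the usual stabilizer for JEPA/BYOL-style self-distillation.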
vjepa2/app/vjepa/transforms.py ADDED
@@ -0,0 +1,154 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torchvision.transforms as transforms
+
+import src.datasets.utils.video.transforms as video_transforms
+from src.datasets.utils.video.randerase import RandomErasing
+
+
+def make_transforms(
+    random_horizontal_flip=True,
+    random_resize_aspect_ratio=(3 / 4, 4 / 3),
+    random_resize_scale=(0.3, 1.0),
+    reprob=0.0,
+    auto_augment=False,
+    motion_shift=False,
+    crop_size=224,
+    normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+):
+
+    _frames_augmentation = VideoTransform(
+        random_horizontal_flip=random_horizontal_flip,
+        random_resize_aspect_ratio=random_resize_aspect_ratio,
+        random_resize_scale=random_resize_scale,
+        reprob=reprob,
+        auto_augment=auto_augment,
+        motion_shift=motion_shift,
+        crop_size=crop_size,
+        normalize=normalize,
+    )
+    return _frames_augmentation
+
+
+class VideoTransform(object):
+
+    def __init__(
+        self,
+        random_horizontal_flip=True,
+        random_resize_aspect_ratio=(3 / 4, 4 / 3),
+        random_resize_scale=(0.3, 1.0),
+        reprob=0.0,
+        auto_augment=False,
+        motion_shift=False,
+        crop_size=224,
+        normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+    ):
+
+        self.random_horizontal_flip = random_horizontal_flip
+        self.random_resize_aspect_ratio = random_resize_aspect_ratio
+        self.random_resize_scale = random_resize_scale
+        self.auto_augment = auto_augment
+        self.motion_shift = motion_shift
+        self.crop_size = crop_size
+        self.mean = torch.tensor(normalize[0], dtype=torch.float32)
+        self.std = torch.tensor(normalize[1], dtype=torch.float32)
+        if not self.auto_augment:
+            # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255.
+            self.mean *= 255.0
+            self.std *= 255.0
+
+        self.autoaug_transform = video_transforms.create_random_augment(
+            input_size=(crop_size, crop_size),
+            # auto_augment="rand-m4-n4-w1-mstd0.5-inc1",
+            auto_augment="rand-m7-n4-mstd0.5-inc1",
+            interpolation="bicubic",
+        )
+
+        self.spatial_transform = (
+            video_transforms.random_resized_crop_with_shift if motion_shift else video_transforms.random_resized_crop
+        )
+
+        self.reprob = reprob
+        self.erase_transform = RandomErasing(
+            reprob,
+            mode="pixel",
+            max_count=1,
+            num_splits=1,
+            device="cpu",
+        )
+
+    def __call__(self, buffer):
+
+        if self.auto_augment:
+            buffer = [transforms.ToPILImage()(frame) for frame in buffer]
+            buffer = self.autoaug_transform(buffer)
+            buffer = [transforms.ToTensor()(img) for img in buffer]
+            buffer = torch.stack(buffer)  # T C H W
+            buffer = buffer.permute(0, 2, 3, 1)  # T H W C
+        elif torch.is_tensor(buffer):
+            # TODO: ensure input is always a tensor?
+            buffer = buffer.to(torch.float32)
+        else:
+            buffer = torch.tensor(buffer, dtype=torch.float32)
+
+        buffer = buffer.permute(3, 0, 1, 2)  # T H W C -> C T H W
+
+        buffer = self.spatial_transform(
+            images=buffer,
+            target_height=self.crop_size,
+            target_width=self.crop_size,
+            scale=self.random_resize_scale,
+            ratio=self.random_resize_aspect_ratio,
+        )
+        if self.random_horizontal_flip:
+            buffer, _ = video_transforms.horizontal_flip(0.5, buffer)
+
+        buffer = _tensor_normalize_inplace(buffer, self.mean, self.std)
+        if self.reprob > 0:
+            buffer = buffer.permute(1, 0, 2, 3)
+            buffer = self.erase_transform(buffer)
+            buffer = buffer.permute(1, 0, 2, 3)
+
+        return buffer
+
+
+def tensor_normalize(tensor, mean, std):
+    """
+    Normalize a given tensor by subtracting the mean and dividing the std.
+    Args:
+        tensor (tensor): tensor to normalize.
+        mean (tensor or list): mean value to subtract.
+        std (tensor or list): std to divide.
+    """
+    if tensor.dtype == torch.uint8:
+        tensor = tensor.float()
+        tensor = tensor / 255.0
+    if isinstance(mean, list):
+        mean = torch.tensor(mean)
+    if isinstance(std, list):
+        std = torch.tensor(std)
+    tensor = tensor - mean
+    tensor = tensor / std
+    return tensor
+
+
+def _tensor_normalize_inplace(tensor, mean, std):
+    """
+    Normalize a given tensor by subtracting the mean and dividing the std.
+    Args:
+        tensor (tensor): tensor to normalize (with dimensions C, T, H, W).
+        mean (tensor): mean value to subtract (in 0 to 255 floats).
+        std (tensor): std to divide (in 0 to 255 floats).
+    """
+    if tensor.dtype == torch.uint8:
+        tensor = tensor.float()
+
+    C, T, H, W = tensor.shape
+    tensor = tensor.view(C, -1).permute(1, 0)  # Make C the last dimension
+    tensor.sub_(mean).div_(std)
+    tensor = tensor.permute(1, 0).view(C, T, H, W)  # Put C back in front
+    return tensor
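When auto-augment is off, `VideoTransform` pre-scales the ImageNet mean/std by 255 and normalizes raw uint8 pixel values directly; this is algebraically the same as dividing by 255 first and normalizing in [0, 1]. A small pure-Python sketch checking that equivalence for a single channel value (the constants are the ImageNet red-channel statistics from the file above):

```python
def normalize_pixel(value, mean, std, uint8_space=True):
    """Normalize one pixel the way VideoTransform does: either with
    mean/std pre-scaled by 255 (uint8 space) or in [0, 1] space."""
    if uint8_space:
        return (value - mean * 255.0) / (std * 255.0)
    return (value / 255.0 - mean) / std


# Both paths agree for a red-channel pixel of 128:
a = normalize_pixel(128, 0.485, 0.229, uint8_space=True)
b = normalize_pixel(128, 0.485, 0.229, uint8_space=False)
print(abs(a - b) < 1e-9)  # → True
```

Keeping the statistics in uint8 space lets `_tensor_normalize_inplace` skip a full-tensor division by 255.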
vjepa2/app/vjepa/utils.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import sys
8
+ import warnings
9
+
10
+ import torch
11
+ import yaml
12
+
13
+ import src.models.predictor as vit_pred
14
+ import src.models.vision_transformer as video_vit
15
+ from src.utils.checkpoint_loader import robust_checkpoint_loader
16
+ from src.utils.schedulers import CosineWDSchedule, LinearDecaySchedule, WarmupCosineSchedule
17
+ from src.utils.wrappers import MultiSeqWrapper, PredictorMultiSeqWrapper
18
+
19
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
20
+ logger = logging.getLogger()
21
+
22
+ MAX_RETRIES = 3
23
+
24
+
25
+ def build_eval_args(
26
+ model_name,
27
+ patch_size,
28
+ tubelet_size,
29
+ num_frames,
30
+ logging_folder,
31
+ checkpoint,
32
+ write_tag,
33
+ eval_cfg_paths,
34
+ uniform_power=False,
35
+ use_sdpa=False,
36
+ clip_duration=None,
37
+ use_silu=False,
38
+ wide_silu=True,
39
+ tag=None,
40
+ ):
41
+ """
42
+ Helper function to parse the pre-training configs to construct the
43
+ evaluation configs, return as a list of eval configs.
44
+ """
45
+ # By convention, the pre-training config should specify any required evals
46
+ # in the 'evals' key
47
+ if eval_cfg_paths is None:
48
+ logger.info("No evaluations specified!")
49
+ return
50
+
51
+ eval_nodes = None
52
+ eval_tasks_per_node = None
53
+ args_eval = []
54
+ for i, f in enumerate(eval_cfg_paths):
55
+ with open(f, "r") as y_file:
56
+ _args = yaml.load(y_file, Loader=yaml.FullLoader)
57
+ _tag = _args.get("tag", "")
58
+ _args["tag"] = f"{tag}-{_tag}"
59
+ _nodes = _args.get("nodes", None)
60
+ _tasks = _args.get("tasks_per_node", 8)
61
+ eval_nodes = _nodes if eval_nodes is None else eval_nodes
62
+ eval_tasks_per_node = _tasks if eval_tasks_per_node is None else eval_tasks_per_node
63
+ if (eval_nodes != _nodes) or (eval_tasks_per_node != _tasks):
64
+ warnings.warn("Configs for online evals must use same number of nodes for slurm-batch processing")
65
+
66
+ # Model params
67
+ _args["pretrain"] = {}
68
+ _args["pretrain"]["model_name"] = model_name
69
+ _args["pretrain"]["patch_size"] = patch_size
70
+ _args["pretrain"]["tubelet_size"] = tubelet_size
71
+ _args["pretrain"]["uniform_power"] = uniform_power
72
+ _args["pretrain"]["use_sdpa"] = use_sdpa
73
+ _args["pretrain"]["clip_duration"] = clip_duration
74
+ _args["pretrain"]["use_silu"] = use_silu
75
+ _args["pretrain"]["wide_silu"] = wide_silu
76
+
77
+ # Data params
78
+ _args["pretrain"]["frames_per_clip"] = num_frames
79
+
80
+ # Misc
81
+ _args["pretrain"]["folder"] = logging_folder
82
+ _args["pretrain"]["checkpoint"] = checkpoint
83
+ _args["pretrain"]["write_tag"] = write_tag
84
+
85
+ args_eval += [_args]
86
+
87
+ return eval_nodes, eval_tasks_per_node, args_eval
88
+
89
+
90
+ def load_checkpoint(
91
+ r_path,
92
+ encoder,
93
+ predictor,
94
+ target_encoder,
95
+ opt,
96
+ scaler,
97
+ is_anneal=False,
98
+ ):
99
+ logger.info(f"Loading checkpoint from {r_path}")
100
+ checkpoint = robust_checkpoint_loader(r_path, map_location=torch.device("cpu"))
101
+
102
+ epoch = 0
103
+ if not is_anneal:
104
+ epoch = checkpoint["epoch"]
105
+
106
+ # -- loading encoder
107
+ pretrained_dict = checkpoint["encoder"]
108
+ msg = encoder.load_state_dict(pretrained_dict)
109
+ logger.info(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}")
110
+
111
+ # -- loading predictor
112
+ pretrained_dict = checkpoint["predictor"]
113
+ msg = predictor.load_state_dict(pretrained_dict)
114
+ logger.info(f"loaded pretrained predictor from epoch {epoch} with msg: {msg}")
115
+
116
+ # -- loading target_encoder
117
+ if target_encoder is not None:
118
+ print(list(checkpoint.keys()))
119
+ pretrained_dict = checkpoint["target_encoder"]
120
+ msg = target_encoder.load_state_dict(pretrained_dict)
121
+ logger.info(f"loaded pretrained target encoder from epoch {epoch} with msg: {msg}")
122
+
123
+ # -- loading optimizer
124
+ opt.load_state_dict(checkpoint["opt"])
125
+ if scaler is not None:
126
+ scaler.load_state_dict(checkpoint["scaler"])
127
+ logger.info(f"loaded optimizers from epoch {epoch}")
128
+ logger.info(f"read-path: {r_path}")
129
+ del checkpoint
130
+
131
+ return (
132
+ encoder,
133
+ predictor,
134
+ target_encoder,
135
+ opt,
136
+ scaler,
137
+ epoch,
138
+ )
139
+
140
+
141
+ def init_video_model(
142
+ device,
143
+ patch_size=16,
144
+ max_num_frames=16,
145
+ tubelet_size=2,
146
+ model_name="vit_base",
147
+ crop_size=224,
148
+ pred_depth=6,
149
+ pred_num_heads=None,
150
+ pred_embed_dim=384,
151
+ uniform_power=False,
152
+ use_mask_tokens=False,
153
+ num_mask_tokens=2,
154
+ zero_init_mask_tokens=True,
155
+ use_sdpa=False,
156
+ use_rope=False,
157
+ use_silu=False,
158
+ use_pred_silu=False,
159
+ wide_silu=False,
160
+ use_activation_checkpointing=False,
161
+ ):
162
+ encoder = video_vit.__dict__[model_name](
163
+ img_size=crop_size,
164
+ patch_size=patch_size,
165
+ num_frames=max_num_frames,
166
+ tubelet_size=tubelet_size,
167
+ uniform_power=uniform_power,
168
+ use_sdpa=use_sdpa,
169
+ use_silu=use_silu,
170
+ wide_silu=wide_silu,
171
+ use_activation_checkpointing=use_activation_checkpointing,
172
+ use_rope=use_rope,
173
+ )
174
+ encoder = MultiSeqWrapper(encoder)
175
+ predictor = vit_pred.__dict__["vit_predictor"](
176
+ img_size=crop_size,
177
+ use_mask_tokens=use_mask_tokens,
178
+ patch_size=patch_size,
179
+ num_frames=max_num_frames,
180
+ tubelet_size=tubelet_size,
181
+ embed_dim=encoder.backbone.embed_dim,
182
+ predictor_embed_dim=pred_embed_dim,
183
+            depth=pred_depth,
+            num_heads=encoder.backbone.num_heads if pred_num_heads is None else pred_num_heads,
+            uniform_power=uniform_power,
+            num_mask_tokens=num_mask_tokens,
+            zero_init_mask_tokens=zero_init_mask_tokens,
+            use_rope=use_rope,
+            use_sdpa=use_sdpa,
+            use_silu=use_pred_silu,
+            wide_silu=wide_silu,
+            use_activation_checkpointing=use_activation_checkpointing,
+        )
+        predictor = PredictorMultiSeqWrapper(predictor)
+
+    encoder.to(device)
+    predictor.to(device)
+    logger.info(encoder)
+    logger.info(predictor)
+
+    def count_parameters(model):
+        return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.info(f"Encoder number of parameters: {count_parameters(encoder)}")
+    logger.info(f"Predictor number of parameters: {count_parameters(predictor)}")
+
+    return encoder, predictor
+
+
+def init_opt(
+    is_anneal,
+    encoder,
+    predictor,
+    iterations_per_epoch,
+    start_lr,
+    ref_lr,
+    warmup,
+    num_epochs,
+    wd=1e-6,
+    final_wd=1e-6,
+    final_lr=0.0,
+    mixed_precision=False,
+    ipe_scale=1.25,
+    betas=(0.9, 0.999),
+    eps=1e-8,
+    zero_init_bias_wd=True,
+):
+    param_groups = [
+        {"params": (p for n, p in encoder.named_parameters() if ("bias" not in n) and (len(p.shape) != 1))},
+        {"params": (p for n, p in predictor.named_parameters() if ("bias" not in n) and (len(p.shape) != 1))},
+        {
+            "params": (p for n, p in encoder.named_parameters() if ("bias" in n) or (len(p.shape) == 1)),
+            "WD_exclude": zero_init_bias_wd,
+            "weight_decay": 0,
+        },
+        {
+            "params": (p for n, p in predictor.named_parameters() if ("bias" in n) or (len(p.shape) == 1)),
+            "WD_exclude": zero_init_bias_wd,
+            "weight_decay": 0,
+        },
+    ]
+
+    optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps)
+    if not is_anneal:
+        scheduler = WarmupCosineSchedule(
+            optimizer,
+            warmup_steps=int(warmup * iterations_per_epoch),
+            start_lr=start_lr,
+            ref_lr=ref_lr,
+            final_lr=final_lr,
+            T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
+        )
+    else:
+        scheduler = LinearDecaySchedule(
+            optimizer,
+            ref_lr=ref_lr,
+            final_lr=final_lr,
+            T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
+        )
+    wd_scheduler = CosineWDSchedule(
+        optimizer,
+        ref_wd=wd,
+        final_wd=final_wd,
+        T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
+    )
+    scaler = torch.amp.GradScaler('cuda') if mixed_precision else None
+    return optimizer, scaler, scheduler, wd_scheduler
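The `init_opt` param groups above follow a common convention: biases and 1-D parameters (e.g. norm scales) are excluded from weight decay. A minimal sketch of the same grouping rule on plain `(name, shape)` pairs — the parameter names below are hypothetical, for illustration only:

```python
def split_param_groups(named_shapes):
    """Split parameters into weight-decay / no-weight-decay groups.

    A parameter is excluded from weight decay when its name contains
    "bias" or it is 1-D (e.g. LayerNorm scales), mirroring the
    filters used to build init_opt's param_groups.
    """
    decay, no_decay = [], []
    for name, shape in named_shapes:
        if "bias" in name or len(shape) == 1:
            no_decay.append(name)
        else:
            decay.append(name)
    return decay, no_decay


# Hypothetical parameter names, mimicking a ViT block.
params = [
    ("blocks.0.attn.qkv.weight", (768, 768)),
    ("blocks.0.attn.qkv.bias", (768,)),
    ("blocks.0.norm1.weight", (768,)),
]
decay, no_decay = split_param_groups(params)
# decay holds only the 2-D weight; bias and norm scale are exempt.
```

In `init_opt` itself the no-decay groups additionally carry `"weight_decay": 0` and a `WD_exclude` flag so the `CosineWDSchedule` leaves them untouched.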
vjepa2/app/vjepa_droid/droid.py ADDED
@@ -0,0 +1,232 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import json
+import os
+from logging import getLogger
+from math import ceil
+
+import h5py
+import numpy as np
+import pandas as pd
+import torch
+import torch.utils.data
+from decord import VideoReader, cpu
+from scipy.spatial.transform import Rotation
+
+_GLOBAL_SEED = 0
+logger = getLogger()
+
+
+def init_data(
+    data_path,
+    batch_size,
+    frames_per_clip=16,
+    fps=5,
+    crop_size=224,
+    rank=0,
+    world_size=1,
+    camera_views=0,
+    stereo_view=False,
+    drop_last=True,
+    num_workers=10,
+    pin_mem=True,
+    persistent_workers=True,
+    collator=None,
+    transform=None,
+    camera_frame=False,
+    tubelet_size=2,
+):
+    dataset = DROIDVideoDataset(
+        data_path=data_path,
+        frames_per_clip=frames_per_clip,
+        transform=transform,
+        fps=fps,
+        camera_views=camera_views,
+        frameskip=tubelet_size,
+        camera_frame=camera_frame,
+    )
+
+    dist_sampler = torch.utils.data.distributed.DistributedSampler(
+        dataset, num_replicas=world_size, rank=rank, shuffle=True
+    )
+
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        collate_fn=collator,
+        sampler=dist_sampler,
+        batch_size=batch_size,
+        drop_last=drop_last,
+        pin_memory=pin_mem,
+        num_workers=num_workers,
+        persistent_workers=(num_workers > 0) and persistent_workers,
+    )
+
+    logger.info("VideoDataset unsupervised data loader created")
+
+    return data_loader, dist_sampler
+
+
+def get_json(directory):
+    for filename in os.listdir(directory):
+        if filename.endswith(".json"):
+            file_path = os.path.join(directory, filename)
+            try:
+                with open(file_path, "r") as f:
+                    return json.load(f)
+            except json.JSONDecodeError:
+                print(f"Error decoding JSON in file: {file_path}")
+            except Exception as e:
+                print(f"An unexpected error occurred while processing {file_path}: {e}")
+
+
+class DROIDVideoDataset(torch.utils.data.Dataset):
+    """Video classification dataset."""
+
+    def __init__(
+        self,
+        data_path,
+        camera_views=["left_mp4_path", "right_mp4_path"],
+        frameskip=2,
+        frames_per_clip=16,
+        fps=5,
+        transform=None,
+        camera_frame=False,
+    ):
+        self.data_path = data_path
+        self.frames_per_clip = frames_per_clip
+        self.frameskip = frameskip
+        self.fps = fps
+        self.transform = transform
+        self.camera_frame = camera_frame
+        if VideoReader is None:
+            raise ImportError('Unable to import "decord" which is required to read videos.')
+
+        # Camera views
+        # ---
+        # wrist camera view
+        # left camera view
+        # right camera view
+        self.camera_views = camera_views
+        self.h5_name = "trajectory.h5"
+
+        samples = list(pd.read_csv(data_path, header=None, delimiter=" ").values[:, 0])
+        self.samples = samples
+
+    def __getitem__(self, index):
+        path = self.samples[index]
+
+        # -- keep trying to load videos until you find a valid sample
+        loaded_video = False
+        while not loaded_video:
+            try:
+                buffer, actions, states, extrinsics, indices = self.loadvideo_decord(path)
+                loaded_video = True
+            except Exception as e:
+                logger.info(f"Encountered exception when loading video {path=} {e=}")
+                loaded_video = False
+                index = np.random.randint(self.__len__())
+                path = self.samples[index]
+
+        return buffer, actions, states, extrinsics, indices
+
+    def poses_to_diffs(self, poses):
+        xyz = poses[:, :3]  # shape [T, 3]
+        thetas = poses[:, 3:6]  # euler angles, shape [T, 3]
+        matrices = [Rotation.from_euler("xyz", theta, degrees=False).as_matrix() for theta in thetas]
+        xyz_diff = xyz[1:] - xyz[:-1]
+        angle_diff = [matrices[t + 1] @ matrices[t].T for t in range(len(matrices) - 1)]
+        angle_diff = [Rotation.from_matrix(mat).as_euler("xyz", degrees=False) for mat in angle_diff]
+        angle_diff = np.stack([d for d in angle_diff], axis=0)
+        closedness = poses[:, -1:]
+        closedness_delta = closedness[1:] - closedness[:-1]
+        return np.concatenate([xyz_diff, angle_diff, closedness_delta], axis=1)
+
+    def transform_frame(self, poses, extrinsics):
+        gripper = poses[:, -1:]
+        poses = poses[:, :-1]
+
+        def pose_to_transform(pose):
+            trans = pose[:3]  # shape [3]
+            theta = pose[3:6]  # euler angles, shape [3]
+            Rot = Rotation.from_euler("xyz", theta, degrees=False).as_matrix()
+            T = np.eye(4)
+            T[:3, :3] = Rot
+            T[:3, 3] = trans
+            return T
+
+        def transform_to_pose(transform):
+            trans = transform[:3, 3]
+            Rot = transform[:3, :3]
+            angle = Rotation.from_matrix(Rot).as_euler("xyz", degrees=False)
+            return np.concatenate([trans, angle], axis=0)
+
+        new_pose = []
+        for p, e in zip(poses, extrinsics):
+            p_transform = pose_to_transform(p)
+            e_transform = pose_to_transform(e)
+            new_pose_transform = np.linalg.inv(e_transform) @ p_transform
+            new_pose += [transform_to_pose(new_pose_transform)]
+        new_pose = np.stack(new_pose, axis=0)
+
+        return np.concatenate([new_pose, gripper], axis=1)
+
+    def loadvideo_decord(self, path):
+        # -- load metadata
+        metadata = get_json(path)
+        if metadata is None:
+            raise Exception(f"No metadata for video {path=}")
+
+        # -- load trajectory info
+        tpath = os.path.join(path, self.h5_name)
+        trajectory = h5py.File(tpath)
+
+        # -- randomly sample a camera view
+        camera_view = self.camera_views[torch.randint(0, len(self.camera_views), (1,))]
+        mp4_name = metadata[camera_view].split("recordings/MP4/")[-1]
+        camera_name = mp4_name.split(".")[0]
+        extrinsics = trajectory["observation"]["camera_extrinsics"][f"{camera_name}_left"]
+        states = np.concatenate(
+            [
+                np.array(trajectory["observation"]["robot_state"]["cartesian_position"]),
+                np.array(trajectory["observation"]["robot_state"]["gripper_position"])[:, None],
+            ],
+            axis=1,
+        )  # [T, 7]
+        vpath = os.path.join(path, "recordings/MP4", mp4_name)
+        vr = VideoReader(vpath, num_threads=-1, ctx=cpu(0))
+        # --
+        vfps = vr.get_avg_fps()
+        fpc = self.frames_per_clip
+        fps = self.fps if self.fps is not None else vfps
+        fstp = ceil(vfps / fps)
+        nframes = int(fpc * fstp)
+        vlen = len(vr)
+
+        if vlen < nframes:
+            raise Exception(f"Video is too short {vpath=}, {nframes=}, {vlen=}")
+
+        # sample a random window of nframes
+        ef = np.random.randint(nframes, vlen)
+        sf = ef - nframes
+        indices = np.arange(sf, sf + nframes, fstp).astype(np.int64)
+        # --
+        states = states[indices, :][:: self.frameskip]
+        extrinsics = extrinsics[indices, :][:: self.frameskip]
+        if self.camera_frame:
+            states = self.transform_frame(states, extrinsics)
+        actions = self.poses_to_diffs(states)
+        # --
+        vr.seek(0)  # go to start of video before sampling frames
+        buffer = vr.get_batch(indices).asnumpy()
+        if self.transform is not None:
+            buffer = self.transform(buffer)
+
+        return buffer, actions, states, extrinsics, indices
+
+    def __len__(self):
+        return len(self.samples)
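`poses_to_diffs` above turns absolute gripper poses into per-step deltas: translation differences, relative rotations `R[t+1] @ R[t].T`, and gripper-closedness deltas. The rotation part can be sketched in plain numpy, using z-axis rotation matrices in place of scipy's `Rotation` (the angles below are made up for illustration):

```python
import numpy as np


def rot_z(theta):
    """Rotation matrix about the z-axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def relative_rotations(angles):
    """Per-step relative rotation R[t+1] @ R[t].T, as in poses_to_diffs."""
    mats = [rot_z(a) for a in angles]
    return [mats[t + 1] @ mats[t].T for t in range(len(mats) - 1)]


# For rotations about a single axis, the relative rotation is simply
# the rotation by the angle difference, so the delta is easy to check:
diffs = relative_rotations([0.0, 0.1, 0.3])
```

In the dataset itself the relative rotation matrices are converted back to Euler angles with `Rotation.from_matrix(...).as_euler(...)`, giving a 3-vector rotation delta per step.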
vjepa2/app/vjepa_droid/train.py ADDED
@@ -0,0 +1,524 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import os
+
+# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS
+try:
+    # -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE
+    # --          SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE
+    # --          THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE
+    # --          TO EACH PROCESS
+    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"]
+except Exception:
+    pass
+
+import copy
+import gc
+import random
+import time
+
+import numpy as np
+import torch
+import torch.multiprocessing as mp
+import torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel
+
+from app.vjepa_droid.droid import init_data
+from app.vjepa_droid.transforms import make_transforms
+from app.vjepa_droid.utils import init_opt, init_video_model, load_checkpoint, load_pretrained
+from src.utils.distributed import init_distributed
+from src.utils.logging import AverageMeter, CSVLogger, get_logger, gpu_timer
+
+# --
+log_timings = True
+log_freq = 10
+CHECKPOINT_FREQ = 1
+GARBAGE_COLLECT_ITR_FREQ = 50
+# --
+
+_GLOBAL_SEED = 0
+random.seed(_GLOBAL_SEED)
+np.random.seed(_GLOBAL_SEED)
+torch.manual_seed(_GLOBAL_SEED)
+torch.backends.cudnn.benchmark = True
+
+
+logger = get_logger(__name__, force=True)
+
+
+def main(args, resume_preempt=False):
+    # ----------------------------------------------------------------------- #
+    #  PASSED IN PARAMS FROM CONFIG FILE
+    # ----------------------------------------------------------------------- #
+
+    # -- META
+    folder = args.get("folder")
+    cfgs_meta = args.get("meta")
+    r_file = cfgs_meta.get("resume_checkpoint", None)
+    p_file = cfgs_meta.get("pretrain_checkpoint", None)
+    load_predictor = cfgs_meta.get("load_predictor", False)
+    context_encoder_key = cfgs_meta.get("context_encoder_key", "encoder")
+    target_encoder_key = cfgs_meta.get("target_encoder_key", "target_encoder")
+    load_encoder = cfgs_meta.get("load_encoder", True)
+    seed = cfgs_meta.get("seed", _GLOBAL_SEED)
+    save_every_freq = cfgs_meta.get("save_every_freq", -1)
+    skip_batches = cfgs_meta.get("skip_batches", -1)
+    use_sdpa = cfgs_meta.get("use_sdpa", False)
+    sync_gc = cfgs_meta.get("sync_gc", False)
+    which_dtype = cfgs_meta.get("dtype")
+    logger.info(f"{which_dtype=}")
+    if which_dtype.lower() == "bfloat16":
+        dtype = torch.bfloat16
+        mixed_precision = True
+    elif which_dtype.lower() == "float16":
+        dtype = torch.float16
+        mixed_precision = True
+    else:
+        dtype = torch.float32
+        mixed_precision = False
+
+    # -- MODEL
+    cfgs_model = args.get("model")
+    compile_model = cfgs_model.get("compile_model", False)
+    use_activation_checkpointing = cfgs_model.get("use_activation_checkpointing", False)
+    model_name = cfgs_model.get("model_name")
+    pred_depth = cfgs_model.get("pred_depth")
+    pred_num_heads = cfgs_model.get("pred_num_heads", None)
+    pred_embed_dim = cfgs_model.get("pred_embed_dim")
+    pred_is_frame_causal = cfgs_model.get("pred_is_frame_causal", True)
+    uniform_power = cfgs_model.get("uniform_power", False)
+    use_rope = cfgs_model.get("use_rope", False)
+    use_silu = cfgs_model.get("use_silu", False)
+    use_pred_silu = cfgs_model.get("use_pred_silu", False)
+    wide_silu = cfgs_model.get("wide_silu", True)
+    use_extrinsics = cfgs_model.get("use_extrinsics", False)
+
+    # -- DATA
+    cfgs_data = args.get("data")
+    datasets = cfgs_data.get("datasets", [])
+    dataset_path = datasets[0]
+    dataset_fpcs = cfgs_data.get("dataset_fpcs")
+    max_num_frames = max(dataset_fpcs)
+    camera_frame = cfgs_data.get("camera_frame", False)
+    camera_views = cfgs_data.get("camera_views", ["left_mp4_path"])
+    stereo_view = cfgs_data.get("stereo_view", False)
+    batch_size = cfgs_data.get("batch_size")
+    tubelet_size = cfgs_data.get("tubelet_size")
+    fps = cfgs_data.get("fps")
+    crop_size = cfgs_data.get("crop_size", 256)
+    patch_size = cfgs_data.get("patch_size")
+    pin_mem = cfgs_data.get("pin_mem", False)
+    num_workers = cfgs_data.get("num_workers", 1)
+    persistent_workers = cfgs_data.get("persistent_workers", True)
+
+    # -- DATA AUGS
+    cfgs_data_aug = args.get("data_aug")
+    horizontal_flip = cfgs_data_aug.get("horizontal_flip", False)
+    ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3])
+    rr_scale = cfgs_data_aug.get("random_resize_scale", [0.3, 1.0])
+    motion_shift = cfgs_data_aug.get("motion_shift", False)
+    reprob = cfgs_data_aug.get("reprob", 0.0)
+    use_aa = cfgs_data_aug.get("auto_augment", False)
+
+    # -- LOSS
+    cfgs_loss = args.get("loss")
+    loss_exp = cfgs_loss.get("loss_exp")
+    normalize_reps = cfgs_loss.get("normalize_reps")
+    auto_steps = min(cfgs_loss.get("auto_steps", 1), max_num_frames)
+    # --
+    tokens_per_frame = int((crop_size // patch_size) ** 2)
+
+    # -- OPTIMIZATION
+    cfgs_opt = args.get("optimization")
+    ipe = cfgs_opt.get("ipe", None)
+    wd = float(cfgs_opt.get("weight_decay"))
+    final_wd = float(cfgs_opt.get("final_weight_decay"))
+    num_epochs = cfgs_opt.get("epochs")
+    anneal = cfgs_opt.get("anneal")
+    warmup = cfgs_opt.get("warmup")
+    start_lr = cfgs_opt.get("start_lr")
+    lr = cfgs_opt.get("lr")
+    final_lr = cfgs_opt.get("final_lr")
+    enc_lr_scale = cfgs_opt.get("enc_lr_scale", 1.0)
+    betas = cfgs_opt.get("betas", (0.9, 0.999))
+    eps = cfgs_opt.get("eps", 1.0e-8)
+    # ----------------------------------------------------------------------- #
+    # ----------------------------------------------------------------------- #
+
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.backends.cudnn.benchmark = True
+    try:
+        mp.set_start_method("spawn")
+    except Exception:
+        pass
+
+    # -- init torch distributed backend
+    world_size, rank = init_distributed()
+    logger.info(f"Initialized (rank/world-size) {rank}/{world_size}")
+
+    # -- set device
+    if not torch.cuda.is_available():
+        device = torch.device("cpu")
+    else:
+        device = torch.device("cuda:0")
+        torch.cuda.set_device(device)
+
+    # -- log/checkpointing paths
+    log_file = os.path.join(folder, f"log_r{rank}.csv")
+    latest_path = os.path.join(folder, "latest.pt")
+    resume_path = os.path.join(folder, r_file) if r_file is not None else latest_path
+    if not os.path.exists(resume_path):
+        resume_path = None
+
+    # -- make csv_logger
+    csv_logger = CSVLogger(
+        log_file,
+        ("%d", "epoch"),
+        ("%d", "itr"),
+        ("%.5f", "loss"),
+        ("%d", "iter-time(ms)"),
+        ("%d", "gpu-time(ms)"),
+        ("%d", "dataload-time(ms)"),
+        mode="+a",
+    )
+
+    # -- init model
+    encoder, predictor = init_video_model(
+        uniform_power=uniform_power,
+        device=device,
+        patch_size=patch_size,
+        max_num_frames=512,
+        tubelet_size=tubelet_size,
+        model_name=model_name,
+        crop_size=crop_size,
+        pred_depth=pred_depth,
+        pred_num_heads=pred_num_heads,
+        pred_embed_dim=pred_embed_dim,
+        action_embed_dim=7,
+        pred_is_frame_causal=pred_is_frame_causal,
+        use_extrinsics=use_extrinsics,
+        use_sdpa=use_sdpa,
+        use_silu=use_silu,
+        use_pred_silu=use_pred_silu,
+        wide_silu=wide_silu,
+        use_rope=use_rope,
+        use_activation_checkpointing=use_activation_checkpointing,
+    )
+    target_encoder = copy.deepcopy(encoder)
+
+    if compile_model:
+        logger.info("Compiling encoder, target_encoder, and predictor.")
+        torch._dynamo.config.optimize_ddp = False
+        encoder.compile()
+        target_encoder.compile()
+        predictor.compile()
+
+    video_collator = torch.utils.data.default_collate
+    transform = make_transforms(
+        random_horizontal_flip=horizontal_flip,
+        random_resize_aspect_ratio=ar_range,
+        random_resize_scale=rr_scale,
+        reprob=reprob,
+        auto_augment=use_aa,
+        motion_shift=motion_shift,
+        crop_size=crop_size,
+    )
+
+    # -- init data-loaders/samplers
+    (unsupervised_loader, unsupervised_sampler) = init_data(
+        data_path=dataset_path,
+        batch_size=batch_size,
+        frames_per_clip=max_num_frames,
+        tubelet_size=1,
+        fps=fps,
+        camera_views=camera_views,
+        camera_frame=camera_frame,
+        stereo_view=stereo_view,
+        transform=transform,
+        collator=video_collator,
+        num_workers=num_workers,
+        world_size=world_size,
+        pin_mem=pin_mem,
+        persistent_workers=persistent_workers,
+        rank=rank,
+    )
+    _dlen = len(unsupervised_loader)
+    if ipe is None:
+        ipe = _dlen
+    logger.info(f"iterations per epoch/dataset length: {ipe}/{_dlen}")
+
+    # -- init optimizer and scheduler
+    optimizer, scaler, scheduler, wd_scheduler = init_opt(
+        encoder=encoder,
+        predictor=predictor,
+        wd=wd,
+        final_wd=final_wd,
+        start_lr=start_lr,
+        ref_lr=lr,
+        final_lr=final_lr,
+        enc_lr_scale=enc_lr_scale,
+        iterations_per_epoch=ipe,
+        anneal=anneal,
+        warmup=warmup,
+        num_epochs=num_epochs,
+        mixed_precision=mixed_precision,
+        betas=betas,
+        eps=eps,
+    )
+    encoder = DistributedDataParallel(encoder, static_graph=True)
+    predictor = DistributedDataParallel(predictor, static_graph=False, find_unused_parameters=True)
+    target_encoder = DistributedDataParallel(target_encoder)
+    for p in target_encoder.parameters():
+        p.requires_grad = False
+
+    # -- load pretrained weights
+    encoder, predictor, target_encoder = load_pretrained(
+        r_path=p_file,
+        encoder=encoder,
+        predictor=predictor,
+        context_encoder_key=context_encoder_key,
+        target_encoder_key=target_encoder_key,
+        target_encoder=target_encoder,
+        load_predictor=load_predictor,
+        load_encoder=load_encoder,
+    )
+
+    start_epoch = 0
+    # -- load training checkpoint
+    if os.path.exists(latest_path):
+        (
+            encoder,
+            predictor,
+            target_encoder,
+            optimizer,
+            scaler,
+            start_epoch,
+        ) = load_checkpoint(
+            r_path=resume_path,
+            encoder=encoder,
+            predictor=predictor,
+            target_encoder=target_encoder,
+            opt=optimizer,
+            scaler=scaler,
+        )
+        for _ in range(start_epoch * ipe):
+            scheduler.step()
+            wd_scheduler.step()
+
+    def save_checkpoint(epoch, path):
+        if rank != 0:
+            return
+        save_dict = {
+            "encoder": encoder.state_dict(),
+            "predictor": predictor.state_dict(),
+            "opt": optimizer.state_dict(),
+            "scaler": None if scaler is None else scaler.state_dict(),
+            "target_encoder": target_encoder.state_dict(),
+            "epoch": epoch,
+            "loss": loss_meter.avg,
+            "batch_size": batch_size,
+            "world_size": world_size,
+            "lr": lr,
+        }
+        try:
+            torch.save(save_dict, path)
+        except Exception as e:
+            logger.info(f"Encountered exception when saving checkpoint: {e}")
+
+    logger.info("Initializing loader...")
+    unsupervised_sampler.set_epoch(start_epoch)
+    loader = iter(unsupervised_loader)
+
+    if skip_batches > 0:
+        logger.info(f"Skip {skip_batches} batches")
+        # -- update distributed-data-loader epoch
+
+        for itr in range(skip_batches):
+            if itr % 10 == 0:
+                logger.info(f"Skip {itr}/{skip_batches} batches")
+            try:
+                _ = next(loader)
+            except Exception:
+                loader = iter(unsupervised_loader)
+                _ = next(loader)
+
+    if sync_gc:
+        gc.disable()
+        gc.collect()
+
+    # -- TRAINING LOOP
+    for epoch in range(start_epoch, num_epochs):
+        logger.info("Epoch %d" % (epoch + 1))
+
+        loss_meter = AverageMeter()
+        jloss_meter = AverageMeter()
+        sloss_meter = AverageMeter()
+        iter_time_meter = AverageMeter()
+        gpu_time_meter = AverageMeter()
+        data_elapsed_time_meter = AverageMeter()
+
+        for itr in range(ipe):
+            itr_start_time = time.time()
+
+            iter_retries = 0
+            iter_successful = False
+            while not iter_successful:
+                try:
+                    sample = next(loader)
+                    iter_successful = True
+                except StopIteration:
+                    logger.info("Exhausted data loaders. Refreshing...")
+                    unsupervised_sampler.set_epoch(epoch)
+                    loader = iter(unsupervised_loader)
+                except Exception as e:
+                    NUM_RETRIES = 5
+                    if iter_retries < NUM_RETRIES:
+                        logger.warning(f"Encountered exception when loading data (num retries {iter_retries}):\n{e}")
+                        iter_retries += 1
+                        time.sleep(5)
+                    else:
+                        logger.warning(f"Exceeded max retries ({NUM_RETRIES}) when loading data. Skipping batch.")
+                        raise e
+
+            def load_clips():
+                clips = sample[0].to(device, non_blocking=True)  # [B C T H W]
+                actions = sample[1].to(device, dtype=torch.float, non_blocking=True)  # [B T-1 7]
+                states = sample[2].to(device, dtype=torch.float, non_blocking=True)  # [B T 7]
+                extrinsics = sample[3].to(device, dtype=torch.float, non_blocking=True)  # [B T 7]
+                return (clips, actions, states, extrinsics)
+
+            clips, actions, states, extrinsics = load_clips()
+            data_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0
+
+            if sync_gc and (itr + 1) % GARBAGE_COLLECT_ITR_FREQ == 0:
+                logger.info("Running garbage collection...")
+                gc.collect()
+
+            def train_step():
+                _new_lr = scheduler.step()
+                _new_wd = wd_scheduler.step()
+                # --
+
+                def forward_target(c):
+                    with torch.no_grad():
+                        c = c.permute(0, 2, 1, 3, 4).flatten(0, 1).unsqueeze(2).repeat(1, 1, 2, 1, 1)
+                        h = target_encoder(c)
+                        h = h.view(batch_size, max_num_frames, -1, h.size(-1)).flatten(1, 2)
+                        if normalize_reps:
+                            h = F.layer_norm(h, (h.size(-1),))
+                        return h
+
+                def forward_predictions(z):
+
+                    def _step_predictor(_z, _a, _s, _e):
+                        _z = predictor(_z, _a, _s, _e)
+                        if normalize_reps:
+                            _z = F.layer_norm(_z, (_z.size(-1),))
+                        return _z
+
+                    # -- one step of predictor with teacher forcing
+                    _z, _a, _s, _e = z[:, :-tokens_per_frame], actions, states[:, :-1], extrinsics[:, :-1]
+                    z_tf = _step_predictor(_z, _a, _s, _e)
+
+                    # -- full auto-regressive rollouts of predictor
+                    _z = torch.cat([z[:, : tokens_per_frame], z_tf[:, : tokens_per_frame]], dim=1)
+                    for n in range(1, auto_steps):
+                        _a, _s, _e = actions[:, : n + 1], states[:, : n + 1], extrinsics[:, : n + 1]
+                        _z_nxt = _step_predictor(_z, _a, _s, _e)[:, -tokens_per_frame:]
+                        _z = torch.cat([_z, _z_nxt], dim=1)
+                    z_ar = _z[:, tokens_per_frame:]
+
+                    return z_tf, z_ar
+
+                def loss_fn(z, h):
+                    _h = h[:, tokens_per_frame : z.size(1) + tokens_per_frame]
+                    return torch.mean(torch.abs(z - _h) ** loss_exp) / loss_exp
+
+                # Step 1. Forward
+                with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision):
+                    h = forward_target(clips)
+                    z_tf, z_ar = forward_predictions(h)
+                    jloss = loss_fn(z_tf, h)
+                    sloss = loss_fn(z_ar, h)
+                    loss = jloss + sloss
+
+                # Step 2. Backward & step
+                if mixed_precision:
+                    scaler.scale(loss).backward()
+                    scaler.unscale_(optimizer)
+                else:
+                    loss.backward()
+                if mixed_precision:
+                    scaler.step(optimizer)
+                    scaler.update()
+                else:
+                    optimizer.step()
+                optimizer.zero_grad()
+
+                return (
+                    float(loss),
+                    float(jloss),
+                    float(sloss),
+                    _new_lr,
+                    _new_wd,
+                )
+
+            (
+                loss,
+                jloss,
+                sloss,
+                _new_lr,
+                _new_wd,
+            ), gpu_etime_ms = gpu_timer(train_step)
+            iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0
+            loss_meter.update(loss)
+            jloss_meter.update(jloss)
+            sloss_meter.update(sloss)
+            iter_time_meter.update(iter_elapsed_time_ms)
+            gpu_time_meter.update(gpu_etime_ms)
+            data_elapsed_time_meter.update(data_elapsed_time_ms)
+
+            # -- Logging
+            def log_stats():
+                csv_logger.log(epoch + 1, itr, loss, iter_elapsed_time_ms, gpu_etime_ms, data_elapsed_time_ms)
+                if (itr % log_freq == 0) or (itr == ipe - 1) or np.isnan(loss) or np.isinf(loss):
+                    logger.info(
+                        "[%d, %5d] loss: %.3f [%.2f, %.2f] "
+                        "[wd: %.2e] [lr: %.2e] "
+                        "[mem: %.2e] "
+                        "[iter: %.1f ms] "
+                        "[gpu: %.1f ms] "
+                        "[data: %.1f ms]"
+                        % (
+                            epoch + 1,
+                            itr,
+                            loss_meter.avg,
+                            jloss_meter.avg,
+                            sloss_meter.avg,
+                            _new_wd,
+                            _new_lr,
+                            torch.cuda.max_memory_allocated() / 1024.0**2,
+                            iter_time_meter.avg,
+                            gpu_time_meter.avg,
+                            data_elapsed_time_meter.avg,
+                        )
+                    )
+
+            log_stats()
+            assert not np.isnan(loss), "loss is nan"
+
+        # -- Save Checkpoint
+        logger.info("avg. loss %.3f" % loss_meter.avg)
+        # -- Save Last
+        if epoch % CHECKPOINT_FREQ == 0 or epoch == (num_epochs - 1):
+            save_checkpoint(epoch + 1, latest_path)
+        if save_every_freq > 0 and epoch % save_every_freq == 0:
+            save_every_file = f"e{epoch}.pt"
+            save_every_path = os.path.join(folder, save_every_file)
+            save_checkpoint(epoch + 1, save_every_path)
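The train step above combines two losses: a teacher-forced prediction (`jloss`, each step conditioned on ground-truth history) and an autoregressive rollout (`sloss`, each step conditioned on the model's own previous outputs). The distinction can be sketched on a toy scalar "world model" — the dynamics function below is made up for illustration:

```python
def step(z):
    # Toy "predictor": imperfect dynamics (true dynamics double the state).
    return 2.0 * z + 0.1


ground_truth = [1.0, 2.0, 4.0, 8.0]  # exact doubling trajectory

# Teacher forcing: predict each next state from the *true* previous state.
teacher_forced = [step(z) for z in ground_truth[:-1]]

# Autoregressive rollout: feed the model its own predictions.
rollout = [ground_truth[0]]
for _ in range(len(ground_truth) - 1):
    rollout.append(step(rollout[-1]))
rollout = rollout[1:]

# Teacher-forced error stays constant; rollout error compounds.
tf_err = [abs(p - t) for p, t in zip(teacher_forced, ground_truth[1:])]
ar_err = [abs(p - t) for p, t in zip(rollout, ground_truth[1:])]
```

Training on both terms, as in `train_step`, penalizes not just one-step accuracy but also the drift that accumulates when the predictor consumes its own outputs.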
vjepa2/app/vjepa_droid/transforms.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ import torch
9
+ import torchvision.transforms as transforms
10
+
11
+ import src.datasets.utils.video.transforms as video_transforms
12
+ from src.datasets.utils.video.randerase import RandomErasing
13
+
14
+
15
+ def make_transforms(
16
+ random_horizontal_flip=True,
17
+ random_resize_aspect_ratio=(3 / 4, 4 / 3),
18
+ random_resize_scale=(0.3, 1.0),
19
+ reprob=0.0,
20
+ auto_augment=False,
21
+ motion_shift=False,
22
+ crop_size=224,
23
+ normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
24
+ ):
25
+
26
+ _frames_augmentation = VideoTransform(
27
+ random_horizontal_flip=random_horizontal_flip,
28
+ random_resize_aspect_ratio=random_resize_aspect_ratio,
29
+ random_resize_scale=random_resize_scale,
30
+ reprob=reprob,
31
+ auto_augment=auto_augment,
32
+ motion_shift=motion_shift,
33
+ crop_size=crop_size,
34
+ normalize=normalize,
35
+ )
36
+ return _frames_augmentation
37
+
38
+
39
+ class VideoTransform(object):
40
+
41
+ def __init__(
42
+ self,
43
+ random_horizontal_flip=True,
44
+ random_resize_aspect_ratio=(3 / 4, 4 / 3),
45
+ random_resize_scale=(0.3, 1.0),
46
+ reprob=0.0,
47
+ auto_augment=False,
48
+ motion_shift=False,
49
+ crop_size=224,
50
+ normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
51
+ ):
52
+
53
+ self.random_horizontal_flip = random_horizontal_flip
54
+ self.random_resize_aspect_ratio = random_resize_aspect_ratio
55
+ self.random_resize_scale = random_resize_scale
56
+ self.auto_augment = auto_augment
57
+ self.motion_shift = motion_shift
58
+ self.crop_size = crop_size
59
+ self.mean = torch.tensor(normalize[0], dtype=torch.float32)
60
+ self.std = torch.tensor(normalize[1], dtype=torch.float32)
61
+ if not self.auto_augment:
62
+ # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255.
63
+ self.mean *= 255.0
64
+ self.std *= 255.0
65
+
+        self.autoaug_transform = video_transforms.create_random_augment(
+            input_size=(crop_size, crop_size),
+            # auto_augment="rand-m4-n4-w1-mstd0.5-inc1",
+            auto_augment="rand-m7-n4-mstd0.5-inc1",
+            interpolation="bicubic",
+        )
+
+        self.spatial_transform = (
+            video_transforms.random_resized_crop_with_shift if motion_shift else video_transforms.random_resized_crop
+        )
+
+        self.reprob = reprob
+        self.erase_transform = RandomErasing(
+            reprob,
+            mode="pixel",
+            max_count=1,
+            num_splits=1,
+            device="cpu",
+        )
+
+    def __call__(self, buffer):
+
+        if self.auto_augment:
+            buffer = [transforms.ToPILImage()(frame) for frame in buffer]
+            buffer = self.autoaug_transform(buffer)
+            buffer = [transforms.ToTensor()(img) for img in buffer]
+            buffer = torch.stack(buffer)  # T C H W
+            buffer = buffer.permute(0, 2, 3, 1)  # T H W C
+        elif torch.is_tensor(buffer):
+            # TODO: ensure input is always a tensor?
+            buffer = buffer.to(torch.float32)
+        else:
+            buffer = torch.tensor(buffer, dtype=torch.float32)
+
+        buffer = buffer.permute(3, 0, 1, 2)  # T H W C -> C T H W
+
+        buffer = self.spatial_transform(
+            images=buffer,
+            target_height=self.crop_size,
+            target_width=self.crop_size,
+            scale=self.random_resize_scale,
+            ratio=self.random_resize_aspect_ratio,
+        )
+        if self.random_horizontal_flip:
+            buffer, _ = video_transforms.horizontal_flip(0.5, buffer)
+
+        buffer = _tensor_normalize_inplace(buffer, self.mean, self.std)
+        if self.reprob > 0:
+            buffer = buffer.permute(1, 0, 2, 3)
+            buffer = self.erase_transform(buffer)
+            buffer = buffer.permute(1, 0, 2, 3)
+
+        return buffer
+
+
+def tensor_normalize(tensor, mean, std):
+    """
+    Normalize a given tensor by subtracting the mean and dividing the std.
+    Args:
+        tensor (tensor): tensor to normalize.
+        mean (tensor or list): mean value to subtract.
+        std (tensor or list): std to divide.
+    """
+    if tensor.dtype == torch.uint8:
+        tensor = tensor.float()
+        tensor = tensor / 255.0
+    if type(mean) == list:
+        mean = torch.tensor(mean)
+    if type(std) == list:
+        std = torch.tensor(std)
+    tensor = tensor - mean
+    tensor = tensor / std
+    return tensor
+
+
+def _tensor_normalize_inplace(tensor, mean, std):
+    """
+    Normalize a given tensor by subtracting the mean and dividing the std.
+    Args:
+        tensor (tensor): tensor to normalize (with dimensions C, T, H, W).
+        mean (tensor): mean value to subtract (in 0 to 255 floats).
+        std (tensor): std to divide (in 0 to 255 floats).
+    """
+    if tensor.dtype == torch.uint8:
+        tensor = tensor.float()
+
+    C, T, H, W = tensor.shape
+    tensor = tensor.view(C, -1).permute(1, 0)  # Make C the last dimension
+    tensor.sub_(mean).div_(std)
+    tensor = tensor.permute(1, 0).view(C, T, H, W)  # Put C back in front
+    return tensor
vjepa2/app/vjepa_droid/utils.py ADDED
@@ -0,0 +1,253 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import logging
+import sys
+
+import torch
+
+import src.models.ac_predictor as vit_ac_pred
+import src.models.vision_transformer as video_vit
+from src.utils.checkpoint_loader import robust_checkpoint_loader
+from src.utils.schedulers import CosineWDSchedule, WSDSchedule
+
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+logger = logging.getLogger()
+
+
+def load_pretrained(
+    r_path,
+    encoder=None,
+    predictor=None,
+    target_encoder=None,
+    context_encoder_key="encoder",
+    target_encoder_key="target_encoder",
+    load_predictor=False,
+    load_encoder=True,
+):
+    logger.info(f"Loading pretrained model from {r_path}")
+    checkpoint = robust_checkpoint_loader(r_path, map_location=torch.device("cpu"))
+
+    epoch = checkpoint["epoch"]
+
+    if load_encoder:
+        # -- loading encoder
+        pretrained_dict = checkpoint[context_encoder_key]
+        pretrained_dict = {k.replace("backbone.", ""): v for k, v in pretrained_dict.items()}
+        msg = encoder.load_state_dict(pretrained_dict, strict=False)
+        logger.info(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}")
+
+    if load_predictor:
+        # -- loading predictor
+        pretrained_dict = checkpoint["predictor"]
+        pretrained_dict = {k.replace("backbone.", ""): v for k, v in pretrained_dict.items()}
+        msg = predictor.load_state_dict(pretrained_dict, strict=False)
+        logger.info(f"loaded pretrained predictor from epoch {epoch} with msg: {msg}")
+
+    # -- loading target_encoder
+    if load_encoder:
+        if target_encoder is not None:
+            print(list(checkpoint.keys()))
+            pretrained_dict = checkpoint[target_encoder_key]
+            pretrained_dict = {k.replace("backbone.", ""): v for k, v in pretrained_dict.items()}
+            msg = target_encoder.load_state_dict(pretrained_dict, strict=False)
+            logger.info(f"loaded pretrained target encoder from epoch {epoch} with msg: {msg}")
+
+    del checkpoint
+
+    return (
+        encoder,
+        predictor,
+        target_encoder,
+    )
+
+
+def load_checkpoint(
+    r_path,
+    encoder,
+    predictor,
+    target_encoder,
+    opt=None,
+    scaler=None,
+    replace_kw=["backbone."],
+):
+    logger.info(f"Loading checkpoint from {r_path}")
+    checkpoint = robust_checkpoint_loader(r_path, map_location=torch.device("cpu"))
+
+    epoch = checkpoint["epoch"]
+
+    # -- loading encoder
+    pretrained_dict = checkpoint["encoder"]
+    for kw in replace_kw:
+        pretrained_dict = {k.replace(kw, ""): v for k, v in pretrained_dict.items()}
+    msg = encoder.load_state_dict(pretrained_dict, strict=False)
+    logger.info(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}")
+
+    # -- loading predictor
+    pretrained_dict = checkpoint["predictor"]
+    for kw in replace_kw:
+        pretrained_dict = {k.replace(kw, ""): v for k, v in pretrained_dict.items()}
+    msg = predictor.load_state_dict(pretrained_dict, strict=False)
+    logger.info(f"loaded pretrained predictor from epoch {epoch} with msg: {msg}")
+
+    # -- loading target_encoder
+    if target_encoder is not None:
+        print(list(checkpoint.keys()))
+        pretrained_dict = checkpoint["target_encoder"]
+        for kw in replace_kw:
+            pretrained_dict = {k.replace(kw, ""): v for k, v in pretrained_dict.items()}
+        msg = target_encoder.load_state_dict(pretrained_dict, strict=False)
+        logger.info(f"loaded pretrained target encoder from epoch {epoch} with msg: {msg}")
+
+    # -- loading optimizer
+    if opt is not None:
+        opt.load_state_dict(checkpoint["opt"])
+
+    if scaler is not None:
+        scaler.load_state_dict(checkpoint["scaler"])
+
+    logger.info(f"loaded optimizers from epoch {epoch}")
+    logger.info(f"read-path: {r_path}")
+    del checkpoint
+
+    return (
+        encoder,
+        predictor,
+        target_encoder,
+        opt,
+        scaler,
+        epoch,
+    )
+
+
+def init_video_model(
+    device,
+    patch_size=16,
+    max_num_frames=16,
+    tubelet_size=2,
+    model_name="vit_base",
+    crop_size=224,
+    pred_depth=6,
+    pred_num_heads=None,
+    pred_embed_dim=384,
+    uniform_power=False,
+    use_sdpa=False,
+    use_rope=False,
+    use_silu=False,
+    use_pred_silu=False,
+    wide_silu=False,
+    pred_is_frame_causal=True,
+    use_activation_checkpointing=False,
+    return_all_tokens=False,
+    action_embed_dim=7,
+    use_extrinsics=False,
+    old_pred=False,
+):
+    encoder = video_vit.__dict__[model_name](
+        img_size=crop_size,
+        patch_size=patch_size,
+        num_frames=max_num_frames,
+        tubelet_size=tubelet_size,
+        uniform_power=uniform_power,
+        use_sdpa=use_sdpa,
+        use_silu=use_silu,
+        wide_silu=wide_silu,
+        use_activation_checkpointing=use_activation_checkpointing,
+        use_rope=use_rope,
+    )
+
+    predictor = vit_ac_pred.__dict__["vit_ac_predictor"](
+        img_size=crop_size,
+        patch_size=patch_size,
+        num_frames=max_num_frames,
+        tubelet_size=tubelet_size,
+        embed_dim=encoder.embed_dim,
+        predictor_embed_dim=pred_embed_dim,
+        action_embed_dim=action_embed_dim,
+        depth=pred_depth,
+        is_frame_causal=pred_is_frame_causal,
+        num_heads=encoder.num_heads if pred_num_heads is None else pred_num_heads,
+        uniform_power=uniform_power,
+        use_rope=use_rope,
+        use_sdpa=use_sdpa,
+        use_silu=use_pred_silu,
+        wide_silu=wide_silu,
+        use_extrinsics=use_extrinsics,
+        use_activation_checkpointing=use_activation_checkpointing,
+    )
+
+    encoder.to(device)
+    predictor.to(device)
+    logger.info(encoder)
+    logger.info(predictor)
+
+    def count_parameters(model):
+        return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.info(f"Encoder number of parameters: {count_parameters(encoder)}")
+    logger.info(f"Predictor number of parameters: {count_parameters(predictor)}")
+
+    return encoder, predictor
+
+
+def init_opt(
+    encoder,
+    predictor,
+    iterations_per_epoch,
+    start_lr,
+    ref_lr,
+    warmup,
+    anneal,
+    num_epochs,
+    wd=1e-6,
+    final_wd=1e-6,
+    final_lr=0.0,
+    mixed_precision=False,
+    betas=(0.9, 0.999),
+    eps=1e-8,
+    zero_init_bias_wd=True,
+    enc_lr_scale=1.0,
+):
+    param_groups = [
+        {
+            "params": (p for n, p in encoder.named_parameters() if ("bias" not in n) and (len(p.shape) != 1)),
+            "lr_scale": enc_lr_scale,
+        },
+        {
+            "params": (p for n, p in predictor.named_parameters() if ("bias" not in n) and (len(p.shape) != 1)),
+        },
+        {
+            "params": (p for n, p in encoder.named_parameters() if ("bias" in n) or (len(p.shape) == 1)),
+            "WD_exclude": zero_init_bias_wd,
+            "weight_decay": 0,
+            "lr_scale": enc_lr_scale,
+        },
+        {
+            "params": (p for n, p in predictor.named_parameters() if ("bias" in n) or (len(p.shape) == 1)),
+            "WD_exclude": zero_init_bias_wd,
+            "weight_decay": 0,
+        },
+    ]
+
+    optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps)
+    scheduler = WSDSchedule(
+        optimizer,
+        warmup_steps=int(warmup * iterations_per_epoch),
+        anneal_steps=int(anneal * iterations_per_epoch),
+        start_lr=start_lr,
+        ref_lr=ref_lr,
+        final_lr=final_lr,
+        T_max=int(num_epochs * iterations_per_epoch),
+    )
+    wd_scheduler = CosineWDSchedule(
+        optimizer,
+        ref_wd=wd,
+        final_wd=final_wd,
+        T_max=int(num_epochs * iterations_per_epoch),
+    )
+    scaler = torch.cuda.amp.GradScaler() if mixed_precision else None
+    return optimizer, scaler, scheduler, wd_scheduler
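`init_opt` above builds four parameter groups so that biases and 1-D tensors (e.g. norm scales) are routed into groups with `weight_decay: 0` while everything else keeps the scheduled decay. The selection predicate can be sketched standalone; the parameter names and shapes below are hypothetical stand-ins for `encoder.named_parameters()`:

```python
# Hypothetical (name, shape) pairs standing in for named_parameters().
params = [
    ("patch_embed.weight", (384, 3, 16, 16)),
    ("patch_embed.bias", (384,)),
    ("blocks.0.norm.weight", (384,)),
    ("blocks.0.attn.qkv.weight", (1152, 384)),
]

# Same predicate as the param_groups in init_opt: a parameter skips
# weight decay if its name contains "bias" or it is one-dimensional.
decayed = [n for n, s in params if "bias" not in n and len(s) != 1]
no_decay = [n for n, s in params if "bias" in n or len(s) == 1]
print(decayed)   # ['patch_embed.weight', 'blocks.0.attn.qkv.weight']
print(no_decay)  # ['patch_embed.bias', 'blocks.0.norm.weight']
```

Excluding these parameters from decay is a common ViT training choice: shrinking norm scales and biases toward zero tends to hurt rather than regularize.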
vjepa2/configs/eval/vitg-384/coin.yaml ADDED
@@ -0,0 +1,163 @@
+cpus_per_task: 16
+eval_name: video_classification_frozen
+folder: /your_folder/evals/vitg-384/coin
+mem_per_gpu: 220G
+nodes: 16
+resume_checkpoint: true
+tag: coin-vitg16-384-16x8x3
+tasks_per_node: 8
+experiment:
+  classifier:
+    num_heads: 16
+    num_probe_blocks: 4
+  data:
+    dataset_type: VideoDataset
+    dataset_train: /your_data_folder/COIN/train_paths.csv
+    dataset_val: /your_data_folder/COIN/val_paths.csv
+    frame_step: 4
+    frames_per_clip: 16
+    num_classes: 180
+    num_segments: 8
+    num_views_per_segment: 3
+    resolution: 384
+  optimization:
+    batch_size: 1
+    multihead_kwargs:
+    - final_lr: 0.0
+      final_weight_decay: 0.01
+      lr: 0.005
+      start_lr: 0.005
+      warmup: 0.0
+      weight_decay: 0.01
+    - final_lr: 0.0
+      final_weight_decay: 0.01
+      lr: 0.003
+      start_lr: 0.003
+      warmup: 0.0
+      weight_decay: 0.01
+    - final_lr: 0.0
+      final_weight_decay: 0.01
+      lr: 0.001
+      start_lr: 0.001
+      warmup: 0.0
+      weight_decay: 0.01
+    - final_lr: 0.0
+      final_weight_decay: 0.01
+      lr: 0.0003
+      start_lr: 0.0003
+      warmup: 0.0
+      weight_decay: 0.01
+    - final_lr: 0.0
+      final_weight_decay: 0.01
+      lr: 0.0001
+      start_lr: 0.0001
+      warmup: 0.0
+      weight_decay: 0.01
+    - final_lr: 0.0
+      final_weight_decay: 0.1
+      lr: 0.005
+      start_lr: 0.005
+      warmup: 0.0
+      weight_decay: 0.1
+    - final_lr: 0.0
+      final_weight_decay: 0.1
+      lr: 0.003
+      start_lr: 0.003
+      warmup: 0.0
+      weight_decay: 0.1
+    - final_lr: 0.0
+      final_weight_decay: 0.1
+      lr: 0.001
+      start_lr: 0.001
+      warmup: 0.0
+      weight_decay: 0.1
+    - final_lr: 0.0
+      final_weight_decay: 0.1
+      lr: 0.0003
+      start_lr: 0.0003
+      warmup: 0.0
+      weight_decay: 0.1
+    - final_lr: 0.0
+      final_weight_decay: 0.1
+      lr: 0.0001
+      start_lr: 0.0001
+      warmup: 0.0
+      weight_decay: 0.1
+    - final_lr: 0.0
+      final_weight_decay: 0.4
+      lr: 0.005
+      start_lr: 0.005
+      warmup: 0.0
+      weight_decay: 0.4
+    - final_lr: 0.0
+      final_weight_decay: 0.4
+      lr: 0.003
+      start_lr: 0.003
+      warmup: 0.0
+      weight_decay: 0.4
+    - final_lr: 0.0
+      final_weight_decay: 0.4
+      lr: 0.001
+      start_lr: 0.001
+      warmup: 0.0
+      weight_decay: 0.4
+    - final_lr: 0.0
+      final_weight_decay: 0.4
+      lr: 0.0003
+      start_lr: 0.0003
+      warmup: 0.0
+      weight_decay: 0.4
+    - final_lr: 0.0
+      final_weight_decay: 0.4
+      lr: 0.0001
+      start_lr: 0.0001
+      warmup: 0.0
+      weight_decay: 0.4
+    - final_lr: 0.0
+      final_weight_decay: 0.8
+      lr: 0.005
+      start_lr: 0.005
+      warmup: 0.0
+      weight_decay: 0.8
+    - final_lr: 0.0
+      final_weight_decay: 0.8
+      lr: 0.003
+      start_lr: 0.003
+      warmup: 0.0
+      weight_decay: 0.8
+    - final_lr: 0.0
+      final_weight_decay: 0.8
+      lr: 0.001
+      start_lr: 0.001
+      warmup: 0.0
+      weight_decay: 0.8
+    - final_lr: 0.0
+      final_weight_decay: 0.8
+      lr: 0.0003
+      start_lr: 0.0003
+      warmup: 0.0
+      weight_decay: 0.8
+    - final_lr: 0.0
+      final_weight_decay: 0.8
+      lr: 0.0001
+      start_lr: 0.0001
+      warmup: 0.0
+      weight_decay: 0.8
+    num_epochs: 20
+    use_bfloat16: true
+    use_pos_embed: false
+  model_kwargs:
+    checkpoint: /your_vjepa2_checkpoints/vitg-384.pt
+    module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
+    pretrain_kwargs:
+      encoder:
+        checkpoint_key: target_encoder
+        img_temporal_dim_size: null
+        model_name: vit_giant_xformers
+        patch_size: 16
+        tubelet_size: 2
+        uniform_power: true
+        use_rope: true
+    wrapper_kwargs:
+      max_frames: 128
+      use_pos_embed: false
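The `multihead_kwargs` list in `coin.yaml` enumerates a full grid of 5 learning rates against 4 weight-decay values, one frozen-probe head per combination. One way to regenerate the same 20 entries programmatically, should the sweep need editing (this script is an illustration, not part of the repo):

```python
from itertools import product

# Sweep axes as they appear in the config: weight decay varies in the
# outer loop, learning rate in the inner loop.
lrs = [0.005, 0.003, 0.001, 0.0003, 0.0001]
wds = [0.01, 0.1, 0.4, 0.8]

grid = [
    {"final_lr": 0.0, "final_weight_decay": wd, "lr": lr,
     "start_lr": lr, "warmup": 0.0, "weight_decay": wd}
    for wd, lr in product(wds, lrs)
]
print(len(grid))  # 20
print(grid[0])    # first entry: lr=0.005, weight_decay=0.01
```

Training all heads jointly amortizes the frozen-encoder forward pass across the whole hyperparameter sweep, which is why a single eval run can try 20 configurations.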