Ryukijano committed on
Commit 8682216 · verified · 1 Parent(s): cc13c69

Deploy minimal DINO-Endo Space app

Files changed (13)
  1. .dockerignore +15 -0
  2. .gitignore +3 -0
  3. Dockerfile +38 -0
  4. README.md +130 -14
  5. app.py +305 -243
  6. model/__init__.py +0 -0
  7. model/mstcn.py +183 -0
  8. model/resnet.py +19 -0
  9. model/transformer.py +246 -0
  10. model_registry.py +156 -0
  11. predictor.py +642 -0
  12. requirements.txt +13 -4
  13. scripts/smoke_test.py +57 -0
.dockerignore ADDED
@@ -0,0 +1,15 @@
+ __pycache__/
+ *.py[cod]
+ *.so
+ *.egg-info/
+ .git/
+ .gitignore
+ .cache/
+ .pytest_cache/
+ .mypy_cache/
+ .streamlit/
+ .env
+ .env.*
+ venv/
+ .venv/
+ *.ipynb
.gitignore ADDED
@@ -0,0 +1,3 @@
+ .cache/
+ __pycache__/
+ *.pyc
Dockerfile ADDED
@@ -0,0 +1,38 @@
+ FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04
+
+ ENV DEBIAN_FRONTEND=noninteractive \
+     PYTHONUNBUFFERED=1 \
+     PIP_NO_CACHE_DIR=1 \
+     SPACE_MODEL_DIR=/app/model \
+     SPACE_ENABLED_MODELS=dinov2 \
+     SPACE_DEFAULT_MODEL=dinov2
+
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     bash \
+     curl \
+     wget \
+     procps \
+     python3 \
+     python3-pip \
+     python3-venv \
+     git \
+     git-lfs \
+     ffmpeg \
+     libgl1 \
+     libglib2.0-0 \
+     && rm -rf /var/lib/apt/lists/*
+
+ RUN useradd -m -u 1000 user && mkdir -p /app && chown user:user /app
+ USER user
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+ WORKDIR /app
+
+ COPY --chown=user requirements.txt /app/requirements.txt
+ RUN python3 -m pip install --upgrade pip && \
+     python3 -m pip install -r requirements.txt
+
+ COPY --chown=user . /app
+
+ EXPOSE 7860
+ CMD ["python3", "-m", "streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0", "--server.headless=true"]
README.md CHANGED
@@ -1,14 +1,130 @@
- ---
- title: LLaMA Mesh
- emoji: 👀
- colorFrom: red
- colorTo: green
- sdk: gradio
- sdk_version: 5.6.0
- app_file: app.py
- pinned: false
- license: llama3.1
- short_description: Create 3D mesh by chatting.
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: DINO-ENDO Phase Recognition
+ emoji: 🩺
+ colorFrom: blue
+ colorTo: green
+ sdk: docker
+ app_port: 7860
+ ---
+
+ # DINO-ENDO Streamlit Space
+
+ This folder is an isolated Hugging Face Space scaffold for the phase-recognition models in this repository.
+ It is intentionally separate from the existing FastAPI webapp and defaults to a **DINO-Endo demo** on paid GPU hardware such as **1x A10G (24 GB VRAM)**.
+ The same code can still expose AI-Endo and V-JEPA2 when you opt into them through environment variables.
+
+ ## Supported model families
+
+ - **AI-Endo**
+   - `resnet50.pth`
+   - `fusion.pth`
+   - `transformer.pth`
+ - **DINO-Endo**
+   - `dinov2_vit14s_latest_checkpoint.pth`
+   - `fusion_transformer_decoder_best_model.pth`
+   - optional `dinov2_decoder.pth`
+   - vendored `dinov2/` source tree
+ - **V-JEPA2**
+   - `vjepa_encoder_human.pt`
+   - `mlp_decoder_human.pth`
+   - vendored `vjepa2/` source tree
+
+ ## Weight delivery strategy
+
+ The default design is:
+
+ 1. Keep the **Space repo mostly code-only**.
+ 2. Upload weights to one or more **Hugging Face model repos**.
+ 3. Let the Space populate `model/` (or `SPACE_MODEL_DIR`) on demand via `huggingface_hub`.
+
+ This works better than checking all weights directly into the Space repo because code and weights stay versioned separately and Space rebuilds stay lighter.
+ A fully local `model/` folder is still supported as a fallback.
+
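+ As a rough illustration of step 3, here is a minimal standalone sketch using a hypothetical repo ID `your-org/dino-endo-weights`; the committed `model_registry.py` implements the full version with env-var-driven repo IDs, revisions, and subfolders:
+
+ ```python
+ from pathlib import Path
+
+ from huggingface_hub import hf_hub_download
+
+ # Hypothetical repo ID; in the Space it comes from PHASE_MODEL_REPO_ID / DINO_MODEL_REPO_ID.
+ REPO_ID = "your-org/dino-endo-weights"
+ MODEL_DIR = Path("model")
+ MODEL_DIR.mkdir(exist_ok=True)
+
+ for filename in (
+     "dinov2_vit14s_latest_checkpoint.pth",
+     "fusion_transformer_decoder_best_model.pth",
+ ):
+     if not (MODEL_DIR / filename).exists():  # prefer a local copy; download only when missing
+         hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=MODEL_DIR)
+ ```
+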
+ ## Default Space behavior
+
+ The Docker Space is configured to boot as a **DINO-Endo-first demo**:
+
+ - `SPACE_ENABLED_MODELS=dinov2`
+ - `SPACE_DEFAULT_MODEL=dinov2`
+
+ If you want the same Space build to expose multiple model families again, override those environment variables in Space Settings, for example:
+
+ ```text
+ SPACE_ENABLED_MODELS=dinov2,aiendo,vjepa2
+ SPACE_DEFAULT_MODEL=dinov2
+ ```
+
+ The Dockerfile is also set up to be **HF Dev Mode compatible**:
+
+ - app code lives under `/app`
+ - `/app` is owned by uid `1000`
+ - the required Dev Mode packages (`bash`, `curl`, `wget`, `procps`, `git`, `git-lfs`) are installed
+
+ ## Runtime configuration
+
+ The app looks for model files in `SPACE_MODEL_DIR` first (default: `./model`).
+ If a required checkpoint is missing locally, it will try to download it from the configured model repo(s).
+
+ ### Common environment variables
+
+ - `SPACE_ENABLED_MODELS` — comma-separated list of model families to expose in the UI
+ - `SPACE_DEFAULT_MODEL` — default selected model when multiple model families are enabled
+ - `SPACE_MODEL_DIR` — local directory where checkpoints should live (default: `./model`)
+ - `PHASE_MODEL_REPO_ID` — shared HF model repo for all weights
+ - `PHASE_MODEL_REVISION` — optional shared revision/tag/commit
+ - `HF_TOKEN` — only needed for private or gated repos
+
+ If `HF_HOME` / `HF_HUB_CACHE` are not set explicitly, the app will automatically use persistent `/data` storage when it exists and otherwise fall back to a local cache inside the Space folder.
+
+ ### Per-model overrides
+
+ - `AIENDO_MODEL_REPO_ID`, `DINO_MODEL_REPO_ID`, `VJEPA2_MODEL_REPO_ID`
+ - `AIENDO_MODEL_REVISION`, `DINO_MODEL_REVISION`, `VJEPA2_MODEL_REVISION`
+ - `AIENDO_MODEL_SUBFOLDER`, `DINO_MODEL_SUBFOLDER`, `VJEPA2_MODEL_SUBFOLDER`
+
+ Use subfolder env vars if you store multiple model families in one repo under different directories.
+
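+ For example, assuming a hypothetical shared repo `your-org/phase-models` with one directory per model family, the Space Settings could combine the shared repo ID with per-family subfolders:
+
+ ```text
+ PHASE_MODEL_REPO_ID=your-org/phase-models
+ AIENDO_MODEL_SUBFOLDER=aiendo
+ DINO_MODEL_SUBFOLDER=dinov2
+ VJEPA2_MODEL_SUBFOLDER=vjepa2
+ ```
+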
+ ## Local development vs. publishing
+
+ The required vendored `dinov2/` and `vjepa2/` source trees are now staged inside this folder, so the Space scaffold is self-contained.
+ If those upstream source trees change and you want to refresh the copies here before publishing, run:
+
+ ```bash
+ python scripts/stage_vendor_sources.py --overwrite
+ ```
+
+ ## Publishing checklist
+
+ 1. Populate the Space folder files here.
+ 2. Run `python scripts/stage_vendor_sources.py --overwrite` if you need to refresh the vendored source copies.
+ 3. Push the contents of this folder to a Hugging Face **Docker Space** (see the sketch below).
+ 4. Upload your checkpoints to HF **model repo(s)**.
+ 5. Configure the relevant repo IDs (and `HF_TOKEN` only if the repos are private).
+
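+ Steps 3 and 4 can be done with plain `git` plus the `huggingface_hub` CLI; a sketch only, with hypothetical user, Space, and repo names:
+
+ ```bash
+ # Step 3: push this folder to a Docker Space (hypothetical names)
+ git remote add space https://huggingface.co/spaces/your-user/dino-endo-demo
+ git push space main
+
+ # Step 4: upload local checkpoints to a model repo (hypothetical names)
+ huggingface-cli upload your-org/dino-endo-weights ./model . --repo-type=model
+ ```
+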
+ ## Local smoke test
+
+ Once the Space dependencies are installed, you can smoke test a predictor directly:
+
+ ```bash
+ python scripts/smoke_test.py --model dinov2 --model-dir /path/to/model
+ python scripts/smoke_test.py --model aiendo --model-dir /path/to/model
+ python scripts/smoke_test.py --model vjepa2 --model-dir /path/to/model
+ ```
+
+ ## Scope of v1
+
+ - Streamlit UI
+ - DINO-Endo demo by default, with optional multi-model selector when enabled
+ - image upload and video upload
+ - per-frame phase timeline output for video
+ - JSON / CSV export
+
+ Not included in v1:
+
+ - auth / user management
+ - SQL database
+ - PDF/HTML report generation
+ - background queue processing
+ - polyp segmentation
app.py CHANGED
@@ -1,243 +1,305 @@
- import os
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
- import gradio as gr
- import os
- # import spaces
- from transformers import GemmaTokenizer, AutoModelForCausalLM
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
- from threading import Thread
- import torch
-
- # Set an environment variable
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
-
-
- DESCRIPTION = '''
- <div>
- <h1 style="text-align: center;">LLaMA-Mesh</h1>
- <div>
- <a style="display:inline-block" href="https://research.nvidia.com/labs/toronto-ai/LLaMA-Mesh/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
- <a style="display:inline-block; margin-left: .5em" href="https://github.com/nv-tlabs/LLaMA-Mesh"><img src='https://img.shields.io/github/stars/nv-tlabs/LLaMA-Mesh?style=social'/></a>
- </div>
- <p>LLaMA-Mesh: Unifying 3D Mesh Generation with Language Models. <a style="display:inline-block" href="https://research.nvidia.com/labs/toronto-ai/LLaMA-Mesh/">[Project Page]</a> <a style="display:inline-block" href="https://github.com/nv-tlabs/LLaMA-Mesh">[Code]</a></p>
- <p> Notice: (1) The default token length is 4096. If you observe incomplete generated meshes, try to increase the maximum token length into 8192.</p>
- <p>(2) We only support generating a single mesh per dialog round. To generate another mesh, click the "clear" button and start a new dialog.</p>
- <p>(3) If the LLM refuses to generate a 3D mesh, try adding more explicit instructions to the prompt, such as "create a 3D model of a table <strong>in OBJ format</strong>." A more effective approach is to request the mesh generation at the start of the dialog.</p>
- </div>
- '''
-
- LICENSE = """
- <p/>
-
- ---
- Built with Meta Llama 3.1 8B
- """
-
- PLACEHOLDER = """
- <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
- <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaMA-Mesh</h1>
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Create 3D meshes by chatting.</p>
- </div>
- """
-
-
- css = """
- h1 {
-   text-align: center;
-   display: block;
- }
-
- #duplicate-button {
-   margin: auto;
-   color: white;
-   background: #1565c0;
-   border-radius: 100vh;
- }
- """
- # Load the tokenizer and model
- model_path = "Zhengyi/LLaMA-Mesh"
- tokenizer = AutoTokenizer.from_pretrained(model_path)
- model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cuda:0", torch_dtype=torch.float16).to('cuda')
- terminators = [
-     tokenizer.eos_token_id,
-     tokenizer.convert_tokens_to_ids("<|eot_id|>")
- ]
-
-
- from trimesh.exchange.gltf import export_glb
- import gradio as gr
- import trimesh
- import numpy as np
- import tempfile
- def apply_gradient_color(mesh_text):
-     """
-     Apply a gradient color to the mesh vertices based on the Y-axis and save as GLB.
-     Args:
-         mesh_text (str): The input mesh in OBJ format as a string.
-     Returns:
-         str: Path to the GLB file with gradient colors applied.
-     """
-     # Load the mesh
-     temp_file = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
-     with open(temp_file+".obj", "w") as f:
-         f.write(mesh_text)
-     # return temp_file
-     mesh = trimesh.load_mesh(temp_file+".obj", file_type='obj')
-
-     # Get vertex coordinates
-     vertices = mesh.vertices
-     y_values = vertices[:, 1]  # Y-axis values
-
-     # Normalize Y values to range [0, 1] for color mapping
-     y_normalized = (y_values - y_values.min()) / (y_values.max() - y_values.min())
-
-     # Generate colors: Map normalized Y values to RGB gradient (e.g., blue to red)
-     colors = np.zeros((len(vertices), 4))  # RGBA
-     colors[:, 0] = y_normalized  # Red channel
-     colors[:, 2] = 1 - y_normalized  # Blue channel
-     colors[:, 3] = 1.0  # Alpha channel (fully opaque)
-
-     # Attach colors to mesh vertices
-     mesh.visual.vertex_colors = colors
-
-     # Export to GLB format
-     glb_path = temp_file+".glb"
-     with open(glb_path, "wb") as f:
-         f.write(export_glb(mesh))
-
-     return glb_path
-
- def visualize_mesh(mesh_text):
-     """
-     Convert the provided 3D mesh text into a visualizable format.
-     This function assumes the input is in OBJ format.
-     """
-     temp_file = "temp_mesh.obj"
-     with open(temp_file, "w") as f:
-         f.write(mesh_text)
-     return temp_file
-
- # @spaces.GPU(duration=120)
- def chat_llama3_8b(message: str,
-                    history: list,
-                    temperature: float,
-                    max_new_tokens: int
-                    ) -> str:
-     """
-     Generate a streaming response using the llama3-8b model.
-     Args:
-         message (str): The input message.
-         history (list): The conversation history used by ChatInterface.
-         temperature (float): The temperature for generating the response.
-         max_new_tokens (int): The maximum number of new tokens to generate.
-     Returns:
-         str: The generated response.
-     """
-     conversation = []
-     for user, assistant in history:
-         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-     conversation.append({"role": "user", "content": message})
-
-     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-
-     generate_kwargs = dict(
-         input_ids= input_ids,
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         temperature=temperature,
-         eos_token_id=terminators,
-     )
-     # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
-     if temperature == 0:
-         generate_kwargs['do_sample'] = False
-
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         #print(outputs)
-         yield "".join(outputs)
-
-
- # Gradio block
- chatbot=gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
-
- with gr.Blocks(fill_height=True, css=css) as demo:
-     with gr.Column():
-         gr.Markdown(DESCRIPTION)
-         # gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
-         with gr.Row():
-             with gr.Column(scale=3):
-                 gr.ChatInterface(
-                     fn=chat_llama3_8b,
-                     chatbot=chatbot,
-                     fill_height=True,
-                     additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
-                     additional_inputs=[
-                         gr.Slider(minimum=0,
-                                   maximum=1,
-                                   step=0.1,
-                                   value=0.95,
-                                   label="Temperature",
-                                   render=False),
-                         gr.Slider(minimum=128,
-                                   maximum=8192,
-                                   step=1,
-                                   value=4096,
-                                   label="Max new tokens",
-                                   render=False),
-                     ],
-                     examples=[
-                         ['Create a 3D model of a wooden hammer'],
-                         ['Create a 3D model of a pyramid in obj format'],
-                         ['Create a 3D model of a cabinet.'],
-                         ['Create a low poly 3D model of a coffe cup'],
-                         ['Create a 3D model of a table.'],
-                         ["Create a low poly 3D model of a tree."],
-                         ['Write a python code for sorting.'],
-                         ['How to setup a human base on Mars? Give short answer.'],
-                         ['Explain theory of relativity to me like I’m 8 years old.'],
-                         ['What is 9,000 * 9,000?'],
-                         ['Create a 3D model of a soda can.'],
-                         ['Create a 3D model of a sword.'],
-                         ['Create a 3D model of a wooden barrel'],
-                         ['Create a 3D model of a chair.']
-                     ],
-                     cache_examples=False,
-                 )
-                 gr.Markdown(LICENSE)
-
-             with gr.Column(scale=2):
-                 output_model = gr.Model3D(
-                     label="3D Mesh Visualization",
-                     interactive=False,
-                 )
-                 gr.Markdown("You can copy the generated 3d objects in the left and paste in the textbox below. Put the button and you will see the visualization of the 3D mesh.")
-
-                 # Add the text box for 3D mesh input and button
-                 mesh_input = gr.Textbox(
-                     label="3D Mesh Input",
-                     placeholder="Paste your 3D mesh in OBJ format here...",
-                     lines=5,
-                 )
-                 visualize_button = gr.Button("Visualize 3D Mesh")
-
-                 # Link the button to the visualization function
-                 visualize_button.click(
-                     fn=apply_gradient_color,
-                     inputs=[mesh_input],
-                     outputs=[output_model]
-                 )
-
- if __name__ == "__main__":
-     demo.launch()
-
-
-
-
+ from __future__ import annotations
+
+ import json
+ import os
+ import tempfile
+ import time
+ from collections import Counter
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+ import pandas as pd
+ import streamlit as st
+ import torch
+ from PIL import Image
+
+ from model_registry import MODEL_SPECS, ensure_model_artifacts, get_model_source_summary
+ from predictor import MODEL_LABELS, PHASE_LABELS, create_predictor, normalize_model_key
+
+ st.set_page_config(page_title="DINO-Endo Phase Recognition", layout="wide")
+
+
+ def _phase_index(phase: str) -> int:
+     try:
+         return PHASE_LABELS.index(phase)
+     except ValueError:
+         return -1
+
+
+ def _image_to_rgb(uploaded_file) -> np.ndarray:
+     image = Image.open(uploaded_file).convert("RGB")
+     return np.array(image)
+
+
+ def _enabled_model_keys() -> list[str]:
+     configured = os.getenv("SPACE_ENABLED_MODELS", "").strip()
+     if not configured:
+         return list(MODEL_SPECS.keys())
+
+     enabled_keys = []
+     seen = set()
+     for token in configured.split(","):
+         raw = token.strip()
+         if not raw:
+             continue
+         normalized = normalize_model_key(raw)
+         if normalized not in MODEL_SPECS:
+             raise RuntimeError(f"SPACE_ENABLED_MODELS contains unsupported model '{raw}'")
+         if normalized not in seen:
+             enabled_keys.append(normalized)
+             seen.add(normalized)
+
+     if not enabled_keys:
+         raise RuntimeError("SPACE_ENABLED_MODELS did not resolve to any supported models")
+     return enabled_keys
+
+
+ def _default_model_key(enabled_model_keys: list[str]) -> str:
+     configured = os.getenv("SPACE_DEFAULT_MODEL", "").strip()
+     if not configured:
+         return "dinov2" if "dinov2" in enabled_model_keys else enabled_model_keys[0]
+
+     normalized = normalize_model_key(configured)
+     if normalized not in enabled_model_keys:
+         raise RuntimeError(
+             f"SPACE_DEFAULT_MODEL '{configured}' is not enabled by SPACE_ENABLED_MODELS"
+         )
+     return normalized
+
+
+ def _space_caption(enabled_model_keys: list[str]) -> str:
+     if enabled_model_keys == ["dinov2"]:
+         return "Streamlit Hugging Face Space demo for the DINO-Endo phase-recognition stack."
+     return "DINO-first Streamlit Hugging Face Space demo for DINO-Endo, AI-Endo, and V-JEPA2."
+
+
+ def _ensure_predictor(model_key: str):
+     active_key = st.session_state.get("active_model_key")
+     active_predictor = st.session_state.get("active_predictor")
+
+     if active_predictor is not None and active_key != model_key:
+         active_predictor.unload()
+         st.session_state.pop("active_predictor", None)
+         st.session_state.pop("active_model_key", None)
+
+     if st.session_state.get("active_predictor") is None:
+         with st.spinner(f"Preparing {MODEL_LABELS[model_key]}..."):
+             model_dir = ensure_model_artifacts(model_key)
+             predictor = create_predictor(model_key, model_dir=str(model_dir))
+             predictor.warm_up()
+         st.session_state["active_predictor"] = predictor
+         st.session_state["active_model_key"] = model_key
+
+     return st.session_state["active_predictor"]
+
+
+ def _analyse_video(uploaded_file, predictor, frame_stride: int, max_frames: int):
+     suffix = Path(uploaded_file.name).suffix or ".mp4"
+     with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
+         tmp.write(uploaded_file.getbuffer())
+         temp_path = Path(tmp.name)
+
+     capture = cv2.VideoCapture(str(temp_path))
+     if not capture.isOpened():
+         temp_path.unlink(missing_ok=True)
+         raise RuntimeError("Unable to open uploaded video")
+
+     total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
+     fps = float(capture.get(cv2.CAP_PROP_FPS) or 0.0)
+     progress = st.progress(0)
+     status = st.empty()
+
+     predictor.reset_state()
+     records = []
+     processed = 0
+     frame_index = 0
+
+     try:
+         while True:
+             ok, frame = capture.read()
+             if not ok:
+                 break
+
+             if frame_index % frame_stride != 0:
+                 frame_index += 1
+                 continue
+
+             rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             started = time.perf_counter()
+             result = predictor.predict(rgb)
+             elapsed_ms = (time.perf_counter() - started) * 1000.0
+
+             probs = result.get("probs", [0.0, 0.0, 0.0, 0.0])
+             record = {
+                 "frame_index": frame_index,
+                 "timestamp_sec": round(frame_index / fps, 3) if fps > 0 else None,
+                 "phase": result.get("phase", "unknown"),
+                 "phase_id": _phase_index(result.get("phase", "unknown")),
+                 "confidence": float(result.get("confidence", 0.0)),
+                 "frames_used": int(result.get("frames_used", processed + 1)),
+                 "idle": float(probs[0]) if len(probs) > 0 else 0.0,
+                 "marking": float(probs[1]) if len(probs) > 1 else 0.0,
+                 "injection": float(probs[2]) if len(probs) > 2 else 0.0,
+                 "dissection": float(probs[3]) if len(probs) > 3 else 0.0,
+                 "inference_ms": round(elapsed_ms, 3),
+             }
+             records.append(record)
+             processed += 1
+
+             if total_frames > 0:
+                 progress.progress(min(frame_index + 1, total_frames) / total_frames)
+             else:
+                 progress.progress(min(processed / max_frames, 1.0))
+             status.caption(f"Processed {processed} sampled frames")
+
+             frame_index += 1
+             if processed >= max_frames:
+                 break
+     finally:
+         capture.release()
+         temp_path.unlink(missing_ok=True)
+         predictor.reset_state()
+
+     progress.empty()
+     status.empty()
+     return records, {"fps": fps, "total_frames": total_frames, "sampled_frames": processed}
+
+
+ def _records_to_frame(records):
+     if not records:
+         return pd.DataFrame(columns=["frame_index", "timestamp_sec", "phase", "confidence"])
+     return pd.DataFrame.from_records(records)
+
+
+ def _download_payloads(df: pd.DataFrame):
+     json_payload = df.to_json(orient="records", indent=2).encode("utf-8")
+     csv_payload = df.to_csv(index=False).encode("utf-8")
+     return json_payload, csv_payload
+
+
+ def _render_single_result(result: dict):
+     probs = result.get("probs", [0.0, 0.0, 0.0, 0.0])
+     metrics = st.columns(3)
+     metrics[0].metric("Predicted phase", result.get("phase", "unknown").upper())
+     metrics[1].metric("Confidence", f"{float(result.get('confidence', 0.0)):.1%}")
+     metrics[2].metric("Frames used", int(result.get("frames_used", 1)))
+
+     prob_df = pd.DataFrame({"phase": list(PHASE_LABELS), "probability": probs})
+     st.bar_chart(prob_df.set_index("phase"))
+     st.download_button(
+         label="Download JSON",
+         data=json.dumps(result, indent=2).encode("utf-8"),
+         file_name="phase_prediction.json",
+         mime="application/json",
+         key="download-single-json",
+     )
+
+
+ def _render_video_results(records, meta):
+     if not records:
+         st.warning("No frames were processed from the uploaded video.")
+         return
+
+     df = _records_to_frame(records)
+     counts = Counter(df["phase"].tolist())
+     dominant_phase, dominant_count = counts.most_common(1)[0]
+
+     metrics = st.columns(4)
+     metrics[0].metric("Sampled frames", int(meta["sampled_frames"]))
+     metrics[1].metric("Dominant phase", dominant_phase.upper())
+     metrics[2].metric("Mean confidence", f"{df['confidence'].mean():.1%}")
+     metrics[3].metric("Average inference", f"{df['inference_ms'].mean():.1f} ms")
+
+     chart_df = df.copy()
+     if "timestamp_sec" in chart_df and chart_df["timestamp_sec"].notna().any():
+         chart_df = chart_df.set_index("timestamp_sec")
+     else:
+         chart_df = chart_df.set_index("frame_index")
+
+     st.subheader("Confidence timeline")
+     st.line_chart(chart_df[["confidence"]])
+
+     st.subheader("Phase timeline")
+     st.line_chart(chart_df[["phase_id"]])
+
+     st.subheader("Per-frame predictions")
+     st.dataframe(df, use_container_width=True, hide_index=True)
+
+     json_payload, csv_payload = _download_payloads(df)
+     left, right = st.columns(2)
+     left.download_button("Download JSON", json_payload, file_name="phase_timeline.json", mime="application/json")
+     right.download_button("Download CSV", csv_payload, file_name="phase_timeline.csv", mime="text/csv")
+
+
+ def main():
+     enabled_model_keys = _enabled_model_keys()
+     default_model_key = _default_model_key(enabled_model_keys)
+
+     st.title("DINO-Endo Surgical Phase Recognition")
+     st.caption(_space_caption(enabled_model_keys))
+
+     st.sidebar.markdown("### Model")
+     if len(enabled_model_keys) == 1:
+         model_key = enabled_model_keys[0]
+         st.sidebar.write(MODEL_LABELS[model_key])
+     else:
+         model_key = st.sidebar.selectbox(
+             "Model",
+             options=enabled_model_keys,
+             index=enabled_model_keys.index(default_model_key),
+             format_func=lambda key: MODEL_LABELS[key],
+         )
+
+     source_summary = get_model_source_summary(model_key)
+     st.sidebar.markdown("### Runtime")
+     st.sidebar.write(f"CUDA available: `{torch.cuda.is_available()}`")
+     if torch.cuda.is_available():
+         st.sidebar.write(f"Device: `{torch.cuda.get_device_name(torch.cuda.current_device())}`")
+     st.sidebar.write(f"Model dir: `{source_summary['model_dir']}`")
+     st.sidebar.write(f"HF repo: `{source_summary['repo_id'] or 'local-only'}`")
+     if source_summary["subfolder"]:
+         st.sidebar.write(f"Repo subfolder: `{source_summary['subfolder']}`")
+
+     image_tab, video_tab = st.tabs(["Image", "Video"])
+
+     with image_tab:
+         uploaded_image = st.file_uploader("Upload an RGB frame", type=["png", "jpg", "jpeg"], key="image-uploader")
+         if uploaded_image is not None:
+             rgb = _image_to_rgb(uploaded_image)
+             st.image(rgb, caption=uploaded_image.name, use_container_width=True)
+             if st.button("Run image inference", key="run-image"):
+                 predictor = _ensure_predictor(model_key)
+                 predictor.reset_state()
+                 started = time.perf_counter()
+                 result = predictor.predict(rgb)
+                 result["inference_ms"] = round((time.perf_counter() - started) * 1000.0, 3)
+                 predictor.reset_state()
+                 _render_single_result(result)
+
+     with video_tab:
+         frame_stride = st.slider("Analyze every Nth frame", min_value=1, max_value=30, value=5, step=1)
+         max_frames = st.slider("Maximum sampled frames", min_value=10, max_value=600, value=180, step=10)
+         uploaded_video = st.file_uploader(
+             "Upload a video",
+             type=["mp4", "mov", "avi", "mkv", "webm", "m4v"],
+             key="video-uploader",
+         )
+         if uploaded_video is not None:
+             st.video(uploaded_video)
+             if st.button("Analyze video", key="run-video"):
+                 predictor = _ensure_predictor(model_key)
+                 records, meta = _analyse_video(uploaded_video, predictor, frame_stride=frame_stride, max_frames=max_frames)
+                 _render_video_results(records, meta)
+
+     if st.sidebar.button("Unload active model"):
+         predictor = st.session_state.get("active_predictor")
+         if predictor is not None:
+             predictor.unload()
+             st.session_state.pop("active_predictor", None)
+             st.session_state.pop("active_model_key", None)
+             st.sidebar.success("Model unloaded")
+
+
+ if __name__ == "__main__":
+     main()
model/__init__.py ADDED
File without changes
model/mstcn.py ADDED
@@ -0,0 +1,183 @@
+ from typing import Optional
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import copy
+
+
+ class MultiStageModel(nn.Module):
+     def __init__(self, mstcn_stages, mstcn_layers, mstcn_f_maps, mstcn_f_dim, out_features, mstcn_causal_conv, is_train=True, dropout_prob: float = 0.0):
+         self.num_stages = mstcn_stages
+         self.num_layers = mstcn_layers
+         self.num_f_maps = mstcn_f_maps
+         self.dim = mstcn_f_dim
+         self.num_classes = out_features
+         self.causal_conv = mstcn_causal_conv
+         self.is_train = is_train
+         print(f"num_stages_classification: {self.num_stages}, num_layers: {self.num_layers}, num_f_maps: {self.num_f_maps}, dim: {self.dim}")
+         super(MultiStageModel, self).__init__()
+         self.stage1 = SingleStageModel(self.num_layers,
+                                        self.num_f_maps,
+                                        self.dim,
+                                        self.num_classes,
+                                        causal_conv=self.causal_conv,
+                                        is_train=is_train,
+                                        dropout_prob=dropout_prob)
+         self.stages = SingleStageModel(self.num_layers,
+                                        self.num_f_maps,
+                                        self.num_classes,
+                                        self.num_classes,
+                                        causal_conv=self.causal_conv,
+                                        is_train=is_train,
+                                        dropout_prob=dropout_prob)
+
+         self.smoothing = False
+
+     def forward(self, x):
+         """
+         If is_train is False (inference), return first-stage features [B, num_f_maps, T]
+         so downstream Transformer receives 32-d features, matching the working pipeline.
+         If is_train is True (training/classification), return stacked class logits.
+         """
+         out = self.stage1(x)
+         if not self.is_train:
+             # Inference path: return temporal features (num_f_maps channels)
+             return out
+
+         # Training path: run second stage on class probabilities
+         outputs_classes = out.unsqueeze(0)
+         out_classes = self.stages(F.softmax(out, dim=1))
+         outputs_classes = torch.cat((outputs_classes, out_classes.unsqueeze(0)), dim=0)
+         return outputs_classes
+
+     @staticmethod
+     def add_model_specific_args(parser):  # pragma: no cover
+         mstcn_reg_model_specific_args = parser.add_argument_group(title='mstcn reg specific args options')
+         mstcn_reg_model_specific_args.add_argument("--mstcn_stages", default=4, type=int)
+         mstcn_reg_model_specific_args.add_argument("--mstcn_layers", default=10, type=int)
+         mstcn_reg_model_specific_args.add_argument("--mstcn_f_maps", default=64, type=int)
+         mstcn_reg_model_specific_args.add_argument("--mstcn_f_dim", default=2048, type=int)
+         mstcn_reg_model_specific_args.add_argument("--mstcn_causal_conv", action='store_true')
+         return parser
+
+
+ class SingleStageModel(nn.Module):
+     def __init__(self,
+                  num_layers: int,
+                  num_f_maps: int,
+                  dim: int,
+                  num_classes: int,
+                  causal_conv: bool = False,
+                  is_train: bool = True,
+                  dropout_prob: float = 0.0):
+         super(SingleStageModel, self).__init__()
+         self.conv_1x1 = nn.Conv1d(dim, num_f_maps, 1)
+         self.is_train = is_train
+         self.layers = nn.ModuleList([
+             copy.deepcopy(DilatedResidualLayer(2 ** i, num_f_maps, num_f_maps, causal_conv=causal_conv, dropout_prob=dropout_prob))
+             for i in range(num_layers)
+         ])
+         if self.is_train:
+             self.conv_out_classes = nn.Conv1d(num_f_maps, num_classes, 1)
+
+     def forward(self, x):
+         out = self.conv_1x1(x)
+         for layer in self.layers:
+             out = layer(out)
+         if self.is_train:
+             out = self.conv_out_classes(out)
+         return out
+
+
+ class DilatedResidualLayer(nn.Module):
+     def __init__(self,
+                  dilation: int,
+                  in_channels: int,
+                  out_channels: int,
+                  causal_conv: bool = False,
+                  kernel_size: int = 3,
+                  dropout_prob: float = 0.0):
+         super(DilatedResidualLayer, self).__init__()
+         self.causal_conv = causal_conv
+         self.dilation = dilation
+         self.kernel_size = kernel_size
+         padding = (dilation * (kernel_size - 1)) if self.causal_conv else dilation
+         self.conv_dilated = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
+         self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
+         self.dropout = nn.Dropout(dropout_prob)
+
+         self.activation = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         out = self.activation(self.conv_dilated(x))
+         out = self.dropout(out)
+         if self.causal_conv:
+             out = out[:, :, :-(self.dilation * 2)]
+         out = self.activation(self.conv_1x1(out))
+         out = self.dropout(out)
+         return x + out
+
+
+ class SingleStageModel1(nn.Module):
+     def __init__(self,
+                  num_layers,
+                  num_f_maps,
+                  dim,
+                  num_classes,
+                  causal_conv=False):
+         super(SingleStageModel1, self).__init__()
+         self.conv_1x1 = nn.Conv1d(dim, num_f_maps, 1)
+
+         self.layers = nn.ModuleList([
+             copy.deepcopy(
+                 DilatedResidualLayer(2**i,
+                                      num_f_maps,
+                                      num_f_maps,
+                                      causal_conv=causal_conv))
+             for i in range(num_layers)
+         ])
+         self.conv_out_classes = nn.Conv1d(num_f_maps, num_classes, 1)
+
+     def forward(self, x):
+         out = self.conv_1x1(x)
+         for layer in self.layers:
+             out = layer(out)
+         out_classes = self.conv_out_classes(out)
+         return out_classes, out
+
+
+ class MultiStageModel1(nn.Module):
+     def __init__(self, mstcn_stages, mstcn_layers, mstcn_f_maps, mstcn_f_dim, out_features, mstcn_causal_conv):
+         self.num_stages = mstcn_stages  # 4 #2
+         self.num_layers = mstcn_layers  # 10 #5
+         self.num_f_maps = mstcn_f_maps  # 64 #64
+         self.dim = mstcn_f_dim  # 2048 # 2048
+         self.num_classes = out_features  # 7
+         self.causal_conv = mstcn_causal_conv
+         print(
+             f"num_stages_classification: {self.num_stages}, num_layers: {self.num_layers}, num_f_maps:"
+             f" {self.num_f_maps}, dim: {self.dim}")
+         super(MultiStageModel1, self).__init__()
+         self.stage1 = SingleStageModel1(self.num_layers,
+                                         self.num_f_maps,
+                                         self.dim,
+                                         self.num_classes,
+                                         causal_conv=self.causal_conv)
+         self.stages = nn.ModuleList([
+             copy.deepcopy(
+                 SingleStageModel1(self.num_layers,
+                                   self.num_f_maps,
+                                   self.num_classes,
+                                   self.num_classes,
+                                   causal_conv=self.causal_conv))
+             for s in range(self.num_stages - 1)
+         ])
+         self.smoothing = False
+
+     def forward(self, x):
+         out_classes, _ = self.stage1(x)
+         outputs_classes = out_classes.unsqueeze(0)
+         for s in self.stages:
+             out_classes, out = s(F.softmax(out_classes, dim=1))
+             outputs_classes = torch.cat(
+                 (outputs_classes, out_classes.unsqueeze(0)), dim=0)
+         return out
model/resnet.py ADDED
@@ -0,0 +1,19 @@
+ import torch
+ import torch.nn as nn
+ import torchvision
+ from torchvision import models, transforms
+ from torchvision.models import ResNet50_Weights
+
+ # User's ResNet variant (adapted for 2048-d features, no head)
+ class ResNet(nn.Module):
+     def __init__(self, out_channels=4, has_fc=False):
+         super(ResNet, self).__init__()
+         self.resnet = torchvision.models.resnet50(pretrained=False)
+         if not has_fc:
+             self.resnet.fc = nn.Identity()  # Output 2048-d features
+         else:
+             # Keep the original fc layer for compatibility
+             pass
+
+     def forward(self, x):
+         return self.resnet(x)
model/transformer.py ADDED
@@ -0,0 +1,246 @@
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ import math
+
+
+ # some code adapted from https://wmathor.com/index.php/archives/1455/
+
+
+ class ScaledDotProductAttention(nn.Module):
+     def __init__(self, d_k, n_heads):
+         super(ScaledDotProductAttention, self).__init__()
+         self.d_k = d_k
+         self.n_heads = n_heads
+
+     def forward(self, Q, K, V):
+         '''
+         Q: [batch_size, n_heads, len_q=1, d_k]
+         K: [batch_size, n_heads, len_k, d_k]
+         V: [batch_size, n_heads, len_v(=len_k), d_v]
+         '''
+         scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(
+             self.d_k)  # scores : [batch_size, n_heads, len_q, len_k]
+
+         attn = nn.Softmax(dim=-1)(scores)  # [batch_size, n_heads, len_q, len_k]
+         context = torch.matmul(attn, V)  # [batch_size, n_heads, len_q, d_v]
+         return context, attn
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, d_model, d_k, d_v, n_heads, len_q, len_k):
+         super(MultiHeadAttention, self).__init__()
+
+         self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
+         self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
+         self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
+         self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)  # Linear only changes the last dimension
+
+         self.d_model = d_model
+         self.d_k = d_k
+         self.d_v = d_v
+         self.n_heads = n_heads
+         self.ScaledDotProductAttention = ScaledDotProductAttention(self.d_k, n_heads)
+         self.len_q = len_q
+         self.len_k = len_k
+
+     def forward(self, input_Q, input_K, input_V):
+         '''
+         input_Q: [batch_size, len_q, d_model] [512, 1, 5] --> Spatial info
+         input_K: [batch_size, len_k, d_model] [512, 30, 5] --> Temporal info
+         input_V: [batch_size, len_v(=len_k), d_model] [512, 30, 5] --> Temporal info
+         '''
+         residual, batch_size = input_Q, input_Q.size(0)
+         # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
+         Q = self.W_Q(input_Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # Q: [batch_size, n_heads, len_q, d_k]
+
+         K = self.W_K(input_K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # K: [batch_size, n_heads, len_k, d_k]
+
+         V = self.W_V(input_V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]
+
+         # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
+         context, attn = self.ScaledDotProductAttention(Q, K, V)
+         context = context.transpose(1, 2).reshape(batch_size, -1,
+                                                   self.n_heads * self.d_v)  # context: [batch_size, len_q, n_heads * d_v]
+         output = self.fc(context)  # [batch_size, len_q, d_model]
+         layer_norm = nn.LayerNorm(self.d_model).to(output.device)
+         return layer_norm(output + residual), attn  # All batch size dimensions are preserved.
+
+
+ class PoswiseFeedForwardNet(nn.Module):
+     def __init__(self, d_model, d_ff):
+         super(PoswiseFeedForwardNet, self).__init__()
+         self.fc = nn.Sequential(
+             nn.Linear(d_model, d_ff, bias=False),
+             nn.ReLU(),
+             nn.Linear(d_ff, d_model, bias=False)
+         )
+         self.d_model = d_model
+
+     def forward(self, inputs):
+         '''
+         inputs: [batch_size, seq_len, d_model]
+         '''
+         residual = inputs
+         output = self.fc(inputs)
+         layer_norm = nn.LayerNorm(self.d_model).to(output.device)
+         return layer_norm(output + residual)  # [batch_size, seq_len, d_model]
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, d_model, d_ff, d_k, d_v, n_heads, len_q):
+         super(EncoderLayer, self).__init__()
+         self.enc_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads, 1, len_q)
+         self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)
+
+     def forward(self, enc_inputs):
+         '''
+         enc_inputs: [batch_size, src_len, d_model]
+         '''
+         # enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]
+         enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)  # enc_inputs to same Q,K,V
+         enc_outputs = self.pos_ffn(enc_outputs)  # enc_outputs: [batch_size, src_len, d_model]
+         return enc_outputs, attn
+
+
+ class Encoder(nn.Module):
+     def __init__(self, d_model, d_ff, d_k, d_v, n_layers, n_heads, len_q):
+         super(Encoder, self).__init__()
+         self.layers = nn.ModuleList([EncoderLayer(d_model, d_ff, d_k, d_v, n_heads, len_q) for _ in range(n_layers)])
+
+     def forward(self, enc_inputs):
+         '''
+         enc_inputs: [batch_size, src_len, d_model]
+         '''
+         enc_outputs = enc_inputs
+         enc_self_attns = []
+         for layer in self.layers:
+             # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
+             enc_outputs, enc_self_attn = layer(enc_outputs)
+             enc_self_attns.append(enc_self_attn)
+         return enc_outputs, enc_self_attns
+
+
+ class DecoderLayer(nn.Module):
+     def __init__(self, d_model, d_ff, d_k, d_v, n_heads, len_q):
+         super(DecoderLayer, self).__init__()
+         self.dec_enc_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads, 1, len_q)
+         self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)
+
+     def forward(self, dec_inputs, enc_outputs):
+         '''
+         dec_inputs: [batch_size, tgt_len, d_model] [512, 1, 5] --> Spatial info
+         enc_outputs: [batch_size, src_len, d_model] [512, 30, 5] --> Temporal info
+         dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
+         dec_enc_attn_mask: [batch_size, tgt_len, src_len]
+         '''
+         # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
+         # dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
+         dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_inputs, enc_outputs, enc_outputs)
+         dec_outputs = self.pos_ffn(dec_outputs)  # [batch_size, tgt_len, d_model]
+         return dec_outputs, dec_enc_attn
+
+
+ class Decoder(nn.Module):
+     def __init__(self, d_model, d_ff, d_k, d_v, n_layers, n_heads, len_q):
+         super(Decoder, self).__init__()
+         self.layers = nn.ModuleList([DecoderLayer(d_model, d_ff, d_k, d_v, n_heads, len_q) for _ in range(n_layers)])
+
+     def forward(self, dec_inputs, enc_outputs):
+         '''
+         dec_inputs: [batch_size, tgt_len, d_model] [512, 1, 5]
+         enc_inputs: [batch_size, src_len, d_model] [512, 30, 5]
+         enc_outputs: [batch_size, src_len, d_model]
+         '''
+         dec_outputs = dec_inputs  # self.tgt_emb(dec_inputs) # [batch_size, tgt_len, d_model]
+         # dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() # [batch_size, tgt_len, tgt_len]
+
+         dec_enc_attns = []
+         for layer in self.layers:
+             # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
+             dec_outputs, dec_enc_attn = layer(dec_outputs, enc_outputs)
+             dec_enc_attns.append(dec_enc_attn)
+         return dec_outputs
+
+
+ # d_model, Embedding Size
+ # d_ff, FeedForward dimension
+ # d_k = d_v, dimension of K(=Q), V
+ # n_layers, number of Encoder / Decoder layers
+ # n_heads, number of heads in Multi-Head Attention
+
+ class Transformer2_3_1(nn.Module):
+     def __init__(self, d_model, d_ff, d_k, d_v, n_layers, n_heads, len_q):
+         super(Transformer2_3_1, self).__init__()
+         self.encoder = Encoder(d_model, d_ff, d_k, d_v, n_layers, n_heads, len_q)
+         self.decoder = Decoder(d_model, d_ff, d_k, d_v, 1, n_heads, len_q)
+
+     def forward(self, enc_inputs, dec_inputs):
+         '''
+         enc_inputs: [Frames, src_len, d_model] [512, 30, 5]
+         dec_inputs: [Frames, 1, d_model] [512, 1, 5]
+         '''
+         # tensor to store decoder outputs
+         # outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size).to(self.device)
+
+         # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
+         enc_outputs, enc_self_attns = self.encoder(enc_inputs)  # Self-attention for temporal features
+         dec_outputs = self.decoder(dec_inputs, enc_outputs)
+         return dec_outputs
+
+
+ class Transformer(nn.Module):
+     def __init__(self, mstcn_f_maps, mstcn_f_dim, out_features, len_q, d_model=None):
+         super(Transformer, self).__init__()
+         # Use provided d_model (256) else fall back to mstcn_f_maps
+         self.d_model = d_model if d_model is not None else mstcn_f_maps
+         self.num_classes = out_features
+         self.len_q = len_q
+
+         # Spatial encoder with d_ff = d_model; heads=8; d_k=d_v=d_model
+         self.spatial_encoder = EncoderLayer(self.d_model, self.d_model, self.d_model, self.d_model, 8, 5)
+         self.transformer = Transformer2_3_1(d_model=self.d_model, d_ff=self.d_model, d_k=self.d_model,
+                                             d_v=self.d_model, n_layers=1, n_heads=8, len_q=len_q)
+         self.fc = nn.Linear(mstcn_f_dim, self.d_model, bias=False)
+
+         # Final head 256 -> num_classes, no bias to match checkpoint
+         self.out = nn.Sequential(
+             nn.ReLU(),
+             nn.Dropout(p=0.1),
+             nn.Linear(self.d_model, out_features, bias=False)
+         )
+
+     def forward(self, x, long_feature):
+         # x: [B, 256, T]; long_feature: [B, T, 256]
+         B, D, T = x.shape
+         out_features = x.transpose(1, 2)  # [B, T, 256]
+
+         # Build sliding windows for temporal inputs
+         inputs = []
+         for i in range(T):
+             if i < self.len_q - 1:
+                 pad = torch.zeros((B, self.len_q - 1 - i, self.d_model), device=x.device)
+                 win = torch.cat([pad, out_features[:, :i + 1, :]], dim=1)
+             else:
+                 win = out_features[:, i - self.len_q + 1:i + 1, :]
+             inputs.append(win)
+         inputs = torch.stack(inputs, dim=0).squeeze(1)  # [T, B, len_q, 256]
+
+         # Project long features and create spatial windows
+         feas = torch.tanh(self.fc(long_feature))  # [B, T, 256]
+         spa_len = min(10, T)
+         out_feas = []
+         for i in range(T):
+             if i < spa_len - 1:
+                 pad = torch.zeros((B, spa_len - 1 - i, self.d_model), device=feas.device)
+                 win = torch.cat([pad, feas[:, :i + 1, :]], dim=1)
+             else:
+                 win = out_features[:, i - spa_len + 1:i + 1, :]
+             out_feas.append(win)
+         out_feas = torch.stack(out_feas, dim=0).squeeze(1)
+         out_feas, _ = self.spatial_encoder(out_feas)
+
+         # Temporal-spatial fusion
+         output = self.transformer(inputs, out_feas)  # [T, B, 1, 256] collapsed → [T, B, 256]
+         output = self.out(output)  # [T, B, C]
+         return output.transpose(0, 1)  # [B, T, C]
model_registry.py ADDED
@@ -0,0 +1,156 @@
+ from __future__ import annotations
+
+ import os
+ import shutil
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, Iterable, Tuple
+
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.utils import EntryNotFoundError
+
+ APP_ROOT = Path(__file__).resolve().parent
+ MODEL_ROOT = Path(os.environ.get("SPACE_MODEL_DIR", APP_ROOT / "model")).expanduser().resolve()
+
+
+ def _default_hf_home() -> Path:
+     data_dir = Path("/data")
+     if data_dir.is_dir():
+         return data_dir / ".huggingface"
+     return APP_ROOT / ".cache" / "huggingface"
+
+
+ HF_HOME = Path(os.environ.setdefault("HF_HOME", str(_default_hf_home()))).expanduser().resolve()
+ os.environ.setdefault("HF_HUB_CACHE", str(HF_HOME / "hub"))
+
+
+ @dataclass(frozen=True)
+ class ModelSpec:
+     key: str
+     label: str
+     required_files: Tuple[str, ...]
+     optional_files: Tuple[str, ...] = ()
+
+
+ MODEL_SPECS: Dict[str, ModelSpec] = {
+     "aiendo": ModelSpec(
+         key="aiendo",
+         label="AI-Endo",
+         required_files=("resnet50.pth", "fusion.pth", "transformer.pth"),
+     ),
+     "dinov2": ModelSpec(
+         key="dinov2",
+         label="DINO-Endo",
+         required_files=("dinov2_vit14s_latest_checkpoint.pth", "fusion_transformer_decoder_best_model.pth"),
+         optional_files=("dinov2_decoder.pth",),
+     ),
+     "vjepa2": ModelSpec(
+         key="vjepa2",
+         label="V-JEPA2",
+         required_files=("vjepa_encoder_human.pt", "mlp_decoder_human.pth"),
+     ),
+ }
+
+
+ def _repo_env_name(model_key: str) -> str:
+     prefix = {"aiendo": "AIENDO", "dinov2": "DINO", "vjepa2": "VJEPA2"}[model_key]
+     return f"{prefix}_MODEL_REPO_ID"
+
+
+ def _revision_env_name(model_key: str) -> str:
+     prefix = {"aiendo": "AIENDO", "dinov2": "DINO", "vjepa2": "VJEPA2"}[model_key]
+     return f"{prefix}_MODEL_REVISION"
+
+
+ def _subfolder_env_name(model_key: str) -> str:
+     prefix = {"aiendo": "AIENDO", "dinov2": "DINO", "vjepa2": "VJEPA2"}[model_key]
+     return f"{prefix}_MODEL_SUBFOLDER"
+
+
+ def get_model_repo_id(model_key: str) -> str | None:
+     return os.getenv(_repo_env_name(model_key)) or os.getenv("PHASE_MODEL_REPO_ID")
+
+
+ def get_model_revision(model_key: str) -> str | None:
+     return os.getenv(_revision_env_name(model_key)) or os.getenv("PHASE_MODEL_REVISION")
+
+
+ def get_model_subfolder(model_key: str) -> str:
+     return (os.getenv(_subfolder_env_name(model_key)) or "").strip("/")
+
+
+ def get_hf_token() -> str | None:
+     return os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
+
+
+ def ensure_model_root() -> Path:
+     MODEL_ROOT.mkdir(parents=True, exist_ok=True)
+     HF_HOME.mkdir(parents=True, exist_ok=True)
+     Path(os.environ["HF_HUB_CACHE"]).mkdir(parents=True, exist_ok=True)
+     return MODEL_ROOT
+
+
+ def _remote_filename(model_key: str, filename: str) -> str:
+     subfolder = get_model_subfolder(model_key)
+     return f"{subfolder}/{filename}" if subfolder else filename
+
+
+ def _download_to_model_root(model_key: str, filename: str, *, optional: bool = False) -> Path | None:
+     target = ensure_model_root() / filename
+     if target.exists():
+         return target
+
+     repo_id = get_model_repo_id(model_key)
+     if not repo_id:
+         if optional:
+             return None
+         raise FileNotFoundError(
+             f"Missing {filename} in {MODEL_ROOT}. Set {_repo_env_name(model_key)} or PHASE_MODEL_REPO_ID, "
+             f"or copy the checkpoint into the local model directory."
+         )
+
+     try:
+         downloaded = hf_hub_download(
+             repo_id=repo_id,
+             filename=_remote_filename(model_key, filename),
+             repo_type="model",
+             revision=get_model_revision(model_key),
+             token=get_hf_token(),
+         )
+     except EntryNotFoundError:
+         if optional:
+             return None
+         raise
+
+     downloaded_path = Path(downloaded)
+     if downloaded_path.resolve() != target.resolve():
+         shutil.copy2(downloaded_path, target)
+     return target
+
+
+ def ensure_model_artifacts(model_key: str) -> Path:
+     if model_key not in MODEL_SPECS:
+         raise KeyError(f"Unknown model key: {model_key}")
+
+     spec = MODEL_SPECS[model_key]
+     ensure_model_root()
+
+     for filename in spec.required_files:
+         _download_to_model_root(model_key, filename, optional=False)
+     for filename in spec.optional_files:
+         _download_to_model_root(model_key, filename, optional=True)
+
+     return MODEL_ROOT
+
+
+ def get_model_source_summary(model_key: str) -> dict:
+     spec = MODEL_SPECS[model_key]
+     return {
+         "label": spec.label,
+         "model_dir": str(MODEL_ROOT),
+         "repo_id": get_model_repo_id(model_key),
+         "revision": get_model_revision(model_key),
+         "subfolder": get_model_subfolder(model_key),
+         "required_files": list(spec.required_files),
+         "optional_files": list(spec.optional_files),
+     }
predictor.py ADDED
@@ -0,0 +1,642 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sys
5
+ from contextlib import nullcontext
6
+ from pathlib import Path
7
+
8
+ import albumentations as A
9
+ import cv2
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ try:
16
+ from torch.amp import autocast
17
+ MIXED_PRECISION_AVAILABLE = True
18
+ except ImportError: # pragma: no cover
19
+ MIXED_PRECISION_AVAILABLE = False
20
+
21
+ from model.resnet import ResNet
22
+ from model.mstcn import MultiStageModel
23
+ from model.transformer import Transformer
24
+
25
+ PHASE_LABELS = ("idle", "marking", "injection", "dissection")
26
+ MODEL_LABELS = {
27
+ "aiendo": "AI-Endo",
28
+ "dinov2": "DINO-Endo",
29
+ "vjepa2": "V-JEPA2",
30
+ }
31
+
32
+
33
+ def _app_root() -> Path:
34
+ return Path(__file__).resolve().parent
35
+
36
+
37
+ def default_model_dir() -> str:
38
+ return str(Path(os.environ.get("SPACE_MODEL_DIR", _app_root() / "model")).expanduser().resolve())
39
+
40
+
41
+ def normalize_model_key(name: str | None) -> str:
42
+ token = (name or "aiendo").lower().replace("-", "").replace("_", "").strip()
43
+ if token in ("aiendo", "resnet", "aiendoresnet", "aiendoresnetmstcn", "aiendoresnetmstcntransformer"):
44
+ return "aiendo"
45
+ if token in ("dinov2", "dinov2endo", "dinoendo", "dino"):
46
+ return "dinov2"
47
+ if token in ("vjepa2", "vjepa", "vjepa2endo"):
48
+ return "vjepa2"
49
+ raise KeyError(f"Unsupported model key: {name}")
50
+
51
+
+ def _load_trusted_checkpoint(path: str, map_location="cpu"):
+     try:
+         return torch.load(path, map_location=map_location, weights_only=False)
+     except TypeError:  # pragma: no cover
+         return torch.load(path, map_location=map_location)
+
+
+ def _strip_state_dict_prefixes(state_dict, prefixes):
+     cleaned_state = {}
+     for key, value in state_dict.items():
+         while any(key.startswith(prefix) for prefix in prefixes):
+             for prefix in prefixes:
+                 if key.startswith(prefix):
+                     key = key[len(prefix):]
+         cleaned_state[key] = value
+     return cleaned_state
+
+
+ def _validate_load_result(
+     load_result,
+     model_name: str,
+     *,
+     allowed_missing=(),
+     allowed_missing_prefixes=(),
+     allowed_unexpected=(),
+     allowed_unexpected_prefixes=(),
+ ):
+     missing = [
+         key
+         for key in load_result.missing_keys
+         if key not in allowed_missing and not any(key.startswith(prefix) for prefix in allowed_missing_prefixes)
+     ]
+     unexpected = [
+         key
+         for key in load_result.unexpected_keys
+         if key not in allowed_unexpected and not any(key.startswith(prefix) for prefix in allowed_unexpected_prefixes)
+     ]
+     if missing or unexpected:
+         problems = []
+         if missing:
+             problems.append(f"missing={missing[:10]}")
+         if unexpected:
+             problems.append(f"unexpected={unexpected[:10]}")
+         raise RuntimeError(f"{model_name} checkpoint mismatch ({'; '.join(problems)})")
+
+
+ def _resolve_vendor_repo(repo_name: str, extra_candidates=()):
+     app_root = _app_root()
+     candidates = [app_root / repo_name]
+     if len(app_root.parents) >= 2:
+         candidates.append(app_root.parents[1] / repo_name)
+     candidates.extend(extra_candidates)
+
+     for candidate in candidates:
+         if candidate and candidate.exists():
+             return candidate
+     raise FileNotFoundError(f"Required vendor repo '{repo_name}' not found. Stage it into this folder or keep the repo-root copy available.")
+
+
+ class Predictor:
+     def __init__(self, model_dir: str | None = None, device: str = "cuda"):
+         self.device = torch.device(device if torch.cuda.is_available() else "cpu")
+         self.model_dir = model_dir or default_model_dir()
+         self.seq_length = 1024
+         self.trans_seq = 30
+         self.aug = A.Compose([A.Resize(height=224, width=224), A.Normalize()])
+         self.frame_feature_cache = None
+         self.label_dict = dict(enumerate(PHASE_LABELS))
+         self.available = False
+
+         self._norm_mean = None
+         self._norm_std = None
+         if self.device.type == "cuda":
+             self._norm_mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
+             self._norm_std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
+
+         self._load_models(self.model_dir)
+
+     def _load_models(self, model_dir: str):
+         self.resnet = ResNet(out_channels=4, has_fc=False)
+         paras = torch.load(os.path.join(model_dir, "resnet50.pth"), map_location=self.device)["model"]
+         paras = {k: v for k, v in paras.items() if "fc" not in k and "embed" not in k}
+         paras = {k.replace("share.", "resnet."): v for k, v in paras.items()}
+         self.resnet.load_state_dict(paras, strict=True)
+         self.resnet.to(self.device).eval()
+
+         self.fusion = MultiStageModel(
+             mstcn_stages=2,
+             mstcn_layers=8,
+             mstcn_f_maps=32,
+             mstcn_f_dim=2048,
+             out_features=4,
+             mstcn_causal_conv=True,
+             is_train=False,
+         )
+         fusion_weights = torch.load(os.path.join(model_dir, "fusion.pth"), map_location=self.device)
+         fusion_load = self.fusion.load_state_dict(fusion_weights, strict=False)
+         _validate_load_result(
+             fusion_load,
+             "AI-Endo fusion",
+             allowed_unexpected_prefixes=("stage1.conv_out_classes.",),
+         )
+         self.fusion.to(self.device).eval()
+
+         self.transformer = Transformer(32, 2048, 4, 30, d_model=32)
+         trans_weights = torch.load(os.path.join(model_dir, "transformer.pth"), map_location=self.device)
+         self.transformer.load_state_dict(trans_weights)
+         self.transformer.to(self.device).eval()
+         self.available = True
+
+     def _amp_context(self):
+         return autocast("cuda") if MIXED_PRECISION_AVAILABLE and self.device.type == "cuda" else nullcontext()
+
+     def _preprocess_gpu(self, rgb_image: np.ndarray) -> torch.Tensor:
+         tensor = torch.from_numpy(rgb_image).permute(2, 0, 1).unsqueeze(0)
+         tensor = tensor.to(self.device, dtype=torch.float32, non_blocking=True).div_(255.0)
+         if tensor.shape[-2:] != (224, 224):
+             tensor = F.interpolate(tensor, size=(224, 224), mode="bilinear", align_corners=False)
+         return (tensor - self._norm_mean) / self._norm_std
+
+     def warm_up(self):
+         dummy = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
+         self.predict(dummy)
+         self.reset_state()
+
+     def reset_state(self):
+         self.frame_feature_cache = None
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     def unload(self):
+         self.available = False
+         self.resnet.to("cpu")
+         self.fusion.to("cpu")
+         self.transformer.to("cpu")
+         self.resnet = None
+         self.fusion = None
+         self.transformer = None
+         self.frame_feature_cache = None
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     def _cache_features(self, feature: torch.Tensor):
+         # Maintain a sliding window of per-frame features capped at seq_length
+         # (the original `>` comparison let the cache grow one frame past the cap).
+         if self.frame_feature_cache is None:
+             self.frame_feature_cache = feature
+         elif self.frame_feature_cache.shape[0] >= self.seq_length:
+             self.frame_feature_cache = torch.cat([self.frame_feature_cache[1:], feature], dim=0)
+         else:
+             self.frame_feature_cache = torch.cat([self.frame_feature_cache, feature], dim=0)
+
+     @torch.inference_mode()
+     def predict(self, rgb_image: np.ndarray):
+         if self._norm_mean is not None:
+             tensor = self._preprocess_gpu(rgb_image)
+         else:
+             processed = self.aug(image=rgb_image)["image"]
+             chw = np.transpose(processed, (2, 0, 1))
+             tensor = torch.from_numpy(chw).unsqueeze(0).contiguous().to(self.device)
+
+         with self._amp_context():
+             feature = self.resnet(tensor).clone()
+         self._cache_features(feature)
+
+         # _cache_features has already appended the current frame, so the cache is
+         # never empty here and must not be concatenated with `feature` again
+         # (doing so would double-count the newest frame). One path therefore
+         # covers the first frame, short histories, and full windows alike.
+         cat_frame_feature = self.frame_feature_cache.unsqueeze(0)
+         temporal_input = cat_frame_feature.transpose(1, 2)
+         temporal_feature = self.fusion(temporal_input)
+         outputs = self.transformer(temporal_feature.detach(), cat_frame_feature)
+         final_logits = outputs[-1, -1, :]
+         probs = F.softmax(final_logits.float(), dim=-1)
+         pred_np = probs.detach().cpu().numpy()
+
+         confidence = float(np.max(pred_np))
+         phase_idx = max(0, min(3, int(np.argmax(pred_np))))
+         phase = self.label_dict.get(phase_idx, "idle")
+         return {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": min(self.trans_seq, self.frame_feature_cache.shape[0])}
+
+
+ class PredictorDinoV2:
+     def __init__(self, model_dir: str | None = None, device: str = "cuda"):
+         self.device = torch.device(device if torch.cuda.is_available() else "cpu")
+         self.model_dir = model_dir or default_model_dir()
+         self.seq_length = 30
+         self.available = False
+         self.backbone = None
+         self.decoder = None
+         self.label_dict = dict(enumerate(PHASE_LABELS))
+         self.aug = A.Compose([
+             A.SmallestMaxSize(max_size=256, interpolation=cv2.INTER_LINEAR),
+             A.CenterCrop(height=224, width=224),
+             A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0),
+         ])
+         self.frame_features = []
+         self._load_models(self.model_dir)
+
+     def _amp_context(self):
+         return autocast("cuda") if MIXED_PRECISION_AVAILABLE and self.device.type == "cuda" else nullcontext()
+
+     def _resolve_local_dino_repo(self):
+         app_root = _app_root()
+         candidates = [app_root / "dinov2"]
+         if len(app_root.parents) >= 2:
+             candidates.append(app_root.parents[1] / "dinov2")
+         candidates.append(Path(torch.hub.get_dir()) / "facebookresearch_dinov2_main")
+         for candidate in candidates:
+             if (candidate / "hubconf.py").is_file():
+                 return str(candidate)
+         raise FileNotFoundError("Local DINOv2 repo not found. Stage dinov2/ into this folder or keep the repo-root copy available.")
+
+     def _load_models(self, model_dir: str):
+         repo_path = self._resolve_local_dino_repo()
+         self.backbone = torch.hub.load(repo_path, "dinov2_vits14", source="local", pretrained=False)
+
+         encoder_path = os.path.join(model_dir, "dinov2_vit14s_latest_checkpoint.pth")
+         if not os.path.exists(encoder_path):
+             raise FileNotFoundError("DINOv2 encoder checkpoint not found")
+         encoder_checkpoint = _load_trusted_checkpoint(encoder_path, map_location="cpu")
+         encoder_state = encoder_checkpoint.get("student", encoder_checkpoint)
+         encoder_state = _strip_state_dict_prefixes(encoder_state, ("module.", "model."))
+         encoder_load = self.backbone.load_state_dict(encoder_state, strict=False)
+         _validate_load_result(encoder_load, "DINOv2 backbone")
+         self.backbone.to(self.device).eval()
+
+         decoder_path = os.path.join(model_dir, "fusion_transformer_decoder_best_model.pth")
+         if not os.path.exists(decoder_path):
+             raise FileNotFoundError("DINOv2 decoder checkpoint not found")
+         decoder_checkpoint = _load_trusted_checkpoint(decoder_path, map_location="cpu")
+         decoder_state = decoder_checkpoint.get("state_dict", decoder_checkpoint)
+         decoder_state = _strip_state_dict_prefixes(decoder_state, ("module.", "model."))
+
+         class FusionTransformerDecoder(nn.Module):
+             def __init__(self, feature_dim=384, num_classes=4, mstcn_stages=2, mstcn_layers=8, mstcn_f_maps=16, mstcn_f_dim=256, seq_length=30, d_model=256):
+                 super().__init__()
+                 self.reduce = nn.Linear(feature_dim, mstcn_f_dim)
+                 self.mstcn = MultiStageModel(
+                     mstcn_stages=mstcn_stages,
+                     mstcn_layers=mstcn_layers,
+                     mstcn_f_maps=mstcn_f_maps,
+                     mstcn_f_dim=mstcn_f_dim,
+                     out_features=num_classes,
+                     mstcn_causal_conv=True,
+                     is_train=False,
+                 )
+                 self.transformer = Transformer(
+                     mstcn_f_maps=mstcn_f_maps,
+                     mstcn_f_dim=mstcn_f_dim,
+                     out_features=num_classes,
+                     len_q=seq_length,
+                     d_model=d_model,
+                 )
+
+             def forward(self, x):
+                 x = x.permute(0, 2, 1)
+                 x_reduced = self.reduce(x)
+                 mstcn_input = x_reduced.permute(0, 2, 1)
+                 temporal_features = self.mstcn(mstcn_input)
+                 if isinstance(temporal_features, (list, tuple)):
+                     temporal_features = temporal_features[-1]
+                 elif isinstance(temporal_features, torch.Tensor) and temporal_features.dim() == 4:
+                     temporal_features = temporal_features[-1]
+
+                 if temporal_features.shape[1] == mstcn_input.shape[1]:
+                     transformer_input = temporal_features.detach()
+                 else:
+                     transformer_input = mstcn_input.detach()
+
+                 transformer_out = self.transformer(transformer_input, x_reduced)
+                 return transformer_out.permute(0, 2, 1)
+
+         self.decoder = FusionTransformerDecoder()
+         decoder_load = self.decoder.load_state_dict(decoder_state, strict=False)
+         _validate_load_result(
+             decoder_load,
+             "DINOv2 decoder",
+             allowed_unexpected_prefixes=(
+                 "mstcn.stage1.conv_out_classes.",
+                 "mstcn.stages.conv_out_classes.",
+             ),
+         )
+         self.decoder.to(self.device).eval()
+         self.available = True
+
+     def reset_state(self):
+         self.frame_features = []
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     def warm_up(self):
+         dummy_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
+         self.predict(dummy_img)
+         self.reset_state()
+
+     def unload(self):
+         if self.backbone is not None:
+             self.backbone.to("cpu")
+         if self.decoder is not None:
+             self.decoder.to("cpu")
+         self.backbone = None
+         self.decoder = None
+         self.frame_features = []
+         self.available = False
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     @torch.inference_mode()
+     def predict(self, rgb_image: np.ndarray):
+         if not self.available or self.backbone is None or self.decoder is None:
+             raise RuntimeError("DINO-Endo predictor is not available")
+
+         processed = self.aug(image=rgb_image)["image"]
+         chw = np.transpose(processed, (2, 0, 1))
+         tensor = torch.tensor(chw, dtype=torch.float32).unsqueeze(0).to(self.device)
+
+         with self._amp_context():
+             feats = self.backbone.forward_features(tensor)
+         if isinstance(feats, dict):
+             feats = feats.get("x_norm_clstoken", next(iter(feats.values())))
+         if feats.dim() == 3:
+             feats = feats.mean(dim=1)
+
+         self.frame_features.append(feats.squeeze(0).detach().cpu())
+         if len(self.frame_features) > self.seq_length:
+             self.frame_features = self.frame_features[-self.seq_length:]
+
+         # Pad short histories by repeating the newest frame so the decoder always
+         # sees a fixed-length window of seq_length features.
+         available_frames = len(self.frame_features)
+         seq = torch.stack(self.frame_features).unsqueeze(0).to(self.device)
+         if available_frames < self.seq_length:
+             last_frame = seq[:, -1:, :]
+             padding = last_frame.repeat(1, self.seq_length - available_frames, 1)
+             seq = torch.cat([seq, padding], dim=1)
+
+         decoder_input = seq.transpose(1, 2)
+         with self._amp_context():
+             logits = self.decoder(decoder_input)
+
+         if logits.dim() != 3:
+             raise ValueError(f"Unexpected DINOv2 decoder output shape: {tuple(logits.shape)}")
+         if logits.shape[1] == len(self.label_dict):
+             last = logits[0, :, -1]
+         elif logits.shape[2] == len(self.label_dict):
+             last = logits[0, -1, :]
+         else:
+             raise ValueError(f"Unexpected DINOv2 class dimension in decoder output: {tuple(logits.shape)}")
+
+         probs = torch.softmax(last, dim=0)
+         pred_np = probs.detach().cpu().numpy()
+         confidence = float(np.max(pred_np))
+         phase_idx = int(np.argmax(pred_np))
+         phase = self.label_dict.get(phase_idx, "idle")
+         return {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": available_frames}
+
+
+ class PredictorVJEPA2:
+     def __init__(self, model_dir: str | None = None, device: str = "cuda"):
+         self.device = torch.device(device if torch.cuda.is_available() else "cpu")
+         self.model_dir = model_dir or default_model_dir()
+         self.available = False
+         self.encoder = None
+         self.decoder = None
+         self.label_dict = dict(enumerate(PHASE_LABELS))
+         self._clip_frames = 16
+         self._tubelet_size = 2
+         self._crop_size = 256
+         self._decoder_seq_length = 30
+         self._frame_buffer = []
+         self._feature_buffer = []
+         self._vjepa_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(3, 1, 1, 1)
+         self._vjepa_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(3, 1, 1, 1)
+         self._load_models(self.model_dir)
+
+     def _amp_context(self):
+         return autocast("cuda") if MIXED_PRECISION_AVAILABLE and self.device.type == "cuda" else nullcontext()
+
+     def _resolve_vjepa_repo(self):
+         extras = []
+         app_root = _app_root()
+         if len(app_root.parents) >= 2:
+             extras.append(app_root.parents[1] / "webapp" / "vjepa2")
+         return _resolve_vendor_repo("vjepa2", extras)
+
+     @staticmethod
+     def _clean_checkpoint_keys(state_dict):
+         cleaned_state = {}
+         for key, value in state_dict.items():
+             while key.startswith("module.") or key.startswith("backbone."):
+                 if key.startswith("module."):
+                     key = key[len("module."):]
+                 elif key.startswith("backbone."):
+                     key = key[len("backbone."):]
+             cleaned_state[key] = value
+         return cleaned_state
+
+     @staticmethod
+     def _validate_load_result(load_result, model_name: str):
+         if load_result.unexpected_keys:
+             sample = ", ".join(load_result.unexpected_keys[:5])
+             raise RuntimeError(f"{model_name} load had unexpected keys: {sample}")
+         if load_result.missing_keys:
+             sample = ", ".join(load_result.missing_keys[:5])
+             raise RuntimeError(f"{model_name} load missed required keys: {sample}")
+
+     def _extract_temporal_features(self, features) -> torch.Tensor:
+         # `features` may be a raw tensor or a dict of token tensors, so the
+         # parameter is deliberately left unannotated.
+         if isinstance(features, dict):
+             features = features.get("x_norm_patchtokens", features.get("x_norm_clstoken", next(iter(features.values()))))
+
+         if features.dim() == 2:
+             return features.unsqueeze(1).repeat(1, self._clip_frames, 1)
+         if features.dim() != 3:
+             raise ValueError(f"Unexpected V-JEPA2 encoder output shape: {tuple(features.shape)}")
+
+         temporal_tokens = self._clip_frames // self._tubelet_size
+         if temporal_tokens <= 0:
+             raise ValueError("Invalid V-JEPA2 temporal configuration")
+         if features.shape[1] % temporal_tokens != 0:
+             raise ValueError(
+                 f"Cannot reshape V-JEPA2 features of shape {tuple(features.shape)} into {temporal_tokens} temporal groups"
+             )
+
+         # Average patch tokens within each temporal group, then repeat per tubelet
+         # so every frame of the clip gets one feature vector.
+         spatial_tokens = features.shape[1] // temporal_tokens
+         features = features.view(features.shape[0], temporal_tokens, spatial_tokens, features.shape[2]).mean(dim=2)
+         return features.repeat_interleave(self._tubelet_size, dim=1)[:, : self._clip_frames, :]
+
+     def _preprocess_clip(self, frames) -> torch.Tensor:
+         resized_frames = [cv2.resize(frame, (self._crop_size, self._crop_size), interpolation=cv2.INTER_LINEAR) for frame in frames]
+         clip = np.stack(resized_frames, axis=0).astype(np.float32) / 255.0
+         tensor = torch.from_numpy(np.transpose(clip, (3, 0, 1, 2)))
+         return (tensor - self._vjepa_mean) / self._vjepa_std
+
+     def _load_models(self, model_dir: str):
+         vjepa2_path = self._resolve_vjepa_repo()
+         if str(vjepa2_path) not in sys.path:
+             sys.path.insert(0, str(vjepa2_path))
+
+         from src.models import vision_transformer as vjepa_vit
+         from src.utils.checkpoint_loader import robust_checkpoint_loader
+
+         encoder_path = os.path.join(model_dir, "vjepa_encoder_human.pt")
+         if not os.path.exists(encoder_path):
+             raise FileNotFoundError("V-JEPA2 encoder not found")
+
+         checkpoint = robust_checkpoint_loader(encoder_path, map_location=torch.device("cpu"))
+         encoder_state = self._clean_checkpoint_keys(checkpoint.get("encoder", checkpoint))
+
+         self.encoder = vjepa_vit.vit_large(
+             patch_size=16,
+             num_frames=self._clip_frames,
+             tubelet_size=self._tubelet_size,
+             img_size=self._crop_size,
+             uniform_power=True,
+             use_sdpa=True,
+             use_rope=True,
+         )
+         encoder_load = self.encoder.load_state_dict(encoder_state, strict=False)
+         self._validate_load_result(encoder_load, "V-JEPA2 encoder")
+         self.encoder.to(self.device).eval()
+
+         decoder_path = os.path.join(model_dir, "mlp_decoder_human.pth")
+         if not os.path.exists(decoder_path):
+             raise FileNotFoundError("V-JEPA2 MLP decoder not found")
+
+         decoder_checkpoint = torch.load(decoder_path, map_location="cpu")
+         decoder_state = decoder_checkpoint.get("model", decoder_checkpoint)
+         decoder_in_dim = int(decoder_checkpoint.get("in_dim", 1024))
+         decoder_num_classes = int(decoder_checkpoint.get("num_classes", len(self.label_dict)))
+         self._decoder_seq_length = int(decoder_checkpoint.get("seq_length", self._decoder_seq_length))
+
+         class MLPDecoder(nn.Module):
+             def __init__(self, in_dim=1024, hidden_dim=256, num_classes=4):
+                 super().__init__()
+                 self.norm = nn.LayerNorm(in_dim)
+                 self.fc1 = nn.Linear(in_dim, hidden_dim)
+                 self.fc2 = nn.Linear(hidden_dim, hidden_dim)
+                 self.fc3 = nn.Linear(hidden_dim, num_classes)
+                 self.relu = nn.ReLU()
+                 self.drop = nn.Dropout(0.5)
+
+             def forward(self, x):
+                 x = x.mean(dim=1)
+                 x = self.norm(x)
+                 x = self.drop(self.relu(self.fc1(x)))
+                 x = self.drop(self.relu(self.fc2(x)))
+                 return self.fc3(x)
+
+         self.decoder = MLPDecoder(in_dim=decoder_in_dim, num_classes=decoder_num_classes)
+         self.decoder.load_state_dict(decoder_state, strict=True)
+         self.decoder.to(self.device).eval()
+         self.available = True
+
+     def reset_state(self):
+         self._frame_buffer = []
+         self._feature_buffer = []
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     def warm_up(self):
+         dummy = np.random.randint(0, 255, (self._crop_size, self._crop_size, 3), dtype=np.uint8)
+         self.predict(dummy)
+         self.reset_state()
+
+     def unload(self):
+         if self.encoder is not None:
+             self.encoder.to("cpu")
+         if self.decoder is not None:
+             self.decoder.to("cpu")
+         self.encoder = None
+         self.decoder = None
+         self._frame_buffer = []
+         self._feature_buffer = []
+         self.available = False
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     @torch.inference_mode()
+     def predict(self, rgb_image: np.ndarray):
+         if not self.available:
+             raise RuntimeError("V-JEPA2 predictor is not available")
+
+         frame = np.ascontiguousarray(rgb_image, dtype=np.uint8)
+         self._frame_buffer.append(frame)
+         if len(self._frame_buffer) > self._clip_frames:
+             self._frame_buffer = self._frame_buffer[-self._clip_frames:]
+
+         # Repeat the last frame until a full clip is buffered.
+         clip_frames = list(self._frame_buffer)
+         while len(clip_frames) < self._clip_frames:
+             clip_frames.append(clip_frames[-1])
+
+         tensor = self._preprocess_clip(clip_frames).unsqueeze(0).to(self.device)
+         with self._amp_context():
+             features = self._extract_temporal_features(self.encoder(tensor))
+
+         latest_feature_idx = min(len(self._frame_buffer), self._clip_frames) - 1
+         latest_feature = features[0, latest_feature_idx].float().detach().cpu()
+         self._feature_buffer.append(latest_feature)
+         if len(self._feature_buffer) > self._decoder_seq_length:
+             self._feature_buffer = self._feature_buffer[-self._decoder_seq_length:]
+
+         available_frames = len(self._feature_buffer)
+         seq = torch.stack(self._feature_buffer, dim=0).unsqueeze(0).to(self.device)
+         if available_frames < self._decoder_seq_length:
+             padding = seq[:, -1:, :].repeat(1, self._decoder_seq_length - available_frames, 1)
+             seq = torch.cat([seq, padding], dim=1)
+
+         with self._amp_context():
+             logits = self.decoder(seq)
+
+         probs = torch.softmax(logits[0], dim=0)
+         pred_np = probs.detach().cpu().numpy()
+         confidence = float(np.max(pred_np))
+         phase_idx = int(np.argmax(pred_np))
+         phase = self.label_dict.get(phase_idx, "idle")
+         return {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": available_frames}
+
+
+ def create_predictor(model_key: str, model_dir: str | None = None, device: str | None = None):
+     resolved_key = normalize_model_key(model_key)
+     resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+     resolved_model_dir = model_dir or default_model_dir()
+
+     if resolved_key == "aiendo":
+         return Predictor(model_dir=resolved_model_dir, device=resolved_device)
+     if resolved_key == "dinov2":
+         return PredictorDinoV2(model_dir=resolved_model_dir, device=resolved_device)
+     if resolved_key == "vjepa2":
+         return PredictorVJEPA2(model_dir=resolved_model_dir, device=resolved_device)
+     raise KeyError(f"Unsupported model key: {model_key}")
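
For reference, a minimal sketch of how these predictors are meant to be driven one RGB frame at a time (hypothetical driver code, not part of the Space; it assumes the matching checkpoints are already staged in the model directory, otherwise construction raises `FileNotFoundError`):

```python
import numpy as np

from predictor import create_predictor

# Stand-in frames; any HxWx3 uint8 RGB array works in place of the random data.
predictor = create_predictor("dinov2")
predictor.reset_state()
for _ in range(5):
    frame = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
    result = predictor.predict(frame)
    print(result["phase"], result["confidence"], result["frames_used"])
predictor.unload()
```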
requirements.txt CHANGED
@@ -1,4 +1,13 @@
- accelerate
- transformers
- trimesh
- numpy
+ --extra-index-url https://download.pytorch.org/whl/cu121
+ streamlit>=1.40,<2
+ torch==2.5.1
+ torchvision==0.20.1
+ numpy>=1.26,<3
+ pandas>=2.2,<3
+ opencv-python-headless>=4.10,<5
+ pillow>=10,<12
+ albumentations>=2.0,<3
+ huggingface_hub>=0.27,<1
+ pyyaml>=6,<7
+ timm>=1.0,<2
+ einops>=0.8,<1
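
The `--extra-index-url` line pulls CUDA 12.1 wheels to match the pinned `torch==2.5.1` / `torchvision==0.20.1` pair. An optional post-install sanity check (not part of the Space itself):

```python
import torch
import torchvision

# Expect 2.5.1 and 0.20.1; torch.version.cuda reports "12.1" for cu121 wheels,
# and cuda.is_available() is False on CPU-only hosts.
print(torch.__version__, torchvision.__version__, torch.version.cuda, torch.cuda.is_available())
```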
scripts/smoke_test.py ADDED
@@ -0,0 +1,57 @@
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import sys
+ from pathlib import Path
+
+ import numpy as np
+
+ SCRIPT_PATH = Path(__file__).resolve()
+ SPACE_ROOT = SCRIPT_PATH.parents[1]
+ if str(SPACE_ROOT) not in sys.path:
+     sys.path.insert(0, str(SPACE_ROOT))
+
+ from predictor import create_predictor
+
+
+ MODEL_REQUIREMENTS = {
+     "aiendo": ("resnet50.pth", "fusion.pth", "transformer.pth"),
+     "dinov2": ("dinov2_vit14s_latest_checkpoint.pth", "fusion_transformer_decoder_best_model.pth"),
+     "vjepa2": ("vjepa_encoder_human.pt", "mlp_decoder_human.pth"),
+ }
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Smoke test the isolated HF Space predictors.")
+     parser.add_argument("--model", choices=sorted(MODEL_REQUIREMENTS), required=True)
+     parser.add_argument("--model-dir", default=str(SPACE_ROOT / "model"))
+     parser.add_argument("--device", default="cuda")
+     parser.add_argument("--image-size", type=int, default=256)
+     return parser.parse_args()
+
+
+ def main() -> None:
+     args = parse_args()
+     model_dir = Path(args.model_dir).expanduser().resolve()
+     missing = [name for name in MODEL_REQUIREMENTS[args.model] if not (model_dir / name).exists()]
+     if missing:
+         raise FileNotFoundError(f"Missing required checkpoints in {model_dir}: {', '.join(missing)}")
+
+     os.environ["SPACE_MODEL_DIR"] = str(model_dir)
+     dummy = np.random.randint(0, 255, (args.image_size, args.image_size, 3), dtype=np.uint8)
+
+     predictor = create_predictor(args.model, model_dir=str(model_dir), device=args.device)
+     predictor.reset_state()
+     result = predictor.predict(dummy)
+     predictor.unload()
+
+     print(f"model={args.model}")
+     print(f"phase={result.get('phase')}")
+     print(f"confidence={result.get('confidence')}")
+     print(f"frames_used={result.get('frames_used')}")
+
+
+ if __name__ == "__main__":
+     main()
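
Assuming the checkpoints listed in `MODEL_REQUIREMENTS` are staged under `model/`, the smoke test can be run as, for example, `python scripts/smoke_test.py --model dinov2 --device cpu`. Passing `--device cuda` on a CPU-only host is also safe, since each predictor falls back to CPU when `torch.cuda.is_available()` is false.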