Deploy minimal DINO-Endo Space app

Files changed:

- .dockerignore +15 -0
- .gitignore +3 -0
- Dockerfile +38 -0
- README.md +130 -14
- app.py +305 -243
- model/__init__.py +0 -0
- model/mstcn.py +183 -0
- model/resnet.py +19 -0
- model/transformer.py +246 -0
- model_registry.py +156 -0
- predictor.py +642 -0
- requirements.txt +13 -4
- scripts/smoke_test.py +57 -0
.dockerignore
ADDED @@ -0,0 +1,15 @@

```text
__pycache__/
*.py[cod]
*.so
*.egg-info/
.git/
.gitignore
.cache/
.pytest_cache/
.mypy_cache/
.streamlit/
.env
.env.*
venv/
.venv/
*.ipynb
```
.gitignore
ADDED @@ -0,0 +1,3 @@

```text
.cache/
__pycache__/
*.pyc
```
Dockerfile
ADDED @@ -0,0 +1,38 @@

```dockerfile
FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04

ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1 \
    SPACE_MODEL_DIR=/app/model \
    SPACE_ENABLED_MODELS=dinov2 \
    SPACE_DEFAULT_MODEL=dinov2

RUN apt-get update && apt-get install -y --no-install-recommends \
    bash \
    curl \
    wget \
    procps \
    python3 \
    python3-pip \
    python3-venv \
    git \
    git-lfs \
    ffmpeg \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

RUN useradd -m -u 1000 user && mkdir -p /app && chown user:user /app
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH
WORKDIR /app

COPY --chown=user requirements.txt /app/requirements.txt
RUN python3 -m pip install --upgrade pip && \
    python3 -m pip install -r requirements.txt

COPY --chown=user . /app

EXPOSE 7860
CMD ["python3", "-m", "streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0", "--server.headless=true"]
```
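A quick way to exercise this image locally, as a sketch: the image tag and the weights repo id below are placeholders, and `--gpus all` assumes the NVIDIA Container Toolkit is installed.

```bash
# Build the Space image and run it on port 7860 with GPU access.
docker build -t dino-endo-space .
docker run --rm --gpus all -p 7860:7860 \
  -e PHASE_MODEL_REPO_ID=your-org/dino-endo-weights \
  dino-endo-space
```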
README.md
CHANGED @@ -1,14 +1,130 @@

The previous 14-line README was a stub front-matter block (empty `title:`, `emoji:`, `colorFrom:`, and `sdk:` fields, with only `colorTo: green` filled in). It is replaced by the content below.

---
title: DINO-ENDO Phase Recognition
emoji: 🩺
colorFrom: blue
colorTo: green
sdk: docker
app_port: 7860
---

# DINO-ENDO Streamlit Space

This folder is an isolated Hugging Face Space scaffold for the phase-recognition models in this repository.
It is intentionally separate from the existing FastAPI webapp and defaults to a **DINO-Endo demo** on paid GPU hardware such as **1x A10G (24 GB VRAM)**.
The same code can still expose AI-Endo and V-JEPA2 when you opt into them through environment variables.

## Supported model families

- **AI-Endo**
  - `resnet50.pth`
  - `fusion.pth`
  - `transformer.pth`
- **DINO-Endo**
  - `dinov2_vit14s_latest_checkpoint.pth`
  - `fusion_transformer_decoder_best_model.pth`
  - optional `dinov2_decoder.pth`
  - vendored `dinov2/` source tree
- **V-JEPA2**
  - `vjepa_encoder_human.pt`
  - `mlp_decoder_human.pth`
  - vendored `vjepa2/` source tree

## Weight delivery strategy

The default design is:

1. Keep the **Space repo mostly code-only**.
2. Upload weights to one or more **Hugging Face model repos** (see the upload sketch below).
3. Let the Space populate `model/` (or `SPACE_MODEL_DIR`) on demand via `huggingface_hub`.

This works better than checking all weights directly into the Space repo because code and weights stay versioned separately and Space rebuilds stay lighter.
A fully local `model/` folder is still supported as a fallback.
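A minimal upload sketch with `huggingface_hub` (the repo id is a placeholder, and `HF_TOKEN` must be set in the environment if the repo is private):

```python
from huggingface_hub import HfApi

api = HfApi()  # picks up HF_TOKEN from the environment
api.upload_file(
    path_or_fileobj="dinov2_vit14s_latest_checkpoint.pth",  # local checkpoint
    path_in_repo="dinov2_vit14s_latest_checkpoint.pth",     # filename the Space expects
    repo_id="your-org/dino-endo-weights",                   # placeholder repo id
    repo_type="model",
)
```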
## Default Space behavior

The Docker Space is configured to boot as a **DINO-Endo-first demo**:

- `SPACE_ENABLED_MODELS=dinov2`
- `SPACE_DEFAULT_MODEL=dinov2`

If you want the same Space build to expose multiple model families again, override those environment variables in Space Settings, for example:

```text
SPACE_ENABLED_MODELS=dinov2,aiendo,vjepa2
SPACE_DEFAULT_MODEL=dinov2
```

The Dockerfile is also set up to be **HF Dev Mode compatible**:

- app code lives under `/app`
- `/app` is owned by uid `1000`
- the required Dev Mode packages (`bash`, `curl`, `wget`, `procps`, `git`, `git-lfs`) are installed

## Runtime configuration

The app looks for model files in `SPACE_MODEL_DIR` first (default: `./model`).
If a required checkpoint is missing locally, it will try to download it from the configured model repo(s).

### Common environment variables

- `SPACE_ENABLED_MODELS` — comma-separated list of model families to expose in the UI
- `SPACE_DEFAULT_MODEL` — default selected model when multiple model families are enabled
- `SPACE_MODEL_DIR` — local directory where checkpoints should live (default: `./model`)
- `PHASE_MODEL_REPO_ID` — shared HF model repo for all weights
- `PHASE_MODEL_REVISION` — optional shared revision/tag/commit
- `HF_TOKEN` — only needed for private or gated repos

If `HF_HOME` / `HF_HUB_CACHE` are not set explicitly, the app will automatically use persistent `/data` storage when it exists and otherwise fall back to a local cache inside the Space folder.

### Per-model overrides

- `AIENDO_MODEL_REPO_ID`, `DINO_MODEL_REPO_ID`, `VJEPA2_MODEL_REPO_ID`
- `AIENDO_MODEL_REVISION`, `DINO_MODEL_REVISION`, `VJEPA2_MODEL_REVISION`
- `AIENDO_MODEL_SUBFOLDER`, `DINO_MODEL_SUBFOLDER`, `VJEPA2_MODEL_SUBFOLDER`

Use subfolder env vars if you store multiple model families in one repo under different directories, as in the example below.
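For example, one shared repo with per-family subfolders (the repo id and folder names are placeholders):

```text
PHASE_MODEL_REPO_ID=your-org/phase-models
AIENDO_MODEL_SUBFOLDER=aiendo
DINO_MODEL_SUBFOLDER=dinov2
VJEPA2_MODEL_SUBFOLDER=vjepa2
```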
## Local development vs. publishing

The required vendored `dinov2/` and `vjepa2/` source trees are now staged inside this folder, so the Space scaffold is self-contained.
If those upstream source trees change and you want to refresh the copies here, run:

```bash
python scripts/stage_vendor_sources.py --overwrite
```

That script refreshes the vendored source copies inside this folder before publishing.

## Publishing checklist

1. Populate the Space folder files here.
2. Run `python scripts/stage_vendor_sources.py --overwrite` if you need to refresh the vendored source copies.
3. Push the contents of this folder to a Hugging Face **Docker Space**.
4. Upload your checkpoints to HF **model repo(s)**.
5. Configure the relevant repo IDs (and `HF_TOKEN` only if the repos are private).

## Local smoke test

Once the Space dependencies are installed, you can smoke test a predictor directly:

```bash
python scripts/smoke_test.py --model dinov2 --model-dir /path/to/model
python scripts/smoke_test.py --model aiendo --model-dir /path/to/model
python scripts/smoke_test.py --model vjepa2 --model-dir /path/to/model
```

## Scope of v1

- Streamlit UI
- DINO-Endo demo by default, with optional multi-model selector when enabled
- image upload and video upload
- per-frame phase timeline output for video
- JSON / CSV export

Not included in v1:

- auth / user management
- SQL database
- PDF/HTML report generation
- background queue processing
- polyp segmentation
app.py
CHANGED @@ -1,243 +1,305 @@

The previous 243-line app is replaced in full by the rewrite below.

```python
from __future__ import annotations

import json
import os
import tempfile
import time
from collections import Counter
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import streamlit as st
import torch
from PIL import Image

from model_registry import MODEL_SPECS, ensure_model_artifacts, get_model_source_summary
from predictor import MODEL_LABELS, PHASE_LABELS, create_predictor, normalize_model_key

st.set_page_config(page_title="DINO-Endo Phase Recognition", layout="wide")


def _phase_index(phase: str) -> int:
    try:
        return PHASE_LABELS.index(phase)
    except ValueError:
        return -1


def _image_to_rgb(uploaded_file) -> np.ndarray:
    image = Image.open(uploaded_file).convert("RGB")
    return np.array(image)


def _enabled_model_keys() -> list[str]:
    configured = os.getenv("SPACE_ENABLED_MODELS", "").strip()
    if not configured:
        return list(MODEL_SPECS.keys())

    enabled_keys = []
    seen = set()
    for token in configured.split(","):
        raw = token.strip()
        if not raw:
            continue
        normalized = normalize_model_key(raw)
        if normalized not in MODEL_SPECS:
            raise RuntimeError(f"SPACE_ENABLED_MODELS contains unsupported model '{raw}'")
        if normalized not in seen:
            enabled_keys.append(normalized)
            seen.add(normalized)

    if not enabled_keys:
        raise RuntimeError("SPACE_ENABLED_MODELS did not resolve to any supported models")
    return enabled_keys


def _default_model_key(enabled_model_keys: list[str]) -> str:
    configured = os.getenv("SPACE_DEFAULT_MODEL", "").strip()
    if not configured:
        return "dinov2" if "dinov2" in enabled_model_keys else enabled_model_keys[0]

    normalized = normalize_model_key(configured)
    if normalized not in enabled_model_keys:
        raise RuntimeError(
            f"SPACE_DEFAULT_MODEL '{configured}' is not enabled by SPACE_ENABLED_MODELS"
        )
    return normalized


def _space_caption(enabled_model_keys: list[str]) -> str:
    if enabled_model_keys == ["dinov2"]:
        return "Streamlit Hugging Face Space demo for the DINO-Endo phase-recognition stack."
    return "DINO-first Streamlit Hugging Face Space demo for DINO-Endo, AI-Endo, and V-JEPA2."


def _ensure_predictor(model_key: str):
    active_key = st.session_state.get("active_model_key")
    active_predictor = st.session_state.get("active_predictor")

    if active_predictor is not None and active_key != model_key:
        active_predictor.unload()
        st.session_state.pop("active_predictor", None)
        st.session_state.pop("active_model_key", None)

    if st.session_state.get("active_predictor") is None:
        with st.spinner(f"Preparing {MODEL_LABELS[model_key]}..."):
            model_dir = ensure_model_artifacts(model_key)
            predictor = create_predictor(model_key, model_dir=str(model_dir))
            predictor.warm_up()
        st.session_state["active_predictor"] = predictor
        st.session_state["active_model_key"] = model_key

    return st.session_state["active_predictor"]


def _analyse_video(uploaded_file, predictor, frame_stride: int, max_frames: int):
    suffix = Path(uploaded_file.name).suffix or ".mp4"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(uploaded_file.getbuffer())
        temp_path = Path(tmp.name)

    capture = cv2.VideoCapture(str(temp_path))
    if not capture.isOpened():
        temp_path.unlink(missing_ok=True)
        raise RuntimeError("Unable to open uploaded video")

    total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
    fps = float(capture.get(cv2.CAP_PROP_FPS) or 0.0)
    progress = st.progress(0)
    status = st.empty()

    predictor.reset_state()
    records = []
    processed = 0
    frame_index = 0

    try:
        while True:
            ok, frame = capture.read()
            if not ok:
                break

            if frame_index % frame_stride != 0:
                frame_index += 1
                continue

            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            started = time.perf_counter()
            result = predictor.predict(rgb)
            elapsed_ms = (time.perf_counter() - started) * 1000.0

            probs = result.get("probs", [0.0, 0.0, 0.0, 0.0])
            record = {
                "frame_index": frame_index,
                "timestamp_sec": round(frame_index / fps, 3) if fps > 0 else None,
                "phase": result.get("phase", "unknown"),
                "phase_id": _phase_index(result.get("phase", "unknown")),
                "confidence": float(result.get("confidence", 0.0)),
                "frames_used": int(result.get("frames_used", processed + 1)),
                "idle": float(probs[0]) if len(probs) > 0 else 0.0,
                "marking": float(probs[1]) if len(probs) > 1 else 0.0,
                "injection": float(probs[2]) if len(probs) > 2 else 0.0,
                "dissection": float(probs[3]) if len(probs) > 3 else 0.0,
                "inference_ms": round(elapsed_ms, 3),
            }
            records.append(record)
            processed += 1

            if total_frames > 0:
                progress.progress(min(frame_index + 1, total_frames) / total_frames)
            else:
                progress.progress(min(processed / max_frames, 1.0))
            status.caption(f"Processed {processed} sampled frames")

            frame_index += 1
            if processed >= max_frames:
                break
    finally:
        capture.release()
        temp_path.unlink(missing_ok=True)
        predictor.reset_state()

    progress.empty()
    status.empty()
    return records, {"fps": fps, "total_frames": total_frames, "sampled_frames": processed}


def _records_to_frame(records):
    if not records:
        return pd.DataFrame(columns=["frame_index", "timestamp_sec", "phase", "confidence"])
    return pd.DataFrame.from_records(records)


def _download_payloads(df: pd.DataFrame):
    json_payload = df.to_json(orient="records", indent=2).encode("utf-8")
    csv_payload = df.to_csv(index=False).encode("utf-8")
    return json_payload, csv_payload


def _render_single_result(result: dict):
    probs = result.get("probs", [0.0, 0.0, 0.0, 0.0])
    metrics = st.columns(3)
    metrics[0].metric("Predicted phase", result.get("phase", "unknown").upper())
    metrics[1].metric("Confidence", f"{float(result.get('confidence', 0.0)):.1%}")
    metrics[2].metric("Frames used", int(result.get("frames_used", 1)))

    prob_df = pd.DataFrame({"phase": list(PHASE_LABELS), "probability": probs})
    st.bar_chart(prob_df.set_index("phase"))
    st.download_button(
        label="Download JSON",
        data=json.dumps(result, indent=2).encode("utf-8"),
        file_name="phase_prediction.json",
        mime="application/json",
        key="download-single-json",
    )


def _render_video_results(records, meta):
    if not records:
        st.warning("No frames were processed from the uploaded video.")
        return

    df = _records_to_frame(records)
    counts = Counter(df["phase"].tolist())
    dominant_phase, dominant_count = counts.most_common(1)[0]

    metrics = st.columns(4)
    metrics[0].metric("Sampled frames", int(meta["sampled_frames"]))
    metrics[1].metric("Dominant phase", dominant_phase.upper())
    metrics[2].metric("Mean confidence", f"{df['confidence'].mean():.1%}")
    metrics[3].metric("Average inference", f"{df['inference_ms'].mean():.1f} ms")

    chart_df = df.copy()
    if "timestamp_sec" in chart_df and chart_df["timestamp_sec"].notna().any():
        chart_df = chart_df.set_index("timestamp_sec")
    else:
        chart_df = chart_df.set_index("frame_index")

    st.subheader("Confidence timeline")
    st.line_chart(chart_df[["confidence"]])

    st.subheader("Phase timeline")
    st.line_chart(chart_df[["phase_id"]])

    st.subheader("Per-frame predictions")
    st.dataframe(df, use_container_width=True, hide_index=True)

    json_payload, csv_payload = _download_payloads(df)
    left, right = st.columns(2)
    left.download_button("Download JSON", json_payload, file_name="phase_timeline.json", mime="application/json")
    right.download_button("Download CSV", csv_payload, file_name="phase_timeline.csv", mime="text/csv")


def main():
    enabled_model_keys = _enabled_model_keys()
    default_model_key = _default_model_key(enabled_model_keys)

    st.title("DINO-Endo Surgical Phase Recognition")
    st.caption(_space_caption(enabled_model_keys))

    st.sidebar.markdown("### Model")
    if len(enabled_model_keys) == 1:
        model_key = enabled_model_keys[0]
        st.sidebar.write(MODEL_LABELS[model_key])
    else:
        model_key = st.sidebar.selectbox(
            "Model",
            options=enabled_model_keys,
            index=enabled_model_keys.index(default_model_key),
            format_func=lambda key: MODEL_LABELS[key],
        )

    source_summary = get_model_source_summary(model_key)
    st.sidebar.markdown("### Runtime")
    st.sidebar.write(f"CUDA available: `{torch.cuda.is_available()}`")
    if torch.cuda.is_available():
        st.sidebar.write(f"Device: `{torch.cuda.get_device_name(torch.cuda.current_device())}`")
    st.sidebar.write(f"Model dir: `{source_summary['model_dir']}`")
    st.sidebar.write(f"HF repo: `{source_summary['repo_id'] or 'local-only'}`")
    if source_summary["subfolder"]:
        st.sidebar.write(f"Repo subfolder: `{source_summary['subfolder']}`")

    image_tab, video_tab = st.tabs(["Image", "Video"])

    with image_tab:
        uploaded_image = st.file_uploader("Upload an RGB frame", type=["png", "jpg", "jpeg"], key="image-uploader")
        if uploaded_image is not None:
            rgb = _image_to_rgb(uploaded_image)
            st.image(rgb, caption=uploaded_image.name, use_container_width=True)
            if st.button("Run image inference", key="run-image"):
                predictor = _ensure_predictor(model_key)
                predictor.reset_state()
                started = time.perf_counter()
                result = predictor.predict(rgb)
                result["inference_ms"] = round((time.perf_counter() - started) * 1000.0, 3)
                predictor.reset_state()
                _render_single_result(result)

    with video_tab:
        frame_stride = st.slider("Analyze every Nth frame", min_value=1, max_value=30, value=5, step=1)
        max_frames = st.slider("Maximum sampled frames", min_value=10, max_value=600, value=180, step=10)
        uploaded_video = st.file_uploader(
            "Upload a video",
            type=["mp4", "mov", "avi", "mkv", "webm", "m4v"],
            key="video-uploader",
        )
        if uploaded_video is not None:
            st.video(uploaded_video)
            if st.button("Analyze video", key="run-video"):
                predictor = _ensure_predictor(model_key)
                records, meta = _analyse_video(uploaded_video, predictor, frame_stride=frame_stride, max_frames=max_frames)
                _render_video_results(records, meta)

    if st.sidebar.button("Unload active model"):
        predictor = st.session_state.get("active_predictor")
        if predictor is not None:
            predictor.unload()
            st.session_state.pop("active_predictor", None)
            st.session_state.pop("active_model_key", None)
            st.sidebar.success("Model unloaded")


if __name__ == "__main__":
    main()
```
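Outside Docker, the same app can be pointed at local weights, as a sketch (assumes the dependencies from `requirements.txt` are installed and checkpoints already sit in `./model`):

```bash
# Run the Space app directly with the DINO-Endo-only configuration.
SPACE_ENABLED_MODELS=dinov2 SPACE_DEFAULT_MODEL=dinov2 SPACE_MODEL_DIR=./model \
  python -m streamlit run app.py --server.headless true --server.port 7860
```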
model/__init__.py
ADDED

File without changes (an empty file; it marks `model/` as a Python package).
model/mstcn.py
ADDED @@ -0,0 +1,183 @@

```python
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy


class MultiStageModel(nn.Module):
    def __init__(self, mstcn_stages, mstcn_layers, mstcn_f_maps, mstcn_f_dim, out_features, mstcn_causal_conv, is_train=True, dropout_prob: float = 0.0):
        self.num_stages = mstcn_stages
        self.num_layers = mstcn_layers
        self.num_f_maps = mstcn_f_maps
        self.dim = mstcn_f_dim
        self.num_classes = out_features
        self.causal_conv = mstcn_causal_conv
        self.is_train = is_train
        print(f"num_stages_classification: {self.num_stages}, num_layers: {self.num_layers}, num_f_maps: {self.num_f_maps}, dim: {self.dim}")
        super(MultiStageModel, self).__init__()
        self.stage1 = SingleStageModel(self.num_layers,
                                       self.num_f_maps,
                                       self.dim,
                                       self.num_classes,
                                       causal_conv=self.causal_conv,
                                       is_train=is_train,
                                       dropout_prob=dropout_prob)
        self.stages = SingleStageModel(self.num_layers,
                                       self.num_f_maps,
                                       self.num_classes,
                                       self.num_classes,
                                       causal_conv=self.causal_conv,
                                       is_train=is_train,
                                       dropout_prob=dropout_prob)

        self.smoothing = False

    def forward(self, x):
        """
        If is_train is False (inference), return first-stage features [B, num_f_maps, T]
        so the downstream Transformer receives 32-d features, matching the working pipeline.
        If is_train is True (training/classification), return stacked class logits.
        """
        out = self.stage1(x)
        if not self.is_train:
            # Inference path: return temporal features (num_f_maps channels)
            return out

        # Training path: run second stage on class probabilities
        outputs_classes = out.unsqueeze(0)
        out_classes = self.stages(F.softmax(out, dim=1))
        outputs_classes = torch.cat((outputs_classes, out_classes.unsqueeze(0)), dim=0)
        return outputs_classes

    @staticmethod
    def add_model_specific_args(parser):  # pragma: no cover
        mstcn_reg_model_specific_args = parser.add_argument_group(title='mstcn reg specific args options')
        mstcn_reg_model_specific_args.add_argument("--mstcn_stages", default=4, type=int)
        mstcn_reg_model_specific_args.add_argument("--mstcn_layers", default=10, type=int)
        mstcn_reg_model_specific_args.add_argument("--mstcn_f_maps", default=64, type=int)
        mstcn_reg_model_specific_args.add_argument("--mstcn_f_dim", default=2048, type=int)
        mstcn_reg_model_specific_args.add_argument("--mstcn_causal_conv", action='store_true')
        return parser


class SingleStageModel(nn.Module):
    def __init__(self,
                 num_layers: int,
                 num_f_maps: int,
                 dim: int,
                 num_classes: int,
                 causal_conv: bool = False,
                 is_train: bool = True,
                 dropout_prob: float = 0.0):
        super(SingleStageModel, self).__init__()
        self.conv_1x1 = nn.Conv1d(dim, num_f_maps, 1)
        self.is_train = is_train
        self.layers = nn.ModuleList([
            copy.deepcopy(DilatedResidualLayer(2 ** i, num_f_maps, num_f_maps, causal_conv=causal_conv, dropout_prob=dropout_prob))
            for i in range(num_layers)
        ])
        if self.is_train:
            self.conv_out_classes = nn.Conv1d(num_f_maps, num_classes, 1)

    def forward(self, x):
        out = self.conv_1x1(x)
        for layer in self.layers:
            out = layer(out)
        if self.is_train:
            out = self.conv_out_classes(out)
        return out


class DilatedResidualLayer(nn.Module):
    def __init__(self,
                 dilation: int,
                 in_channels: int,
                 out_channels: int,
                 causal_conv: bool = False,
                 kernel_size: int = 3,
                 dropout_prob: float = 0.0):
        super(DilatedResidualLayer, self).__init__()
        self.causal_conv = causal_conv
        self.dilation = dilation
        self.kernel_size = kernel_size
        padding = (dilation * (kernel_size - 1)) if self.causal_conv else dilation
        self.conv_dilated = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
        self.dropout = nn.Dropout(dropout_prob)

        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.activation(self.conv_dilated(x))
        out = self.dropout(out)
        if self.causal_conv:
            # Trim the trailing positions added by the one-sided causal padding
            out = out[:, :, :-(self.dilation * 2)]
        out = self.activation(self.conv_1x1(out))
        out = self.dropout(out)
        return x + out


class SingleStageModel1(nn.Module):
    def __init__(self,
                 num_layers,
                 num_f_maps,
                 dim,
                 num_classes,
                 causal_conv=False):
        super(SingleStageModel1, self).__init__()
        self.conv_1x1 = nn.Conv1d(dim, num_f_maps, 1)

        self.layers = nn.ModuleList([
            copy.deepcopy(
                DilatedResidualLayer(2 ** i,
                                     num_f_maps,
                                     num_f_maps,
                                     causal_conv=causal_conv))
            for i in range(num_layers)
        ])
        self.conv_out_classes = nn.Conv1d(num_f_maps, num_classes, 1)

    def forward(self, x):
        out = self.conv_1x1(x)
        for layer in self.layers:
            out = layer(out)
        out_classes = self.conv_out_classes(out)
        return out_classes, out


class MultiStageModel1(nn.Module):
    def __init__(self, mstcn_stages, mstcn_layers, mstcn_f_maps, mstcn_f_dim, out_features, mstcn_causal_conv):
        self.num_stages = mstcn_stages  # 4 #2
        self.num_layers = mstcn_layers  # 10 #5
        self.num_f_maps = mstcn_f_maps  # 64 #64
        self.dim = mstcn_f_dim  # 2048 # 2048
        self.num_classes = out_features  # 7
        self.causal_conv = mstcn_causal_conv
        print(
            f"num_stages_classification: {self.num_stages}, num_layers: {self.num_layers}, num_f_maps:"
            f" {self.num_f_maps}, dim: {self.dim}")
        super(MultiStageModel1, self).__init__()
        self.stage1 = SingleStageModel1(self.num_layers,
                                        self.num_f_maps,
                                        self.dim,
                                        self.num_classes,
                                        causal_conv=self.causal_conv)
        self.stages = nn.ModuleList([
            copy.deepcopy(
                SingleStageModel1(self.num_layers,
                                  self.num_f_maps,
                                  self.num_classes,
                                  self.num_classes,
                                  causal_conv=self.causal_conv))
            for s in range(self.num_stages - 1)
        ])
        self.smoothing = False

    def forward(self, x):
        out_classes, _ = self.stage1(x)
        outputs_classes = out_classes.unsqueeze(0)
        for s in self.stages:
            out_classes, out = s(F.softmax(out_classes, dim=1))
            outputs_classes = torch.cat(
                (outputs_classes, out_classes.unsqueeze(0)), dim=0)
        return out
```
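A minimal shape check for the inference path (the hyperparameter values here are illustrative, not taken from the shipped checkpoints):

```python
import torch

from model.mstcn import MultiStageModel

# Illustrative sizes: 768-d frame features, 32 temporal feature maps, 4 phases.
model = MultiStageModel(
    mstcn_stages=2, mstcn_layers=8, mstcn_f_maps=32, mstcn_f_dim=768,
    out_features=4, mstcn_causal_conv=True, is_train=False,
).eval()

x = torch.randn(1, 768, 100)  # [B, f_dim, T]
with torch.no_grad():
    feats = model(x)
print(feats.shape)  # torch.Size([1, 32, 100]): num_f_maps channels, length preserved
```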
model/resnet.py
ADDED @@ -0,0 +1,19 @@

```python
import torch
import torch.nn as nn
import torchvision
from torchvision import models, transforms
from torchvision.models import ResNet50_Weights

# User's ResNet variant (adapted for 2048-d features, no head)
class ResNet(nn.Module):
    def __init__(self, out_channels=4, has_fc=False):
        super(ResNet, self).__init__()
        self.resnet = torchvision.models.resnet50(pretrained=False)
        if not has_fc:
            self.resnet.fc = nn.Identity()  # Output 2048-d features
        else:
            # Keep the original fc layer for compatibility
            pass

    def forward(self, x):
        return self.resnet(x)
```
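A quick check that the headless backbone emits 2048-d features (weights here are random; the Space loads `resnet50.pth` separately):

```python
import torch

from model.resnet import ResNet

backbone = ResNet(has_fc=False).eval()  # fc replaced by Identity
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 2048])
```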
model/transformer.py
ADDED @@ -0,0 +1,246 @@

```python
import torch
import numpy as np
import torch.nn as nn
import math


# some code adapted from https://wmathor.com/index.php/archives/1455/


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k, n_heads):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k
        self.n_heads = n_heads

    def forward(self, Q, K, V):
        '''
        Q: [batch_size, n_heads, len_q=1, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        '''
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(
            self.d_k)  # scores : [batch_size, n_heads, len_q, len_k]

        attn = nn.Softmax(dim=-1)(scores)  # [batch_size, n_heads, len_q, len_k]
        context = torch.matmul(attn, V)  # [batch_size, n_heads, len_q, d_v]
        return context, attn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, d_k, d_v, n_heads, len_q, len_k):
        super(MultiHeadAttention, self).__init__()

        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)  # Linear only changes the last dimension

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.n_heads = n_heads
        self.ScaledDotProductAttention = ScaledDotProductAttention(self.d_k, n_heads)
        self.len_q = len_q
        self.len_k = len_k

    def forward(self, input_Q, input_K, input_V):
        '''
        input_Q: [batch_size, len_q, d_model] [512, 1, 5] --> Spatial info
        input_K: [batch_size, len_k, d_model] [512, 30, 5] --> Temporal info
        input_V: [batch_size, len_v(=len_k), d_model] [512, 30, 5] --> Temporal info
        '''
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # Q: [batch_size, n_heads, len_q, d_k]

        K = self.W_K(input_K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # K: [batch_size, n_heads, len_k, d_k]

        V = self.W_V(input_V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = self.ScaledDotProductAttention(Q, K, V)
        context = context.transpose(1, 2).reshape(batch_size, -1,
                                                  self.n_heads * self.d_v)  # context: [batch_size, len_q, n_heads * d_v]
        output = self.fc(context)  # [batch_size, len_q, d_model]
        layer_norm = nn.LayerNorm(self.d_model).to(output.device)
        return layer_norm(output + residual), attn  # All batch size dimensions are preserved.


class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )
        self.d_model = d_model

    def forward(self, inputs):
        '''
        inputs: [batch_size, seq_len, d_model]
        '''
        residual = inputs
        output = self.fc(inputs)
        layer_norm = nn.LayerNorm(self.d_model).to(output.device)
        return layer_norm(output + residual)  # [batch_size, seq_len, d_model]


class EncoderLayer(nn.Module):
    def __init__(self, d_model, d_ff, d_k, d_v, n_heads, len_q):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads, 1, len_q)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_inputs):
        '''
        enc_inputs: [batch_size, src_len, d_model]
        '''
        # enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)  # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs)  # enc_outputs: [batch_size, src_len, d_model]
        return enc_outputs, attn


class Encoder(nn.Module):
    def __init__(self, d_model, d_ff, d_k, d_v, n_layers, n_heads, len_q):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([EncoderLayer(d_model, d_ff, d_k, d_v, n_heads, len_q) for _ in range(n_layers)])

    def forward(self, enc_inputs):
        '''
        enc_inputs: [batch_size, src_len, d_model]
        '''
        enc_outputs = enc_inputs
        enc_self_attns = []
        for layer in self.layers:
            # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            enc_outputs, enc_self_attn = layer(enc_outputs)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns


class DecoderLayer(nn.Module):
    def __init__(self, d_model, d_ff, d_k, d_v, n_heads, len_q):
        super(DecoderLayer, self).__init__()
        self.dec_enc_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads, 1, len_q)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, dec_inputs, enc_outputs):
        '''
        dec_inputs: [batch_size, tgt_len, d_model] [512, 1, 5] --> Spatial info
        enc_outputs: [batch_size, src_len, d_model] [512, 30, 5] --> Temporal info
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        '''
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
        # dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_inputs, enc_outputs, enc_outputs)
        dec_outputs = self.pos_ffn(dec_outputs)  # [batch_size, tgt_len, d_model]
        return dec_outputs, dec_enc_attn


class Decoder(nn.Module):
    def __init__(self, d_model, d_ff, d_k, d_v, n_layers, n_heads, len_q):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(d_model, d_ff, d_k, d_v, n_heads, len_q) for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_outputs):
        '''
        dec_inputs: [batch_size, tgt_len, d_model] [512, 1, 5]
        enc_inputs: [batch_size, src_len, d_model] [512, 30, 5]
        enc_outputs: [batch_size, src_len, d_model]
        '''
        dec_outputs = dec_inputs  # self.tgt_emb(dec_inputs) # [batch_size, tgt_len, d_model]
        # dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() # [batch_size, tgt_len, tgt_len]

        dec_enc_attns = []
        for layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
            dec_outputs, dec_enc_attn = layer(dec_outputs, enc_outputs)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs


# d_model, Embedding Size
# d_ff, FeedForward dimension
# d_k = d_v, dimension of K(=Q), V
# n_layers, number of Encoder or Decoder Layers
# n_heads, number of heads in Multi-Head Attention

class Transformer2_3_1(nn.Module):
    def __init__(self, d_model, d_ff, d_k, d_v, n_layers, n_heads, len_q):
        super(Transformer2_3_1, self).__init__()
        self.encoder = Encoder(d_model, d_ff, d_k, d_v, n_layers, n_heads, len_q)
        self.decoder = Decoder(d_model, d_ff, d_k, d_v, 1, n_heads, len_q)

    def forward(self, enc_inputs, dec_inputs):
        '''
        enc_inputs: [Frames, src_len, d_model] [512, 30, 5]
        dec_inputs: [Frames, 1, d_model] [512, 1, 5]
        '''
        # tensor to store decoder outputs
        # outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size).to(self.device)

        # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)  # Self-attention for temporal features
        dec_outputs = self.decoder(dec_inputs, enc_outputs)
        return dec_outputs


class Transformer(nn.Module):
    def __init__(self, mstcn_f_maps, mstcn_f_dim, out_features, len_q, d_model=None):
        super(Transformer, self).__init__()
        # Use provided d_model (256) else fall back to mstcn_f_maps
        self.d_model = d_model if d_model is not None else mstcn_f_maps
        self.num_classes = out_features
        self.len_q = len_q

        # Spatial encoder with d_ff = d_model; heads=8; d_k=d_v=d_model
        self.spatial_encoder = EncoderLayer(self.d_model, self.d_model, self.d_model, self.d_model, 8, 5)
        self.transformer = Transformer2_3_1(d_model=self.d_model, d_ff=self.d_model, d_k=self.d_model,
                                            d_v=self.d_model, n_layers=1, n_heads=8, len_q=len_q)
        self.fc = nn.Linear(mstcn_f_dim, self.d_model, bias=False)

        # Final head 256 -> num_classes, no bias to match checkpoint
        self.out = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(self.d_model, out_features, bias=False)
        )

    def forward(self, x, long_feature):
        # x: [B, 256, T]; long_feature: [B, T, 256]
        B, D, T = x.shape
        out_features = x.transpose(1, 2)  # [B, T, 256]

        # Build sliding windows for temporal inputs
        inputs = []
        for i in range(T):
            if i < self.len_q - 1:
                pad = torch.zeros((B, self.len_q - 1 - i, self.d_model), device=x.device)
                win = torch.cat([pad, out_features[:, :i + 1, :]], dim=1)
            else:
                win = out_features[:, i - self.len_q + 1:i + 1, :]
            inputs.append(win)
        inputs = torch.stack(inputs, dim=0).squeeze(1)  # [T, B, len_q, 256]

        # Project long features and create spatial windows
        feas = torch.tanh(self.fc(long_feature))  # [B, T, 256]
        spa_len = min(10, T)
        out_feas = []
        for i in range(T):
            if i < spa_len - 1:
                pad = torch.zeros((B, spa_len - 1 - i, self.d_model), device=feas.device)
                win = torch.cat([pad, feas[:, :i + 1, :]], dim=1)
            else:
                win = out_features[:, i - spa_len + 1:i + 1, :]
            out_feas.append(win)
        out_feas = torch.stack(out_feas, dim=0).squeeze(1)
        out_feas, _ = self.spatial_encoder(out_feas)

        # Temporal-spatial fusion
        output = self.transformer(inputs, out_feas)  # [T, B, 1, 256] collapsed → [T, B, 256]
        output = self.out(output)  # [T, B, C]
        return output.transpose(0, 1)  # [B, T, C]
```
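The docstring shapes can be exercised directly on the fusion module `Transformer2_3_1`; the tiny `d_model=5` below mirrors the in-code comments and is illustrative only:

```python
import torch

from model.transformer import Transformer2_3_1

fusion = Transformer2_3_1(d_model=5, d_ff=5, d_k=5, d_v=5,
                          n_layers=1, n_heads=8, len_q=30).eval()

enc_inputs = torch.randn(512, 30, 5)  # temporal window per frame
dec_inputs = torch.randn(512, 1, 5)   # one spatial query per frame
with torch.no_grad():
    out = fusion(enc_inputs, dec_inputs)
print(out.shape)  # torch.Size([512, 1, 5]): one fused vector per frame
```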
model_registry.py
ADDED @@ -0,0 +1,156 @@

```python
from __future__ import annotations

import os
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, Tuple

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

APP_ROOT = Path(__file__).resolve().parent
MODEL_ROOT = Path(os.environ.get("SPACE_MODEL_DIR", APP_ROOT / "model")).expanduser().resolve()


def _default_hf_home() -> Path:
    data_dir = Path("/data")
    if data_dir.is_dir():
        return data_dir / ".huggingface"
    return APP_ROOT / ".cache" / "huggingface"


HF_HOME = Path(os.environ.setdefault("HF_HOME", str(_default_hf_home()))).expanduser().resolve()
os.environ.setdefault("HF_HUB_CACHE", str(HF_HOME / "hub"))


@dataclass(frozen=True)
class ModelSpec:
    key: str
    label: str
    required_files: Tuple[str, ...]
    optional_files: Tuple[str, ...] = ()


MODEL_SPECS: Dict[str, ModelSpec] = {
    "aiendo": ModelSpec(
        key="aiendo",
        label="AI-Endo",
        required_files=("resnet50.pth", "fusion.pth", "transformer.pth"),
    ),
    "dinov2": ModelSpec(
        key="dinov2",
        label="DINO-Endo",
        required_files=("dinov2_vit14s_latest_checkpoint.pth", "fusion_transformer_decoder_best_model.pth"),
        optional_files=("dinov2_decoder.pth",),
    ),
    "vjepa2": ModelSpec(
        key="vjepa2",
        label="V-JEPA2",
        required_files=("vjepa_encoder_human.pt", "mlp_decoder_human.pth"),
    ),
}


def _repo_env_name(model_key: str) -> str:
    prefix = {"aiendo": "AIENDO", "dinov2": "DINO", "vjepa2": "VJEPA2"}[model_key]
    return f"{prefix}_MODEL_REPO_ID"


def _revision_env_name(model_key: str) -> str:
    prefix = {"aiendo": "AIENDO", "dinov2": "DINO", "vjepa2": "VJEPA2"}[model_key]
    return f"{prefix}_MODEL_REVISION"


def _subfolder_env_name(model_key: str) -> str:
    prefix = {"aiendo": "AIENDO", "dinov2": "DINO", "vjepa2": "VJEPA2"}[model_key]
    return f"{prefix}_MODEL_SUBFOLDER"


def get_model_repo_id(model_key: str) -> str | None:
    return os.getenv(_repo_env_name(model_key)) or os.getenv("PHASE_MODEL_REPO_ID")


def get_model_revision(model_key: str) -> str | None:
    return os.getenv(_revision_env_name(model_key)) or os.getenv("PHASE_MODEL_REVISION")


def get_model_subfolder(model_key: str) -> str:
    return (os.getenv(_subfolder_env_name(model_key)) or "").strip("/")


def get_hf_token() -> str | None:
    return os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")


def ensure_model_root() -> Path:
    MODEL_ROOT.mkdir(parents=True, exist_ok=True)
    HF_HOME.mkdir(parents=True, exist_ok=True)
    Path(os.environ["HF_HUB_CACHE"]).mkdir(parents=True, exist_ok=True)
    return MODEL_ROOT


def _remote_filename(model_key: str, filename: str) -> str:
    subfolder = get_model_subfolder(model_key)
    return f"{subfolder}/{filename}" if subfolder else filename


def _download_to_model_root(model_key: str, filename: str, *, optional: bool = False) -> Path | None:
    target = ensure_model_root() / filename
    if target.exists():
        return target

    repo_id = get_model_repo_id(model_key)
    if not repo_id:
        if optional:
            return None
        raise FileNotFoundError(
            f"Missing {filename} in {MODEL_ROOT}. Set {_repo_env_name(model_key)} or PHASE_MODEL_REPO_ID, "
            f"or copy the checkpoint into the local model directory."
        )

    try:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename=_remote_filename(model_key, filename),
            repo_type="model",
            revision=get_model_revision(model_key),
            token=get_hf_token(),
        )
    except EntryNotFoundError:
        if optional:
            return None
        raise

    downloaded_path = Path(downloaded)
    if downloaded_path.resolve() != target.resolve():
        shutil.copy2(downloaded_path, target)
    return target


def ensure_model_artifacts(model_key: str) -> Path:
    if model_key not in MODEL_SPECS:
        raise KeyError(f"Unknown model key: {model_key}")

    spec = MODEL_SPECS[model_key]
    ensure_model_root()

    for filename in spec.required_files:
        _download_to_model_root(model_key, filename, optional=False)
    for filename in spec.optional_files:
        _download_to_model_root(model_key, filename, optional=True)

    return MODEL_ROOT


def get_model_source_summary(model_key: str) -> dict:
    spec = MODEL_SPECS[model_key]
    return {
        "label": spec.label,
        "model_dir": str(MODEL_ROOT),
        "repo_id": get_model_repo_id(model_key),
        "revision": get_model_revision(model_key),
        "subfolder": get_model_subfolder(model_key),
        "required_files": list(spec.required_files),
        "optional_files": list(spec.optional_files),
    }
```
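How resolution plays out, as a sketch (the repo id is a placeholder; `ensure_model_artifacts` needs network access unless the files already sit in `SPACE_MODEL_DIR`):

```python
import os

# Set before the download call; the registry reads repo ids lazily.
os.environ["DINO_MODEL_REPO_ID"] = "your-org/dino-endo-weights"  # placeholder

from model_registry import ensure_model_artifacts, get_model_source_summary

print(get_model_source_summary("dinov2"))     # where files will be looked up
model_dir = ensure_model_artifacts("dinov2")  # downloads any missing required file
print(model_dir)                              # resolved SPACE_MODEL_DIR
```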
predictor.py
ADDED
|
@@ -0,0 +1,642 @@
+from __future__ import annotations
+
+import os
+import sys
+from contextlib import nullcontext
+from pathlib import Path
+
+import albumentations as A
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+    from torch.amp import autocast
+    MIXED_PRECISION_AVAILABLE = True
+except ImportError:  # pragma: no cover
+    MIXED_PRECISION_AVAILABLE = False
+
+from model.resnet import ResNet
+from model.mstcn import MultiStageModel
+from model.transformer import Transformer
+
+PHASE_LABELS = ("idle", "marking", "injection", "dissection")
+MODEL_LABELS = {
+    "aiendo": "AI-Endo",
+    "dinov2": "DINO-Endo",
+    "vjepa2": "V-JEPA2",
+}
+
+
+def _app_root() -> Path:
+    return Path(__file__).resolve().parent
+
+
+def default_model_dir() -> str:
+    return str(Path(os.environ.get("SPACE_MODEL_DIR", _app_root() / "model")).expanduser().resolve())
+
+
+def normalize_model_key(name: str | None) -> str:
+    token = (name or "aiendo").lower().replace("-", "").replace("_", "").strip()
+    if token in ("aiendo", "resnet", "aiendoresnet", "aiendoresnetmstcn", "aiendoresnetmstcntransformer"):
+        return "aiendo"
+    if token in ("dinov2", "dinov2endo", "dinoendo", "dino"):
+        return "dinov2"
+    if token in ("vjepa2", "vjepa", "vjepa2endo"):
+        return "vjepa2"
+    raise KeyError(f"Unsupported model key: {name}")
+
+
+def _load_trusted_checkpoint(path: str, map_location="cpu"):
+    try:
+        return torch.load(path, map_location=map_location, weights_only=False)
+    except TypeError:  # pragma: no cover - older torch without weights_only
+        return torch.load(path, map_location=map_location)
+
+
+def _strip_state_dict_prefixes(state_dict, prefixes):
+    cleaned_state = {}
+    for key, value in state_dict.items():
+        # Strip repeatedly so nested prefixes like "module.model." are fully removed.
+        while any(key.startswith(prefix) for prefix in prefixes):
+            for prefix in prefixes:
+                if key.startswith(prefix):
+                    key = key[len(prefix):]
+        cleaned_state[key] = value
+    return cleaned_state
+
+
+def _validate_load_result(
+    load_result,
+    model_name: str,
+    *,
+    allowed_missing=(),
+    allowed_missing_prefixes=(),
+    allowed_unexpected=(),
+    allowed_unexpected_prefixes=(),
+):
+    missing = [
+        key
+        for key in load_result.missing_keys
+        if key not in allowed_missing and not any(key.startswith(prefix) for prefix in allowed_missing_prefixes)
+    ]
+    unexpected = [
+        key
+        for key in load_result.unexpected_keys
+        if key not in allowed_unexpected and not any(key.startswith(prefix) for prefix in allowed_unexpected_prefixes)
+    ]
+    if missing or unexpected:
+        problems = []
+        if missing:
+            problems.append(f"missing={missing[:10]}")
+        if unexpected:
+            problems.append(f"unexpected={unexpected[:10]}")
+        raise RuntimeError(f"{model_name} checkpoint mismatch ({'; '.join(problems)})")
+
+
+def _resolve_vendor_repo(repo_name: str, extra_candidates=()):
+    app_root = _app_root()
+    candidates = [app_root / repo_name]
+    if len(app_root.parents) >= 2:
+        candidates.append(app_root.parents[1] / repo_name)
+    candidates.extend(extra_candidates)
+
+    for candidate in candidates:
+        if candidate and candidate.exists():
+            return candidate
+    raise FileNotFoundError(
+        f"Required vendor repo '{repo_name}' not found. Stage it into this folder or keep the repo-root copy available."
+    )
+
+
+class Predictor:
+    def __init__(self, model_dir: str | None = None, device: str = "cuda"):
+        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
+        self.model_dir = model_dir or default_model_dir()
+        self.seq_length = 1024
+        self.trans_seq = 30
+        self.aug = A.Compose([A.Resize(height=224, width=224), A.Normalize()])
+        self.frame_feature_cache = None
+        self.label_dict = dict(enumerate(PHASE_LABELS))
+        self.available = False
+
+        self._norm_mean = None
+        self._norm_std = None
+        if self.device.type == "cuda":
+            self._norm_mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
+            self._norm_std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
+
+        self._load_models(self.model_dir)
+
+    def _load_models(self, model_dir: str):
+        self.resnet = ResNet(out_channels=4, has_fc=False)
+        paras = torch.load(os.path.join(model_dir, "resnet50.pth"), map_location=self.device)["model"]
+        paras = {k: v for k, v in paras.items() if "fc" not in k and "embed" not in k}
+        paras = {k.replace("share.", "resnet."): v for k, v in paras.items()}
+        self.resnet.load_state_dict(paras, strict=True)
+        self.resnet.to(self.device).eval()
+
+        self.fusion = MultiStageModel(
+            mstcn_stages=2,
+            mstcn_layers=8,
+            mstcn_f_maps=32,
+            mstcn_f_dim=2048,
+            out_features=4,
+            mstcn_causal_conv=True,
+            is_train=False,
+        )
+        fusion_weights = torch.load(os.path.join(model_dir, "fusion.pth"), map_location=self.device)
+        fusion_load = self.fusion.load_state_dict(fusion_weights, strict=False)
+        _validate_load_result(
+            fusion_load,
+            "AI-Endo fusion",
+            allowed_unexpected_prefixes=("stage1.conv_out_classes.",),
+        )
+        self.fusion.to(self.device).eval()
+
+        self.transformer = Transformer(32, 2048, 4, 30, d_model=32)
+        trans_weights = torch.load(os.path.join(model_dir, "transformer.pth"), map_location=self.device)
+        self.transformer.load_state_dict(trans_weights)
+        self.transformer.to(self.device).eval()
+        self.available = True
+
+    def _amp_context(self):
+        return autocast("cuda") if MIXED_PRECISION_AVAILABLE and self.device.type == "cuda" else nullcontext()
+
+    def _preprocess_gpu(self, rgb_image: np.ndarray) -> torch.Tensor:
+        tensor = torch.from_numpy(rgb_image).permute(2, 0, 1).unsqueeze(0)
+        tensor = tensor.to(self.device, dtype=torch.float32, non_blocking=True).div_(255.0)
+        if tensor.shape[-2:] != (224, 224):
+            tensor = F.interpolate(tensor, size=(224, 224), mode="bilinear", align_corners=False)
+        return (tensor - self._norm_mean) / self._norm_std
+
+    def warm_up(self):
+        dummy = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
+        self.predict(dummy)
+        self.reset_state()
+
+    def reset_state(self):
+        self.frame_feature_cache = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    def unload(self):
+        self.available = False
+        self.resnet.to("cpu")
+        self.fusion.to("cpu")
+        self.transformer.to("cpu")
+        self.resnet = None
+        self.fusion = None
+        self.transformer = None
+        self.frame_feature_cache = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    def _cache_features(self, feature: torch.Tensor):
+        # Sliding window: once the cache exceeds seq_length entries, drop the oldest
+        # feature for each new frame appended.
+        if self.frame_feature_cache is None:
+            self.frame_feature_cache = feature
+        elif self.frame_feature_cache.shape[0] > self.seq_length:
+            self.frame_feature_cache = torch.cat([self.frame_feature_cache[1:], feature], dim=0)
+        else:
+            self.frame_feature_cache = torch.cat([self.frame_feature_cache, feature], dim=0)
+
+    @torch.inference_mode()
+    def predict(self, rgb_image: np.ndarray):
+        if self._norm_mean is not None:
+            tensor = self._preprocess_gpu(rgb_image)
+        else:
+            processed = self.aug(image=rgb_image)["image"]
+            chw = np.transpose(processed, (2, 0, 1))
+            tensor = torch.from_numpy(chw).unsqueeze(0).contiguous().to(self.device)
+
+        with self._amp_context():
+            feature = self.resnet(tensor).clone()
+        self._cache_features(feature)
+
+        # _cache_features always leaves at least one feature in the cache, so this
+        # cold-start branch is defensive and effectively unreachable.
+        if self.frame_feature_cache is None:
+            single_frame_feature = feature.unsqueeze(1)
+            temporal_input = single_frame_feature.transpose(1, 2)
+            temporal_feature = self.fusion(temporal_input)
+            outputs = self.transformer(temporal_feature.detach(), single_frame_feature)
+            final_logits = outputs[-1, -1, :]
+            probs = F.softmax(final_logits.float(), dim=-1)
+            pred_np = probs.detach().cpu().numpy()
+            confidence = float(np.max(pred_np))
+            phase_idx = max(0, min(3, int(np.argmax(pred_np))))
+            phase = self.label_dict.get(phase_idx, "idle")
+            return {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": 1}
+
+        if self.frame_feature_cache.shape[0] < 30:
+            available_frames = self.frame_feature_cache.shape[0] + 1
+            cat_frame_feature = torch.cat([self.frame_feature_cache, feature], dim=0).unsqueeze(0)
+            temporal_input = cat_frame_feature.transpose(1, 2)
+            temporal_feature = self.fusion(temporal_input)
+            outputs = self.transformer(temporal_feature.detach(), cat_frame_feature)
+            final_logits = outputs[-1, -1, :]
+            probs = F.softmax(final_logits.float(), dim=-1)
+            pred_np = probs.detach().cpu().numpy()
+            confidence = float(np.max(pred_np))
+            phase_idx = max(0, min(3, int(np.argmax(pred_np))))
+            phase = self.label_dict.get(phase_idx, "idle")
+            return {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": available_frames}
+
+        cat_frame_feature = self.frame_feature_cache.unsqueeze(0)
+        temporal_input = cat_frame_feature.transpose(1, 2)
+        temporal_feature = self.fusion(temporal_input)
+        outputs = self.transformer(temporal_feature.detach(), cat_frame_feature)
+        final_logits = outputs[-1, -1, :]
+        probs = F.softmax(final_logits.float(), dim=-1)
+        pred_np = probs.detach().cpu().numpy()
+
+        confidence = float(np.max(pred_np))
+        phase_idx = max(0, min(3, int(np.argmax(pred_np))))
+        phase = self.label_dict.get(phase_idx, "idle")
+        return {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": min(self.trans_seq, self.frame_feature_cache.shape[0])}
+
+
+class PredictorDinoV2:
+    def __init__(self, model_dir: str | None = None, device: str = "cuda"):
+        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
+        self.model_dir = model_dir or default_model_dir()
+        self.seq_length = 30
+        self.available = False
+        self.backbone = None
+        self.decoder = None
+        self.label_dict = dict(enumerate(PHASE_LABELS))
+        self.aug = A.Compose([
+            A.SmallestMaxSize(max_size=256, interpolation=cv2.INTER_LINEAR),
+            A.CenterCrop(height=224, width=224),
+            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0),
+        ])
+        self.frame_features = []
+        self._load_models(self.model_dir)
+
+    def _amp_context(self):
+        return autocast("cuda") if MIXED_PRECISION_AVAILABLE and self.device.type == "cuda" else nullcontext()
+
+    def _resolve_local_dino_repo(self):
+        app_root = _app_root()
+        candidates = [app_root / "dinov2"]
+        if len(app_root.parents) >= 2:
+            candidates.append(app_root.parents[1] / "dinov2")
+        candidates.append(Path(torch.hub.get_dir()) / "facebookresearch_dinov2_main")
+        for candidate in candidates:
+            if (candidate / "hubconf.py").is_file():
+                return str(candidate)
+        raise FileNotFoundError(
+            "Local DINOv2 repo not found. Stage dinov2/ into this folder or keep the repo-root copy available."
+        )
+
+    def _load_models(self, model_dir: str):
+        repo_path = self._resolve_local_dino_repo()
+        self.backbone = torch.hub.load(repo_path, "dinov2_vits14", source="local", pretrained=False)
+
+        encoder_path = os.path.join(model_dir, "dinov2_vit14s_latest_checkpoint.pth")
+        if not os.path.exists(encoder_path):
+            raise FileNotFoundError("DINOv2 encoder checkpoint not found")
+        encoder_checkpoint = _load_trusted_checkpoint(encoder_path, map_location="cpu")
+        encoder_state = encoder_checkpoint.get("student", encoder_checkpoint)
+        encoder_state = _strip_state_dict_prefixes(encoder_state, ("module.", "model."))
+        encoder_load = self.backbone.load_state_dict(encoder_state, strict=False)
+        _validate_load_result(encoder_load, "DINOv2 backbone")
+        self.backbone.to(self.device).eval()
+
+        decoder_path = os.path.join(model_dir, "fusion_transformer_decoder_best_model.pth")
+        if not os.path.exists(decoder_path):
+            raise FileNotFoundError("DINOv2 decoder checkpoint not found")
+        decoder_checkpoint = _load_trusted_checkpoint(decoder_path, map_location="cpu")
+        decoder_state = decoder_checkpoint.get("state_dict", decoder_checkpoint)
+        decoder_state = _strip_state_dict_prefixes(decoder_state, ("module.", "model."))
+
+        class FusionTransformerDecoder(nn.Module):
+            def __init__(self, feature_dim=384, num_classes=4, mstcn_stages=2, mstcn_layers=8,
+                         mstcn_f_maps=16, mstcn_f_dim=256, seq_length=30, d_model=256):
+                super().__init__()
+                self.reduce = nn.Linear(feature_dim, mstcn_f_dim)
+                self.mstcn = MultiStageModel(
+                    mstcn_stages=mstcn_stages,
+                    mstcn_layers=mstcn_layers,
+                    mstcn_f_maps=mstcn_f_maps,
+                    mstcn_f_dim=mstcn_f_dim,
+                    out_features=num_classes,
+                    mstcn_causal_conv=True,
+                    is_train=False,
+                )
+                self.transformer = Transformer(
+                    mstcn_f_maps=mstcn_f_maps,
+                    mstcn_f_dim=mstcn_f_dim,
+                    out_features=num_classes,
+                    len_q=seq_length,
+                    d_model=d_model,
+                )
+
+            def forward(self, x):
+                x = x.permute(0, 2, 1)
+                x_reduced = self.reduce(x)
+                mstcn_input = x_reduced.permute(0, 2, 1)
+                temporal_features = self.mstcn(mstcn_input)
+                if isinstance(temporal_features, (list, tuple)):
+                    temporal_features = temporal_features[-1]
+                elif isinstance(temporal_features, torch.Tensor) and temporal_features.dim() == 4:
+                    temporal_features = temporal_features[-1]
+
+                if temporal_features.shape[1] == mstcn_input.shape[1]:
+                    transformer_input = temporal_features.detach()
+                else:
+                    transformer_input = mstcn_input.detach()
+
+                transformer_out = self.transformer(transformer_input, x_reduced)
+                return transformer_out.permute(0, 2, 1)
+
+        self.decoder = FusionTransformerDecoder()
+        decoder_load = self.decoder.load_state_dict(decoder_state, strict=False)
+        _validate_load_result(
+            decoder_load,
+            "DINOv2 decoder",
+            allowed_unexpected_prefixes=(
+                "mstcn.stage1.conv_out_classes.",
+                "mstcn.stages.conv_out_classes.",
+            ),
+        )
+        self.decoder.to(self.device).eval()
+        self.available = True
+
+    def reset_state(self):
+        self.frame_features = []
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    def warm_up(self):
+        dummy_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
+        self.predict(dummy_img)
+        self.reset_state()
+
+    def unload(self):
+        if self.backbone is not None:
+            self.backbone.to("cpu")
+        if self.decoder is not None:
+            self.decoder.to("cpu")
+        self.backbone = None
+        self.decoder = None
+        self.frame_features = []
+        self.available = False
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    @torch.inference_mode()
+    def predict(self, rgb_image: np.ndarray):
+        if not self.available or self.backbone is None or self.decoder is None:
+            raise RuntimeError("DINO-Endo predictor is not available")
+
+        processed = self.aug(image=rgb_image)["image"]
+        chw = np.transpose(processed, (2, 0, 1))
+        tensor = torch.tensor(chw, dtype=torch.float32).unsqueeze(0).to(self.device)
+
+        with self._amp_context():
+            feats = self.backbone.forward_features(tensor)
+            if isinstance(feats, dict):
+                feats = feats.get("x_norm_clstoken", next(iter(feats.values())))
+            if feats.dim() == 3:
+                feats = feats.mean(dim=1)
+
+        self.frame_features.append(feats.squeeze(0).detach().cpu())
+        if len(self.frame_features) > self.seq_length:
+            self.frame_features = self.frame_features[-self.seq_length:]
+
+        available_frames = len(self.frame_features)
+        seq = torch.stack(self.frame_features[-available_frames:]).unsqueeze(0).to(self.device)
+        # Right-pad shorter sequences with the newest feature so the decoder
+        # always sees exactly seq_length steps.
+        if available_frames < self.seq_length:
+            last_frame = seq[:, -1:, :]
+            padding = last_frame.repeat(1, self.seq_length - available_frames, 1)
+            seq = torch.cat([seq, padding], dim=1)
+
+        decoder_input = seq.transpose(1, 2)
+        with self._amp_context():
+            logits = self.decoder(decoder_input)
+
+        if logits.dim() != 3:
+            raise ValueError(f"Unexpected DINOv2 decoder output shape: {tuple(logits.shape)}")
+        if logits.shape[1] == len(self.label_dict):
+            last = logits[0, :, -1]
+        elif logits.shape[2] == len(self.label_dict):
+            last = logits[0, -1, :]
+        else:
+            raise ValueError(f"Unexpected DINOv2 class dimension in decoder output: {tuple(logits.shape)}")
+
+        probs = torch.softmax(last, dim=0)
+        pred_np = probs.detach().cpu().numpy()
+        confidence = float(np.max(pred_np))
+        phase_idx = int(np.argmax(pred_np))
+        phase = self.label_dict.get(phase_idx, "idle")
+        return {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": available_frames}
+
+
+class PredictorVJEPA2:
+    def __init__(self, model_dir: str | None = None, device: str = "cuda"):
+        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
+        self.model_dir = model_dir or default_model_dir()
+        self.available = False
+        self.encoder = None
+        self.decoder = None
+        self.label_dict = dict(enumerate(PHASE_LABELS))
+        self._clip_frames = 16
+        self._tubelet_size = 2
+        self._crop_size = 256
+        self._decoder_seq_length = 30
+        self._frame_buffer = []
+        self._feature_buffer = []
+        self._vjepa_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(3, 1, 1, 1)
+        self._vjepa_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(3, 1, 1, 1)
+        self._load_models(self.model_dir)
+
+    def _amp_context(self):
+        return autocast("cuda") if MIXED_PRECISION_AVAILABLE and self.device.type == "cuda" else nullcontext()
+
+    def _resolve_vjepa_repo(self):
+        extras = []
+        app_root = _app_root()
+        if len(app_root.parents) >= 2:
+            extras.append(app_root.parents[1] / "webapp" / "vjepa2")
+        return _resolve_vendor_repo("vjepa2", extras)
+
+    @staticmethod
+    def _clean_checkpoint_keys(state_dict):
+        cleaned_state = {}
+        for key, value in state_dict.items():
+            while key.startswith("module.") or key.startswith("backbone."):
+                if key.startswith("module."):
+                    key = key[len("module."):]
+                elif key.startswith("backbone."):
+                    key = key[len("backbone."):]
+            cleaned_state[key] = value
+        return cleaned_state
+
+    @staticmethod
+    def _validate_load_result(load_result, model_name: str):
+        if load_result.unexpected_keys:
+            sample = ", ".join(load_result.unexpected_keys[:5])
+            raise RuntimeError(f"{model_name} load had unexpected keys: {sample}")
+        if load_result.missing_keys:
+            sample = ", ".join(load_result.missing_keys[:5])
+            raise RuntimeError(f"{model_name} load missed required keys: {sample}")
+
+    def _extract_temporal_features(self, features: torch.Tensor) -> torch.Tensor:
+        if isinstance(features, dict):
+            features = features.get("x_norm_patchtokens", features.get("x_norm_clstoken", next(iter(features.values()))))
+
+        if features.dim() == 2:
+            return features.unsqueeze(1).repeat(1, self._clip_frames, 1)
+        if features.dim() != 3:
+            raise ValueError(f"Unexpected V-JEPA2 encoder output shape: {tuple(features.shape)}")
+
+        temporal_tokens = self._clip_frames // self._tubelet_size
+        if temporal_tokens <= 0:
+            raise ValueError("Invalid V-JEPA2 temporal configuration")
+        if features.shape[1] % temporal_tokens != 0:
+            raise ValueError(
+                f"Cannot reshape V-JEPA2 features of shape {tuple(features.shape)} into {temporal_tokens} temporal groups"
+            )
+
+        # Average the spatial tokens within each tubelet, then repeat each tubelet
+        # feature tubelet_size times to recover per-frame temporal resolution.
+        spatial_tokens = features.shape[1] // temporal_tokens
+        features = features.view(features.shape[0], temporal_tokens, spatial_tokens, features.shape[2]).mean(dim=2)
+        return features.repeat_interleave(self._tubelet_size, dim=1)[:, : self._clip_frames, :]
+
+    def _preprocess_clip(self, frames) -> torch.Tensor:
+        resized_frames = [cv2.resize(frame, (self._crop_size, self._crop_size), interpolation=cv2.INTER_LINEAR) for frame in frames]
+        clip = np.stack(resized_frames, axis=0).astype(np.float32) / 255.0
+        tensor = torch.from_numpy(np.transpose(clip, (3, 0, 1, 2)))
+        return (tensor - self._vjepa_mean) / self._vjepa_std
+
+    def _load_models(self, model_dir: str):
+        vjepa2_path = self._resolve_vjepa_repo()
+        if str(vjepa2_path) not in sys.path:
+            sys.path.insert(0, str(vjepa2_path))
+
+        from src.models import vision_transformer as vjepa_vit
+        from src.utils.checkpoint_loader import robust_checkpoint_loader
+
+        encoder_path = os.path.join(model_dir, "vjepa_encoder_human.pt")
+        if not os.path.exists(encoder_path):
+            raise FileNotFoundError("V-JEPA2 encoder not found")
+
+        checkpoint = robust_checkpoint_loader(encoder_path, map_location=torch.device("cpu"))
+        encoder_state = self._clean_checkpoint_keys(checkpoint.get("encoder", checkpoint))
+
+        self.encoder = vjepa_vit.vit_large(
+            patch_size=16,
+            num_frames=self._clip_frames,
+            tubelet_size=self._tubelet_size,
+            img_size=self._crop_size,
+            uniform_power=True,
+            use_sdpa=True,
+            use_rope=True,
+        )
+        encoder_load = self.encoder.load_state_dict(encoder_state, strict=False)
+        self._validate_load_result(encoder_load, "V-JEPA2 encoder")
+        self.encoder.to(self.device).eval()
+
+        decoder_path = os.path.join(model_dir, "mlp_decoder_human.pth")
+        if not os.path.exists(decoder_path):
+            raise FileNotFoundError("V-JEPA2 MLP decoder not found")
+
+        decoder_checkpoint = torch.load(decoder_path, map_location="cpu")
+        decoder_state = decoder_checkpoint.get("model", decoder_checkpoint)
+        decoder_in_dim = int(decoder_checkpoint.get("in_dim", 1024))
+        decoder_num_classes = int(decoder_checkpoint.get("num_classes", len(self.label_dict)))
+        self._decoder_seq_length = int(decoder_checkpoint.get("seq_length", self._decoder_seq_length))
+
+        class MLPDecoder(nn.Module):
+            def __init__(self, in_dim=1024, hidden_dim=256, num_classes=4):
+                super().__init__()
+                self.norm = nn.LayerNorm(in_dim)
+                self.fc1 = nn.Linear(in_dim, hidden_dim)
+                self.fc2 = nn.Linear(hidden_dim, hidden_dim)
+                self.fc3 = nn.Linear(hidden_dim, num_classes)
+                self.relu = nn.ReLU()
+                self.drop = nn.Dropout(0.5)
+
+            def forward(self, x):
+                x = x.mean(dim=1)
+                x = self.norm(x)
+                x = self.drop(self.relu(self.fc1(x)))
+                x = self.drop(self.relu(self.fc2(x)))
+                return self.fc3(x)
+
+        self.decoder = MLPDecoder(in_dim=decoder_in_dim, num_classes=decoder_num_classes)
+        self.decoder.load_state_dict(decoder_state, strict=True)
+        self.decoder.to(self.device).eval()
+        self.available = True
+
+    def reset_state(self):
+        self._frame_buffer = []
+        self._feature_buffer = []
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    def warm_up(self):
+        dummy = np.random.randint(0, 255, (self._crop_size, self._crop_size, 3), dtype=np.uint8)
+        self.predict(dummy)
+        self.reset_state()
+
+    def unload(self):
+        if self.encoder is not None:
+            self.encoder.to("cpu")
+        if self.decoder is not None:
+            self.decoder.to("cpu")
+        self.encoder = None
+        self.decoder = None
+        self._frame_buffer = []
+        self._feature_buffer = []
+        self.available = False
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    @torch.inference_mode()
+    def predict(self, rgb_image: np.ndarray):
+        if not self.available:
+            raise RuntimeError("V-JEPA2 predictor is not available")
+
+        frame = np.ascontiguousarray(rgb_image, dtype=np.uint8)
+        self._frame_buffer.append(frame)
+        if len(self._frame_buffer) > self._clip_frames:
+            self._frame_buffer = self._frame_buffer[-self._clip_frames:]
+
+        # Pad partial clips by repeating the most recent frame until the
+        # encoder's fixed clip length is reached.
+        clip_frames = list(self._frame_buffer)
+        while len(clip_frames) < self._clip_frames:
+            clip_frames.append(clip_frames[-1])
+
+        tensor = self._preprocess_clip(clip_frames).unsqueeze(0).to(self.device)
+        with self._amp_context():
+            features = self._extract_temporal_features(self.encoder(tensor))
+
+        latest_feature_idx = min(len(self._frame_buffer), self._clip_frames) - 1
+        latest_feature = features[0, latest_feature_idx].float().detach().cpu()
+        self._feature_buffer.append(latest_feature)
+        if len(self._feature_buffer) > self._decoder_seq_length:
+            self._feature_buffer = self._feature_buffer[-self._decoder_seq_length:]
+
+        available_frames = len(self._feature_buffer)
+        seq = torch.stack(self._feature_buffer, dim=0).unsqueeze(0).to(self.device)
+        if available_frames < self._decoder_seq_length:
+            padding = seq[:, -1:, :].repeat(1, self._decoder_seq_length - available_frames, 1)
+            seq = torch.cat([seq, padding], dim=1)
+
+        with self._amp_context():
+            logits = self.decoder(seq)
+
+        probs = torch.softmax(logits[0], dim=0)
+        pred_np = probs.detach().cpu().numpy()
+        confidence = float(np.max(pred_np))
+        phase_idx = int(np.argmax(pred_np))
+        phase = self.label_dict.get(phase_idx, "idle")
+        return {"phase": phase, "probs": pred_np.tolist(), "confidence": confidence, "frames_used": available_frames}
+
+
+def create_predictor(model_key: str, model_dir: str | None = None, device: str | None = None):
+    resolved_key = normalize_model_key(model_key)
+    resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+    resolved_model_dir = model_dir or default_model_dir()
+
+    if resolved_key == "aiendo":
+        return Predictor(model_dir=resolved_model_dir, device=resolved_device)
+    if resolved_key == "dinov2":
+        return PredictorDinoV2(model_dir=resolved_model_dir, device=resolved_device)
+    if resolved_key == "vjepa2":
+        return PredictorVJEPA2(model_dir=resolved_model_dir, device=resolved_device)
+    raise KeyError(f"Unsupported model key: {model_key}")
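As a usage sketch (not part of this commit): all three predictor classes expose the same warm_up / predict / reset_state / unload surface, so a caller can stream RGB frames like this; the input video path is hypothetical.

import cv2

from predictor import create_predictor

predictor = create_predictor("dinov2")      # picks CUDA when available, else CPU
predictor.warm_up()

cap = cv2.VideoCapture("example_clip.mp4")  # hypothetical input file
while True:
    ok, frame_bgr = cap.read()
    if not ok:
        break
    # Predictors expect RGB uint8 frames; OpenCV decodes to BGR.
    result = predictor.predict(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    print(result["phase"], f"{result['confidence']:.2f}", result["frames_used"])
cap.release()
predictor.unload()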
requirements.txt
CHANGED
@@ -1,4 +1,13 @@
+--extra-index-url https://download.pytorch.org/whl/cu121
+streamlit>=1.40,<2
+torch==2.5.1
+torchvision==0.20.1
+numpy>=1.26,<3
+pandas>=2.2,<3
+opencv-python-headless>=4.10,<5
+pillow>=10,<12
+albumentations>=2.0,<3
+huggingface_hub>=0.27,<1
+pyyaml>=6,<7
+timm>=1.0,<2
+einops>=0.8,<1
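Note on the pins: the first line adds the PyTorch cu121 wheel index, so a plain `pip install -r requirements.txt` should resolve the `+cu121` builds of torch 2.5.1 and torchvision 0.20.1 while everything else comes from PyPI; the remaining ranges are upper-bounded to avoid untested major-version bumps.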
scripts/smoke_test.py
ADDED
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import argparse
+import os
+import sys
+from pathlib import Path
+
+import numpy as np
+
+SCRIPT_PATH = Path(__file__).resolve()
+SPACE_ROOT = SCRIPT_PATH.parents[1]
+if str(SPACE_ROOT) not in sys.path:
+    sys.path.insert(0, str(SPACE_ROOT))
+
+from predictor import create_predictor
+
+
+MODEL_REQUIREMENTS = {
+    "aiendo": ("resnet50.pth", "fusion.pth", "transformer.pth"),
+    "dinov2": ("dinov2_vit14s_latest_checkpoint.pth", "fusion_transformer_decoder_best_model.pth"),
+    "vjepa2": ("vjepa_encoder_human.pt", "mlp_decoder_human.pth"),
+}
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Smoke test the isolated HF Space predictors.")
+    parser.add_argument("--model", choices=sorted(MODEL_REQUIREMENTS), required=True)
+    parser.add_argument("--model-dir", default=str(SPACE_ROOT / "model"))
+    parser.add_argument("--device", default="cuda")
+    parser.add_argument("--image-size", type=int, default=256)
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+    model_dir = Path(args.model_dir).expanduser().resolve()
+    missing = [name for name in MODEL_REQUIREMENTS[args.model] if not (model_dir / name).exists()]
+    if missing:
+        raise FileNotFoundError(f"Missing required checkpoints in {model_dir}: {', '.join(missing)}")
+
+    os.environ["SPACE_MODEL_DIR"] = str(model_dir)
+    dummy = np.random.randint(0, 255, (args.image_size, args.image_size, 3), dtype=np.uint8)
+
+    predictor = create_predictor(args.model, model_dir=str(model_dir), device=args.device)
+    predictor.reset_state()
+    result = predictor.predict(dummy)
+    predictor.unload()
+
+    print(f"model={args.model}")
+    print(f"phase={result.get('phase')}")
+    print(f"confidence={result.get('confidence')}")
+    print(f"frames_used={result.get('frames_used')}")
+
+
+if __name__ == "__main__":
+    main()
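Assuming the checkpoints for a model are already staged under `model/` (or fetched via `model_registry`), a typical run is `python scripts/smoke_test.py --model dinov2 --device cpu`; it prints the model key, predicted phase, confidence, and `frames_used` for one random frame, which is enough to catch missing files or shape mismatches before deploying.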