---
license: apache-2.0
library_name: pytorch
tags:
- video-segmentation
- object-grouping
- trajectory-tokens
- perceiver
- dinov3
- vision
datasets:
- facebook/segment-anything
- facebook/sav-dataset
pipeline_tag: image-segmentation
---
# TrajTok-v2 Segmenter
The trajectory segmenter from **TrajTok-v2**: a class-agnostic spatio-temporal
object grouper that maps an image or video clip to at most K (default 128)
**trajectory tokens**. Each token binds the patches that belong to the same
object instance across space and time.

This checkpoint is the headline release, trained for 3 epochs on a mixture of
~12.4 M samples (SA-1B, SA-V, and filtered image / video pairs) using
2 nodes × 8 H100 GPUs.
## Architecture
```
video / image (T, 3, 224, 224)
DINOv3-small ConvNeXt → patch features F (T·56·56, D=512)
PerceiverResampler (K=128 learnable trajectory queries, depth=2)
soft-mask assignment: M[k, p] = softmax_k(q_k · F_p) (paper Eq. 1)
trajectory tokens: z_k = Σ_p M[k, p] · F_p (paper Eq. 2)
```
Total parameters: ~59 M.
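
The two equations in the diagram can be sketched in a few lines. This is a minimal NumPy illustration, not the repository's implementation: the function name `soft_assign`, the random inputs, and the use of T=2 frames are assumptions made only to keep the example small and self-contained.

```python
import numpy as np

def soft_assign(queries, feats):
    """queries: (K, D) trajectory queries; feats: (N, D) patch features.

    Returns the soft mask M with shape (K, N) and tokens z with shape (K, D).
    """
    scores = queries @ feats.T                            # (K, N): q_k · F_p
    scores = scores - scores.max(axis=0, keepdims=True)   # stabilise the softmax
    M = np.exp(scores)
    M = M / M.sum(axis=0, keepdims=True)                  # softmax over k (Eq. 1)
    z = M @ feats                                         # z_k = Σ_p M[k, p] · F_p (Eq. 2)
    return M, z

rng = np.random.default_rng(0)
K, D, N = 128, 512, 2 * 56 * 56      # N assumes T=2 frames (illustrative)
queries = 0.05 * rng.standard_normal((K, D))
feats = rng.standard_normal((N, D))
M, z = soft_assign(queries, feats)
print(M.shape, z.shape)
```

Because the softmax runs over the K queries, every patch distributes a total weight of 1 across the trajectories, so each column of `M` sums to 1.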
## Intended use
- **Token-efficient visual encoders** for downstream models (e.g. VLMs):
swap your patch-token-based vision encoder for this segmenter to get ≤ 128
object-grounded tokens per clip instead of hundreds of grid patches.
- **Class-agnostic object proposal / tracking** for retrieval, captioning,
or analytics pipelines that need lightweight instance grouping.
- **Starting point for fine-tuning** on specialized domains (medical,
satellite, robotics) where you have unlabeled video.
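
To make the token-efficiency argument concrete, the sketch below counts grid positions per clip and divides by the K=128 token cap. The helper `token_budget` is hypothetical, and the 14×14 grid used for the comparison is an illustrative assumption about a ViT-style encoder, not a figure from this model card.

```python
def token_budget(num_frames, grid=56, max_tokens=128):
    """Return (grid positions in the clip, positions per trajectory token)."""
    positions = num_frames * grid * grid
    return positions, positions / max_tokens

# The segmenter's own 56×56 latent grid on an 8-frame clip.
seg_positions, seg_ratio = token_budget(8)            # (25088, 196.0)

# Hypothetical ViT-style encoder with a 14×14 patch grid per frame.
vit_positions, vit_ratio = token_budget(8, grid=14)   # (1568, 12.25)

print(seg_positions, seg_ratio)
print(vit_positions, vit_ratio)
```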
## How to use
```python
import torch
import yaml
from easydict import EasyDict as edict

from trajtok_segmenter.model.segmenter import SimpleSegmenter

# Load the released checkpoint.
# weights_only=False unpickles arbitrary objects, so only load files you trust.
state = torch.load("path/to/latest.pth", map_location="cpu", weights_only=False)
sd = state["model"]

# Strip the outer SegmentWrapper prefix (the training script wraps SimpleSegmenter).
prefix = "vision_encoder."
sd = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in sd.items()}

# Build the matching architecture from the pretraining config.
with open("trajtokv2/segmenter/configs/pretrain.yaml") as f:
    cfg = yaml.safe_load(f)
model = SimpleSegmenter(
    config=edict(cfg["traj_model"]),
    backbone_config=edict(cfg["backbone"]),
    perceiver_config=edict(cfg["perceiver"]),
    high_res=False,
).cuda().eval()

# strict=False tolerates key mismatches, so report what was skipped.
result = model.load_state_dict(sd, strict=False)
print(f"missing: {len(result.missing_keys)}, unexpected: {len(result.unexpected_keys)}")

# Run a forward pass on a clip.
video = torch.randn(1, 8, 3, 224, 224).cuda()  # (B, T, 3, H, W); T=1 for images
with torch.no_grad():
    logits = model(video)                      # (B, N=T·56·56, K=128)
traj_id = logits.argmax(-1)                    # per-patch trajectory ID
soft_mask = logits.softmax(-1)                 # per-patch trajectory weight
```
See the [main repository](https://github.com/hellomuffin/trajtokv2) for the
full demo (`segmenter/scripts/demo_image.py`), evaluation drivers
(DAVIS / MOSE / YT-VIS), and training code.
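
The flat `(B, N, K)` logits are usually easier to inspect as per-frame ID maps. The sketch below reshapes them, assuming the N axis is ordered frame-major, i.e. (T, 56, 56) flattened; that ordering is an assumption and should be checked against the repository. The helper `logits_to_id_maps` is hypothetical, and random logits stand in for real model output.

```python
import numpy as np

def logits_to_id_maps(logits, num_frames, grid=56):
    """logits: (B, N, K) with N = num_frames * grid * grid.

    Returns integer trajectory IDs with shape (B, num_frames, grid, grid).
    Assumes the N axis is a frame-major flattening of (T, grid, grid).
    """
    batch, num_patches, _ = logits.shape
    if num_patches != num_frames * grid * grid:
        raise ValueError("N does not match num_frames * grid * grid")
    ids = logits.argmax(axis=-1)                  # hard assignment per patch
    return ids.reshape(batch, num_frames, grid, grid)

rng = np.random.default_rng(0)
fake_logits = rng.standard_normal((1, 2 * 56 * 56, 128))  # stand-in for model output
id_maps = logits_to_id_maps(fake_logits, num_frames=2)
active = np.unique(id_maps)                       # trajectories that own at least one patch
print(id_maps.shape, active.size)
```

With real outputs, convert first with `logits.cpu().numpy()`. The number of active trajectories can be lower than K, which is why the token count is stated as an upper bound.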
## Training data
The released checkpoint was trained on the `filteredmixdata_all` mixture:

| Source | Samples | Type |
|---|---|---|
| `big_image_new` | ~300 K | filtered image-caption pairs with auto-generated trajectory masks |
| `big_video_new` | ~1 M | filtered video-caption pairs with auto-generated per-frame trajectory masks |
| **SA-1B** ([Meta AI](https://ai.meta.com/datasets/segment-anything/)) | ~11 M | original SA-1B images + instance masks |
| **SA-V** ([Meta AI](https://ai.meta.com/datasets/segment-anything-video/)) | ~48 K | SA-V videos + per-frame instance masks |

Roughly 12.4 M samples in total, interleaved by media type via a MetaLoader.

The segmenter's perceiver was trained from scratch (random Fourier init); the
DINOv3-small backbone was initialised from Meta's
[DINOv3 ConvNeXt-small public release](https://github.com/facebookresearch/dinov3)
and fine-tuned end-to-end.
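
Interleaving by media type means each training step draws a whole batch from a single source, so image and video batches never mix within a step. The generator below is a generic sketch of that idea, not the repository's MetaLoader: the function `interleave`, the length-proportional sampling rule, and the toy loaders are all assumptions.

```python
import random

def interleave(loaders, seed=0):
    """Yield (source_name, batch) pairs until every loader is exhausted.

    A source is picked at random in proportion to its remaining batches,
    so all sources finish at roughly the same time.
    """
    rng = random.Random(seed)
    iters = {name: iter(loader) for name, loader in loaders.items()}
    remaining = {name: len(loader) for name, loader in loaders.items() if len(loader) > 0}
    while remaining:
        names = list(remaining)
        name = rng.choices(names, weights=[remaining[n] for n in names])[0]
        yield name, next(iters[name])
        remaining[name] -= 1
        if remaining[name] == 0:
            del remaining[name]

# Toy loaders: lists of placeholder batches standing in for DataLoaders.
loaders = {
    "image": ["image-0", "image-1", "image-2"],
    "video": ["video-0"],
}
schedule = list(interleave(loaders))
print(schedule)
```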
## Training configuration
| Knob | Value |
|---|---|
| Trajectory tokens K | 128 |
| Embedding dim | 512 |
| Backbone | DINOv3-small ConvNeXt |
| Perceiver depth | 2 |
| Input resolution | 224×224 |
| Latent grid | 56×56 |
| Loss | dice + focal (per-patch class loss) + per-patch pixel loss |
| Optimizer | AdamW (lr=1e-4, wd=0.02) |
| Schedule | cosine, 1-epoch warmup |
| Epochs | 3 |
| Per-modality batch size | image=64, video=8, sa1b=64, sav=8 |
| Hardware | 2 nodes × 8 × H100 (80 GB) |
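
The schedule row can be read as a learning-rate function of the training step. The sketch below assumes a linear warmup over the first epoch and a cosine decay to zero over the remaining two; the warmup shape, the final value, and `steps_per_epoch=1000` are assumptions, since the table only states "cosine, 1-epoch warmup".

```python
import math

def lr_at(step, steps_per_epoch, epochs=3, warmup_epochs=1, base_lr=1e-4, min_lr=0.0):
    """Learning rate at a given optimizer step (0-indexed)."""
    total = steps_per_epoch * epochs
    warmup = steps_per_epoch * warmup_epochs
    if step < warmup:
        return base_lr * (step + 1) / warmup          # linear warmup (assumed)
    progress = (step - warmup) / max(1, total - warmup)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

steps_per_epoch = 1000                                # hypothetical value
for step in (0, 999, 2000, 2999):
    print(step, lr_at(step, steps_per_epoch))
```

The rate peaks at `base_lr` at the end of epoch 1 and is half of `base_lr` at the midpoint of the decay phase.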
## Limitations
- **Frame count cap**: trained with T ≤ 8 frames per clip. Behaviour on
longer clips is untested; for longer videos, run the model on windows of at
most 8 frames and use `merge_tracklets` from
`trajtok_segmenter.eval.eval_segmenter` to stitch IDs across windows.
- **Resolution**: trained at 224×224 inputs producing a 56×56 trajectory
grid. Other resolutions run, but quality degrades the further the input is
from 224×224.
- **Class-agnostic only**: outputs trajectory IDs, not class labels. Pair
with an open-vocabulary captioner / classifier for semantic tags.
- **Domain bias**: SA-1B + SA-V are skewed towards everyday scenes; expect
domain-shift drops on medical, satellite, or stylised content.
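
Stitching IDs across windows can be pictured as relabelling a later window so that it reuses the IDs of the earlier one wherever their masks agree on a shared frame. The sketch below is a generic IoU-matching illustration and is not the repository's `merge_tracklets`, whose signature is not shown here; the function `stitch_windows`, the one-frame overlap, and the 0.5 threshold are assumptions.

```python
import numpy as np

def stitch_windows(prev_window, next_window, iou_thresh=0.5):
    """Relabel next_window (T, H, W) into prev_window's ID space.

    Assumes prev_window[-1] and next_window[0] are the same video frame.
    With iou_thresh >= 0.5 at most one new ID can match any old ID.
    """
    overlap_prev, overlap_next = prev_window[-1], next_window[0]
    out = np.empty_like(next_window)
    next_free = int(prev_window.max()) + 1            # first unused ID
    for new_id in np.unique(next_window):
        mask = overlap_next == new_id
        best_id, best_iou = None, 0.0
        if mask.any():
            for old_id in np.unique(overlap_prev[mask]):
                old_mask = overlap_prev == old_id
                iou = (mask & old_mask).sum() / (mask | old_mask).sum()
                if iou > best_iou:
                    best_id, best_iou = old_id, iou
        if best_iou > iou_thresh:
            target = best_id                          # continue an existing trajectory
        else:
            target = next_free                        # start a new trajectory
            next_free += 1
        out[next_window == new_id] = target
    return out

# Toy example: two 2-frame windows on a 4×4 grid sharing one frame.
prev_window = np.zeros((2, 4, 4), dtype=int)
prev_window[:, :, 2:] = 1                             # IDs 0 (left) and 1 (right)
next_window = np.full((2, 4, 4), 5)
next_window[:, :, 2:] = 7                             # same regions, different IDs
next_window[1, 0, 0] = 9                              # object appearing after the overlap
stitched = stitch_windows(prev_window, next_window)
print(np.unique(stitched))
```

In the toy example the two regions keep IDs 0 and 1 from the earlier window, and the object that first appears after the shared frame receives the fresh ID 2.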
## Citation
```bibtex
@article{zheng2026trajtokv2,
title = {TrajTok-v2: Trajectory-aware visual tokenization for vision-language models},
author = {Zheng, Chenhao and others},
journal = {arXiv preprint arXiv:2602.22779},
year = {2026},
}
```
## License
Apache-2.0 for this checkpoint. The DINOv3 ConvNeXt-small backbone weights
are downloaded separately and are governed by Meta AI's license for DINOv3;
check the DINOv3 repository for the current terms. SA-1B and SA-V training
data are licensed under their respective terms by Meta AI.