TrajTok-v2 Segmenter
The trajectory segmenter from TrajTok-v2: a class-agnostic spatio-temporal object grouper that maps an image or video clip to at most K (default 128) trajectory tokens. Each token binds the patches that belong to the same object instance across space and time.
This checkpoint is the headline release: trained on a mixture of ~12 M samples (SA-1B, SA-V, filtered image / video pairs) for 3 epochs on 2 nodes × 8 H100s.
Architecture
video / image (T, 3, 224, 224)
│
↓
DINOv3-small ConvNeXt → patch features F (T·56·56, D=512)
│
↓
PerceiverResampler (K=128 learnable trajectory queries, depth=2)
│
↓
soft-mask assignment: M[k, p] = softmax_k(q_k · F_p) (paper Eq. 1)
│
↓
trajectory tokens: z_k = Σ_p M[k, p] · F_p (paper Eq. 2)
Total parameters: ~59 M.
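Eq. 1 and Eq. 2 above can be sketched in a few lines of NumPy. This is a minimal illustration of the two formulas as written in the diagram, not the repository's implementation; the function name `soft_assign` and the toy shapes are assumptions chosen for the example.

```python
import numpy as np


def soft_assign(queries, feats):
    """Soft-mask assignment (Eq. 1) followed by token pooling (Eq. 2).

    queries: (K, D) trajectory queries q_k
    feats:   (P, D) patch features F_p
    Returns M with shape (K, P) and tokens z with shape (K, D).
    """
    scores = queries @ feats.T                         # (K, P), q_k . F_p
    scores = scores - scores.max(axis=0, keepdims=True)  # stabilise the exponentials
    weights = np.exp(scores)
    M = weights / weights.sum(axis=0, keepdims=True)   # softmax over k, per patch
    z = M @ feats                                      # z_k = sum_p M[k, p] * F_p
    return M, z


# Toy sizes (hypothetical); the released model uses K=128 and D=512.
rng = np.random.default_rng(0)
K, P, D = 4, 6, 8
q = rng.standard_normal((K, D))
F = rng.standard_normal((P, D))
M, z = soft_assign(q, F)
```

Because each column of M sums to 1 over k, the tokens together account for every patch feature exactly once: Σ_k z_k = Σ_p F_p.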
Intended use
- Token-efficient visual encoders for downstream models (e.g. VLMs): swap your patch-token-based vision encoder for this segmenter to get ≤ 128 object-grounded tokens per clip instead of hundreds of grid patches.
- Class-agnostic object proposal / tracking for retrieval, captioning, or analytics pipelines that need lightweight instance grouping.
- Starting point for fine-tuning on specialized domains (medical, satellite, robotics) where you have unlabeled video.
How to use
import torch, yaml
from easydict import EasyDict as edict
from trajtok_segmenter.model.segmenter import SimpleSegmenter
# Load the released checkpoint
state = torch.load("path/to/latest.pth", map_location="cpu", weights_only=False)
sd = state["model"]
# Strip outer SegmentWrapper prefix (the training script wraps SimpleSegmenter)
sd = {k[len("vision_encoder."):] if k.startswith("vision_encoder.") else k: v for k, v in sd.items()}
# Build matching architecture
with open("trajtokv2/segmenter/configs/pretrain.yaml") as f:
    cfg = yaml.safe_load(f)
model = SimpleSegmenter(
    config=edict(cfg["traj_model"]),
    backbone_config=edict(cfg["backbone"]),
    perceiver_config=edict(cfg["perceiver"]),
    high_res=False,
).cuda().eval()
model.load_state_dict(sd, strict=False)
# Run forward on a clip
video = torch.randn(1, 8, 3, 224, 224).cuda() # (B, T, 3, H, W); T=1 for images
with torch.no_grad():
    logits = model(video)             # (B, N=T·56·56, K=128)
    traj_id = logits.argmax(-1)       # per-patch trajectory ID
    soft_mask = logits.softmax(-1)    # per-patch trajectory weight
See the main repository for the full demo (segmenter/scripts/demo_image.py), evaluation drivers (DAVIS / MOSE / YT-VIS), and training code.
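The flat per-patch IDs from the snippet above are easier to inspect once reshaped into per-frame maps. The sketch below assumes patches are flattened frame by frame and then row by row, which the output shape (B, T·56·56, K) suggests but does not guarantee; `to_id_maps` and `active_ids` are hypothetical helpers, not part of the package, and random logits stand in for real model output.

```python
import numpy as np


def to_id_maps(traj_id, num_frames, grid=56):
    """Reshape flat per-patch IDs (B, T*grid*grid) into (B, T, grid, grid).

    Assumes frame-major, then row-major patch ordering.
    """
    traj_id = np.asarray(traj_id)
    batch, num_patches = traj_id.shape
    expected = num_frames * grid * grid
    if num_patches != expected:
        raise ValueError(f"expected {expected} patches, got {num_patches}")
    return traj_id.reshape(batch, num_frames, grid, grid)


def active_ids(traj_id):
    """Per batch item, the sorted trajectory IDs that own at least one patch."""
    return [np.unique(row) for row in np.asarray(traj_id)]


# Stand-in for model output: random logits with the documented shape (B, N, K).
rng = np.random.default_rng(0)
T, K = 2, 128
logits = rng.standard_normal((1, T * 56 * 56, K))
traj_id = logits.argmax(-1)
id_maps = to_id_maps(traj_id, num_frames=T)
used = active_ids(traj_id)
```

With a real forward pass, convert first with `traj_id.cpu().numpy()`. The length of each entry in `used` is the number of trajectory tokens actually occupied, which can be well below K.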
Training data
The released checkpoint was trained on the filteredmixdata_all mixture:
| Source | Samples | Type |
|---|---|---|
| big_image_new | ~300 K | filtered image-caption pairs with auto-generated trajectory masks |
| big_video_new | ~1 M | filtered video-caption pairs with auto-generated per-frame trajectory masks |
| SA-1B (Meta AI) | ~11 M | original SA-1B images + instance masks |
| SA-V (Meta AI) | ~48 K | SA-V videos + per-frame instance masks |
Roughly 12.4 M samples in total, interleaved by media type via a MetaLoader. The segmenter's perceiver was trained from scratch (random Fourier init); the DINOv3-small backbone was initialised from Meta's DINOv3 ConvNeXt-small public release and fine-tuned end-to-end.
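To illustrate what interleaving by media type can look like, here is a hypothetical stand-in written in plain Python. It is not the repository's MetaLoader, whose interface is not shown in this card; the function name `interleave`, the seed argument, and the toy batch lists are all assumptions. Each source contributes as many slots as it has batches, and the slot order is shuffled once per epoch.

```python
import random


def interleave(loaders, seed=0):
    """Yield (source_name, batch) pairs, mixing sources in shuffled order.

    loaders: dict mapping a source name to a sized sequence of batches.
    Each source appears exactly len(batches) times, so every batch is
    visited once and larger sources are drawn proportionally more often.
    """
    order = [name for name, batches in loaders.items() for _ in range(len(batches))]
    random.Random(seed).shuffle(order)
    iterators = {name: iter(batches) for name, batches in loaders.items()}
    for name in order:
        yield name, next(iterators[name])


# Toy stand-ins for per-modality data loaders (hypothetical contents).
toy_loaders = {
    "image": ["img_batch_0", "img_batch_1", "img_batch_2"],
    "video": ["vid_batch_0"],
}
schedule = list(interleave(toy_loaders))
```

Keeping each batch single-modality lets image and video sources use different batch sizes, as in the per-modality settings listed under the training configuration.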
Training configuration
| Knob | Value |
|---|---|
| Trajectory tokens K | 128 |
| Embedding dim | 512 |
| Backbone | DINOv3-small ConvNeXt |
| Perceiver depth | 2 |
| Input resolution | 224×224 |
| Latent grid | 56×56 |
| Loss | dice + focal (per-patch class loss) + per-patch pixel loss |
| Optimizer | AdamW (lr=1e-4, wd=0.02) |
| Schedule | cosine, 1-epoch warmup |
| Epochs | 3 |
| Per-modality batch size | image=64, video=8, sa1b=64, sav=8 |
| Hardware | 2 nodes × 8 × H100 (80 GB) |
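The schedule row (cosine, 1-epoch warmup, 3 epochs, lr=1e-4) can be written as a small function of the step index. This is a sketch under stated assumptions: linear warmup and decay to zero are guesses, since the table gives neither the warmup shape nor a minimum learning rate, and `steps_per_epoch` is a hypothetical value that depends on the loader setup.

```python
import math


def lr_at(step, steps_per_epoch, base_lr=1e-4, epochs=3, warmup_epochs=1, min_lr=0.0):
    """Learning rate at a given optimizer step.

    Linear warmup over the first `warmup_epochs` (assumed shape),
    then cosine decay from base_lr to min_lr over the remaining epochs.
    """
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = epochs * steps_per_epoch
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Hypothetical epoch length, for illustration only.
steps_per_epoch = 100
peak = lr_at(steps_per_epoch, steps_per_epoch)        # first step after warmup
midway = lr_at(2 * steps_per_epoch, steps_per_epoch)  # halfway through the decay
final = lr_at(3 * steps_per_epoch, steps_per_epoch)   # end of training
```

With these assumptions the rate reaches 1e-4 at the end of epoch 1, falls to half of that at the end of epoch 2, and reaches zero at the end of epoch 3.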
Limitations
- Frame count cap: trained at T ≤ 8 frames per clip. Behaviour on longer clips at inference is untested; process long videos in windows and use merge_tracklets from trajtok_segmenter.eval.eval_segmenter to stitch IDs across windows.
- Resolution: trained at 224×224 inputs producing a 56×56 trajectory grid. Other resolutions run, but quality degrades the further the input departs from 224×224.
- Class-agnostic only: outputs trajectory IDs, not class labels. Pair with an open-vocabulary captioner / classifier for semantic tags.
- Domain bias: SA-1B + SA-V are skewed towards everyday scenes; expect domain-shift drops on medical, satellite, or stylised content.
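For the frame count cap, a long clip first has to be cut into windows of at most 8 frames before any stitching. The helper below is a hypothetical sketch of that windowing step only; `window_spans` and its `overlap` parameter are assumptions, and the stitching itself is left to merge_tracklets, whose signature is not documented in this card.

```python
def window_spans(num_frames, window=8, overlap=2):
    """Return (start, end) frame spans covering a clip with windows of `window` frames.

    Consecutive windows share `overlap` frames so that IDs can be matched
    across them; the last window is placed flush with the end of the clip.
    """
    if not 0 <= overlap < window:
        raise ValueError("overlap must satisfy 0 <= overlap < window")
    if num_frames <= window:
        return [(0, num_frames)]
    stride = window - overlap
    spans = []
    start = 0
    while start + window < num_frames:
        spans.append((start, start + window))
        start += stride
    spans.append((num_frames - window, num_frames))
    return spans


spans = window_spans(20)  # a 20-frame clip, 8-frame windows, 2 shared frames
```

Each span can be sliced from the video tensor along the time axis and passed through the model separately. Overlapping frames give a stitching step something to match on; whether merge_tracklets expects overlap is not stated here.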
Citation
@article{zheng2026trajtokv2,
title = {TrajTok-v2: Trajectory-aware visual tokenization for vision-language models},
author = {Zheng, Chenhao and others},
journal = {arXiv preprint arXiv:2602.22779},
year = {2026},
}
License
Apache-2.0. Bundled DINOv3 ConvNeXt-small backbone weights (downloaded separately) are also Apache-2.0 (Meta AI). SA-1B and SA-V training data are licensed under their respective terms by Meta AI.