TrajTok-v2 Segmenter

This is the trajectory segmenter from TrajTok-v2: a class-agnostic spatio-temporal object grouper that maps an image or video clip into at most K (default 128) trajectory tokens. Each token binds the patches that belong to the same object instance across space and time.

This checkpoint is the headline release. It was trained for 3 epochs on a mixture of ~12.4 M samples (SA-1B, SA-V, and filtered image-caption / video-caption pairs) on 2 nodes × 8 H100s.

Architecture

video / image  (T, 3, 224, 224)
        │
        ↓
  DINOv3-small ConvNeXt  →  patch features F  (T·56·56, D=512)
        │
        ↓
  PerceiverResampler (K=128 learnable trajectory queries, depth=2)
        │
        ↓
  soft-mask assignment:  M[k, p] = softmax_k(q_k · F_p)    (paper Eq. 1)
        │
        ↓
  trajectory tokens:     z_k = Σ_p M[k, p] · F_p           (paper Eq. 2)

Total parameters: ~59 M.
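The two equations in the diagram can be sketched in a few lines of NumPy. This is a minimal illustration, not the repository's implementation: the function name soft_assign and the toy sizes are assumptions, and q_k stands for whatever query embeddings the PerceiverResampler produces.

```python
import numpy as np

def soft_assign(queries: np.ndarray, feats: np.ndarray):
    """queries: (K, D) trajectory queries q_k; feats: (P, D) patch features F_p.

    Returns the soft mask M of shape (K, P) and the tokens z of shape (K, D).
    """
    scores = queries @ feats.T                            # (K, P): q_k · F_p
    scores = scores - scores.max(axis=0, keepdims=True)   # stabilise the softmax
    weights = np.exp(scores)
    M = weights / weights.sum(axis=0, keepdims=True)      # Eq. 1: softmax over k
    z = M @ feats                                         # Eq. 2: z_k = Σ_p M[k, p] · F_p
    return M, z

rng = np.random.default_rng(0)
K, P, D = 4, 10, 8   # toy sizes; the card's values are K=128, P=T·56·56, D=512
M, z = soft_assign(rng.normal(size=(K, D)), rng.normal(size=(P, D)))
print(M.shape, z.shape)
```

Because the softmax runs over k, each patch distributes a total weight of 1 across the K trajectories. The model's forward pass returns logits as (B, N, K), so M here corresponds to the transpose of logits.softmax(-1) for a single clip.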

Intended use

  • Token-efficient visual encoders for downstream models (e.g. VLMs): swap your patch-token-based vision encoder for this segmenter to get ≤ 128 object-grounded tokens per clip instead of hundreds of grid patches.
  • Class-agnostic object proposal / tracking for retrieval, captioning, or analytics pipelines that need lightweight instance grouping.
  • Starting point for fine-tuning on specialized domains (medical, satellite, robotics) where you have unlabeled video.
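To make the token-efficiency claim concrete, here is a small worked calculation. The ViT baseline with 16-pixel patches is an assumption chosen for illustration, not a comparison taken from this card, and token_budget is a hypothetical helper.

```python
def token_budget(num_frames: int, k: int = 128, grid: int = 56,
                 resolution: int = 224, vit_patch: int = 16) -> dict:
    """Compare per-clip token counts for three representations."""
    latent_patches = num_frames * grid * grid                  # T·56·56 patch features
    vit_tokens = num_frames * (resolution // vit_patch) ** 2   # hypothetical ViT baseline
    return {
        "latent_patches": latent_patches,
        "vit_tokens": vit_tokens,
        "trajectory_tokens_max": k,          # K is an upper bound on emitted tokens
        "vit_to_traj_ratio": vit_tokens / k,
    }

budget = token_budget(num_frames=8)
print(budget)
```

For an 8-frame clip this gives 25088 latent patches and 1568 tokens for the assumed ViT baseline, against at most 128 trajectory tokens, a ratio of 12.25 relative to the baseline.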

How to use

import torch, yaml
from easydict import EasyDict as edict
from trajtok_segmenter.model.segmenter import SimpleSegmenter

# Load the released checkpoint.
# weights_only=False unpickles arbitrary Python objects, so only load checkpoints you trust.
state = torch.load("path/to/latest.pth", map_location="cpu", weights_only=False)
sd = state["model"]
# Strip outer SegmentWrapper prefix (the training script wraps SimpleSegmenter)
sd = {k[len("vision_encoder."):] if k.startswith("vision_encoder.") else k: v for k, v in sd.items()}

# Build matching architecture
with open("trajtokv2/segmenter/configs/pretrain.yaml") as f:
    cfg = yaml.safe_load(f)
model = SimpleSegmenter(
    config=edict(cfg["traj_model"]),
    backbone_config=edict(cfg["backbone"]),
    perceiver_config=edict(cfg["perceiver"]),
    high_res=False,
).cuda().eval()
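# strict=False ignores missing / unexpected keys; print them to catch a wrong prefix or config
result = model.load_state_dict(sd, strict=False)
print("missing:", result.missing_keys, "unexpected:", result.unexpected_keys)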

# Run forward on a clip
video = torch.randn(1, 8, 3, 224, 224).cuda()         # (B, T, 3, H, W); T=1 for images
with torch.no_grad():
    logits = model(video)                              # (B, N=T·56·56, K=128)
traj_id = logits.argmax(-1)                            # per-patch trajectory ID
soft_mask = logits.softmax(-1)                         # per-patch trajectory weight

See the main repository for the full demo (segmenter/scripts/demo_image.py), evaluation drivers (DAVIS / MOSE / YT-VIS), and training code.
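As a rough sketch of what can be done with the forward output, the snippet below turns per-patch logits into a (T, 56, 56) label grid and extracts one trajectory's mask. It uses NumPy on random stand-in data so it runs without the checkpoint. The helper names are hypothetical, and it assumes the N patches are flattened frame by frame and then row by row, which the card does not state explicitly.

```python
import numpy as np

def to_label_grid(logits: np.ndarray, num_frames: int, grid: int = 56) -> np.ndarray:
    """logits: (N, K) for one clip with N = num_frames * grid * grid.

    Returns integer trajectory IDs of shape (num_frames, grid, grid),
    assuming frame-major, then row-major patch order.
    """
    traj_id = logits.argmax(axis=-1)
    return traj_id.reshape(num_frames, grid, grid)

def trajectory_sizes(label_grid: np.ndarray) -> dict:
    """Map each trajectory ID that owns at least one patch to its patch count."""
    ids, counts = np.unique(label_grid, return_counts=True)
    return dict(zip(ids.tolist(), counts.tolist()))

rng = np.random.default_rng(0)
T, G, K = 2, 56, 128
logits = rng.normal(size=(T * G * G, K))   # stand-in for model(video)[0].cpu().numpy()
labels = to_label_grid(logits, T, G)
sizes = trajectory_sizes(labels)
largest = max(sizes, key=sizes.get)
mask_of_largest = labels == largest        # boolean mask of shape (T, G, G)
print(len(sizes), "active trajectories; largest covers", sizes[largest], "patches")
```

Only IDs that win the argmax for at least one patch appear in sizes, which is why the number of active trajectories can be below K.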

Training data

The released checkpoint was trained on the filteredmixdata_all mixture:

| Source | Samples | Type |
|---|---|---|
| big_image_new | ~300 K | filtered image-caption pairs with auto-generated trajectory masks |
| big_video_new | ~1 M | filtered video-caption pairs with auto-generated per-frame trajectory masks |
| SA-1B (Meta AI) | ~11 M | original SA-1B images + instance masks |
| SA-V (Meta AI) | ~48 K | SA-V videos + per-frame instance masks |

Roughly 12.4 M samples in total, interleaved by media type via a MetaLoader. The segmenter's perceiver was trained from scratch (random Fourier init); the DINOv3-small backbone was initialised from Meta's DINOv3 ConvNeXt-small public release and fine-tuned end-to-end.
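One way to picture interleaving by media type is a loop that draws single-modality batches from several loaders in turn. The sketch below is an illustration only and is not the repository's MetaLoader, whose ordering may be random or weighted. The function interleave and the toy data are assumptions.

```python
from typing import Any, Dict, Iterable, Iterator, Tuple

def interleave(loaders: Dict[str, Iterable[Any]]) -> Iterator[Tuple[str, Any]]:
    """Yield (source_name, batch) round-robin until every loader is exhausted."""
    iterators = {name: iter(loader) for name, loader in loaders.items()}
    while iterators:
        for name in list(iterators):          # copy keys so we can delete while looping
            try:
                yield name, next(iterators[name])
            except StopIteration:
                del iterators[name]           # this source is done; keep cycling the rest

# Toy stand-ins for per-modality data loaders
toy = {"image": ["i0", "i1", "i2"], "video": ["v0"], "sa1b": ["s0", "s1"]}
order = [name for name, _ in interleave(toy)]
print(order)
```

Each yielded batch comes from exactly one source, so image and video batches can keep different shapes and batch sizes.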

Training configuration

| Knob | Value |
|---|---|
| Trajectory tokens K | 128 |
| Embedding dim | 512 |
| Backbone | DINOv3-small ConvNeXt |
| Perceiver depth | 2 |
| Input resolution | 224×224 |
| Latent grid | 56×56 |
| Loss | dice + focal (per-patch class loss) + per-patch pixel loss |
| Optimizer | AdamW (lr=1e-4, wd=0.02) |
| Schedule | cosine, 1-epoch warmup |
| Epochs | 3 |
| Per-modality batch size | image=64, video=8, sa1b=64, sav=8 |
| Hardware | 2 nodes × 8 × H100 (80 GB) |
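The schedule row can be read as the following learning-rate function. This is a sketch under stated assumptions: the warmup is taken to be linear, the cosine is taken to decay to zero, and the rate is updated per step. None of these details is given in this card, and steps_per_epoch depends on the batch setup, so it is left as a parameter.

```python
import math

def lr_at(step: int, steps_per_epoch: int, epochs: int = 3, base_lr: float = 1e-4,
          warmup_epochs: int = 1, min_lr: float = 0.0) -> float:
    """Learning rate at a given optimizer step: linear warmup, then cosine decay."""
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = epochs * steps_per_epoch
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps            # assumed linear ramp
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Sample the schedule at the start, end of warmup, midpoint of decay, and end
for s in (0, 99, 200, 300):
    print(s, lr_at(s, steps_per_epoch=100))
```

With these assumptions the rate reaches lr=1e-4 at the end of the first epoch and falls to half of that midway through the remaining two epochs.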

Limitations

  • Frame count cap: trained at T ≤ 8 frames per clip. Behaviour on longer clips is untested; for those, run the model on windows of at most 8 frames and use merge_tracklets from trajtok_segmenter.eval.eval_segmenter to stitch IDs across windows.
  • Resolution: trained at 224×224 inputs, which produce a 56×56 trajectory grid. Other resolutions run, but quality degrades as the input moves away from 224×224.
  • Class-agnostic only: outputs trajectory IDs, not class labels. Pair with an open-vocabulary captioner / classifier for semantic tags.
  • Domain bias: SA-1B + SA-V are skewed towards everyday scenes; expect domain-shift drops on medical, satellite, or stylised content.
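For the frame-count cap, a longer clip can be cut into windows of at most 8 frames before each window is passed to the model. The helper below is a hypothetical sketch of that windowing step only. The overlap of 2 frames is an assumption, chosen so that neighbouring windows share frames that a stitching step could match on. The call signature of merge_tracklets is not described in this card, so it is not invoked here.

```python
from typing import List, Tuple

def frame_windows(num_frames: int, window: int = 8, overlap: int = 2) -> List[Tuple[int, int]]:
    """Return [start, end) frame ranges of at most `window` frames covering the clip."""
    if not 0 <= overlap < window:
        raise ValueError("overlap must satisfy 0 <= overlap < window")
    if num_frames <= 0:
        return []
    stride = window - overlap
    windows = []
    start = 0
    while True:
        end = min(start + window, num_frames)
        windows.append((start, end))
        if end == num_frames:        # last window reached the end of the clip
            break
        start += stride
    return windows

print(frame_windows(20))
```

Each range can be used to slice the time axis of a (B, T, 3, H, W) tensor, for example video[:, start:end].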

Citation

@article{zheng2026trajtokv2,
  title   = {TrajTok-v2: Trajectory-aware visual tokenization for vision-language models},
  author  = {Zheng, Chenhao and others},
  journal = {arXiv preprint arXiv:2602.22779},
  year    = {2026},
}

License

Apache-2.0. Bundled DINOv3 ConvNeXt-small backbone weights (downloaded separately) are also Apache-2.0 (Meta AI). SA-1B and SA-V training data are licensed under their respective terms by Meta AI.
