| --- |
| license: apache-2.0 |
| library_name: pytorch |
| tags: |
| - video-segmentation |
| - object-grouping |
| - trajectory-tokens |
| - perceiver |
| - dinov3 |
| - vision |
| datasets: |
| - facebook/segment-anything |
| - facebook/sav-dataset |
| pipeline_tag: image-segmentation |
| --- |
| |
| # TrajTok-v2 Segmenter |
|
|
| The trajectory segmenter from **TrajTok-v2** — a class-agnostic spatio-temporal |
| object grouper that maps an image or video clip into ≤ K (default 128) |
| **trajectory tokens**. Each token binds patches that belong to the same |
| object instance over space and time. |
|
|
| This checkpoint is the headline release: trained on a mixture of ~12 M |
| samples (SA-1B, SA-V, filtered image / video pairs) for 3 epochs on |
| 2 nodes × 8 H100s. |
|
|
| ## Architecture |
|
|
| ``` |
| video / image (T, 3, 224, 224) |
| │ |
| ↓ |
| DINOv3-small ConvNeXt → patch features F (T·56·56, D=512) |
| │ |
| ↓ |
| PerceiverResampler (K=128 learnable trajectory queries, depth=2) |
| │ |
| ↓ |
| soft-mask assignment: M[k, p] = softmax_k(q_k · F_p) (paper Eq. 1) |
| │ |
| ↓ |
| trajectory tokens: z_k = Σ_p M[k, p] · F_p (paper Eq. 2) |
| ``` |
|
|
| Total parameters: ~59 M. |
|
|
| ## Intended use |
|
|
| - **Token-efficient visual encoders** for downstream models (e.g. VLMs): |
| swap your patch-token-based vision encoder for this segmenter to get ≤ 128 |
| object-grounded tokens per clip instead of hundreds of grid patches. |
| - **Class-agnostic object proposal / tracking** for retrieval, captioning, |
| or analytics pipelines that need lightweight instance grouping. |
| - **Starting point for fine-tuning** on specialized domains (medical, |
| satellite, robotics) where you have unlabeled video. |
|
|
| ## How to use |
|
|
| ```python |
| import torch, yaml |
| from easydict import EasyDict as edict |
| from trajtok_segmenter.model.segmenter import SimpleSegmenter |
| |
| # Load the released checkpoint |
| state = torch.load("path/to/latest.pth", map_location="cpu", weights_only=False) |
| sd = state["model"] |
| # Strip outer SegmentWrapper prefix (the training script wraps SimpleSegmenter) |
| sd = {k[len("vision_encoder."):] if k.startswith("vision_encoder.") else k: v for k, v in sd.items()} |
| |
| # Build matching architecture |
| cfg = yaml.safe_load(open("trajtokv2/segmenter/configs/pretrain.yaml")) |
| model = SimpleSegmenter( |
| config=edict(cfg["traj_model"]), |
| backbone_config=edict(cfg["backbone"]), |
| perceiver_config=edict(cfg["perceiver"]), |
| high_res=False, |
| ).cuda().eval() |
| model.load_state_dict(sd, strict=False) |
| |
| # Run forward on a clip |
| video = torch.randn(1, 8, 3, 224, 224).cuda() # (B, T, 3, H, W); T=1 for images |
| with torch.no_grad(): |
| logits = model(video) # (B, N=T·56·56, K=128) |
| traj_id = logits.argmax(-1) # per-patch trajectory ID |
| soft_mask = logits.softmax(-1) # per-patch trajectory weight |
| ``` |
|
|
| See the [main repository](https://github.com/hellomuffin/trajtokv2) for the |
| full demo (`segmenter/scripts/demo_image.py`), evaluation drivers |
| (DAVIS / MOSE / YT-VIS), and training code. |
|
|
| ## Training data |
|
|
| The released checkpoint was trained on the `filteredmixdata_all` mixture: |
|
|
| | Source | Samples | Type | |
| |---|---|---| |
| | `big_image_new` | ~300 K | filtered image-caption pairs with auto-generated trajectory masks | |
| | `big_video_new` | ~1 M | filtered video-caption pairs with auto-generated per-frame trajectory masks | |
| | **SA-1B** ([Meta AI](https://ai.meta.com/datasets/segment-anything/)) | ~11 M | original SA-1B images + instance masks | |
| | **SA-V** ([Meta AI](https://ai.meta.com/datasets/segment-anything-video/)) | ~48 K | SA-V videos + per-frame instance masks | |
|
|
| Roughly 12.4 M samples in total, interleaved by media type via a MetaLoader. |
| The segmenter's perceiver was trained from scratch (random Fourier init); the |
| DINOv3-small backbone was initialised from Meta's |
| [DINOv3 ConvNeXt-small public release](https://github.com/facebookresearch/dinov3) |
| and fine-tuned end-to-end. |
|
|
| ## Training configuration |
|
|
| | Knob | Value | |
| |---|---| |
| | Trajectory tokens K | 128 | |
| | Embedding dim | 512 | |
| | Backbone | DINOv3-small ConvNeXt | |
| | Perceiver depth | 2 | |
| | Input resolution | 224×224 | |
| | Latent grid | 56×56 | |
| | Loss | dice + focal (per-patch class loss) + per-patch pixel loss | |
| | Optimizer | AdamW (lr=1e-4, wd=0.02) | |
| | Schedule | cosine, 1-epoch warmup | |
| | Epochs | 3 | |
| | Per-modality batch size | image=64, video=8, sa1b=64, sav=8 | |
| | Hardware | 2 nodes × 8 × H100 (80 GB) | |
|
|
| ## Limitations |
|
|
| - **Frame count cap**: trained at T ≤ 8 frames per clip. Longer-clip |
| behaviour at inference is untested; use `merge_tracklets` from |
| `trajtok_segmenter.eval.eval_segmenter` to stitch IDs across windows. |
| - **Resolution**: trained at 224×224 inputs producing a 56×56 trajectory |
| grid. Other resolutions work but degrade quality away from this point. |
| - **Class-agnostic only**: outputs trajectory IDs, not class labels. Pair |
| with an open-vocabulary captioner / classifier for semantic tags. |
| - **Domain bias**: SA-1B + SA-V are skewed towards everyday scenes; expect |
| domain-shift drops on medical, satellite, or stylised content. |
|
|
| ## Citation |
|
|
| ```bibtex |
| @article{zheng2026trajtokv2, |
| title = {TrajTok-v2: Trajectory-aware visual tokenization for vision-language models}, |
| author = {Zheng, Chenhao and others}, |
| journal = {arXiv preprint arXiv:2602.22779}, |
| year = {2026}, |
| } |
| ``` |
|
|
| ## License |
|
|
| Apache-2.0. Bundled DINOv3 ConvNeXt-small backbone weights (downloaded |
| separately) are also Apache-2.0 (Meta AI). SA-1B and SA-V training data |
| are licensed under their respective terms by Meta AI. |
|
|