ViT + Recursive Reasoning (TRM) for Video Classification

A lightweight video classification model that applies recursive reasoning cycles over spatial tokens from a Vision Transformer, achieving strong results on HMDB51 with only ~6M parameters.

Model Description

This model uses the Tiny Recursive Model (TRM) architecture adapted for video understanding. Instead of extracting a single CLS token per frame and compensating with deep temporal modeling, this approach retains all spatial patch tokens and refines them through shared-weight iterative reasoning cycles before temporal aggregation.

Architecture

Video Frames → ViT (per-frame) → All Patch Tokens → Mean Pool
    → TRM Reasoning (H=2 cycles, L=2 shared layers) → Temporal Transformer (1 layer)
    → Classifier (51 classes)
  • Backbone: vit_tiny_patch16_224 (pretrained on ImageNet)
  • TRM: 2 reasoning cycles (H=2), 2 shared transformer layers per cycle (L=2), 4 attention heads
  • Temporal: 1 transformer encoder layer
  • Total parameters: ~6M

Key Innovation

The TRM reasoning module shares weights across cycles: the same transformer layers are applied multiple times, increasing effective depth without increasing model size. This follows the original TRM paper (Jolicoeur-Martineau, 2025).
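The point is easy to verify in code: unrolling one shared layer for more cycles adds compute and effective depth, but the parameter count never changes. A minimal sketch (the `recurse` helper is illustrative, not part of this repository):

```python
import torch
import torch.nn as nn

# One transformer layer whose weights will be reused every cycle
layer = nn.TransformerEncoderLayer(
    d_model=192, nhead=4, dim_feedforward=768, batch_first=True
)

def recurse(x, layer, cycles):
    # The SAME layer object (same weights) is applied `cycles` times
    for _ in range(cycles):
        x = layer(x)
    return x

# Parameter count is fixed regardless of how many cycles we unroll:
n_params = sum(p.numel() for p in layer.parameters())
```

Whether `recurse` runs 1, 2, or 10 cycles, `n_params` is unchanged; only the depth of the computation graph grows.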

Results on HMDB51

| Model | Parameters | Val Accuracy | Test Accuracy |
|---|---|---|---|
| Baseline (standard pipeline) | ~6M | 60% | 58% |
| Recursive Reasoning (2 cycles) | ~6M | 69.4% | 69% |
| Improvement | | +9.4 pts | +11 pts |

Optimal Reasoning Depth

Performance peaks at 2 reasoning cycles:

  • 1 cycle: 68.4%
  • 2 cycles: 69.4% (best)
  • 3 cycles: 67.0%

Training Configuration

backbone: vit_tiny_patch16_224 (pretrained)
trm_H_cycles: 2
trm_L_layers: 2
trm_num_heads: 4
temporal_num_layers: 1
temporal_num_heads: 4
learning_rate: 3e-4
weight_decay: 0.05
warmup_epochs: 5
max_epochs: 30
batch_size: 8
num_temporal_clips: 4
num_frames: 16
frame_stride: 4
label_smoothing: 0.1
seed: 22
optimizer: AdamW
scheduler: cosine with warmup
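The optimizer and schedule from the config could be set up as follows. This is a sketch under one common reading of "cosine with warmup" (linear warmup for 5 epochs, then cosine decay over the remaining 25); the training code may use a different variant, and `make_scheduler` is a hypothetical helper.

```python
import math
import torch

def make_scheduler(optimizer, warmup_epochs=5, max_epochs=30):
    # Linear warmup, then cosine decay to zero — one common interpretation
    # of "cosine with warmup"; the exact schedule used in training may differ.
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        progress = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

model = torch.nn.Linear(8, 51)  # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
scheduler = make_scheduler(optimizer)

# label_smoothing from the config applies to the classification loss:
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
```

Stepped once per epoch, this ramps the learning rate from 6e-5 to the peak 3e-4 over the first five epochs, then decays it toward zero by epoch 30.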

Usage

import torch
from vit_trm_video import ViTTRMVideo

# Load from checkpoint
model = ViTTRMVideo.load_from_checkpoint(
    "vit-trm-epoch=29-val_acc=0.7113.ckpt",
    strict=False,
)
model.eval()

# Inference: video tensor of shape (batch, num_frames, 3, 224, 224)
video = torch.randn(1, 16, 3, 224, 224)
with torch.no_grad():
    logits = model(video)
    predicted_class = logits.argmax(dim=-1)

Files

  • vit-trm-epoch=29-val_acc=0.7113.ckpt β€” PyTorch Lightning checkpoint (best validation accuracy)
  • vit_trm_video.py β€” Model architecture (ViTTRMVideo)
  • vit_video_baseline.py β€” Baseline model for comparison (ViTVideoBaseline)

Citation

If you use this model, please cite:

@article{jolicoeur2025less,
  title={Less is More: Recursive Reasoning with Tiny Networks},
  author={Jolicoeur-Martineau, Alexia},
  journal={arXiv preprint arXiv:2510.04871},
  year={2025}
}

License

Apache 2.0
