# ViT + Recursive Reasoning (TRM) for Video Classification

A lightweight video classification model that applies recursive reasoning cycles over spatial tokens from a Vision Transformer, achieving strong results on HMDB51 with only ~6M parameters.
## Model Description
This model uses the Tiny Recursive Model (TRM) architecture adapted for video understanding. Instead of extracting a single CLS token per frame and compensating with deep temporal modeling, this approach retains all spatial patch tokens and refines them through shared-weight iterative reasoning cycles before temporal aggregation.
### Architecture

```
Video Frames → ViT (per-frame) → All Patch Tokens → Mean Pool
  → TRM Reasoning (H=2 cycles, L=2 shared layers) → Temporal Transformer (1 layer)
  → Classifier (51 classes)
```
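One plausible reading of the pipeline above can be sketched in PyTorch. Everything here is illustrative, not the released implementation: `TRMVideoSketch`, the stand-in linear patch embedding (the real model uses a pretrained `vit_tiny_patch16_224` from timm), and the assumption that spatial mean pooling precedes the reasoning stage, as the diagram orders it.

```python
import torch
import torch.nn as nn

class TRMVideoSketch(nn.Module):
    """Hedged sketch: per-frame tokens -> spatial mean pool ->
    shared-weight TRM reasoning -> temporal transformer -> classifier."""

    def __init__(self, dim=192, num_classes=51, H=2, L=2, heads=4):
        super().__init__()
        self.H, self.L = H, L
        # Stand-in for the pretrained ViT: a linear embedding per 16x16 patch.
        self.patch_embed = nn.Linear(3 * 16 * 16, dim)
        # ONE shared layer, reused for all H * L reasoning steps.
        self.reason = nn.TransformerEncoderLayer(dim, heads, dim * 4, batch_first=True)
        self.temporal = nn.TransformerEncoderLayer(dim, heads, dim * 4, batch_first=True)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, video):                        # (B, T, 3, 224, 224)
        B, T = video.shape[:2]
        # Split each frame into 14x14 = 196 non-overlapping 16x16 patches.
        patches = video.unfold(3, 16, 16).unfold(4, 16, 16)   # (B, T, 3, 14, 14, 16, 16)
        patches = patches.permute(0, 1, 3, 4, 2, 5, 6).reshape(B, T, 196, -1)
        tokens = self.patch_embed(patches)           # (B, T, 196, D)
        x = tokens.mean(dim=2)                       # spatial mean pool -> (B, T, D)
        for _ in range(self.H):                      # recursive reasoning cycles...
            for _ in range(self.L):                  # ...reusing the same weights
                x = self.reason(x)
        x = self.temporal(x)                         # 1-layer temporal transformer
        return self.head(x.mean(dim=1))              # (B, num_classes)
```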
- **Backbone:** `vit_tiny_patch16_224` (pretrained on ImageNet)
- **TRM:** 2 reasoning cycles (H=2), 2 shared transformer layers per cycle (L=2), 4 attention heads
- **Temporal:** 1 transformer encoder layer
- **Total parameters:** ~6M
### Key Innovation

The TRM reasoning module uses weight sharing across cycles: the same transformer layer is applied multiple times, producing deeper processing without increasing model size. This is consistent with the original TRM paper (Jolicoeur-Martineau, 2025).
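The weight-sharing idea can be illustrated with a minimal sketch. The 192-d tokens and 4 heads match `vit_tiny`; the specific loop structure and layer type are assumptions for illustration, not the released code:

```python
import torch
import torch.nn as nn

# One transformer layer reused for every reasoning step: H * L = 4
# applications of depth, but only one layer's worth of parameters.
dim, heads, H, L = 192, 4, 2, 2
shared = nn.TransformerEncoderLayer(dim, heads, dim * 4, batch_first=True)

def recurse(x, layer, n_cycles, n_layers):
    for _ in range(n_cycles * n_layers):   # 4 applications, 1 set of weights
        x = layer(x)
    return x

tokens = torch.randn(1, 196, dim)          # e.g. one frame's patch tokens
shared.eval()
with torch.no_grad():
    refined = recurse(tokens, shared, H, L)

n_params = sum(p.numel() for p in shared.parameters())
# An unshared stack of the same depth would need H * L separate layers,
# i.e. 4x the parameters, for the same number of applications.
```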
## Results on HMDB51
| Model | Parameters | Val Accuracy | Test Accuracy |
|---|---|---|---|
| Baseline (standard pipeline) | ~6M | 60% | 58% |
| Recursive Reasoning (2 cycles) | ~6M | 69.4% | 69% |
| Improvement | – | +9.4 pts | +11 pts |
### Optimal Reasoning Depth
Performance peaks at 2 reasoning cycles:
- 1 cycle: 68.4%
- 2 cycles: 69.4% (best)
- 3 cycles: 67.0%
## Training Configuration

```yaml
backbone: vit_tiny_patch16_224  # pretrained
trm_H_cycles: 2
trm_L_layers: 2
trm_num_heads: 4
temporal_num_layers: 1
temporal_num_heads: 4
learning_rate: 3e-4
weight_decay: 0.05
warmup_epochs: 5
max_epochs: 30
batch_size: 8
num_temporal_clips: 4
num_frames: 16
frame_stride: 4
label_smoothing: 0.1
seed: 22
optimizer: AdamW
scheduler: cosine with warmup
```
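The optimizer and schedule above can be sketched as follows. AdamW with lr 3e-4 and weight decay 0.05 come straight from the config; the exact warmup-then-cosine implementation used in training is an assumption, shown here via a `LambdaLR`:

```python
import math
import torch

model = torch.nn.Linear(192, 51)  # stand-in for the full model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)

warmup_epochs, max_epochs = 5, 30

def lr_lambda(epoch):
    if epoch < warmup_epochs:                      # linear warmup to the base lr
        return (epoch + 1) / warmup_epochs
    # cosine decay from 1 -> ~0 over the remaining epochs
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```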
## Usage

```python
import torch
from vit_trm_video import ViTTRMVideo

# Load from checkpoint
model = ViTTRMVideo.load_from_checkpoint(
    "vit-trm-epoch=29-val_acc=0.7113.ckpt",
    strict=False,
)
model.eval()

# Inference: video tensor of shape (batch, num_frames, 3, 224, 224)
video = torch.randn(1, 16, 3, 224, 224)
with torch.no_grad():
    logits = model(video)
predicted_class = logits.argmax(dim=-1)
```
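For real videos, the config implies sampling 16 frames at stride 4 and averaging predictions over 4 temporal clips at evaluation time. A hedged sketch of that sampling (the exact logic in the training code is an assumption; `sample_clip_indices` is a hypothetical helper):

```python
import torch

def sample_clip_indices(num_video_frames, num_frames=16, stride=4, num_clips=4):
    """Frame indices for `num_clips` evenly spaced clips of strided frames."""
    span = (num_frames - 1) * stride + 1          # frames covered by one clip
    max_start = max(num_video_frames - span, 0)
    starts = torch.linspace(0, max_start, num_clips).long()
    clips = []
    for s in starts:
        idx = s + torch.arange(num_frames) * stride
        # Clamp for videos shorter than one clip span (repeats the last frame).
        clips.append(idx.clamp(max=num_video_frames - 1))
    return torch.stack(clips)                     # (num_clips, num_frames)

indices = sample_clip_indices(120)
# At eval time, run the model on each clip and average the logits
# before the final argmax.
```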
## Files

- `vit-trm-epoch=29-val_acc=0.7113.ckpt` – PyTorch Lightning checkpoint (best validation accuracy)
- `vit_trm_video.py` – Model architecture (`ViTTRMVideo`)
- `vit_video_baseline.py` – Baseline model for comparison (`ViTVideoBaseline`)
## Citation

If you use this model, please cite:

```bibtex
@article{jolicoeur2025less,
  title={Less is More: Recursive Reasoning with Tiny Networks},
  author={Jolicoeur-Martineau, Alexia},
  journal={arXiv preprint arXiv:2510.04871},
  year={2025}
}
```
## License
Apache 2.0