bcgxtberg committed on
Commit
77191d4
·
verified ·
1 Parent(s): 435c34f

Upload folder using huggingface_hub

Browse files
Files changed (4)
  1. README.md +69 -0
  2. ssv2_datamodule.py +362 -0
  3. train_ssv2.py +131 -0
  4. vit_trm_video.py +348 -0
README.md ADDED
# ViT-TRM on Something-Something V2

Extends the [ViT-TRM architecture](https://hf.co/adelabdalla221/vit-trm-hmdb51) from HMDB51 (51 classes) to **Something-Something V2** (174 fine-grained hand-object interaction classes).

## Architecture

```
Video Frames → ViT (per-frame) → Mean Pool → Positional Encoding
→ TRM Reasoning (H=2 cycles, L=2 shared layers) → Mean Pool → Classifier (174 classes)
```

- **Backbone**: `vit_tiny_patch16_224` (ImageNet pretrained)
- **TRM**: 2 cycles × 2 shared transformer layers, 4 heads (~6M params)
- **Dataset**: SSv2 — 174 template actions, ~220K videos of hand-object interactions

## Setup

```bash
pip install torch torchvision pytorch-lightning timm torchmetrics decord
```

## Getting the Data

**Option A: Local download** from [20BN](https://developer.qualcomm.com/software/ai-datasets/something-something):
```
ssv2/
  videos/            # .webm files (1.webm, 2.webm, ...)
  labels/
    train.json
    validation.json
    labels.json
```

**Option B: HF Hub** (requires access): `HuggingFaceM4/something-something-v2`

## Training

```bash
# From scratch
python train_ssv2.py --data_dir /path/to/ssv2

# Transfer learning from HMDB51 checkpoint (recommended)
python train_ssv2.py \
    --data_dir /path/to/ssv2 \
    --pretrained_ckpt ../vit-trm-hmdb51/vit-trm-epoch=29-val_acc=0.7113.ckpt

# From HF Hub
python train_ssv2.py --from_hub
```

### Key flags

| Flag | Default | Description |
|------|---------|-------------|
| `--pretrained_ckpt` | None | Transfer backbone + TRM weights from HMDB51 |
| `--trm_H_cycles` | 2 | Number of recursive reasoning cycles |
| `--frame_stride` | 2 | Temporal stride (SSv2 videos are short) |
| `--num_frames` | 16 | Frames sampled per clip |
| `--batch_size` | 8 | Training batch size |
| `--max_epochs` | 30 | Training epochs |
| `--precision` | 16-mixed | Mixed-precision training |

## Why SSv2?

Unlike HMDB51, which can be partly solved from scene and object appearance alone, SSv2 requires **temporal reasoning** — understanding the motion and interaction pattern. This makes it a better test of the TRM recursive reasoning approach:

- "Pushing something from left to right" vs "Pushing something from right to left" differ only in motion direction
- 174 fine-grained template actions, ~220K training videos
- A standard benchmark for temporal modeling in video understanding
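To make the `--num_frames` and `--frame_stride` flags concrete, here is a hypothetical standalone helper (not part of the repo) that mirrors the centered fixed-stride sampling implemented by `sample_frames_stride()` in `ssv2_datamodule.py`:

```python
# Hypothetical sketch mirroring sample_frames_stride() / sample_frames_uniform()
# from ssv2_datamodule.py; names here are illustrative, not the repo's API.

def sampled_indices(total_frames: int, num_frames: int = 16, stride: int = 2) -> list:
    """Centered fixed-stride sampling; falls back to uniform for short clips."""
    needed = (num_frames - 1) * stride + 1
    if needed > total_frames:
        # Uniform fallback; pad with the last frame if the clip is very short.
        if total_frames <= num_frames:
            return list(range(total_frames)) + [total_frames - 1] * (num_frames - total_frames)
        step = total_frames / num_frames
        return [int(i * step) for i in range(num_frames)]
    start = (total_frames - needed) // 2
    return [start + i * stride for i in range(num_frames)]

print(sampled_indices(100))  # 100-frame clip: centered window, frames 34..64, stride 2
print(sampled_indices(10))   # 10-frame clip: all frames, padded with the last one
```

With the defaults (16 frames, stride 2), any clip of at least 31 frames yields a centered 30-frame window; shorter clips degrade gracefully to uniform sampling with padding.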
ssv2_datamodule.py ADDED
#!/usr/bin/env python3
"""
Something-Something V2 DataModule for PyTorch Lightning.

Loads SSv2 from the Hugging Face Hub or from a local directory of webm files.
Each sample is a short video (~2-6 s) of a hand performing one of 174 template actions
(e.g. "Pushing [something] from left to right").

Usage:
    dm = SSv2DataModule(data_dir="/path/to/ssv2", batch_size=8)
    dm.setup()
    for batch in dm.train_dataloader():
        ...
"""

import json
from pathlib import Path
from typing import Callable, List, Optional

import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torchvision.transforms as T

try:
    import decord
    decord.bridge.set_bridge("torch")
    HAS_DECORD = True
except ImportError:
    HAS_DECORD = False

try:
    from datasets import load_dataset as hf_load_dataset
    HAS_HF_DATASETS = True
except ImportError:
    HAS_HF_DATASETS = False


# ---------------------------------------------------------------------------
# Video sampling helpers
# ---------------------------------------------------------------------------

def sample_frames_uniform(total_frames: int, num_frames: int) -> List[int]:
    """Uniformly sample `num_frames` indices from [0, total_frames),
    padding with the last frame when the video is shorter."""
    if total_frames <= num_frames:
        return list(range(total_frames)) + [total_frames - 1] * (num_frames - total_frames)
    stride = total_frames / num_frames
    return [int(i * stride) for i in range(num_frames)]


def sample_frames_stride(total_frames: int, num_frames: int, stride: int) -> List[int]:
    """Sample `num_frames` with a fixed stride, centered in the video;
    fall back to uniform sampling when the video is too short."""
    needed = (num_frames - 1) * stride + 1
    if needed > total_frames:
        return sample_frames_uniform(total_frames, num_frames)
    start = (total_frames - needed) // 2
    return [start + i * stride for i in range(num_frames)]


# ---------------------------------------------------------------------------
# Dataset: local directory of webm/mp4 files + label JSON
# ---------------------------------------------------------------------------

class SSv2LocalDataset(Dataset):
    """
    Loads SSv2 from a local directory.

    Expected layout:
        data_dir/
            videos/                 # or 20bn-something-something-v2/
                1.webm
                2.webm
                ...
            labels/
                train.json          # [{"id": "1", "template": "...", "label": "..."}, ...]
                validation.json
                test.json           # (no labels)
                labels.json         # {"0": "Approaching [something] with your camera", ...}
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "train",
        num_frames: int = 16,
        frame_stride: int = 2,
        transform: Optional[Callable] = None,
        num_clips: int = 1,
    ):
        super().__init__()
        self.data_dir = Path(data_dir)
        self.split = split
        self.num_frames = num_frames
        self.frame_stride = frame_stride
        self.transform = transform
        self.num_clips = num_clips

        if not HAS_DECORD:
            raise ImportError("decord is required for local video loading. Install: pip install decord")

        # Find the video directory
        vid_dirs = ["videos", "20bn-something-something-v2"]
        self.video_dir = None
        for d in vid_dirs:
            candidate = self.data_dir / d
            if candidate.exists():
                self.video_dir = candidate
                break
        if self.video_dir is None:
            self.video_dir = self.data_dir / "videos"

        # Load the label mapping
        labels_file = self.data_dir / "labels" / "labels.json"
        if labels_file.exists():
            with open(labels_file) as f:
                idx_to_label = json.load(f)
            self.label_to_idx = {v: int(k) for k, v in idx_to_label.items()}
        else:
            self.label_to_idx = {}

        # Load split annotations
        split_file = self.data_dir / "labels" / f"{split}.json"
        if not split_file.exists():
            # Try alternate naming
            alt = self.data_dir / "labels" / f"something-something-v2-{split}.json"
            if alt.exists():
                split_file = alt
            else:
                raise FileNotFoundError(f"Cannot find annotation file for split '{split}' in {self.data_dir / 'labels'}")

        with open(split_file) as f:
            self.annotations = json.load(f)

        # Build label_to_idx from annotations if it was not loaded from labels.json
        if not self.label_to_idx:
            all_labels = sorted(
                set(a.get("template", a.get("label", "")) for a in self.annotations if "template" in a or "label" in a)
            )
            self.label_to_idx = {lbl: i for i, lbl in enumerate(all_labels)}

        self.num_classes = len(self.label_to_idx)
        print(f"SSv2 [{split}]: {len(self.annotations)} videos, {self.num_classes} classes")

    def __len__(self):
        return len(self.annotations) * self.num_clips

    def __getitem__(self, idx):
        # NOTE: clip_idx does not yet vary the sampling; all clips of a video
        # share the same centered window. Kept for future jittered multi-clip sampling.
        clip_idx = idx % self.num_clips
        video_idx = idx // self.num_clips
        ann = self.annotations[video_idx]

        video_id = str(ann["id"])
        label_str = ann.get("template", ann.get("label", None))
        label = self.label_to_idx.get(label_str, -1) if label_str else -1

        # Find the video file
        video_path = None
        for ext in [".webm", ".mp4"]:
            candidate = self.video_dir / f"{video_id}{ext}"
            if candidate.exists():
                video_path = str(candidate)
                break
        if video_path is None:
            raise FileNotFoundError(f"Video not found: {video_id} in {self.video_dir}")

        # Decode frames
        vr = decord.VideoReader(video_path)
        total = len(vr)
        indices = sample_frames_stride(total, self.num_frames, self.frame_stride)
        frames = vr.get_batch(indices)  # (T, H, W, C) torch tensor (decord bridge)

        # Convert to (T, C, H, W) float in [0, 1]
        frames = frames.permute(0, 3, 1, 2).float() / 255.0

        if self.transform is not None:
            frames = torch.stack([self.transform(f) for f in frames])

        return {"video": frames, "label": label, "video_id": video_id}


# ---------------------------------------------------------------------------
# Dataset: Hugging Face Hub
# ---------------------------------------------------------------------------

class SSv2HFDataset(Dataset):
    """
    Loads SSv2 from the Hugging Face Hub using the `datasets` library.
    Tries known Hub IDs: 'HuggingFaceM4/something-something-v2' or 'lmms-lab/SSv2'.
    Falls back to manual download instructions if gated.
    """

    def __init__(
        self,
        split: str = "train",
        num_frames: int = 16,
        frame_stride: int = 2,
        transform: Optional[Callable] = None,
        num_clips: int = 1,
        hf_dataset_id: str = "HuggingFaceM4/something-something-v2",
    ):
        super().__init__()
        if not HAS_HF_DATASETS:
            raise ImportError("Install: pip install datasets")

        self.num_frames = num_frames
        self.frame_stride = frame_stride
        self.transform = transform
        self.num_clips = num_clips

        print(f"Loading SSv2 from Hub: {hf_dataset_id} (split={split}) ...")
        self.ds = hf_load_dataset(hf_dataset_id, split=split)

        # Infer the label column and class count
        if "label" in self.ds.features:
            feat = self.ds.features["label"]
            if hasattr(feat, "names"):
                self.num_classes = len(feat.names)
            else:
                self.num_classes = 174
        else:
            self.num_classes = 174

        print(f"SSv2 HF [{split}]: {len(self.ds)} samples, {self.num_classes} classes")

    def __len__(self):
        return len(self.ds) * self.num_clips

    def __getitem__(self, idx):
        video_idx = idx // self.num_clips
        sample = self.ds[video_idx]

        label = sample.get("label", -1)
        video_id = str(sample.get("video_id", sample.get("id", video_idx)))

        # The HF dataset typically stores video as bytes or decoded frames
        video_data = sample.get("video", None)
        if video_data is None:
            raise ValueError("No 'video' column in HF dataset")

        # If video_data is a dict with 'path'/'bytes', decode with decord
        if isinstance(video_data, dict):
            import io
            if not HAS_DECORD:
                raise ImportError("decord is required to decode video bytes. Install: pip install decord")
            video_bytes = video_data.get("bytes", None)
            if video_bytes:
                vr = decord.VideoReader(io.BytesIO(video_bytes))
                total = len(vr)
                indices = sample_frames_stride(total, self.num_frames, self.frame_stride)
                frames = vr.get_batch(indices).permute(0, 3, 1, 2).float() / 255.0
            else:
                raise ValueError("Cannot decode video from HF dataset sample")
        elif isinstance(video_data, torch.Tensor):
            frames = video_data
            if frames.ndim == 4 and frames.shape[-1] in (1, 3):
                frames = frames.permute(0, 3, 1, 2).float()
            if frames.max() > 1.0:
                frames = frames / 255.0
            total = frames.shape[0]
            indices = sample_frames_stride(total, self.num_frames, self.frame_stride)
            frames = frames[indices]
        else:
            raise ValueError(f"Unexpected video format: {type(video_data)}")

        if self.transform is not None:
            frames = torch.stack([self.transform(f) for f in frames])

        return {"video": frames, "label": label, "video_id": video_id}


# ---------------------------------------------------------------------------
# Lightning DataModule
# ---------------------------------------------------------------------------

def build_train_transform(img_size: int = 224):
    # No RandomHorizontalFlip: SSv2 labels are direction-sensitive
    # ("... from left to right" vs "... from right to left"), so a horizontal
    # flip would change the correct label.
    return T.Compose([
        T.Resize((img_size, img_size)),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def build_val_transform(img_size: int = 224):
    return T.Compose([
        T.Resize((img_size, img_size)),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


class SSv2DataModule(pl.LightningDataModule):
    """
    SSv2 DataModule supporting both local files and the HF Hub.

    Args:
        data_dir: Path to local SSv2 data. If None, loads from the HF Hub.
        hf_dataset_id: HF Hub dataset ID (used when data_dir is None).
        num_frames: Frames to sample per clip.
        frame_stride: Temporal stride between sampled frames.
        img_size: Spatial resize target.
        batch_size: Training batch size.
        num_workers: DataLoader workers.
        num_clips_val: Number of clips per video at val/test time.
    """

    def __init__(
        self,
        data_dir: Optional[str] = None,
        hf_dataset_id: str = "HuggingFaceM4/something-something-v2",
        num_frames: int = 16,
        frame_stride: int = 2,
        img_size: int = 224,
        batch_size: int = 8,
        num_workers: int = 4,
        num_clips_val: int = 4,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.data_dir = data_dir
        self.hf_dataset_id = hf_dataset_id
        self.num_frames = num_frames
        self.frame_stride = frame_stride
        self.img_size = img_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_clips_val = num_clips_val

    def setup(self, stage=None):
        train_tf = build_train_transform(self.img_size)
        val_tf = build_val_transform(self.img_size)

        if self.data_dir is not None:
            self.train_ds = SSv2LocalDataset(
                self.data_dir, "train", self.num_frames, self.frame_stride, train_tf, num_clips=1,
            )
            self.val_ds = SSv2LocalDataset(
                self.data_dir, "validation", self.num_frames, self.frame_stride, val_tf, num_clips=self.num_clips_val,
            )
        else:
            self.train_ds = SSv2HFDataset(
                "train", self.num_frames, self.frame_stride, train_tf, num_clips=1,
                hf_dataset_id=self.hf_dataset_id,
            )
            self.val_ds = SSv2HFDataset(
                "validation", self.num_frames, self.frame_stride, val_tf, num_clips=self.num_clips_val,
                hf_dataset_id=self.hf_dataset_id,
            )
        self.num_classes = self.train_ds.num_classes

    def train_dataloader(self):
        return DataLoader(
            self.train_ds, batch_size=self.batch_size, shuffle=True,
            num_workers=self.num_workers, pin_memory=True, drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_workers, pin_memory=True,
        )

    def test_dataloader(self):
        return self.val_dataloader()
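The datasets above expose `num_clips` copies of each video, and the dataloader index is unpacked into a (video, clip) pair. A hypothetical standalone sketch (not from the repo) of that index arithmetic:

```python
# Hypothetical sketch of the index arithmetic in SSv2LocalDataset.__getitem__:
# dataset index -> (video index, clip index) when num_clips > 1.

def unpack_index(idx: int, num_clips: int) -> tuple:
    return idx // num_clips, idx % num_clips

# With 3 videos and num_clips=2, the dataset reports length 6 and the clips of
# each video sit at consecutive indices:
pairs = [unpack_index(i, 2) for i in range(3 * 2)]
print(pairs)  # → [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
```

Note that as currently written, the clip index does not alter the centered frame sampling, so the validation clips of one video are identical; jittering the start frame per clip would make multi-clip evaluation more informative.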
train_ssv2.py ADDED
#!/usr/bin/env python3
"""
Train ViT-TRM on Something-Something V2.

Examples:
    # From scratch on local SSv2 data:
    python train_ssv2.py --data_dir /path/to/ssv2

    # Transfer from an HMDB51 pretrained checkpoint:
    python train_ssv2.py --data_dir /path/to/ssv2 --pretrained_ckpt ../vit-trm-hmdb51/vit-trm-epoch=29-val_acc=0.7113.ckpt

    # From the HF Hub (if you have access):
    python train_ssv2.py --from_hub

    # Quick smoke test (runs a single batch of train and val):
    python train_ssv2.py --data_dir /path/to/ssv2 --fast_dev_run
"""

import argparse

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

from vit_trm_video import ViTTRMVideo
from ssv2_datamodule import SSv2DataModule


def main():
    parser = argparse.ArgumentParser(description="Train ViT-TRM on SSv2")

    # Data
    parser.add_argument("--data_dir", type=str, default=None, help="Local SSv2 data directory")
    parser.add_argument("--from_hub", action="store_true", help="Load SSv2 from the HF Hub")
    parser.add_argument("--hf_dataset_id", type=str, default="HuggingFaceM4/something-something-v2")
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--frame_stride", type=int, default=2, help="SSv2 videos are short; stride=2 works well")
    parser.add_argument("--img_size", type=int, default=224)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--num_clips_val", type=int, default=4)

    # Model
    parser.add_argument("--vit_name", type=str, default="vit_tiny_patch16_224")
    parser.add_argument("--vit_pretrained", action="store_true", default=True)
    parser.add_argument("--vit_freeze", action="store_true", default=False)
    parser.add_argument("--trm_H_cycles", type=int, default=2)
    parser.add_argument("--trm_L_layers", type=int, default=2)
    parser.add_argument("--trm_num_heads", type=int, default=4)
    parser.add_argument("--num_classes", type=int, default=174)
    parser.add_argument("--pretrained_ckpt", type=str, default=None,
                        help="Path to an HMDB51 checkpoint to transfer backbone+TRM weights from")

    # Training
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=0.05)
    parser.add_argument("--warmup_epochs", type=int, default=5)
    parser.add_argument("--max_epochs", type=int, default=30)
    parser.add_argument("--label_smoothing", type=float, default=0.1)
    parser.add_argument("--iterative_refinement", action="store_true", default=False)

    # Trainer
    parser.add_argument("--accelerator", type=str, default="auto")
    parser.add_argument("--devices", type=int, default=1)
    parser.add_argument("--precision", type=str, default="16-mixed")
    parser.add_argument("--fast_dev_run", action="store_true", default=False)
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()
    pl.seed_everything(args.seed)

    # Data
    data_dir = args.data_dir if not args.from_hub else None
    dm = SSv2DataModule(
        data_dir=data_dir,
        hf_dataset_id=args.hf_dataset_id,
        num_frames=args.num_frames,
        frame_stride=args.frame_stride,
        img_size=args.img_size,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        num_clips_val=args.num_clips_val,
    )

    # Model
    model = ViTTRMVideo(
        img_size=args.img_size,
        vit_name=args.vit_name,
        vit_pretrained=args.vit_pretrained,
        vit_freeze=args.vit_freeze,
        trm_H_cycles=args.trm_H_cycles,
        trm_L_layers=args.trm_L_layers,
        trm_num_heads=args.trm_num_heads,
        num_classes=args.num_classes,
        lr=args.lr,
        weight_decay=args.weight_decay,
        warmup_epochs=args.warmup_epochs,
        max_epochs=args.max_epochs,
        label_smoothing=args.label_smoothing,
        iterative_refinement=args.iterative_refinement,
        pretrained_ckpt=args.pretrained_ckpt,
    )

    # Callbacks
    ckpt_callback = ModelCheckpoint(
        dirpath="checkpoints",
        filename="vit-trm-ssv2-{epoch:02d}-{val_acc:.4f}",
        monitor="val_acc",
        mode="max",
        save_top_k=3,
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")

    # Trainer
    trainer = pl.Trainer(
        accelerator=args.accelerator,
        devices=args.devices,
        precision=args.precision,
        max_epochs=args.max_epochs,
        callbacks=[ckpt_callback, lr_monitor],
        fast_dev_run=args.fast_dev_run,
        log_every_n_steps=50,
    )

    trainer.fit(model, dm)

    # Test with the best checkpoint
    if not args.fast_dev_run:
        trainer.test(model, dm, ckpt_path="best")


if __name__ == "__main__":
    main()
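The `--lr`, `--warmup_epochs`, and `--max_epochs` flags feed the warmup-plus-cosine schedule that `vit_trm_video.py` builds with `LambdaLR`. A pure-Python sketch of the multiplier it applies to the base learning rate (the function name here is illustrative):

```python
import math

def lr_factor(epoch: int, warmup_epochs: int = 5, max_epochs: int = 30) -> float:
    """Multiplier on the base LR: linear warmup, then cosine decay to ~0.
    Mirrors lr_lambda() in vit_trm_video.configure_optimizers()."""
    if epoch < warmup_epochs:
        return (epoch + 1) / max(1, warmup_epochs)
    progress = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, max(0.0, progress))))

# Warmup ramps 0.2 -> 1.0 over the first 5 epochs, then cosine decays toward 0.
print([round(lr_factor(e), 3) for e in (0, 4, 5, 17, 29)])
# → [0.2, 1.0, 1.0, 0.531, 0.004]
```

With the defaults, the ViT parameter groups run at `0.1 * lr` times this same factor, so the backbone always trains an order of magnitude slower than the TRM and classifier.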
vit_trm_video.py ADDED
#!/usr/bin/env python3
"""
ViT + TRM Video Classifier — dataset-agnostic version.

Architecture:
    - ViT per-frame feature extraction
    - TRM reasoning cycles (shared-weight transformer layers)
    - Temporal pooling
    - Classifier

Supports video-level evaluation by aggregating multi-clip predictions.
"""

import math
from collections import defaultdict
from typing import Dict, Optional

import torch
import torch.nn as nn
import pytorch_lightning as pl
import timm
import torchmetrics


def build_sinusoidal_positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor:
    """Standard sinusoidal positional encoding of shape (1, seq_len, dim)."""
    position = torch.arange(seq_len, device=device).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, device=device) * (-math.log(10000.0) / dim))
    pe = torch.zeros(seq_len, dim, device=device)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0)


class ReasoningCycle(nn.Module):
    """
    Single reasoning cycle (TRM's H-cycle).
    Applies the same (weight-shared) transformer layer L times to refine representations.
    """

    def __init__(self, hidden_size: int, num_heads: int, num_layers: int, dropout: float = 0.1):
        super().__init__()
        self.num_layers = num_layers
        self.shared_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=hidden_size * 4,
            dropout=dropout,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for _ in range(self.num_layers):
            x = self.shared_layer(x)
        return self.norm(x)


class ViTTRMVideo(pl.LightningModule):
    """
    ViT + TRM for video classification.

    Architecture:
        1. ViT per-frame feature extraction
        2. TRM recursive reasoning over temporal tokens
        3. Mean-pool + classifier
    """

    def __init__(
        self,
        # Frame encoder (ViT) config
        img_size: int = 224,
        vit_name: str = "vit_tiny_patch16_224",
        vit_pretrained: bool = True,
        vit_freeze: bool = False,
        # TRM config
        trm_H_cycles: int = 2,
        trm_L_layers: int = 2,
        trm_hidden_size: Optional[int] = None,
        trm_num_heads: int = 4,
        # Task config
        num_classes: int = 174,
        # Training config
        lr: float = 3e-4,
        weight_decay: float = 0.05,
        warmup_epochs: int = 5,
        max_epochs: int = 50,
        label_smoothing: float = 0.1,
        # Iterative refinement
        iterative_refinement: bool = False,
        num_refinement_steps: Optional[int] = None,
        # Transfer learning: path to a pretrained checkpoint to load backbone + TRM from
        pretrained_ckpt: Optional[str] = None,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.lr = lr
        self.weight_decay = weight_decay
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.num_classes = num_classes
        self.label_smoothing = label_smoothing
        self.trm_H_cycles = trm_H_cycles
        self.iterative_refinement = iterative_refinement

        if num_refinement_steps is None:
            self.num_refinement_steps = trm_H_cycles
        else:
            self.num_refinement_steps = num_refinement_steps

        if iterative_refinement:
            self.automatic_optimization = False

        # ViT backbone
        self.vit = timm.create_model(
            vit_name,
            pretrained=vit_pretrained,
            num_classes=0,
            img_size=img_size,
            dynamic_img_size=True,
        )
        if hasattr(self.vit, "reset_classifier"):
            self.vit.reset_classifier(0, global_pool="")

        self.vit_freeze = vit_freeze
        if vit_freeze:
            for p in self.vit.parameters():
                p.requires_grad = False
            self.vit.eval()

        vit_embed_dim = getattr(self.vit, "num_features", None) or getattr(self.vit, "embed_dim", None)
        if vit_embed_dim is None:
            raise ValueError("Could not infer ViT embedding dimension from timm model.")

        if trm_hidden_size is None:
            trm_hidden_size = int(vit_embed_dim)
        self.trm_hidden_size = trm_hidden_size

        # TRM reasoning cycles
        self.reasoning_cycle = ReasoningCycle(
            hidden_size=self.trm_hidden_size,
            num_heads=trm_num_heads,
            num_layers=trm_L_layers,
            dropout=0.1,
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.trm_hidden_size),
            nn.Linear(self.trm_hidden_size, num_classes),
        )

        # Metrics
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

        self.validation_outputs = []

        # Optionally load pretrained weights (e.g. from an HMDB51 checkpoint)
        if pretrained_ckpt is not None:
            self._load_pretrained(pretrained_ckpt)

    def _load_pretrained(self, ckpt_path: str):
        """Load backbone + TRM weights from a prior checkpoint, skipping the classifier."""
        ckpt = torch.load(ckpt_path, map_location="cpu")
        state_dict = ckpt.get("state_dict", ckpt)
        # Filter out classifier weights (different num_classes)
        filtered = {k: v for k, v in state_dict.items() if not k.startswith("classifier.")}
        missing, unexpected = self.load_state_dict(filtered, strict=False)
        print(f"Loaded pretrained weights from {ckpt_path}")
        print(f"  Missing keys (expected — new classifier): {missing}")
        if unexpected:
            print(f"  Unexpected keys (ignored): {unexpected}")

    def forward(self, video, num_cycles=None):
        if num_cycles is None:
            num_cycles = self.trm_H_cycles

        B, T, C, H, W = video.shape
        frames_bt = video.view(B * T, C, H, W)

        tokens = self.vit.forward_features(frames_bt)   # (B*T, N, D) patch tokens
        frame_features = tokens.mean(dim=1)             # mean-pool tokens per frame
        features = frame_features.view(B, T, -1)
        pos = build_sinusoidal_positional_encoding(T, features.size(-1), features.device)
        features = features + pos

        for _ in range(num_cycles):
            features = self.reasoning_cycle(features)

        pooled = features.mean(dim=1)                   # mean-pool over time
        logits = self.classifier(pooled)
        return logits

    def _unpack_batch(self, batch: Dict[str, torch.Tensor]):
        if isinstance(batch, tuple):
            return batch[0], batch[1], None
        video_ids = batch.get("video_id", None)
        return batch["video"], batch["label"], video_ids

    def training_step(self, batch, batch_idx):
        videos, labels, _ = self._unpack_batch(batch)

        if self.iterative_refinement:
            opt = self.optimizers()
            opt.zero_grad()
            total_loss = 0.0
            for step in range(1, self.num_refinement_steps + 1):
                logits = self(videos, num_cycles=step)
                loss = nn.functional.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
                total_loss += loss / self.num_refinement_steps
            self.manual_backward(total_loss)
            opt.step()
            with torch.no_grad():
                final_logits = self(videos)
                final_preds = torch.argmax(final_logits, dim=1)
            # The scheduler decays per epoch; in manual optimization Lightning
            # does not step it automatically, so step it once at the epoch's last batch.
            sch = self.lr_schedulers()
            if sch is not None and self.trainer.is_last_batch:
                sch.step()
            acc = self.train_acc(final_preds, labels)
            self.log("train_loss", total_loss, on_step=True, on_epoch=True, prog_bar=True)
            self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
            return total_loss
        else:
            logits = self(videos)
            loss = nn.functional.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
            preds = torch.argmax(logits, dim=1)
            acc = self.train_acc(preds, labels)
            self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
            self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
            return loss

    def validation_step(self, batch, batch_idx):
        videos, labels, video_ids = self._unpack_batch(batch)
        logits = self(videos)
        loss = nn.functional.cross_entropy(logits, labels)
        preds = torch.argmax(logits, dim=1)
        acc = self.val_acc(preds, labels)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_acc_clip", acc, on_step=False, on_epoch=True, prog_bar=True)
        if video_ids is not None:
            self.validation_outputs.append({
                "video_ids": video_ids,
                "logits": logits.detach().cpu(),
                "labels": labels.detach().cpu(),
                "preds": preds.detach().cpu(),
            })
        return loss

    def _video_level_accuracy(self):
        """Average clip logits per video, then compute video-level accuracy.
        Returns (accuracy, num_videos, num_clips)."""
        video_logits = defaultdict(list)
        video_labels = {}
        for output in self.validation_outputs:
            for i, vid in enumerate(output["video_ids"]):
                video_logits[vid].append(output["logits"][i])
                video_labels[vid] = output["labels"][i].item()
        video_preds, video_true = [], []
        for vid in sorted(video_logits.keys()):
            avg = torch.stack(video_logits[vid]).mean(dim=0)
            video_preds.append(torch.argmax(avg).item())
            video_true.append(video_labels[vid])
        video_acc = (torch.tensor(video_preds) == torch.tensor(video_true)).float().mean()
        num_clips = sum(len(v) for v in video_logits.values())
        return video_acc, len(video_logits), num_clips

    def on_validation_epoch_end(self):
        if not self.validation_outputs:
            return
        video_acc, num_videos, num_clips = self._video_level_accuracy()
        self.log("val_acc_video", video_acc, on_epoch=True, prog_bar=True)
        self.log("val_acc", video_acc, on_epoch=True, prog_bar=True)
        print(f"\n  Video-level val: {num_videos} videos, {num_clips} clips, acc={video_acc:.4f}")
        self.validation_outputs.clear()

    def test_step(self, batch, batch_idx):
        videos, labels, video_ids = self._unpack_batch(batch)
        logits = self(videos)
        loss = nn.functional.cross_entropy(logits, labels)
        preds = torch.argmax(logits, dim=1)
        acc = self.val_acc(preds, labels)  # reuse the accuracy metric for clip-level test
        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test_acc_clip", acc, on_step=False, on_epoch=True, prog_bar=True)
        if video_ids is not None:
            self.validation_outputs.append({
                "video_ids": video_ids,
                "logits": logits.detach().cpu(),
                "labels": labels.detach().cpu(),
                "preds": preds.detach().cpu(),
            })
        return loss

    def on_test_epoch_end(self):
        if not self.validation_outputs:
            return
        video_acc, num_videos, _ = self._video_level_accuracy()
        self.log("test_acc_video", video_acc, on_epoch=True, prog_bar=True)
        self.log("test_acc", video_acc, on_epoch=True, prog_bar=True)
        print(f"\n  Video-level test: {num_videos} videos, acc={video_acc:.4f}")
        self.validation_outputs.clear()

    def on_train_epoch_start(self):
        if self.vit_freeze:
            self.vit.eval()

    def configure_optimizers(self):
        # Split parameters into decay / no-decay groups, with a 10x smaller LR
        # for the (pretrained) ViT backbone.
        decay, no_decay = [], []
        for n, p in self.named_parameters():
            if not p.requires_grad:
                continue
            if p.ndim < 2 or n.endswith("bias") or "norm" in n.lower() or "bn" in n.lower():
                no_decay.append(p)
            else:
                decay.append(p)
        vit_param_ids = {id(p) for p in self.vit.parameters()}
        optimizer = torch.optim.AdamW([
            {"params": [p for p in decay if id(p) not in vit_param_ids], "lr": self.lr, "weight_decay": self.weight_decay},
            {"params": [p for p in no_decay if id(p) not in vit_param_ids], "lr": self.lr, "weight_decay": 0.0},
            {"params": [p for p in decay if id(p) in vit_param_ids], "lr": self.lr * 0.1, "weight_decay": self.weight_decay},
            {"params": [p for p in no_decay if id(p) in vit_param_ids], "lr": self.lr * 0.1, "weight_decay": 0.0},
        ])

        def lr_lambda(epoch: int) -> float:
            if epoch < self.warmup_epochs:
                return float((epoch + 1) / max(1, self.warmup_epochs))
            progress = (epoch - self.warmup_epochs) / max(1, (self.max_epochs - self.warmup_epochs))
            return 0.5 * (1.0 + math.cos(math.pi * min(1.0, max(0.0, progress))))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}}


if __name__ == "__main__":
    model = ViTTRMVideo(num_classes=174, trm_H_cycles=2)
    x = torch.randn(2, 16, 3, 224, 224)  # (B, T, C, H, W)
    y = model(x)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {num_params:,}")
    print("Logits:", y.shape)  # (2, 174)
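Video-level evaluation averages the logits of all clips of a video before taking the argmax, so one noisy clip can be outvoted by the others. A hypothetical pure-Python sketch (plain lists instead of tensors, illustrative names) of what `on_validation_epoch_end` computes:

```python
from collections import defaultdict

def video_level_accuracy(clip_logits, clip_video_ids, clip_labels):
    """Average clip logits per video, argmax, and compare to the video label.
    A list-based sketch of the tensor aggregation in on_validation_epoch_end()."""
    by_video = defaultdict(list)
    labels = {}
    for logits, vid, label in zip(clip_logits, clip_video_ids, clip_labels):
        by_video[vid].append(logits)
        labels[vid] = label
    correct = 0
    for vid, clips in by_video.items():
        avg = [sum(vals) / len(clips) for vals in zip(*clips)]   # mean over clips
        pred = max(range(len(avg)), key=avg.__getitem__)         # argmax over classes
        correct += int(pred == labels[vid])
    return correct / len(by_video)

# Two videos, two clips each; the noisy second clip of video "b" is outvoted.
acc = video_level_accuracy(
    clip_logits=[[2.0, 0.1], [1.5, 0.2], [0.1, 2.0], [1.2, 1.0]],
    clip_video_ids=["a", "a", "b", "b"],
    clip_labels=[0, 0, 1, 1],
)
print(acc)  # → 1.0
```

Averaging logits (rather than predictions) is the standard multi-clip protocol: it weighs confident clips more heavily than a majority vote over per-clip argmaxes would.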