#!/usr/bin/env python3
"""
UCF-101 DataModule for PyTorch Lightning.
Loads UCF-101 from the Hugging Face Hub (flwrlabs/ucf101).
The HF dataset stores individual frames with video_id + frame index.
We build a lightweight index (video_id -> row indices) and load frames
lazily at __getitem__ time to avoid OOM.
Usage:
    dm = UCF101DataModule(batch_size=8)
    dm.setup()
    for batch in dm.train_dataloader():
        ...
"""
from typing import Optional, Callable, List, Dict
from collections import defaultdict
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torchvision.transforms as T
from datasets import load_dataset


def sample_indices(total: int, num_frames: int, clip_idx: int = 0, num_clips: int = 1) -> List[int]:
    """Uniformly sample num_frames indices from [0, total).

    clip_idx/num_clips shift the sampling grid so that different clips of
    the same video cover different temporal offsets (otherwise multi-clip
    evaluation would return num_clips identical clips).
    """
    if total <= num_frames:
        # Short video: use every frame, then pad by repeating the last one.
        return list(range(total)) + [total - 1] * (num_frames - total)
    stride = total / num_frames
    offset = stride * clip_idx / num_clips
    return [min(int(i * stride + offset), total - 1) for i in range(num_frames)]
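
# A quick sanity check of the sampling grid (values computed by hand):
#   sample_indices(10, 4)                          -> [0, 2, 5, 7]
#   sample_indices(3, 5)                           -> [0, 1, 2, 2, 2]  (last frame repeated)
#   sample_indices(10, 4, clip_idx=1, num_clips=2) -> [1, 3, 6, 8]    (shifted grid)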


class UCF101Dataset(Dataset):
    """
    Lazy-loading UCF-101 dataset. Builds a lightweight index of
    (video_id -> sorted row indices) and only loads pixel data on access.
    """

    def __init__(
        self,
        split: str = "train",
        num_frames: int = 16,
        transform: Optional[Callable] = None,
        num_clips: int = 1,
        max_videos: Optional[int] = None,
    ):
        super().__init__()
        self.num_frames = num_frames
        self.transform = transform
        self.num_clips = num_clips
        print(f"Loading UCF-101 [{split}] from HF Hub...")
        self.ds = load_dataset("flwrlabs/ucf101", split=split)
        # Build lightweight index using batch column access (fast):
        # reading whole columns avoids decoding any images.
        print(f"Building video index for {len(self.ds)} frames...")
        all_video_ids = self.ds["video_id"]
        all_frames = self.ds["frame"]
        all_labels = self.ds["label"]
        video_rows: Dict[str, List] = defaultdict(list)
        video_labels: Dict[str, int] = {}
        for row_idx, (vid, frame, label) in enumerate(zip(all_video_ids, all_frames, all_labels)):
            video_rows[vid].append((frame, row_idx))
            video_labels[vid] = label
        del all_video_ids, all_frames, all_labels
        # Sort frames within each video and store as index
        self.video_index = []  # list of {"video_id", "label", "row_indices"}
        for vid in sorted(video_rows.keys()):
            sorted_rows = sorted(video_rows[vid], key=lambda x: x[0])
            row_indices = [r[1] for r in sorted_rows]
            self.video_index.append({
                "video_id": vid,
                "label": video_labels[vid],
                "row_indices": row_indices,
            })
        if max_videos is not None:
            self.video_index = self.video_index[:max_videos]
        self.num_classes = 101
        print(f"UCF-101 [{split}]: {len(self.video_index)} videos, {self.num_classes} classes")

    def __len__(self):
        # Each video yields num_clips samples (clips differ via sample_indices).
        return len(self.video_index) * self.num_clips

    def __getitem__(self, idx):
        video_idx = idx // self.num_clips
        clip_idx = idx % self.num_clips
        entry = self.video_index[video_idx]
        row_indices = entry["row_indices"]
        sampled = sample_indices(len(row_indices), self.num_frames, clip_idx, self.num_clips)
        frames = []
        for i in sampled:
            row = self.ds[row_indices[i]]
            img = row["image"]  # PIL Image, decoded lazily on row access
            frames.append(T.functional.to_tensor(img))
        video_tensor = torch.stack(frames)  # (T, C, H, W)
        if self.transform is not None:
            # Apply the transform to the whole clip at once so random ops
            # (e.g. RandomHorizontalFlip) are temporally consistent across frames.
            video_tensor = self.transform(video_tensor)
        return {
            "video": video_tensor,
            "label": entry["label"],
            "video_id": entry["video_id"],
        }
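
# Note: with PyTorch's default collate_fn, a batch from this dataset looks like
#   {"video": FloatTensor (B, T, C, H, W), "label": LongTensor (B,), "video_id": list[str]}
# (tensors are stacked, Python ints become a LongTensor, strings stay a list).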


def build_train_transform(img_size: int = 224):
    # Clip-level augmentation: operates on a (T, C, H, W) float tensor.
    return T.Compose([
        T.Resize((img_size, img_size)),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        # ImageNet normalization statistics.
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def build_val_transform(img_size: int = 224):
    # Deterministic eval pipeline: resize + ImageNet normalization only.
    return T.Compose([
        T.Resize((img_size, img_size)),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
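
# Quick shape check for the pipelines above (hypothetical sizes):
#   clip = torch.rand(16, 3, 240, 320)      # (T, C, H, W), values in [0, 1]
#   build_train_transform(224)(clip).shape  # -> torch.Size([16, 3, 224, 224])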


class UCF101DataModule(pl.LightningDataModule):
    def __init__(
        self,
        num_frames: int = 16,
        img_size: int = 224,
        batch_size: int = 8,
        num_workers: int = 0,
        num_clips_val: int = 4,
        max_videos: Optional[int] = None,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.num_frames = num_frames
        self.img_size = img_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_clips_val = num_clips_val
        self.max_videos = max_videos

    def setup(self, stage=None):
        train_tf = build_train_transform(self.img_size)
        val_tf = build_val_transform(self.img_size)
        self.train_ds = UCF101Dataset(
            "train", self.num_frames, train_tf, num_clips=1, max_videos=self.max_videos,
        )
        self.val_ds = UCF101Dataset(
            "test", self.num_frames, val_tf, num_clips=self.num_clips_val, max_videos=self.max_videos,
        )
        self.num_classes = 101

    def train_dataloader(self):
        return DataLoader(
            self.train_ds, batch_size=self.batch_size, shuffle=True,
            num_workers=self.num_workers, pin_memory=True, drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_workers, pin_memory=True,
        )

    def test_dataloader(self):
        return self.val_dataloader()
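

# Minimal smoke test (a sketch: assumes network access to the HF Hub and that
# the flwrlabs/ucf101 dataset follows the schema used above).
if __name__ == "__main__":
    dm = UCF101DataModule(batch_size=2, num_frames=4, max_videos=8)
    dm.setup()
    batch = next(iter(dm.train_dataloader()))
    print(batch["video"].shape)  # expected: torch.Size([2, 4, 3, 224, 224])
    print(batch["label"])        # LongTensor of shape (2,)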