| |
| """ |
| UCF-101 DataModule for PyTorch Lightning. |
| |
| Loads UCF-101 from the Hugging Face Hub (flwrlabs/ucf101). |
| The HF dataset stores individual frames with video_id + frame index. |
| We build a lightweight index (video_id -> row indices) and load frames |
| lazily at __getitem__ time to avoid OOM. |
| |
| Usage: |
| dm = UCF101DataModule(batch_size=8) |
| dm.setup() |
| for batch in dm.train_dataloader(): |
| ... |
| """ |
|
|
| from typing import Optional, Callable, List, Dict |
| from collections import defaultdict |
|
|
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| import pytorch_lightning as pl |
| import torchvision.transforms as T |
|
|
| from datasets import load_dataset |
|
|
|
|
def sample_indices(total: int, num_frames: int) -> List[int]:
    """Uniformly sample ``num_frames`` indices from ``[0, total)``.

    If the video has fewer frames than requested, every frame is used and the
    last index is repeated as right-padding so the result always has exactly
    ``num_frames`` entries.

    Args:
        total: Number of available frames (must be positive).
        num_frames: Number of indices to sample (must be positive).

    Returns:
        A list of ``num_frames`` monotonically non-decreasing frame indices.

    Raises:
        ValueError: If ``total`` or ``num_frames`` is not positive.
    """
    # BUG FIX: previously total == 0 silently produced [-1] * num_frames,
    # which would index the *last* row via Python's negative indexing
    # (or crash later with a confusing IndexError). Fail loudly instead.
    if total <= 0 or num_frames <= 0:
        raise ValueError(
            f"total ({total}) and num_frames ({num_frames}) must be positive"
        )
    if total <= num_frames:
        return list(range(total)) + [total - 1] * (num_frames - total)
    stride = total / num_frames
    return [int(i * stride) for i in range(num_frames)]
|
|
|
|
class UCF101Dataset(Dataset):
    """
    Lazy-loading UCF-101 dataset backed by the HF Hub dataset ``flwrlabs/ucf101``.

    The HF dataset stores individual frames keyed by ``video_id`` plus a frame
    index. At init time we build a lightweight index (video_id -> sorted row
    indices) and only decode pixel data in ``__getitem__``, keeping memory
    bounded.

    Each item is a dict with:
        - ``video``: float tensor of shape (num_frames, C, H, W)
        - ``label``: int class label
        - ``video_id``: str identifier of the source video
    """

    def __init__(
        self,
        split: str = "train",
        num_frames: int = 16,
        transform: Optional[Callable] = None,
        num_clips: int = 1,
        max_videos: Optional[int] = None,
    ):
        """
        Args:
            split: HF dataset split name ("train" or "test").
            num_frames: Frames sampled per clip.
            transform: Callable applied to the whole stacked clip tensor of
                shape (T, C, H, W). torchvision v1 transforms accept batched
                (..., C, H, W) input, so random ops are consistent per clip.
            num_clips: Clips per video; ``__len__`` is videos * num_clips.
            max_videos: Optional cap on the number of indexed videos (debug).
        """
        super().__init__()
        self.num_frames = num_frames
        self.transform = transform
        self.num_clips = num_clips

        print(f"Loading UCF-101 [{split}] from HF Hub...")
        self.ds = load_dataset("flwrlabs/ucf101", split=split)

        # Pull only the lightweight columns; pixel data stays on disk.
        print(f"Building video index for {len(self.ds)} frames...")
        all_video_ids = self.ds["video_id"]
        all_frames = self.ds["frame"]
        all_labels = self.ds["label"]

        video_rows: Dict[str, List] = defaultdict(list)
        video_labels: Dict[str, int] = {}

        for row_idx, (vid, frame, label) in enumerate(
            zip(all_video_ids, all_frames, all_labels)
        ):
            video_rows[vid].append((frame, row_idx))
            video_labels[vid] = label

        # Release the column copies before building the final index.
        del all_video_ids, all_frames, all_labels

        # Sort rows of each video by frame number so row_indices is in
        # temporal order; sort video ids for a deterministic item order.
        self.video_index = []
        for vid in sorted(video_rows.keys()):
            sorted_rows = sorted(video_rows[vid], key=lambda x: x[0])
            row_indices = [r[1] for r in sorted_rows]
            self.video_index.append({
                "video_id": vid,
                "label": video_labels[vid],
                "row_indices": row_indices,
            })

        if max_videos is not None:
            self.video_index = self.video_index[:max_videos]

        # UCF-101 has a fixed label set of 101 action classes.
        self.num_classes = 101
        print(f"UCF-101 [{split}]: {len(self.video_index)} videos, {self.num_classes} classes")

    def __len__(self):
        return len(self.video_index) * self.num_clips

    def __getitem__(self, idx):
        video_idx = idx // self.num_clips
        # BUG FIX: the clip index was previously ignored, so with
        # num_clips > 1 every "clip" of a video was identical and
        # multi-clip evaluation was a no-op.
        clip_idx = idx % self.num_clips
        entry = self.video_index[video_idx]

        row_indices = entry["row_indices"]
        total = len(row_indices)
        sampled = sample_indices(total, self.num_frames)

        if self.num_clips > 1 and total > self.num_frames:
            # Shift each clip by a fraction of the sampling stride so the
            # num_clips clips cover different frames of the video.
            stride = total / self.num_frames
            shift = int(clip_idx * stride / self.num_clips)
            sampled = [min(i + shift, total - 1) for i in sampled]

        frames = []
        for i in sampled:
            row = self.ds[row_indices[i]]  # lazy row fetch; decodes one image
            img = row["image"]
            frames.append(T.functional.to_tensor(img))

        video_tensor = torch.stack(frames)  # (T, C, H, W)

        # BUG FIX: the transform was previously applied per frame, so random
        # ops (horizontal flip, color jitter) were drawn independently for
        # every frame of a clip, breaking temporal consistency. Applying it
        # once to the stacked (T, C, H, W) tensor draws the randomness once
        # per clip (torchvision transforms broadcast over leading dims).
        if self.transform is not None:
            video_tensor = self.transform(video_tensor)

        return {
            "video": video_tensor,
            "label": entry["label"],
            "video_id": entry["video_id"],
        }
|
|
|
|
def build_train_transform(img_size: int = 224):
    """Training transform: resize, light augmentation, ImageNet normalization.

    Expects float tensors in [0, 1] (the output of ``to_tensor``).

    Args:
        img_size: Target square side length after resizing.

    Returns:
        A ``T.Compose`` pipeline.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    ops = [
        T.Resize((img_size, img_size)),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        T.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return T.Compose(ops)
|
|
|
|
def build_val_transform(img_size: int = 224):
    """Evaluation transform: deterministic resize + ImageNet normalization.

    Args:
        img_size: Target square side length after resizing.

    Returns:
        A ``T.Compose`` pipeline (no random augmentation).
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    return T.Compose([
        T.Resize((img_size, img_size)),
        T.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
|
|
|
|
class UCF101DataModule(pl.LightningDataModule):
    """Lightning DataModule for UCF-101 loaded from the HF Hub.

    Wraps :class:`UCF101Dataset` for the "train" and "test" splits; the
    test dataloader reuses the validation dataloader.
    """

    def __init__(
        self,
        num_frames: int = 16,
        img_size: int = 224,
        batch_size: int = 8,
        num_workers: int = 0,
        num_clips_val: int = 4,
        max_videos: Optional[int] = None,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.num_frames = num_frames
        self.img_size = img_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_clips_val = num_clips_val
        self.max_videos = max_videos

    def setup(self, stage=None):
        """Instantiate the train ("train" split) and val ("test" split) datasets."""
        self.train_ds = UCF101Dataset(
            "train",
            self.num_frames,
            build_train_transform(self.img_size),
            num_clips=1,
            max_videos=self.max_videos,
        )
        self.val_ds = UCF101Dataset(
            "test",
            self.num_frames,
            build_val_transform(self.img_size),
            num_clips=self.num_clips_val,
            max_videos=self.max_videos,
        )
        self.num_classes = 101

    def _loader(self, dataset, *, shuffle, drop_last=False):
        # Single place for the DataLoader settings shared by all splits.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=drop_last,
        )

    def train_dataloader(self):
        return self._loader(self.train_ds, shuffle=True, drop_last=True)

    def val_dataloader(self):
        return self._loader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        # Evaluation reuses the validation split/loader.
        return self.val_dataloader()
|
|