import json
import logging
import os
from pathlib import Path
from typing import Union

import pandas as pd
import torch
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder

from ...utils.dist_utils import local_rank

log = logging.getLogger()

_CLIP_SIZE = 384
_CLIP_FPS = 8.0

_SYNC_SIZE = 224
_SYNC_FPS = 25.0

class VideoDataset(Dataset):

    def __init__(
        self,
        video_root: Union[str, Path],
        *,
        duration_sec: float = 8.0,
    ):
        self.video_root = Path(video_root)
        self.duration_sec = duration_sec

        # Number of frames each branch should yield for one clip.
        self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)

        self.clip_transform = v2.Compose([
            v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ])

        self.sync_transform = v2.Compose([
            v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
            v2.CenterCrop(_SYNC_SIZE),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        # Populated by subclasses: maps video id -> caption.
        self.captions = {}
        self.videos = sorted(list(self.captions.keys()))

    def sample(self, idx: int) -> dict[str, torch.Tensor]:
        video_id = self.videos[idx]
        caption = self.captions[video_id]

        # Decode the same file as two streams: one at the CLIP frame rate and
        # one at the sync frame rate.
        reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
        reader.add_basic_video_stream(
            frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
            frame_rate=_CLIP_FPS,
            format='rgb24',
        )
        reader.add_basic_video_stream(
            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
            frame_rate=_SYNC_FPS,
            format='rgb24',
        )

        reader.fill_buffer()
        data_chunk = reader.pop_chunks()

        clip_chunk = data_chunk[0]
        sync_chunk = data_chunk[1]
        if clip_chunk is None:
            raise RuntimeError(f'CLIP video returned None {video_id}')
        if clip_chunk.shape[0] < self.clip_expected_length:
            raise RuntimeError(
                f'CLIP video too short {video_id}, '
                f'expected {self.clip_expected_length}, got {clip_chunk.shape[0]}')

        if sync_chunk is None:
            raise RuntimeError(f'Sync video returned None {video_id}')
        if sync_chunk.shape[0] < self.sync_expected_length:
            raise RuntimeError(
                f'Sync video too short {video_id}, '
                f'expected {self.sync_expected_length}, got {sync_chunk.shape[0]}')

        # Truncate to the expected length before applying the transforms.
        clip_chunk = clip_chunk[:self.clip_expected_length]
        if clip_chunk.shape[0] != self.clip_expected_length:
            raise RuntimeError(f'CLIP video wrong length {video_id}, '
                               f'expected {self.clip_expected_length}, '
                               f'got {clip_chunk.shape[0]}')
        clip_chunk = self.clip_transform(clip_chunk)

        sync_chunk = sync_chunk[:self.sync_expected_length]
        if sync_chunk.shape[0] != self.sync_expected_length:
            raise RuntimeError(f'Sync video wrong length {video_id}, '
                               f'expected {self.sync_expected_length}, '
                               f'got {sync_chunk.shape[0]}')
        sync_chunk = self.sync_transform(sync_chunk)

        data = {
            'name': video_id,
            'caption': caption,
            'clip_video': clip_chunk,
            'sync_video': sync_chunk,
        }

        return data

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        try:
            return self.sample(idx)
        except Exception as e:
            log.error(f'Error loading video {self.videos[idx]}: {e}')
            # Failed samples yield None; the DataLoader's collate_fn must
            # filter these out (see the sketch below).
            return None

    def __len__(self):
        return len(self.captions)

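# Since __getitem__ returns None for videos that fail to decode, a DataLoader
# over these datasets needs a collate function that drops failed samples. The
# helper below is an illustrative sketch, not part of the original module; the
# surrounding project may ship its own equivalent.
def drop_none_collate(batch: list) -> dict:
    # Local import keeps the sketch self-contained.
    from torch.utils.data.dataloader import default_collate

    # Remove failed samples before stacking; refuse an all-failed batch.
    batch = [sample for sample in batch if sample is not None]
    if len(batch) == 0:
        raise RuntimeError('All samples in the batch failed to load')
    return default_collate(batch)
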
class VGGSound(VideoDataset):

    def __init__(
        self,
        video_root: Union[str, Path],
        csv_path: Union[str, Path],
        *,
        duration_sec: float = 8.0,
    ):
        super().__init__(video_root, duration_sec=duration_sec)
        self.video_root = Path(video_root)
        self.csv_path = Path(csv_path)

        videos = sorted(os.listdir(self.video_root))
        if local_rank == 0:
            log.info(f'{len(videos)} videos found in {video_root}')
        self.captions = {}

        df = pd.read_csv(csv_path,
                         header=None,
                         names=['id', 'sec', 'caption', 'split']).to_dict(orient='records')

        videos_not_found = []
        for row in df:
            if row['split'] == 'test':
                start_sec = int(row['sec'])
                video_id = str(row['id'])

                # Clip files are named '<YouTube id>_<start second, zero-padded to 6 digits>.mp4'.
                video_name = f'{video_id}_{start_sec:06d}'
                if video_name + '.mp4' not in videos:
                    videos_not_found.append(video_name)
                    continue

                self.captions[video_name] = row['caption']

        if local_rank == 0:
            log.info(f'{len(self.captions)} usable videos found')
            if videos_not_found:
                log.info(f'{len(videos_not_found)} found in {csv_path} but not in {video_root}')
                log.info('A small number of missing videos is expected, as not all '
                         'videos are still available on YouTube')

        self.videos = sorted(list(self.captions.keys()))

class MovieGen(VideoDataset):

    def __init__(
        self,
        video_root: Union[str, Path],
        jsonl_root: Union[str, Path],
        *,
        duration_sec: float = 10.0,
    ):
        super().__init__(video_root, duration_sec=duration_sec)
        self.video_root = Path(video_root)
        self.jsonl_root = Path(jsonl_root)

        videos = sorted(os.listdir(self.video_root))
        videos = [v[:-4] for v in videos]  # strip the '.mp4' extension
        self.captions = {}

        for v in videos:
            # Each per-video .jsonl file is read as a single JSON object here;
            # json.load would fail on a file holding multiple JSON lines.
            with open(self.jsonl_root / (v + '.jsonl')) as f:
                data = json.load(f)
                self.captions[v] = data['audio_prompt']

        if local_rank == 0:
            log.info(f'{len(videos)} videos found in {video_root}')

        self.videos = videos
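
# A minimal usage sketch, assuming a local VGGSound download. The paths and
# batch size are hypothetical, not part of the module, and because this file
# uses relative imports it must be run as a module (python -m ...) from the
# package root rather than as a standalone script.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = VGGSound(
        video_root='data/vggsound/videos',  # hypothetical path
        csv_path='data/vggsound/vgg.csv',  # hypothetical path
        duration_sec=8.0,
    )
    loader = DataLoader(dataset, batch_size=4, num_workers=2, collate_fn=drop_none_collate)
    for batch in loader:
        # At 8 s: clip_video is (B, 64, 3, 384, 384); sync_video is (B, 200, 3, 224, 224).
        print(batch['name'], batch['clip_video'].shape, batch['sync_video'].shape)
        break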