Spaces:
Running on Zero
Running on Zero
| import logging | |
| log = logging.getLogger() | |
| import os | |
| from pathlib import Path | |
| from typing import Optional, Union | |
| from PIL import Image | |
| from transformers import AutoProcessor | |
| import pandas as pd | |
| import torch | |
| import torchaudio | |
| from torch.utils.data.dataset import Dataset | |
| from torchvision.transforms import v2 | |
| from torio.io import StreamingMediaDecoder | |
| import mediapy | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import subprocess | |
| from torchvision.utils import save_image | |
| try: | |
| from moviepy import VideoFileClip | |
| except ImportError: | |
| from moviepy.editor import VideoFileClip | |
| _CLIP_FPS = 4 | |
| _CLIP_SIZE = 288 | |
| _SYNC_FPS = 25 | |
| _SYNC_SIZE = 224 | |
| def pad_to_square(video_tensor): | |
| if len(video_tensor.shape) != 4: | |
| raise ValueError("Input tensor must have shape (l, c, h, w)") | |
| l, c, h, w = video_tensor.shape | |
| max_side = max(h, w) | |
| pad_h = max_side - h | |
| pad_w = max_side - w | |
| padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) | |
| video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0) | |
| video_tensor = F.interpolate(video_padded, size=(_CLIP_SIZE, _CLIP_SIZE), mode='bilinear', align_corners=False) | |
| return video_tensor | |
| def get_video_duration(video_path): | |
| video = VideoFileClip(str(video_path)) | |
| return video.duration | |
| class VGGSound(Dataset): | |
| def __init__( | |
| self, | |
| root: Path, | |
| *, | |
| tsv_path: Path, | |
| sample_rate: int = 44100, | |
| normalize_audio: bool = False, | |
| start_row: int = None, | |
| end_row: int = None, | |
| save_dir: str = '', | |
| use_variable_length: bool = False, | |
| video_encoder: str = 'videoprism', | |
| video_resolution: int = _CLIP_SIZE, | |
| inference_mode: bool = False, | |
| video_fps: int = _CLIP_FPS | |
| ): | |
| self.inference_mode = inference_mode | |
| self.sample_rate=sample_rate | |
| self.root = Path(root) | |
| self.normalize_audio = normalize_audio | |
| self.use_variable_length = use_variable_length | |
| self.video_encoder = video_encoder | |
| self.video_resolution = video_resolution | |
| self.video_fps = video_fps | |
| self.videos = [] | |
| self.caption_cot = [] | |
| df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') | |
| if start_row is not None and end_row is not None: | |
| df_list = df_list[start_row:end_row] | |
| for record in df_list: | |
| id = record['id'] | |
| if os.path.exists(f'{save_dir}/{id}.npz'): continue | |
| caption_cot = record['caption_cot'] | |
| if not os.path.exists(os.path.join(self.root, id)+".mp4"): | |
| continue | |
| self.videos.append(id) | |
| self.caption_cot.append(caption_cot) | |
| log.info(f'processing {len(self.videos)} videos') | |
| self.sync_transform = v2.Compose([ | |
| v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), | |
| v2.CenterCrop(_SYNC_SIZE), | |
| v2.ToImage(), | |
| v2.ToDtype(torch.float32, scale=True), | |
| v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), | |
| ]) | |
| self.resampler = {} | |
| def sample(self, idx: int) -> dict[str, torch.Tensor]: | |
| video_id = self.videos[idx] | |
| caption_cot = self.caption_cot[idx] | |
| duration_sec= get_video_duration(self.root / (video_id + '.mp4')) | |
| reader = StreamingMediaDecoder(self.root / (video_id + '.mp4')) | |
| reader.add_basic_video_stream( | |
| frames_per_chunk=int(_CLIP_FPS * duration_sec), | |
| frame_rate=_CLIP_FPS, | |
| format='rgb24', | |
| ) | |
| reader.add_basic_video_stream( | |
| frames_per_chunk=int(_SYNC_FPS * duration_sec), | |
| frame_rate=_SYNC_FPS, | |
| format='rgb24', | |
| ) | |
| if not self.inference_mode: | |
| reader.add_basic_audio_stream(frames_per_chunk=2**30,) | |
| reader.fill_buffer() | |
| data_chunk = reader.pop_chunks() | |
| clip_chunk = data_chunk[0] | |
| sync_chunk = data_chunk[1] | |
| if not self.inference_mode: | |
| audio_chunk = data_chunk[2] | |
| audio_chunk = audio_chunk.transpose(0, 1) | |
| else: | |
| num_samples = int(self.sample_rate * duration_sec) | |
| audio_chunk = torch.randn((2, num_samples)) | |
| if len(audio_chunk.shape) != 2: | |
| raise RuntimeError(f'error audio shape {video_id}') | |
| if clip_chunk is None: | |
| raise RuntimeError(f'CLIP video returned None {video_id}') | |
| if sync_chunk is None: | |
| raise RuntimeError(f'Sync video returned None {video_id}') | |
| if not self.inference_mode: | |
| sample_rate = int(reader.get_out_stream_info(2).sample_rate) | |
| else: | |
| sample_rate = self.sample_rate | |
| abs_max = audio_chunk[0].abs().max() | |
| if self.normalize_audio: | |
| abs_max = audio_chunk.abs().max() | |
| audio_chunk = audio_chunk / abs_max * 0.95 | |
| clip_expected_length = int(_CLIP_FPS * duration_sec) | |
| sync_expected_length = int(_SYNC_FPS * duration_sec) | |
| clip_chunk = clip_chunk[:clip_expected_length] | |
| if clip_chunk.shape[0] != clip_expected_length: | |
| current_length = clip_chunk.shape[0] | |
| padding_needed = clip_expected_length - current_length | |
| # If assertion passes, proceed with padding | |
| if padding_needed > 0: | |
| last_frame = clip_chunk[-1] | |
| padding = last_frame.repeat(padding_needed, 1, 1, 1) | |
| clip_chunk = torch.cat((clip_chunk, padding), dim=0) | |
| clip_chunk = pad_to_square(clip_chunk) | |
| clip_chunk = clip_chunk.permute(0, 2, 3, 1) | |
| clip_chunk = mediapy.to_float01(clip_chunk) | |
| sync_chunk = sync_chunk[:sync_expected_length] | |
| if sync_chunk.shape[0] != sync_expected_length: | |
| # padding using the last frame, but no more than 2 | |
| current_length = sync_chunk.shape[0] | |
| last_frame = sync_chunk[-1] | |
| padding = last_frame.repeat(sync_expected_length - current_length, 1, 1, 1) | |
| sync_chunk = torch.cat((sync_chunk, padding), dim=0) | |
| sync_chunk = self.sync_transform(sync_chunk) | |
| data = { | |
| 'id': video_id, | |
| 'caption_cot': caption_cot, | |
| 'audio': audio_chunk, | |
| 'clip_video': clip_chunk, | |
| 'sync_video': sync_chunk, | |
| } | |
| return data | |
| def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: | |
| try: | |
| return self.sample(idx) | |
| except Exception as e: | |
| logging.error(f'Error loading {self.videos[idx]}: {e}') | |
| return None | |
| def __len__(self) -> int: | |
| return len(self.videos) | |