Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n | |
| import os | |
| from dataclasses import dataclass | |
| from io import BytesIO | |
| from typing import Optional, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torchaudio | |
| from datasets import load_dataset | |
| from torchcodec.decoders import AudioDecoder, VideoDecoder | |
@dataclass
class Item:
    """One benchmark example as produced by ``SAMAudioBench.__getitem__``."""

    # Temporal anchors as ("+", start, end) triples built from the row's "spans".
    anchors: list[Tuple[str, float, float]]
    # Video frames multiplied by the object mask; None when the row has no mask.
    masked_video_frames: Optional[torch.Tensor]
    # Audio resampled to the collate_fn rate and averaged to a single channel.
    audio_samples: torch.Tensor
    # Text description of the target sound.
    description: str
class SAMAudioBench(torch.utils.data.Dataset):
    """Dataset over the test split of the ``facebook/sam-audio-bench`` HF dataset.

    Each row refers to a media file that must already exist under
    ``<cache_path>/sam_audio_bench/<source_dataset>/`` (see the README for how
    to prepare it). Rows are decoded lazily in ``__getitem__``; ``collate``
    batches them through the user-supplied ``collate_fn``.

    Args:
        cache_path: Root directory containing the ``sam_audio_bench`` folder.
        collate_fn: Callable invoked by ``collate`` with the keyword arguments
            ``descriptions``, ``audios``, ``anchors`` and ``masked_videos``. It
            must also expose an ``audio_sampling_rate`` attribute, which
            ``__getitem__`` uses as the target sample rate.
        span: If True, pass per-item anchors to ``collate_fn``; otherwise None.
        visual: If True, pass masked video frames to ``collate_fn`` when at
            least one item in the batch has them; otherwise None.
        subset: If given, keep only rows whose ``paper_eval_sets`` contains it.

    Raises:
        AssertionError: if the ``sam_audio_bench`` cache directory is missing.
    """

    def __init__(
        self,
        cache_path,
        collate_fn,
        span: bool = True,
        visual: bool = True,
        subset: Optional[str] = None,
    ):
        self.cache_path = os.path.join(cache_path, "sam_audio_bench")
        DATA_MSG = (
            f"`SAMAudioBench` requires the user to create a directory named {self.cache_path} "
            "see the README.md file for how to prepare"
        )
        # Validate the local cache before loading the HF dataset so a missing
        # directory fails immediately instead of after a download.
        assert os.path.exists(self.cache_path), DATA_MSG
        self.dataset = load_dataset("facebook/sam-audio-bench")["test"]
        self.subset = subset
        self._span = span
        self._visual = visual
        if subset is not None:
            self.dataset = self.dataset.filter(lambda x: subset in x["paper_eval_sets"])
        self.collate_fn = collate_fn

    # NOTE(review): possibly meant to be a @property; kept as a plain method so
    # existing `dataset.visual()` callers keep working — confirm with callers.
    def visual(self):
        """Return whether masked video frames are forwarded to ``collate_fn``."""
        return self._visual

    def __len__(self):
        return len(self.dataset)

    def _get_path(
        self, video_id: str, source_dataset: str, start_offset: float, end_offset: float
    ) -> Tuple[str, bool]:
        """Locate the cached media file for one row.

        The first existing candidate wins:

        1. ``<video_id>.mp4``: the full video. The caller must cut out
           ``[start_offset, end_offset]`` itself, so ``select_frames`` is True.
        2. ``<video_id>_<start_ms>_<end_ms>.mp4``
        3. ``<video_id>_<start_s>_<end_s>.mp4``
        4. ``<video_id>.<start_ms:08d>_<end_ms:08d>.mp4``

        Candidates 2-4 are treated as pre-trimmed clips (``select_frames`` is
        False). If nothing exists, candidate 4 is returned anyway and the
        caller is responsible for checking existence.

        Returns:
            ``(path, select_frames)``.
        """
        base = f"{self.cache_path}/{source_dataset}"
        full_video = f"{base}/{video_id}.mp4"
        if os.path.exists(full_video):
            return full_video, True

        # NOTE(review): int() truncates, so e.g. 1.005 s becomes 1004 ms; this
        # must match how clip files were named when the cache was built.
        start_ms = int(start_offset * 1000)
        end_ms = int(end_offset * 1000)
        clip_candidates = (
            f"{base}/{video_id}_{start_ms}_{end_ms}.mp4",
            f"{base}/{video_id}_{int(start_offset)}_{int(end_offset)}.mp4",
            f"{base}/{video_id}.{start_ms:08d}_{end_ms:08d}.mp4",
        )
        for path in clip_candidates:
            if os.path.exists(path):
                return path, False
        return clip_candidates[-1], False

    def collate(self, items: list["Item"]):
        """Batch ``items`` by forwarding their fields to ``collate_fn``.

        ``anchors`` is None unless ``span`` was enabled. ``masked_videos`` is
        None unless ``visual`` was enabled and at least one item has masked
        frames.
        """
        has_video = any(item.masked_video_frames is not None for item in items)
        return self.collate_fn(
            descriptions=[item.description for item in items],
            audios=[item.audio_samples for item in items],
            anchors=[item.anchors for item in items] if self._span else None,
            masked_videos=[item.masked_video_frames for item in items]
            if has_video and self._visual
            else None,
        )

    def _get_masked_video(self, item, video_path, select_frames):
        """Return the row's video frames multiplied by its object mask.

        The mask is the ``video_masklet`` array of the ``.npz`` payload in
        ``item["mask_bytes"]``. It is resampled along time (nearest index) and,
        if needed, spatially (``F.interpolate`` default nearest mode) to match
        the decoded frames.

        Returns:
            None if the row has no mask, else ``frames * mask``.
        """
        if item["mask_bytes"] is None:
            return None
        # Presumably shaped (T, H, W); TODO confirm against the dataset card.
        mask = torch.from_numpy(np.load(BytesIO(item["mask_bytes"]))["video_masklet"])

        decoder = VideoDecoder(video_path)
        if select_frames:
            # Full video on disk: decode only the row's time window.
            frames = decoder.get_frames_played_in_range(
                item["start_offset"], item["end_offset"]
            ).data
        else:
            # Pre-trimmed clip: every frame belongs to the row.
            frames = decoder[:].data
        num_frames = frames.size(0)

        if mask.size(0) != num_frames:
            # The mask and the decoded video can differ by a few frames; pick
            # evenly spaced (rounded) mask indices so the counts match.
            idxs = torch.linspace(0, mask.size(0) - 1, num_frames).round().long()
            mask = mask[idxs]
        # Add a channel axis so the mask broadcasts over the frames' channels
        # (torchcodec frames are (N, C, H, W) in its default dimension order).
        mask = mask.unsqueeze(1)
        if mask.shape[-2:] != frames.shape[-2:]:
            # NOTE(review): F.interpolate rejects bool tensors; this assumes the
            # stored mask has a numeric dtype — confirm.
            mask = F.interpolate(mask, size=frames.shape[-2:])
        return frames * mask

    def __getitem__(self, idx) -> "Item":
        """Decode row ``idx`` into an ``Item``.

        Audio is taken from the same media file as the video. It is averaged
        over its first (channel) dimension and resampled to
        ``collate_fn.audio_sampling_rate`` when the rates differ.

        Raises:
            AssertionError: if the resolved media file does not exist.
        """
        item = self.dataset[idx]
        video_path, select_frames = self._get_path(
            item["video_id"],
            item["source_dataset"],
            item["start_offset"],
            item["end_offset"],
        )
        assert os.path.exists(video_path), f"{video_path} does not exist!"

        audio_samples = AudioDecoder(video_path).get_samples_played_in_range(
            start_seconds=item["start_offset"] if select_frames else 0,
            stop_seconds=item["end_offset"] if select_frames else None,
        )
        # Downmix before resampling. Resampling is linear, so the result is the
        # same (up to float rounding) while only one channel is resampled.
        mono = audio_samples.data.mean(0, keepdim=True)
        target_rate = self.collate_fn.audio_sampling_rate
        if audio_samples.sample_rate != target_rate:
            mono = torchaudio.functional.resample(
                mono, audio_samples.sample_rate, target_rate
            )

        masked_video_frames = self._get_masked_video(item, video_path, select_frames)
        return Item(
            description=item["description"],
            anchors=[("+", start, end) for start, end in item["spans"]],
            masked_video_frames=masked_video_frames,
            audio_samples=mono,
        )