# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import math
from typing import List, Union

import torch
import torchaudio

from sam_audio.model.config import ImageBindRankerConfig
from sam_audio.ranking.ranker import Ranker

try:
    from imagebind.data import (
        ConstantClipsPerVideoSampler,
        NormalizeVideo,
        SpatialCrop,
        get_clip_timepoints,
        load_and_transform_video_data,
        pv_transforms,
        transforms,
        waveform2melspec,
    )
    from imagebind.models.imagebind_model import ModalityType, imagebind_huge

    __imagebind_exists__ = True
except ImportError:
    __imagebind_exists__ = False
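
# ImageBind is an optional dependency: when the import fails, the flag above
# stays False and ImageBindRanker raises an assertion error at construction time.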


def load_and_transform_audio_data(
    audios: List[Union[str, torch.Tensor]],
    input_sample_rate=None,
    num_mel_bins=128,
    target_length=204,
    sample_rate=16000,
    clip_duration=2,
    clips_per_video=3,
    mean=-4.268,
    std=9.138,
    device=None,
):
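    """Convert audio files or waveform tensors into normalized mel-spectrogram clips.

    Mirrors ImageBind's audio preprocessing: each input is resampled to
    ``sample_rate``, split into ``clips_per_video`` clips of ``clip_duration``
    seconds, and converted to log-mel spectrograms via ``waveform2melspec``.

    Returns a tensor of shape
    ``(len(audios), clips_per_video, 1, num_mel_bins, target_length)``.
    """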
    if audios is None:
        return None
    audio_outputs = []
    clip_sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )
    for audio in audios:
        if isinstance(audio, str):
            waveform, input_sample_rate = torchaudio.load(audio)
        else:
            assert torch.is_tensor(audio)
            # The tensor's native rate is needed below to resample correctly.
            assert input_sample_rate is not None
            # Preprocessing needs to be done in full precision
            waveform = audio.float()
        if waveform.ndim == 1:
            waveform = waveform[None]
        if sample_rate != input_sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=input_sample_rate, new_freq=sample_rate
            )
        all_clips_timepoints = get_clip_timepoints(
            clip_sampler, waveform.size(1) / sample_rate
        )
        all_clips = []
        for clip_timepoints in all_clips_timepoints:
            waveform_clip = waveform[
                :,
                int(clip_timepoints[0] * sample_rate) : int(
                    clip_timepoints[1] * sample_rate
                ),
            ]
            waveform_melspec = waveform2melspec(
                waveform_clip, sample_rate, num_mel_bins, target_length
            )
            all_clips.append(waveform_melspec)

        normalize = transforms.Normalize(mean=mean, std=std)
        all_clips = [normalize(ac).to(device) for ac in all_clips]
        all_clips = torch.stack(all_clips, dim=0)
        audio_outputs.append(all_clips)

    return torch.stack(audio_outputs, dim=0)
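

# Usage sketch (hypothetical file names; requires ImageBind installed):
#     feats = load_and_transform_audio_data(["a.wav", "b.wav"])
#     # feats: (2, clips_per_video, 1, num_mel_bins, target_length)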


class VideoTransform:
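    """Replicate ImageBind's video preprocessing for in-memory frame tensors.

    Samples ``clips_per_video`` clips per video, scales each frame's short side
    to 224, normalizes with CLIP statistics, and takes 3 spatial crops per clip.
    """
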
    def __init__(self, clip_duration=2, clips_per_video=5):
        self.clip_duration = clip_duration
        self.clips_per_video = clips_per_video
        self.clip_sampler = ConstantClipsPerVideoSampler(
            clip_duration=clip_duration, clips_per_video=clips_per_video
        )
        self.video_transform = transforms.Compose(
            [
                pv_transforms.ShortSideScale(224),
                NormalizeVideo(
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )
        self.spatial_crop = SpatialCrop(224, num_crops=3)

    def load_video_fast(self, videos, durations, **kwargs):
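        """Select only the frames the final clips need, per video.

        ``videos`` are (T, C, H, W) frame tensors and ``durations`` their
        lengths in seconds. Returns, for each video, a tuple of
        ``clips_per_video`` frame chunks shaped (C, frames_per_clip, H, W).
        """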
        result = []
        for video, duration in zip(videos, durations, strict=False):
            nframes = video.size(0)
            fps = video.size(0) / duration
            timepoints = get_clip_timepoints(
                self.clip_sampler,
                duration,
            )
            # Instead of loading 5 2s clips and then sub-sampling frames, we figure
            # out the indices of the final clips we want and only decode those.
            all_idxs = []
            for start_time, end_time in timepoints:
                idxs = torch.arange(
                    min(int(math.ceil(fps * start_time)), nframes - 1),
                    min(int(math.ceil(fps * end_time)), nframes),
                )
                ts = (
                    torch.linspace(0, idxs.size(0) - 1, self.clip_duration)
                    .clamp(max=idxs.size(0) - 1)
                    .long()
                )
                all_idxs.append(idxs[ts])
            all_idxs = torch.cat(all_idxs)
            fast_frames = video[all_idxs].transpose(0, 1)
            result.append(fast_frames.chunk(self.clips_per_video, dim=1))
        return result

    def transform_video(self, batch, device=None):
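        """Normalize and spatially crop pre-chunked clips.

        ``batch`` is the output of ``load_video_fast``. Returns a tensor of
        shape (batch, clips_per_video * num_crops, C, frames, 224, 224).
        """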
        device = device or torch.device("cpu")
        video_outputs = []
        for all_video in batch:
            all_video = [
                self.video_transform(clip.to(device) / 255.0) for clip in all_video
            ]
            all_video = self.spatial_crop(all_video)
            all_video = torch.stack(all_video, dim=0)
            video_outputs.append(all_video)
        return torch.stack(video_outputs, dim=0)

    def __call__(self, videos, durations, device=None):
        return self.transform_video(
            self.load_video_fast(videos, durations), device=device
        )


class ImageBindRanker(Ranker):
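    """Rank separated-audio candidates by audio-video agreement in ImageBind space."""
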
    def __init__(self, cfg: ImageBindRankerConfig):
        super().__init__()
        assert __imagebind_exists__, (
            "Install ImageBind in order to use this ranker: https://github.com/facebookresearch/ImageBind/tree/main"
        )
        # Fall back to the pretrained imagebind_huge weights when no checkpoint is given.
        self.model = imagebind_huge(pretrained=cfg.checkpoint is None)
        if cfg.checkpoint is not None:
            self.model.load_state_dict(torch.load(cfg.checkpoint, map_location="cpu"))
        self.model = self.model.eval()
        self.video_transform = VideoTransform()

    def forward(
        self,
        extracted_audio: list[torch.Tensor],
        videos: list[torch.Tensor | str],
        sample_rate: int = 48_000,
        **kwargs,
    ):
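        """Score each audio candidate against its source video.

        ``extracted_audio[i]`` holds the candidate waveforms for video ``i``;
        ``videos`` are file paths or (T, C, H, W) frame tensors. Returns cosine
        similarities of shape (batch, candidates, 1); higher means the audio
        matches the video better.
        """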
        audio_data = torch.cat(
            [
                load_and_transform_audio_data(x, input_sample_rate=sample_rate)
                for x in extracted_audio
            ],
            dim=0,
        )
        if isinstance(videos[0], str):
            # ImageBind's loader takes the target device as its second argument.
            video_data = load_and_transform_video_data(videos, audio_data.device)
        else:
            durations = [x.size(-1) / sample_rate for x in extracted_audio]
            video_data = self.video_transform(videos, durations, audio_data.device)
        inputs = {ModalityType.AUDIO: audio_data, ModalityType.VISION: video_data}
        embs = self.model(inputs)
        audio_embs, video_embs = embs[ModalityType.AUDIO], embs[ModalityType.VISION]
        # L2-normalize so the batched dot products below are cosine similarities.
        audio_embs, video_embs = (
            audio_embs / ((audio_embs**2).sum(dim=-1, keepdim=True) ** 0.5),
            video_embs / ((video_embs**2).sum(dim=-1, keepdim=True) ** 0.5),
        )
        bsz = len(extracted_audio)
        candidates = len(audio_embs) // bsz
        scores = audio_embs.view(bsz, candidates, -1) @ video_embs.view(bsz, -1, 1)
        return scores
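

# Usage sketch (hypothetical config values; requires ImageBind installed):
#     ranker = ImageBindRanker(ImageBindRankerConfig(checkpoint=None))
#     scores = ranker(extracted_audio=[candidates], videos=["clip.mp4"])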