# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import math
from typing import List, Union

import torch
import torchaudio

from sam_audio.model.config import ImageBindRankerConfig
from sam_audio.ranking.ranker import Ranker

try:
    from imagebind.data import (
        ConstantClipsPerVideoSampler,
        NormalizeVideo,
        SpatialCrop,
        get_clip_timepoints,
        load_and_transform_video_data,
        pv_transforms,
        transforms,
        waveform2melspec,
    )
    from imagebind.models.imagebind_model import ModalityType, imagebind_huge

    __imagebind_exists__ = True
except ImportError:
    __imagebind_exists__ = False


def load_and_transform_audio_data(
    audios: List[Union[str, torch.Tensor]],
    input_sample_rate=None,
    num_mel_bins=128,
    target_length=204,
    sample_rate=16000,
    clip_duration=2,
    clips_per_video=3,
    mean=-4.268,
    std=9.138,
    device=None,
):
    if audios is None:
        return None

    audio_outputs = []
    clip_sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )

    for audio in audios:
        if isinstance(audio, str):
            waveform, input_sample_rate = torchaudio.load(audio)
        else:
            assert torch.is_tensor(audio)
            assert input_sample_rate is not None, (
                "input_sample_rate is required for tensor inputs"
            )
            # Preprocessing needs to be done in full precision.
            waveform = audio.float()
        if waveform.ndim == 1:
            waveform = waveform[None]
        if sample_rate != input_sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=input_sample_rate, new_freq=sample_rate
            )
        all_clips_timepoints = get_clip_timepoints(
            clip_sampler, waveform.size(1) / sample_rate
        )
        all_clips = []
        for clip_timepoints in all_clips_timepoints:
            waveform_clip = waveform[
                :,
                int(clip_timepoints[0] * sample_rate) : int(
                    clip_timepoints[1] * sample_rate
                ),
            ]
            waveform_melspec = waveform2melspec(
                waveform_clip, sample_rate, num_mel_bins, target_length
            )
            all_clips.append(waveform_melspec)

        normalize = transforms.Normalize(mean=mean, std=std)
        all_clips = [normalize(ac).to(device) for ac in all_clips]
        all_clips = torch.stack(all_clips, dim=0)
        audio_outputs.append(all_clips)

    return torch.stack(audio_outputs, dim=0)


class VideoTransform:
    def __init__(self, clip_duration=2, clips_per_video=5):
        self.clip_duration = clip_duration
        self.clips_per_video = clips_per_video
        self.clip_sampler = ConstantClipsPerVideoSampler(
            clip_duration=clip_duration, clips_per_video=clips_per_video
        )
        self.video_transform = transforms.Compose(
            [
                pv_transforms.ShortSideScale(224),
                NormalizeVideo(
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )
        self.spatial_crop = SpatialCrop(224, num_crops=3)

    def load_video_fast(self, videos, durations, **kwargs):
        result = []
        for video, duration in zip(videos, durations, strict=False):
            nframes = video.size(0)
            fps = video.size(0) / duration
            timepoints = get_clip_timepoints(self.clip_sampler, duration)
            # Instead of loading five 2 s clips and then sub-sampling frames, we
            # figure out the indices of the frames we want and only gather those.
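            # Each clip window contributes ``self.clip_duration`` evenly spaced
            # frame indices, clamped so they never run past the end of the video;
            # concatenating them lets us gather every needed frame in one pass.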
            all_idxs = []
            for start_time, end_time in timepoints:
                idxs = torch.arange(
                    min(int(math.ceil(fps * start_time)), nframes - 1),
                    min(int(math.ceil(fps * end_time)), nframes),
                )
                ts = (
                    torch.linspace(0, idxs.size(0) - 1, self.clip_duration)
                    .clamp(max=idxs.size(0) - 1)
                    .long()
                )
                all_idxs.append(idxs[ts])
            all_idxs = torch.cat(all_idxs)
            # (T, C, H, W) -> (C, T, H, W), then split back into per-clip chunks.
            fast_frames = video[all_idxs].transpose(0, 1)
            result.append(fast_frames.chunk(self.clips_per_video, dim=1))
        return result

    def transform_video(self, batch, device=None):
        device = device or torch.device("cpu")
        video_outputs = []
        for all_video in batch:
            all_video = [
                self.video_transform(clip.to(device) / 255.0) for clip in all_video
            ]
            all_video = self.spatial_crop(all_video)
            all_video = torch.stack(all_video, dim=0)
            video_outputs.append(all_video)
        return torch.stack(video_outputs, dim=0)

    def __call__(self, videos, durations, device=None):
        return self.transform_video(
            self.load_video_fast(videos, durations), device=device
        )


class ImageBindRanker(Ranker):
    def __init__(self, cfg: ImageBindRankerConfig):
        super().__init__()
        assert __imagebind_exists__, (
            "Install ImageBind in order to use this ranker: "
            "https://github.com/facebookresearch/ImageBind/tree/main"
        )
        self.model = imagebind_huge(pretrained=cfg.checkpoint is None)
        if cfg.checkpoint is not None:
            self.model.load_state_dict(torch.load(cfg.checkpoint, map_location="cpu"))
        self.model = self.model.eval()
        self.video_transform = VideoTransform()

    @torch.inference_mode()
    def forward(
        self,
        extracted_audio: list[torch.Tensor],
        videos: list[torch.Tensor | str],
        sample_rate: int = 48_000,
        **kwargs,
    ):
        audio_data = torch.cat(
            [
                load_and_transform_audio_data(x, input_sample_rate=sample_rate)
                for x in extracted_audio
            ],
            dim=0,
        )
        if isinstance(videos[0], str):
            video_data = load_and_transform_video_data(videos)
        else:
            durations = [x.size(-1) / sample_rate for x in extracted_audio]
            video_data = self.video_transform(videos, durations, audio_data.device)

        inputs = {ModalityType.AUDIO: audio_data, ModalityType.VISION: video_data}
        embs = self.model(inputs)
        audio_embs, video_embs = embs[ModalityType.AUDIO], embs[ModalityType.VISION]
        # L2-normalize both embeddings so the dot product below is a cosine similarity.
        audio_embs, video_embs = (
            audio_embs / ((audio_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
            video_embs / ((video_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
        )
        bsz = len(extracted_audio)
        candidates = len(audio_embs) // bsz
        # (bsz, candidates, dim) @ (bsz, dim, 1) -> (bsz, candidates, 1)
        scores = audio_embs.view(bsz, candidates, -1) @ video_embs.view(bsz, -1, 1)
        return scores
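

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not exercised by this module). It assumes
# ImageBind is installed and that ImageBindRankerConfig accepts the optional
# ``checkpoint`` path used by the constructor above; tensor shapes follow the
# conventions expected by forward().
#
#     cfg = ImageBindRankerConfig(checkpoint=None)  # fall back to pretrained weights
#     ranker = ImageBindRanker(cfg)
#
#     # One video with two candidate separated-audio tracks of 3 s at 48 kHz;
#     # the video is a uint8 tensor of shape (frames, channels, height, width).
#     extracted_audio = [torch.randn(2, 3 * 48_000)]
#     videos = [torch.randint(0, 256, (90, 3, 256, 256), dtype=torch.uint8)]
#
#     scores = ranker(extracted_audio, videos, sample_rate=48_000)
#     # scores: (batch, candidates, 1) cosine similarity between each candidate's
#     # audio embedding and the corresponding video embedding.
# ---------------------------------------------------------------------------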