# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import math
from typing import List, Union

import torch
import torchaudio

from sam_audio.model.config import ImageBindRankerConfig
from sam_audio.ranking.ranker import Ranker

try:
    from imagebind.data import (
        ConstantClipsPerVideoSampler,
        NormalizeVideo,
        SpatialCrop,
        get_clip_timepoints,
        load_and_transform_video_data,
        pv_transforms,
        transforms,
        waveform2melspec,
    )
    from imagebind.models.imagebind_model import ModalityType, imagebind_huge

    __imagebind_exists__ = True
except ImportError:
    __imagebind_exists__ = False
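
# ImageBind is an optional dependency: the flag above gates ImageBindRanker,
# which asserts on it before touching any of the imported symbols.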


def load_and_transform_audio_data(
    audios: List[Union[str, torch.Tensor]],
    input_sample_rate=None,
    num_mel_bins=128,
    target_length=204,
    sample_rate=16000,
    clip_duration=2,
    clips_per_video=3,
    mean=-4.268,
    std=9.138,
    device=None,
):
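    """
    Convert audio inputs (file paths or waveform tensors) into batched,
    normalized mel-spectrogram clips for ImageBind's audio encoder.

    Each input is resampled to ``sample_rate``, split into ``clips_per_video``
    clips of ``clip_duration`` seconds, converted to a ``num_mel_bins`` x
    ``target_length`` mel spectrogram, and normalized with ``mean``/``std``.
    Returns a tensor stacking all clips for all inputs, or None if ``audios``
    is None.
    """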
    if audios is None:
        return None
    audio_outputs = []
    clip_sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )
    for audio in audios:
        if isinstance(audio, str):
            waveform, input_sample_rate = torchaudio.load(audio)
        else:
            assert torch.is_tensor(audio)
            # Tensor inputs carry no sample rate, so the caller must supply one.
            assert input_sample_rate is not None
            # Preprocessing needs to be done in full precision
            waveform = audio.float()
            if waveform.ndim == 1:
                waveform = waveform[None]
        if sample_rate != input_sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=input_sample_rate, new_freq=sample_rate
            )
        all_clips_timepoints = get_clip_timepoints(
            clip_sampler, waveform.size(1) / sample_rate
        )
        all_clips = []
        for clip_timepoints in all_clips_timepoints:
            waveform_clip = waveform[
                :,
                int(clip_timepoints[0] * sample_rate) : int(
                    clip_timepoints[1] * sample_rate
                ),
            ]
            waveform_melspec = waveform2melspec(
                waveform_clip, sample_rate, num_mel_bins, target_length
            )
            all_clips.append(waveform_melspec)
        normalize = transforms.Normalize(mean=mean, std=std)
        all_clips = [normalize(ac).to(device) for ac in all_clips]
        all_clips = torch.stack(all_clips, dim=0)
        audio_outputs.append(all_clips)
    return torch.stack(audio_outputs, dim=0)


class VideoTransform:
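    """
    ImageBind-style video preprocessing (constant-clips sampling, short-side
    scale to 224, channel normalization, 3-way spatial crop) applied to
    already-decoded frame tensors instead of video files.
    """
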
    def __init__(self, clip_duration=2, clips_per_video=5):
        self.clip_duration = clip_duration
        self.clips_per_video = clips_per_video
        self.clip_sampler = ConstantClipsPerVideoSampler(
            clip_duration=clip_duration, clips_per_video=clips_per_video
        )
        self.video_transform = transforms.Compose(
            [
                pv_transforms.ShortSideScale(224),
                NormalizeVideo(
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )
        self.spatial_crop = SpatialCrop(224, num_crops=3)

    def load_video_fast(self, videos, durations, **kwargs):
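        """
        Gather only the frames ImageBind's clip sampling would keep, instead
        of materializing every clip and sub-sampling afterwards. ``videos``
        are (T, C, H, W) frame tensors and ``durations`` their lengths in
        seconds; returns one tuple of per-clip (C, T', H, W) tensors per video.
        """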
        result = []
        for video, duration in zip(videos, durations, strict=False):
            nframes = video.size(0)
            fps = video.size(0) / duration
            timepoints = get_clip_timepoints(
                self.clip_sampler,
                duration,
            )
            # Instead of loading 5 2s clips and then sub-sampling frames, we
            # figure out the indices of the final clips we want and only
            # decode those.
            all_idxs = []
            for start_time, end_time in timepoints:
                idxs = torch.arange(
                    min(int(math.ceil(fps * start_time)), nframes - 1),
                    min(int(math.ceil(fps * end_time)), nframes),
                )
                # Uniformly sub-sample the clip; as in ImageBind, clip_duration
                # doubles as the number of frames kept per clip.
                ts = (
                    torch.linspace(0, idxs.size(0) - 1, self.clip_duration)
                    .clamp(max=idxs.size(0) - 1)
                    .long()
                )
                all_idxs.append(idxs[ts])
            all_idxs = torch.cat(all_idxs)
            fast_frames = video[all_idxs].transpose(0, 1)
            result.append(fast_frames.chunk(self.clips_per_video, dim=1))
        return result

    def transform_video(self, batch, device=None):
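        """
        Scale pixels to [0, 1], apply the ImageBind video transform, take the
        spatial crops, and stack everything into a single batch tensor.
        """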
        device = device or torch.device("cpu")
        video_outputs = []
        for all_video in batch:
            all_video = [
                self.video_transform(clip.to(device) / 255.0) for clip in all_video
            ]
            all_video = self.spatial_crop(all_video)
            all_video = torch.stack(all_video, dim=0)
            video_outputs.append(all_video)
        return torch.stack(video_outputs, dim=0)

    def __call__(self, videos, durations, device=None):
        return self.transform_video(
            self.load_video_fast(videos, durations), device=device
        )


class ImageBindRanker(Ranker):
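    """
    Ranks candidate extracted-audio tracks against their source videos by
    cosine similarity between ImageBind audio and vision embeddings.
    """
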
    def __init__(self, cfg: ImageBindRankerConfig):
        super().__init__()
        assert __imagebind_exists__, (
            "Install ImageBind in order to use this ranker: https://github.com/facebookresearch/ImageBind/tree/main"
        )
        self.model = imagebind_huge(pretrained=cfg.checkpoint is None)
        if cfg.checkpoint is not None:
            self.model.load_state_dict(torch.load(cfg.checkpoint, map_location="cpu"))
        self.model = self.model.eval()
        self.video_transform = VideoTransform()

    @torch.inference_mode()
    def forward(
        self,
        extracted_audio: list[torch.Tensor],
        videos: list[torch.Tensor | str],
        sample_rate: int = 48_000,
        **kwargs,
    ):
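        """
        Score each audio candidate against its video. Returns a
        (batch, candidates, 1) tensor of cosine similarities between
        L2-normalized ImageBind embeddings.
        """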
        # Each element of extracted_audio holds all candidate waveforms for
        # one video; flatten them into a single audio batch.
        audio_data = torch.cat(
            [
                load_and_transform_audio_data(x, input_sample_rate=sample_rate)
                for x in extracted_audio
            ],
            dim=0,
        )
        if isinstance(videos[0], str):
            # ImageBind's loader expects the target device as its second argument.
            video_data = load_and_transform_video_data(videos, audio_data.device)
        else:
            durations = [x.size(-1) / sample_rate for x in extracted_audio]
            video_data = self.video_transform(videos, durations, audio_data.device)
        inputs = {ModalityType.AUDIO: audio_data, ModalityType.VISION: video_data}
        embs = self.model(inputs)
        audio_embs, video_embs = embs[ModalityType.AUDIO], embs[ModalityType.VISION]
        # L2-normalize both modalities so the dot products below are cosine
        # similarities.
        audio_embs, video_embs = (
            audio_embs / ((audio_embs**2).sum(dim=-1, keepdim=True) ** 0.5),
            video_embs / ((video_embs**2).sum(dim=-1, keepdim=True) ** 0.5),
        )
        bsz = len(extracted_audio)
        candidates = len(audio_embs) // bsz
        # (bsz, candidates, dim) @ (bsz, dim, 1) -> (bsz, candidates, 1)
        scores = audio_embs.view(bsz, candidates, -1) @ video_embs.view(bsz, -1, 1)
        return scores
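

# Minimal usage sketch (hypothetical values; assumes the Ranker base class is
# an nn.Module and that ImageBindRankerConfig takes a ``checkpoint`` field,
# where None falls back to the pretrained imagebind_huge weights):
#
#   cfg = ImageBindRankerConfig(checkpoint=None)
#   ranker = ImageBindRanker(cfg)
#   # One (num_candidates, num_samples) waveform tensor per video, plus each
#   # video's decoded (T, C, H, W) uint8 frames:
#   scores = ranker(extracted_audio=[candidates], videos=[frames], sample_rate=48_000)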