# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Optional

import torch

from imagebind.models.imagebind_model import ModalityType, imagebind_huge

from sam_audio.ranking.imagebind import VideoTransform, load_and_transform_audio_data


class ImageBind(torch.nn.Module):
    """Scores audio/video pairs via cosine similarity of ImageBind embeddings."""

    def __init__(
        self,
        checkpoint: Optional[str] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        # Download the pretrained imagebind_huge weights only when no local
        # checkpoint is given; otherwise load the checkpoint onto CPU first.
        self.model = imagebind_huge(pretrained=checkpoint is None)
        if checkpoint is not None:
            self.model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
        self.model = self.model.eval()
        self.video_transform = VideoTransform()
        # Default to CUDA when available so inference runs on the GPU.
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = self.model.to(self.device)
    def __call__(
        self,
        target_wavs: list[torch.Tensor],
        videos: list[torch.Tensor],
        target_wavs_sample_rate: int = 48_000,
        **kwargs,
    ) -> dict[str, list[float]]:
        audio_data = load_and_transform_audio_data(
            target_wavs, input_sample_rate=target_wavs_sample_rate
        )
        # Clip length in seconds for each waveform, so the video transform can
        # sample frames covering the same time span as the audio.
        durations = [x.size(-1) / target_wavs_sample_rate for x in target_wavs]
        video_data = self.video_transform(videos, durations, audio_data.device)
        inputs = {ModalityType.AUDIO: audio_data, ModalityType.VISION: video_data}
        embs = self.model(inputs)
        audio_embs, video_embs = embs[ModalityType.AUDIO], embs[ModalityType.VISION]
        # L2-normalize both embeddings so the batched matmul below computes
        # cosine similarity.
        audio_embs, video_embs = (
            audio_embs / ((audio_embs**2).sum(dim=-1, keepdim=True) ** 0.5),
            video_embs / ((video_embs**2).sum(dim=-1, keepdim=True) ** 0.5),
        )
        # The audio transform may yield several candidate embeddings per
        # waveform; score each against its video embedding with a
        # (bsz, candidates, d) @ (bsz, d, 1) batched product.
        bsz = len(target_wavs)
        candidates = len(audio_embs) // bsz
        scores = audio_embs.view(bsz, candidates, -1) @ video_embs.view(bsz, -1, 1)
        # Collapse the singleton dims and return plain Python floats.
        return {"ImageBind": scores.squeeze((1, 2)).cpu().tolist()}