Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n | |
| from typing import Optional | |
| import torch | |
| from imagebind.models.imagebind_model import ModalityType, imagebind_huge | |
| from sam_audio.ranking.imagebind import VideoTransform, load_and_transform_audio_data | |
class ImageBind(torch.nn.Module):
    """Score audio/video agreement with the ImageBind "huge" model.

    Audio clips and videos are embedded with ImageBind, each embedding is
    L2-normalised, and the score for an item is the dot product (i.e. cosine
    similarity) between its audio and video embeddings.
    """

    def __init__(
        self,
        checkpoint: Optional[str] = None,
        device: Optional[torch.device] = None,
    ):
        """Build the ImageBind model and move it to ``device``.

        Args:
            checkpoint: Optional path to a state dict for ``imagebind_huge``.
                When ``None``, the model is built with ``pretrained=True``;
                otherwise it is built without pretrained weights and this
                state dict is loaded into it.
            device: Device to run on. Defaults to CUDA when available,
                otherwise CPU.
        """
        super().__init__()
        self.model = imagebind_huge(pretrained=checkpoint is None)
        if checkpoint is not None:
            # NOTE(review): torch.load unpickles the file, so only load trusted
            # checkpoints. Consider weights_only=True (torch>=1.13; it is the
            # default from torch 2.6) if checkpoints are plain state dicts.
            state_dict = torch.load(checkpoint, map_location="cpu")
            self.model.load_state_dict(state_dict)
        self.model = self.model.eval()
        self.video_transform = VideoTransform()
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = self.model.to(self.device)

    def __call__(
        self,
        target_wavs: list[torch.Tensor],
        videos: list[torch.Tensor],
        target_wavs_sample_rate: int = 48_000,
        **kwargs,
    ) -> dict[str, list[float]]:
        """Return ImageBind audio-video similarity scores.

        Args:
            target_wavs: Waveforms to score. The last dimension is treated as
                time when computing each clip's duration.
            videos: Videos to compare against; passed to ``VideoTransform``
                together with the audio durations. Presumably one video per
                entry of ``target_wavs`` -- the ``view(bsz, -1, 1)`` below
                requires the vision embeddings to reshape per item.
            target_wavs_sample_rate: Sample rate of ``target_wavs`` in Hz.
            **kwargs: Unused here; presumably accepted so callers can pass a
                shared set of keyword arguments to several rankers.

        Returns:
            ``{"ImageBind": scores}``. ``scores`` is the result of
            ``squeeze(1, 2)`` on a ``(bsz, candidates, 1)`` tensor. With one
            audio embedding per item (``candidates == 1``), this is a flat
            list with one float per item. Otherwise it is a list of
            per-item lists. An empty ``target_wavs`` gives an empty list.
        """
        if not target_wavs:
            # Nothing to score; also avoids dividing by a zero batch size when
            # computing `candidates` below.
            return {"ImageBind": []}

        # Scoring is pure inference: skip building an autograd graph.
        with torch.no_grad():
            audio_data = load_and_transform_audio_data(
                target_wavs, input_sample_rate=target_wavs_sample_rate
            )
            # The model was moved to self.device in __init__; bring the inputs
            # there too so CPU-resident waveforms work with a CUDA model.
            audio_data = audio_data.to(self.device)
            durations = [x.size(-1) / target_wavs_sample_rate for x in target_wavs]
            video_data = self.video_transform(videos, durations, self.device)

            embs = self.model(
                {ModalityType.AUDIO: audio_data, ModalityType.VISION: video_data}
            )
            # L2-normalise so the dot product below is a cosine similarity.
            # F.normalize clamps the norm, so an all-zero embedding scores 0
            # instead of producing NaN.
            audio_embs = torch.nn.functional.normalize(
                embs[ModalityType.AUDIO], dim=-1
            )
            video_embs = torch.nn.functional.normalize(
                embs[ModalityType.VISION], dim=-1
            )

            bsz = len(target_wavs)
            candidates = len(audio_embs) // bsz
            # (bsz, candidates, D) @ (bsz, D, 1) -> (bsz, candidates, 1)
            scores = audio_embs.view(bsz, candidates, -1) @ video_embs.view(
                bsz, -1, 1
            )
            return {"ImageBind": scores.squeeze(1, 2).cpu().tolist()}