Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved | |
| from tempfile import TemporaryDirectory | |
| from typing import Optional | |
| import torch | |
| from torchcodec.encoders import AudioEncoder | |
| from sam_audio.ranking.clap import get_model | |
class CLAP(torch.nn.Module):
    """Score audio clips against text descriptions with a CLAP model.

    The wrapped model is obtained from ``get_model``. Calling an instance
    writes every waveform to a temporary ``.wav`` file, embeds the files and
    the descriptions with the model, and returns the resulting similarity
    scores.
    """

    def __init__(
        self,
        checkpoint: Optional[str] = None,
        device: Optional[torch.device] = None,
    ):
        """Build the underlying CLAP model.

        Args:
            checkpoint: Currently unused; the argument is accepted but never
                read by this constructor.
            device: Device handed to ``get_model``. When ``None``,
                ``self.device`` falls back to CUDA if it is available and to
                CPU otherwise.
        """
        super().__init__()
        # NOTE(review): get_model receives the caller's raw ``device`` (which
        # may be None), not the resolved fallback stored in ``self.device``
        # below -- confirm that get_model handles None as intended.
        self.model = get_model(device)
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

    def __call__(
        self,
        target_wavs: list[torch.Tensor],
        descriptions: list[str],
        target_wavs_sample_rate: int = 48_000,
        **kwargs,
    ) -> dict[str, list[float]]:
        """Compute one CLAP similarity score per (waveform, description) pair.

        Args:
            target_wavs: Waveforms to score. A 1-D tensor gets a leading axis
                added before encoding; tensors of any other rank are passed to
                the encoder unchanged.
            descriptions: Text descriptions, paired with ``target_wavs`` by
                position.
            target_wavs_sample_rate: Sample rate used when writing the
                waveforms to disk.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with the single key ``"CLAPSimilarity"`` whose value is the
            list of similarity scores; the list is empty when ``target_wavs``
            is empty.
        """
        # Nothing to score: skip the temporary directory and avoid handing an
        # empty file list to the CLAP backend. ``len`` is used rather than
        # truthiness so a non-list sequence of waveforms is handled as well.
        if len(target_wavs) == 0:
            return {"CLAPSimilarity": []}

        with TemporaryDirectory() as tdir, torch.inference_mode():
            # The model embeds audio from file paths, so each waveform is
            # first written out as a .wav file inside the temp directory.
            file_list = []
            for i, wav in enumerate(target_wavs):
                path = f"{tdir}/hyp_{i}.wav"
                samples = wav.cpu()
                if samples.ndim == 1:
                    # Add a leading axis to 1-D input
                    # (presumably the channel axis -- TODO confirm).
                    samples = samples[None]
                AudioEncoder(
                    samples=samples,
                    sample_rate=target_wavs_sample_rate,
                ).to_file(path)
                file_list.append(path)

            # Embeddings must be computed while the temporary files exist.
            audio_embs = self.model.get_audio_embedding_from_filelist(
                file_list, use_tensor=True
            )
            text_embs = self.model.get_text_embedding(descriptions, use_tensor=True)

            # Batched dot product between matching rows of the two embedding
            # batches (assumes both are 2-D (batch, dim) -- TODO confirm).
            sims = audio_embs.unsqueeze(1) @ text_embs.unsqueeze(2)
        return {"CLAPSimilarity": sims.cpu()[:, 0, 0].tolist()}