ray-006's picture
Upload 43 files
fc605f9 verified
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from tempfile import TemporaryDirectory
from typing import Optional
import torch
from torchcodec.encoders import AudioEncoder
from sam_audio.ranking.clap import get_model
class CLAP(torch.nn.Module):
    """Score audio/text agreement with a CLAP model.

    Each waveform in ``target_wavs`` is paired with the description at the
    same index. The score for a pair is the dot product of its CLAP audio
    embedding and CLAP text embedding.
    """

    def __init__(
        self,
        checkpoint: Optional[str] = None,
        device: Optional[torch.device] = None,
    ):
        """Build the scorer.

        Args:
            checkpoint: Not used by this implementation; kept so callers that
                pass it keep working. NOTE(review): confirm whether it should
                be forwarded to ``get_model``.
            device: Device for the CLAP model. Defaults to CUDA when it is
                available, otherwise CPU.
        """
        super().__init__()
        # Resolve the device first so the model is built on exactly the
        # device that ``self.device`` reports.
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = get_model(self.device)

    def __call__(
        self,
        target_wavs: list[torch.Tensor],
        descriptions: list[str],
        target_wavs_sample_rate: int = 48_000,
        **kwargs,
    ) -> dict[str, list[float]]:
        """Compute one CLAP similarity per (waveform, description) pair.

        Args:
            target_wavs: Waveforms to score. A 1-D tensor gets a leading
                channel axis added. Any other tensor is handed to
                ``AudioEncoder`` unchanged (presumably ``(channels, samples)``;
                confirm against callers).
            descriptions: Text descriptions, paired index-wise with
                ``target_wavs``.
            target_wavs_sample_rate: Sample rate recorded in the temporary WAV
                files written for each waveform.
            **kwargs: Accepted and ignored.

        Returns:
            ``{"CLAPSimilarity": scores}``, where ``scores[i]`` is the score for
            ``target_wavs[i]`` against ``descriptions[i]``.

        Raises:
            ValueError: If ``target_wavs`` and ``descriptions`` have different
                lengths.
        """
        if len(target_wavs) != len(descriptions):
            # Without this check, a length-1 side could silently broadcast in
            # the batched matmul below instead of failing loudly.
            raise ValueError(
                f"Got {len(target_wavs)} waveforms but {len(descriptions)} "
                "descriptions; they must be paired one-to-one."
            )
        if not target_wavs:
            # Nothing to score; avoid handing empty batches to the model.
            return {"CLAPSimilarity": []}

        with TemporaryDirectory() as tdir, torch.inference_mode():
            file_list = []
            for i, wav in enumerate(target_wavs):
                path = f"{tdir}/hyp_{i}.wav"
                # The audio-embedding helper consumes file paths, so each
                # waveform is round-tripped through a temporary WAV file.
                samples = wav.cpu()
                if samples.ndim == 1:
                    samples = samples[None]  # add a channel axis
                AudioEncoder(
                    samples=samples,
                    sample_rate=target_wavs_sample_rate,
                ).to_file(path)
                file_list.append(path)
            audio_embs = self.model.get_audio_embedding_from_filelist(
                file_list, use_tensor=True
            )
            text_embs = self.model.get_text_embedding(descriptions, use_tensor=True)
            # Row-wise dot product: (B, 1, D) @ (B, D, 1) -> (B, 1, 1),
            # assuming both embeddings are (batch, dim) — TODO confirm.
            sims = audio_embs.unsqueeze(1) @ text_embs.unsqueeze(2)
            return {"CLAPSimilarity": sims.cpu()[:, 0, 0].tolist()}