Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n | |
| import json | |
| import logging | |
| import math | |
| import os | |
| from typing import Callable, List, Optional, Tuple | |
| import torch | |
| import torchaudio | |
| from huggingface_hub import hf_hub_download | |
| from torch.nn.utils.rnn import pad_sequence | |
| from torchcodec.decoders import AudioDecoder, VideoDecoder | |
| from transformers import AutoTokenizer, BatchFeature | |
| from sam_audio.model.config import SAMAudioConfig, SAMAudioJudgeConfig | |
# Module-level logger; used below to warn when `sampling_rate` is omitted.
logger = logging.getLogger(__name__)

# An anchor is a `(token, start_time, end_time)` tuple. `token` is looked up in
# the anchor vocabulary built in `Batch.process_anchors`; the two times are
# multiplied by the audio sampling rate there, i.e. they are in seconds.
Anchor = Tuple[str, float, float]
| def batch_audio( | |
| audios: list[str | torch.Tensor], audio_sampling_rate: int = 48_000 | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| wavs = [] | |
| for audio in audios: | |
| if isinstance(audio, str): | |
| wav, sr = torchaudio.load(audio) | |
| if sr != audio_sampling_rate: | |
| wav = torchaudio.functional.resample(wav, sr, audio_sampling_rate) | |
| else: | |
| wav = audio | |
| wavs.append(wav.mean(0)) | |
| sizes = torch.tensor([wav.size(-1) for wav in wavs]) | |
| return pad_sequence(wavs, batch_first=True).unsqueeze(1), sizes | |
class Batch:
    """Container for one collated batch of model inputs.

    Besides storing the tensors it is given, the constructor derives
    ``anchor_ids`` and ``anchor_alignment`` from the optional ``anchors``
    (see :meth:`process_anchors`).

    Args:
        audios: Batched waveforms; the first dimension is the batch.
        sizes: Per-sample length in feature frames.
        wav_sizes: Per-sample length in waveform samples.
        descriptions: One text description per sample.
        hop_length: Number of waveform samples per feature frame.
        audio_sampling_rate: Sampling rate used to convert anchor times
            (seconds) into waveform sample indices.
        anchors: Optional per-sample list of ``(token, start, end)`` anchors.
        audio_pad_mask: Optional boolean mask ``(batch, frames)`` that is
            ``True`` on valid frames. When omitted, the anchor alignment is
            built from a mask derived from ``sizes``.
        masked_video: Optional list with one video tensor per sample.
    """

    def __init__(
        self,
        audios: torch.Tensor,
        sizes: torch.Tensor,
        wav_sizes: torch.Tensor,
        descriptions: list[str],
        hop_length: int,
        audio_sampling_rate: int,
        anchors: Optional[list[list["Anchor"]]] = None,
        audio_pad_mask: Optional[torch.Tensor] = None,
        masked_video: Optional[list[torch.Tensor]] = None,
    ):
        self.audios = audios
        self.sizes = sizes
        self.wav_sizes = wav_sizes
        self.descriptions = descriptions
        self.audio_pad_mask = audio_pad_mask
        self.masked_video = masked_video
        self.hop_length = hop_length
        self.audio_sampling_rate = audio_sampling_rate
        self.process_anchors(anchors)
        assert self.audios.size(0) == len(self.descriptions)

    def _wav_to_feature_idx(self, wav_idx: int):
        """Map a waveform sample index to a feature-frame index (rounded up)."""
        return math.ceil(wav_idx / self.hop_length)

    def _frame_mask(self) -> torch.Tensor:
        """Return the valid-frame mask used to build the anchor alignment.

        Uses ``audio_pad_mask`` when it was provided; otherwise builds the
        equivalent ``(batch, max(sizes))`` mask from ``sizes`` so that a
        ``Batch`` can be created without an explicit mask.
        """
        if self.audio_pad_mask is not None:
            return self.audio_pad_mask
        sizes = self.sizes
        num_frames = int(sizes.max()) if sizes.numel() > 0 else 0
        frame_idx = torch.arange(num_frames, device=sizes.device)
        return frame_idx[None] < sizes[:, None]

    def to(self, device: torch.device):
        """Move every tensor held by the batch to ``device`` in place.

        Returns:
            ``self``, so calls can be chained.
        """
        self.audios = self.audios.to(device)
        self.anchor_ids = self.anchor_ids.to(device)
        self.anchor_alignment = self.anchor_alignment.to(device)
        self.sizes = self.sizes.to(device)
        self.wav_sizes = self.wav_sizes.to(device)
        if self.audio_pad_mask is not None:
            self.audio_pad_mask = self.audio_pad_mask.to(device)
        if self.masked_video is not None:
            self.masked_video = [v.to(device) for v in self.masked_video]
        return self

    def process_anchors(self, anchors: Optional[list[list["Anchor"]]]):
        """Build ``anchor_ids`` and ``anchor_alignment`` from ``anchors``.

        ``anchor_ids[b]`` always starts with ``[<null>, <pad>]`` followed by
        one id per anchor of sample ``b`` (rows are right-padded with
        ``<pad>``). ``anchor_alignment[b, t]`` is the position inside
        ``anchor_ids[b]`` that frame ``t`` points to: ``0`` (``<null>``) by
        default, ``1`` (``<pad>``) on padded frames, and the anchor's position
        on frames covered by an anchor. Later anchors overwrite earlier ones
        where they overlap.

        Args:
            anchors: Per-sample anchors, or ``None`` for "no anchors".

        Raises:
            KeyError: If an anchor token is not ``"+"`` or ``"-"`` (or one of
                the special tokens of the anchor vocabulary).
        """
        anchor_dict = {"<null>": 0, "+": 1, "-": 2, "<pad>": 3}
        batch_size = len(self.audios)
        frame_mask = self._frame_mask()

        # The alignment starts out identical with and without anchors.
        anchor_alignment = torch.full(
            (batch_size, frame_mask.size(-1)), 0, dtype=torch.long
        )
        anchor_alignment[~frame_mask] = 1  # point to pad token

        if anchors is None:
            anchor_ids = torch.full(
                (batch_size, 2), anchor_dict["<null>"], dtype=torch.long
            )
            anchor_ids[:, 1] = anchor_dict["<pad>"]
        else:
            ids = []
            for i, anchor_list in enumerate(anchors):
                current = [anchor_dict["<null>"], anchor_dict["<pad>"]]
                for token, start_time, end_time in anchor_list:
                    start_idx = self._wav_to_feature_idx(
                        start_time * self.audio_sampling_rate
                    )
                    end_idx = self._wav_to_feature_idx(
                        end_time * self.audio_sampling_rate
                    )
                    # Frames of this anchor point at the slot about to be added.
                    anchor_alignment[i, start_idx:end_idx] = len(current)
                    current.append(anchor_dict[token])
                ids.append(torch.tensor(current))
            anchor_ids = pad_sequence(
                ids, batch_first=True, padding_value=anchor_dict["<pad>"]
            )

        self.anchor_ids = anchor_ids
        self.anchor_alignment = anchor_alignment
        self.anchors = anchors
def mask_from_sizes(sizes: torch.Tensor) -> torch.Tensor:
    """Build a boolean validity mask from per-sample lengths.

    Args:
        sizes: 1-D tensor with the length of every sample.

    Returns:
        Boolean tensor of shape ``(len(sizes), max(sizes))`` on the same
        device as ``sizes``; entry ``[b, t]`` is ``True`` iff ``t < sizes[b]``.
        An empty ``sizes`` yields an empty ``(0, 0)`` mask.
    """
    if sizes.numel() == 0:
        # `max()` is undefined on an empty tensor.
        return torch.zeros((len(sizes), 0), dtype=torch.bool, device=sizes.device)
    longest = sizes.max().item()
    # Create the positions on the device of `sizes` so the comparison below
    # never mixes devices.
    positions = torch.arange(longest, device=sizes.device)
    return positions.expand(len(sizes), -1) < sizes.unsqueeze(1)
| def load_video( | |
| sizes: torch.Tensor, | |
| videos: List[str], | |
| feature_idx_to_wav_idx: Callable[[torch.Tensor], torch.Tensor], | |
| audio_sampling_rate: int, | |
| ) -> list[torch.Tensor]: | |
| all_frames = [] | |
| for size, video in zip(sizes, videos, strict=False): | |
| audio_timestamps = ( | |
| feature_idx_to_wav_idx(torch.arange(size)) / audio_sampling_rate | |
| ) | |
| if isinstance(video, str): | |
| decoder = VideoDecoder(video, dimension_order="NCHW") | |
| data = decoder.get_frames_in_range(0, len(decoder)) | |
| diffs = (audio_timestamps[None] - data.pts_seconds[:, None]).abs() | |
| frame_idxs = diffs.argmin(dim=0) | |
| frames = data.data[frame_idxs] | |
| else: | |
| assert video.size(1) == 3, ( | |
| f"Expected video tensor to be in NCHW format, but found {video.size(1)} channels" | |
| ) | |
| idx = torch.linspace(0, video.size(0) - 1, int(size)).round().long() | |
| frames = video[idx] | |
| all_frames.append(frames) | |
| return all_frames | |
| class Processor: | |
| config_cls: Callable | |
| def __init__(self, audio_hop_length: int, audio_sampling_rate: int): | |
| self.audio_hop_length = audio_hop_length | |
| self.audio_sampling_rate = audio_sampling_rate | |
| def _get_config(cls, model_name_or_path: str): | |
| if os.path.exists(model_name_or_path): | |
| config_path = os.path.join(model_name_or_path, "config.json") | |
| else: | |
| config_path = hf_hub_download( | |
| repo_id=model_name_or_path, | |
| filename="config.json", | |
| revision=cls.revision, | |
| ) | |
| with open(config_path) as fin: | |
| config = cls.config_cls(**json.load(fin)) | |
| return config | |
| def from_pretrained(cls, model_name_or_path: str) -> "Processor": | |
| config = cls._get_config(model_name_or_path) | |
| return cls( | |
| audio_hop_length=config.audio_codec.hop_length, | |
| audio_sampling_rate=config.audio_codec.sample_rate, | |
| ) | |
| def feature_to_wav_idx(self, feature_idx): | |
| return feature_idx * self.audio_hop_length | |
| def wav_to_feature_idx(self, wav_idx): | |
| if torch.is_tensor(wav_idx): | |
| ceil = torch.ceil | |
| else: | |
| ceil = math.ceil | |
| return ceil(wav_idx / self.audio_hop_length) | |
| def mask_videos( | |
| self, | |
| videos: List[str | torch.Tensor], | |
| masks: List[str | torch.Tensor], | |
| ) -> list[torch.Tensor]: | |
| video = [VideoDecoder(v)[:] if isinstance(v, str) else v for v in videos] | |
| video_mask = [VideoDecoder(v)[:] if isinstance(v, str) else v for v in masks] | |
| return [v * m.eq(0) for v, m in zip(video, video_mask, strict=False)] | |
class SAMAudioProcessor(Processor):
    """Processor that collates descriptions, audio, anchors and video into a ``Batch``."""

    config_cls = SAMAudioConfig
    revision = None

    def __call__(
        self,
        descriptions: list[str],
        audios: list[str | torch.Tensor],
        anchors: Optional[list[list[Anchor]]] = None,
        masked_videos: Optional[list[str | torch.Tensor]] = None,
    ):
        """
        Collate raw inputs into a single model batch.

        Args:
            descriptions (list[str]): One text description per audio sample.
            audios (list[str | torch.Tensor]): Audio file paths or waveform tensors.
                A tensor:
                    - should have shape (channels, time), with channels=1 for mono and 2 for stereo.
                    - should already be sampled at 48_000 hz
            anchors (Optional[list[list[Anchor]]], optional): Anchors of every sample,
                each anchor being a tuple (token, start_time, end_time).
            masked_videos (Optional[list[str | torch.Tensor]], optional): Masked video file paths or tensors.
                A tensor should have shape (N, C, H, W)

        Returns:
            Batch: The batched audio, sizes, descriptions, anchor ids, anchor alignment,
            audio pad mask and, when given, the masked video frames.
        """
        num_samples = len(descriptions)
        assert num_samples == len(audios)
        assert anchors is None or num_samples == len(anchors)
        assert masked_videos is None or num_samples == len(masked_videos)

        waveforms, wav_sizes = batch_audio(audios, self.audio_sampling_rate)
        feature_sizes = self.wav_to_feature_idx(wav_sizes)
        pad_mask = mask_from_sizes(feature_sizes)

        if masked_videos is None:
            video_frames = None
        else:
            video_frames = load_video(
                feature_sizes,
                masked_videos,
                self.feature_to_wav_idx,
                self.audio_sampling_rate,
            )

        return Batch(
            audios=waveforms,
            sizes=feature_sizes,
            wav_sizes=wav_sizes,
            descriptions=descriptions,
            hop_length=self.audio_hop_length,
            audio_sampling_rate=self.audio_sampling_rate,
            anchors=anchors,
            audio_pad_mask=pad_mask,
            masked_video=video_frames,
        )
class SAMAudioJudgeProcessor(Processor):
    """Processor for the judge model: tokenizes text and batches audio."""

    config_cls = SAMAudioJudgeConfig
    revision = "sam_audio"

    def __init__(
        self,
        audio_hop_length: int,
        audio_sampling_rate: int,
        tokenizer: AutoTokenizer,
    ):
        super().__init__(audio_hop_length, audio_sampling_rate)
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, model_name_or_path: str) -> "SAMAudioJudgeProcessor":
        """Create the processor and its tokenizer from ``model_name_or_path``."""
        config = cls._get_config(model_name_or_path)
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        return cls(
            audio_hop_length=config.audio_codec.hop_length,
            audio_sampling_rate=config.audio_codec.sample_rate,
            tokenizer=tokenizer,
        )

    def _reflect_pad(self, wav):
        """Right-pad ``wav`` by reflection to a multiple of the hop length.

        A 1-D input gets a leading channel dimension first. Inputs whose
        length is already a multiple of the hop length are returned unpadded.
        """
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)
        remainder = wav.size(-1) % self.audio_hop_length
        if remainder == 0:
            return wav
        # NOTE(review): reflect padding needs the pad width to be smaller than
        # the input length, so clips much shorter than one hop may be rejected
        # by torch here; confirm whether such inputs need to be supported.
        p1d = (0, self.audio_hop_length - remainder)
        return torch.nn.functional.pad(wav, p1d, mode="reflect")

    def _load_audio(self, path: str):
        """Decode ``path`` as mono audio at the processor's sampling rate."""
        ad = AudioDecoder(path, sample_rate=self.audio_sampling_rate, num_channels=1)
        return ad.get_all_samples().data

    def _process_audio(
        self,
        raw_audio,
        sampling_rate: Optional[int] = None,
    ):
        """Turn audio paths/tensors into padded ``input_values`` plus a mask.

        Args:
            raw_audio: A file path, a single waveform tensor, or a list/tuple
                of paths or tensors. Every tensor is one example of shape
                ``(num_samples,)`` or ``(channels, num_samples)``.
            sampling_rate: Sampling rate of tensor inputs. When given it must
                equal the processor's sampling rate.

        Returns:
            ``BatchFeature`` with ``input_values`` of shape
            ``(batch, channels, max_samples)`` and a boolean ``padding_mask``
            of shape ``(batch, max_samples)`` that is ``True`` on real samples.

        Raises:
            ValueError: If no audio is given, if ``sampling_rate`` does not
                match the processor, or if an example has more than 2
                dimensions.
        """
        # Normalise every accepted input form to a list of examples.
        if isinstance(raw_audio, str) or torch.is_tensor(raw_audio):
            # Iterating a bare tensor would split it sample by sample, so it
            # is treated as a single example.
            raw_audio = [raw_audio]
        else:
            raw_audio = list(raw_audio)
        if not raw_audio:
            raise ValueError("`raw_audio` must contain at least one audio example")

        from_file = False
        if isinstance(raw_audio[0], str):
            raw_audio = [self._load_audio(audio_file) for audio_file in raw_audio]
            from_file = True

        if sampling_rate is not None:
            if sampling_rate != self.audio_sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.audio_sampling_rate}. Please make sure that the provided audio input was sampled with"
                    f" {self.audio_sampling_rate} and not {sampling_rate}."
                )
        elif not from_file:
            # Files are decoded at the right rate; tensors cannot be checked.
            logger.warning(
                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        # verify inputs are valid before padding/transposing them
        for example in raw_audio:
            if example.ndim > 2:
                raise ValueError(
                    f"Expected input shape (channels, num_samples), but got shape ({example.shape})"
                )

        # Transpose to (num_samples, channels) so pad_sequence pads over time.
        padded = [self._reflect_pad(x).T for x in raw_audio]
        lengths = torch.tensor([x.size(0) for x in padded])
        input_values = pad_sequence(padded, batch_first=True).transpose(1, 2)
        padding_mask = torch.arange(lengths.max())[None] < lengths[:, None]

        return BatchFeature(
            {"input_values": input_values, "padding_mask": padding_mask}
        )

    def __call__(
        self,
        text: Optional[str] = None,
        input_audio: Optional[
            str | list[str] | torch.Tensor | list[torch.Tensor]
        ] = None,
        separated_audio: Optional[
            str | list[str] | torch.Tensor | list[torch.Tensor]
        ] = None,
        sampling_rate: Optional[int] = None,
        **kwargs,
    ):
        """Build the judge model inputs.

        Args:
            text: Text to tokenize (padded to the longest entry, truncated to
                512 tokens).
            input_audio: Audio that becomes ``input_values`` and
                ``padding_mask``.
            separated_audio: Audio that becomes ``separated_values``.
            sampling_rate: Sampling rate of tensor audio inputs.
            **kwargs: Accepted for call compatibility; not used.

        Returns:
            ``BatchFeature`` containing only the entries for the arguments
            that were provided.
        """
        batch = BatchFeature()
        if text is not None:
            batch.update(
                self.tokenizer(
                    text,
                    return_tensors="pt",
                    padding="longest",
                    max_length=512,
                    truncation=True,
                )
            )
        if input_audio is not None:
            batch.update(self._process_audio(input_audio, sampling_rate))
        if separated_audio is not None:
            batch["separated_values"] = self._process_audio(
                separated_audio, sampling_rate
            )["input_values"]
        return batch
# Public API of this module.
__all__ = ["SAMAudioProcessor", "SAMAudioJudgeProcessor", "Batch"]