# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import math
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
from core.audio_visual_encoder import PEAudioFrame, PEAudioFrameTransform
from torchdiffeq import odeint

from sam_audio.model.align import AlignModalities
from sam_audio.model.base import BaseModel
from sam_audio.model.codec import DACVAE
from sam_audio.model.config import SAMAudioConfig
from sam_audio.model.text_encoder import T5TextEncoder
from sam_audio.model.transformer import DiT
from sam_audio.model.vision_encoder import PerceptionEncoder
from sam_audio.processor import Batch
from sam_audio.ranking import create_ranker

# Default ODE solver options: a fixed-step midpoint solver. With
# step_size = 2/32 over the integration interval [0, 1], the solver takes
# 16 midpoint steps (32 function evaluations, two per step).
DFLT_ODE_OPT = {"method": "midpoint", "options": {"step_size": 2 / 32}}


class SinusoidalEmbedding(torch.nn.Module):
    """Standard sinusoidal embedding, used here to embed the ODE timestep."""

    def __init__(self, dim, theta=10000):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        inv_freq = torch.exp(
            -math.log(theta) * torch.arange(half_dim).float() / half_dim
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, pos=None):
        if pos is None:
            seq_len, device = x.shape[1], x.device
            pos = torch.arange(seq_len, device=device)
        emb = torch.einsum("i, j -> i j", pos, self.inv_freq)
        emb = torch.cat((emb.cos(), emb.sin()), dim=-1)
        return emb


class EmbedAnchors(torch.nn.Module):
    """Embeds anchor IDs and adds them to the input through a zero-initialized
    tanh gate, so the module starts out as an identity mapping."""

    def __init__(self, num_embeddings: int, embedding_dim: int, out_dim: int):
        super().__init__()
        # One extra embedding row serves as the padding entry.
        self.embed = torch.nn.Embedding(
            num_embeddings + 1, embedding_dim, padding_idx=num_embeddings
        )
        self.gate = torch.nn.Parameter(torch.tensor([0.0]))
        self.proj = torch.nn.Linear(embedding_dim, out_dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        anchor_ids: Optional[torch.Tensor] = None,
        anchor_alignment: Optional[torch.Tensor] = None,
    ):
        if anchor_ids is None:
            return x
        # Spread each anchor embedding over the frames it is aligned to.
        embs = self.embed(anchor_ids.gather(1, anchor_alignment))
        proj = self.proj(embs)
        return x + self.gate.tanh() * proj


@dataclass
class SeparationResult:
    # `target` and `residual` are per-example lists because each example is
    # unbatched back to its original length.
    target: List[torch.Tensor]
    residual: List[torch.Tensor]
    noise: torch.Tensor
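# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal example of
# how the helpers above behave, assuming dim=8 and a batch of 2 timesteps.
# This function exists for documentation only and is never called by the
# model code.
# ---------------------------------------------------------------------------
def _demo_embeddings():  # pragma: no cover - documentation sketch
    emb = SinusoidalEmbedding(dim=8)
    t = torch.tensor([0.25, 0.75])  # one flow timestep per batch element
    time_emb = emb(t, pos=t)  # -> shape [2, 8]

    anchors = EmbedAnchors(num_embeddings=4, embedding_dim=8, out_dim=16)
    x = torch.randn(2, 5, 16)  # [B, T, C]
    anchor_ids = torch.randint(0, 4, (2, 3))  # 3 anchors per example
    anchor_alignment = torch.randint(0, 3, (2, 5))  # anchor index per frame
    out = anchors(x, anchor_ids, anchor_alignment)  # same shape as x
    # With the gate initialized at zero, the module is initially the identity.
    assert torch.allclose(out, x)
    return time_emb, out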
class SAMAudio(BaseModel):
    config_cls = SAMAudioConfig
    revision = None

    def __init__(self, cfg: SAMAudioConfig):
        super().__init__()
        self.audio_codec = DACVAE(cfg.audio_codec)
        self.text_encoder = T5TextEncoder(cfg.text_encoder)
        self.vision_encoder = PerceptionEncoder(cfg.vision_encoder)
        self.transformer = DiT(cfg.transformer)
        self.proj = torch.nn.Linear(cfg.in_channels, cfg.transformer.dim)
        self.align_masked_video = AlignModalities(
            cfg.vision_encoder.dim, cfg.transformer.dim
        )
        self.embed_anchors = EmbedAnchors(
            cfg.num_anchors, cfg.anchor_embedding_dim, cfg.transformer.dim
        )
        self.memory_proj = torch.nn.Linear(cfg.text_encoder.dim, cfg.transformer.dim)
        self.timestep_emb = SinusoidalEmbedding(cfg.transformer.dim)
        self.visual_ranker = create_ranker(cfg.visual_ranker)
        self.text_ranker = create_ranker(cfg.text_ranker)
        if cfg.span_predictor is not None:
            self.span_predictor = PEAudioFrame.from_config(
                cfg.span_predictor, pretrained=True
            )
            self.span_predictor_transform = PEAudioFrameTransform.from_config(
                cfg.span_predictor
            )

    @property
    def sample_rate(self):
        return self.audio_codec.sample_rate

    def align_inputs(
        self,
        noisy_audio,
        audio_features: torch.Tensor,
        masked_video_features: Optional[torch.Tensor] = None,
        anchor_ids: Optional[torch.Tensor] = None,
        anchor_alignment: Optional[torch.Tensor] = None,
    ):
        # Concatenate the noisy latents, a zero placeholder, and the clean
        # audio features along the channel dimension, then project to the
        # transformer width and inject video and anchor conditioning.
        x = torch.cat(
            [
                noisy_audio,
                torch.zeros_like(audio_features),
                audio_features,
            ],
            dim=2,
        )
        projected = self.proj(x)
        aligned = self.align_masked_video(projected, masked_video_features)
        aligned = self.embed_anchors(aligned, anchor_ids, anchor_alignment)
        return aligned

    def forward(
        self,
        noisy_audio: torch.Tensor,
        audio_features: torch.Tensor,
        text_features: torch.Tensor,
        time: torch.Tensor,
        masked_video_features: Optional[torch.Tensor] = None,
        text_mask: Optional[torch.Tensor] = None,
        anchor_ids: Optional[torch.Tensor] = None,
        anchor_alignment: Optional[torch.Tensor] = None,
        audio_pad_mask: Optional[torch.Tensor] = None,
    ):
        """
        Forward pass for the model, representing one function evaluation of
        the ODE.

        In the descriptions below, B is batch size, T is sequence length, and
        C is channel size. Note that the sizes of C and T may vary across
        arguments (e.g. text_features vs. audio_features); the letters only
        designate a channel or time/sequence-length dimension, respectively.

        Args:
            noisy_audio (torch.Tensor): Noisy audio input tensor (being
                denoised).
            audio_features (torch.Tensor): Clean audio features [B x T x C].
            text_features (torch.Tensor): Encoded text features [B x T x C].
            time (torch.Tensor): Timestep tensor for positional encoding [B].
            masked_video_features (Optional[torch.Tensor]): Masked video
                features [B x C x T].
            text_mask (Optional[torch.Tensor]): Padding mask for text
                features [B x T].
            anchor_ids (Optional[torch.Tensor]): Anchor IDs [B x T].
                Defaults to None.
            anchor_alignment (Optional[torch.Tensor]): Anchor alignment
                [B x T].
            audio_pad_mask (Optional[torch.Tensor]): Padding mask for audio
                input [B x T].

        Returns:
            torch.Tensor
        """
        aligned_inputs = self.align_inputs(
            noisy_audio,
            audio_features,
            masked_video_features=masked_video_features,
            anchor_ids=anchor_ids,
            anchor_alignment=anchor_alignment,
        )
        # The timestep embedding always forms part of the memory; text
        # features, when present, are projected and added to it.
        timestep_emb = self.timestep_emb(time, pos=time).unsqueeze(1)
        memory = timestep_emb
        if text_features is not None:
            memory = self.memory_proj(text_features) + timestep_emb
        return self.transformer(
            aligned_inputs,
            time,
            padding_mask=audio_pad_mask,
            memory=memory,
            memory_padding_mask=text_mask,
        )

    def _get_audio_features(self, audios: torch.Tensor):
        # Encode to latents [B x T x C], then duplicate along the channel
        # dimension: one copy for the target stream, one for the residual.
        audio_features = self.audio_codec(audios).transpose(1, 2)
        return torch.cat([audio_features, audio_features], dim=2)

    def _get_video_features(self, video, audio_features):
        B, T, _ = audio_features.shape
        if video is None:
            return audio_features.new_zeros(B, self.vision_encoder.dim, T)
        else:
            return self.vision_encoder(video).transpose(1, 2)

    def _repeat_for_reranking(self, tensor, candidates):
        # Repeat each example `candidates` times along the batch dimension:
        # [B, ...] -> [B * candidates, ...], with copies of the same example
        # kept adjacent.
        if candidates > 1:
            B = tensor.size(0)
            rest = tensor.shape[1:]
            return (
                tensor.unsqueeze(1)
                .expand(B, candidates, *rest)
                .reshape(B * candidates, *rest)
            )
        else:
            return tensor

    def _unrepeat_from_reranking(self, tensor, candidates):
        return tensor[::candidates]

    def _get_forward_args(self, batch: Batch, candidates: int = 1):
        audio_features = self._get_audio_features(batch.audios)
        text_features, text_mask = self.text_encoder(batch.descriptions)
        masked_video_features = self._get_video_features(
            batch.masked_video, audio_features
        )
        return {
            "audio_features": self._repeat_for_reranking(audio_features, candidates),
            "text_features": self._repeat_for_reranking(text_features, candidates),
            "text_mask": self._repeat_for_reranking(text_mask, candidates),
            "masked_video_features": self._repeat_for_reranking(
                masked_video_features, candidates
            ),
            "anchor_ids": self._repeat_for_reranking(batch.anchor_ids, candidates),
            "anchor_alignment": self._repeat_for_reranking(
                batch.anchor_alignment, candidates
            ),
            "audio_pad_mask": self._repeat_for_reranking(
                batch.audio_pad_mask, candidates
            ),
        }
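    # -----------------------------------------------------------------------
    # Illustrative shape walk-through (not executed), assuming B=2 examples
    # and candidates=4 reranking samples:
    #
    #   audio_features: [2, T, 2C] --_repeat_for_reranking--> [8, T, 2C]
    #     (rows 0-3 are copies of example 0, rows 4-7 of example 1)
    #   _unrepeat_from_reranking(tensor, 4) == tensor[::4]  -> [2, T, 2C]
    #     (recovers one copy per original example)
    #
    # Every conditioning tensor is expanded the same way, so a single batched
    # ODE solve produces `candidates` independent samples per input.
    # -----------------------------------------------------------------------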
    def predict_spans(
        self, batch: Batch, audio_features: torch.Tensor, audio_pad_mask: torch.Tensor
    ) -> Batch:
        input = self.span_predictor_transform(text=batch.descriptions).to(
            audio_features.device
        )
        output = self.span_predictor(
            input_features=audio_features[:, :, :128],
            padding_mask=audio_pad_mask,
            return_spans=True,
            **input,
        )
        # Prefix each predicted span with the "+" (add) marker before
        # registering it as an anchor.
        anchors = [
            [["+"] + span for span in example_spans] for example_spans in output.spans
        ]
        batch.process_anchors(anchors)
        return batch

    @torch.inference_mode()
    def separate(
        self,
        batch: Batch,
        noise: Optional[torch.Tensor] = None,
        ode_opt: Dict[str, Any] = DFLT_ODE_OPT,
        reranking_candidates: int = 1,
        predict_spans: bool = False,
    ) -> SeparationResult:
        # Encode audio, text, and video conditioning (repeated per candidate).
        forward_args = self._get_forward_args(batch, candidates=reranking_candidates)

        if predict_spans and hasattr(self, "span_predictor") and batch.anchors is None:
            batch = self.predict_spans(
                batch=batch,
                audio_features=self._unrepeat_from_reranking(
                    forward_args["audio_features"], reranking_candidates
                ),
                audio_pad_mask=self._unrepeat_from_reranking(
                    forward_args["audio_pad_mask"], reranking_candidates
                ),
            )

        audio_features = forward_args["audio_features"]
        B, T, C = audio_features.shape
        C = C // 2  # audio_features is stacked, so the actual channel count is half
        if noise is None:
            noise = torch.randn_like(audio_features)

        def vector_field(t, noisy_audio):
            return self.forward(
                noisy_audio=noisy_audio,
                time=t.expand(noisy_audio.size(0)),
                **forward_args,
            )

        # Integrate the learned vector field from t=0 (noise) to t=1 (data).
        states = odeint(
            vector_field,
            noise,
            torch.tensor([0.0, 1.0], device=noise.device),
            **ode_opt,
        )
        generated_features = states[-1].transpose(1, 2)

        # generated_features has shape [B, 2C, T]: the target and residual
        # streams stacked along the channel dimension. Reshape to stack them
        # along the batch dimension before decoding.
        wavs = self.audio_codec.decode(generated_features.reshape(2 * B, C, T)).view(
            B, 2, -1
        )
        bsz = wavs.size(0) // reranking_candidates
        sizes = self.audio_codec.feature_idx_to_wav_idx(batch.sizes)
        target_wavs = self.unbatch(
            wavs[:, 0].view(bsz, reranking_candidates, -1), sizes
        )
        residual_wavs = self.unbatch(
            wavs[:, 1].view(bsz, reranking_candidates, -1), sizes
        )

        # Pick the best candidate per example: prefer the visual ranker when
        # video is available, fall back to the text ranker, otherwise take
        # the first candidate.
        if (
            reranking_candidates > 1
            and batch.masked_video is not None
            and self.visual_ranker is not None
        ):
            scores = self.visual_ranker(
                extracted_audio=target_wavs,
                videos=batch.masked_video,
                sample_rate=self.audio_codec.sample_rate,
            )
            idxs = scores.argmax(dim=1)
        elif reranking_candidates > 1 and self.text_ranker is not None:
            input_audio = [
                audio[:, :size].expand(reranking_candidates, -1)
                for audio, size in zip(batch.audios, sizes, strict=False)
            ]
            scores = self.text_ranker(
                extracted_audio=target_wavs,
                input_audio=input_audio,
                descriptions=batch.descriptions,
                sample_rate=self.audio_codec.sample_rate,
            )
            idxs = scores.argmax(dim=1)
        else:
            idxs = torch.zeros(bsz, dtype=torch.long, device=noise.device)

        return SeparationResult(
            target=[wav[idx] for wav, idx in zip(target_wavs, idxs, strict=False)],
            residual=[wav[idx] for wav, idx in zip(residual_wavs, idxs, strict=False)],
            noise=noise,
        )

    def unbatch(self, wavs: torch.Tensor, sizes: torch.Tensor, time_dim: int = -1):
        # Trim each row back to its original (pre-padding) length.
        result = []
        for row, size in zip(wavs, sizes, strict=False):
            result.append(row.narrow(dim=time_dim, start=0, length=size))
        return result

    def load_state_dict(self, state_dict, strict=True):
        missing_keys, unexpected_keys = super().load_state_dict(
            state_dict, strict=False
        )
        if strict:
            # These submodules are loaded directly from HF, not from the
            # checkpoint, so their keys are expected to be missing.
            skip_regex = re.compile(
                "(^text_encoder|^visual_ranker|^text_ranker|^span_predictor)"
            )
            missing_keys = [x for x in missing_keys if not skip_regex.search(x)]
            if len(missing_keys) > 0 or len(unexpected_keys) > 0:
                raise RuntimeError(
                    f"Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}"
                )
        return missing_keys, unexpected_keys
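# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how DFLT_ODE_OPT is
# consumed by torchdiffeq. The toy vector field below stands in for
# SAMAudio.forward; with the midpoint method and step_size = 2/32, the solver
# takes 16 fixed steps from t=0 to t=1. For documentation only.
# ---------------------------------------------------------------------------
def _demo_odeint():  # pragma: no cover - documentation sketch
    y0 = torch.randn(2, 4)  # stands in for the initial noise latents

    def toy_vector_field(t, y):
        # A simple linear field; the real model predicts the flow direction
        # from noisy latents plus text/video/anchor conditioning.
        return -y

    states = odeint(
        toy_vector_field,
        y0,
        torch.tensor([0.0, 1.0]),
        **DFLT_ODE_OPT,
    )
    # `states` holds the solution at each requested time; states[-1] is the
    # integrated result at t=1, analogous to generated_features above.
    return states[-1]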
__all__ = ["SAMAudio"]
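# ---------------------------------------------------------------------------
# Usage sketch (illustrative; the loading and preprocessing steps below are
# assumptions, not APIs defined in this file):
#
#   model = SAMAudio(...)                        # built from a SAMAudioConfig
#   batch = ...                                  # a sam_audio.processor.Batch
#   result = model.separate(batch, reranking_candidates=4, predict_spans=True)
#   target_wavs = result.target                  # list of tensors, one per
#   residual_wavs = result.residual              # batch example
# ---------------------------------------------------------------------------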