# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
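"""SAMAudio separation model.

Generates target and residual audio by integrating an ODE over codec features,
conditioned on text descriptions, masked-video features, and optional anchor spans.
"""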
import math
import re
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
from core.audio_visual_encoder import PEAudioFrame, PEAudioFrameTransform
from torchdiffeq import odeint
from sam_audio.model.align import AlignModalities
from sam_audio.model.base import BaseModel
from sam_audio.model.codec import DACVAE
from sam_audio.model.config import SAMAudioConfig
from sam_audio.model.text_encoder import T5TextEncoder
from sam_audio.model.transformer import DiT
from sam_audio.model.vision_encoder import PerceptionEncoder
from sam_audio.processor import Batch
from sam_audio.ranking import create_ranker
DFLT_ODE_OPT = {"method": "midpoint", "options": {"step_size": 2 / 32}}
class SinusoidalEmbedding(torch.nn.Module):
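    """Sinusoidal embedding that maps positions (or ODE timesteps) to [cos | sin] feature vectors."""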
def __init__(self, dim, theta=10000):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
inv_freq = torch.exp(
-math.log(theta) * torch.arange(half_dim).float() / half_dim
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, x, pos=None):
if pos is None:
seq_len, device = x.shape[1], x.device
pos = torch.arange(seq_len, device=device)
emb = torch.einsum("i, j -> i j", pos, self.inv_freq)
emb = torch.cat((emb.cos(), emb.sin()), dim=-1)
return emb
class EmbedAnchors(torch.nn.Module):
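    """Adds learned anchor embeddings to the input through a tanh-gated projection.

    The extra index (``num_embeddings``) acts as the padding / no-anchor embedding.
    """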
def __init__(self, num_embeddings: int, embedding_dim: int, out_dim: int):
super().__init__()
self.embed = torch.nn.Embedding(
num_embeddings + 1, embedding_dim, padding_idx=num_embeddings
)
self.gate = torch.nn.Parameter(torch.tensor([0.0]))
self.proj = torch.nn.Linear(embedding_dim, out_dim, bias=False)
def forward(
self,
x: torch.Tensor,
anchor_ids: Optional[torch.Tensor] = None,
anchor_alignment: Optional[torch.Tensor] = None,
):
if anchor_ids is None:
return x
embs = self.embed(anchor_ids.gather(1, anchor_alignment))
proj = self.proj(embs)
return x + self.gate.tanh() * proj
@dataclass
class SeparationResult:
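    """Output of ``SAMAudio.separate``: separated target audio, residual audio, and the initial noise."""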
target: torch.Tensor
residual: torch.Tensor
noise: torch.Tensor
class SAMAudio(BaseModel):
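    """Audio separation model.

    A DiT predicts a vector field over DAC-VAE codec features; ``separate`` integrates it with
    an ODE solver from noise (t=0) to clean features (t=1), conditioned on text, masked-video
    features, and optional anchor spans, and decodes the result into target and residual audio.
    """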
config_cls = SAMAudioConfig
revision = None
def __init__(self, cfg: SAMAudioConfig):
super().__init__()
self.audio_codec = DACVAE(cfg.audio_codec)
self.text_encoder = T5TextEncoder(cfg.text_encoder)
self.vision_encoder = PerceptionEncoder(cfg.vision_encoder)
self.transformer = DiT(cfg.transformer)
self.proj = torch.nn.Linear(cfg.in_channels, cfg.transformer.dim)
self.align_masked_video = AlignModalities(
cfg.vision_encoder.dim, cfg.transformer.dim
)
self.embed_anchors = EmbedAnchors(
cfg.num_anchors, cfg.anchor_embedding_dim, cfg.transformer.dim
)
self.memory_proj = torch.nn.Linear(cfg.text_encoder.dim, cfg.transformer.dim)
self.timestep_emb = SinusoidalEmbedding(cfg.transformer.dim)
self.visual_ranker = create_ranker(cfg.visual_ranker)
self.text_ranker = create_ranker(cfg.text_ranker)
if cfg.span_predictor is not None:
self.span_predictor = PEAudioFrame.from_config(
cfg.span_predictor, pretrained=True
)
self.span_predictor_transform = PEAudioFrameTransform.from_config(
cfg.span_predictor
)
@property
def sample_rate(self):
return self.audio_codec.sample_rate
def align_inputs(
self,
noisy_audio,
audio_features: torch.Tensor,
masked_video_features: Optional[torch.Tensor] = None,
anchor_ids: Optional[torch.Tensor] = None,
anchor_alignment: Optional[torch.Tensor] = None,
):
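        """Concatenate the noisy latents, a zero block, and the conditioning audio features along
        channels, project to the transformer width, then fuse masked-video and anchor conditioning."""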
x = torch.cat(
[
noisy_audio,
torch.zeros_like(audio_features),
audio_features,
],
dim=2,
)
projected = self.proj(x)
aligned = self.align_masked_video(projected, masked_video_features)
aligned = self.embed_anchors(aligned, anchor_ids, anchor_alignment)
return aligned
def forward(
self,
noisy_audio: torch.Tensor,
audio_features: torch.Tensor,
text_features: torch.Tensor,
time: torch.Tensor,
masked_video_features: Optional[torch.Tensor] = None,
text_mask: Optional[torch.Tensor] = None,
anchor_ids: Optional[torch.Tensor] = None,
anchor_alignment: Optional[torch.Tensor] = None,
audio_pad_mask: Optional[torch.Tensor] = None,
):
"""
Forward pass for the model. Represents one function evaluation of the ODE.
In the below descriptions, B is batch size, T is sequence length, C is channel size.
Note that the size of C and T may vary across arguments (ex. text_features vs. audio_features),
it is used only to designate a Channel or time/sequence-length dimension respectively.
Args:
noisy_audio (torch.Tensor): Noisy audio input tensor (being denoised).
audio_features (torch.Tensor): Clean audio features [B x T x C].
text_features (torch.Tensor): Encoded text features tensor [B x T x C].
time (torch.Tensor): Timestep tensor for positional encoding [B].
masked_video_features (Optional[torch.Tensor], optional): Masked video features tensor. [B x C x T].
text_mask (Optional[torch.Tensor], optional): Padding mask for text features. [B x T].
anchor_ids (Optional[torch.Tensor], optional): Anchor IDs tensor. Defaults to None [B x T].
anchor_alignment (Optional[torch.Tensor], optional): Anchor alignment tensor. B x T.
audio_pad_mask (Optional[torch.Tensor], optional): Padding mask for audio input. [B x T].
Returns:
torch.Tensor
"""
aligned_inputs = self.align_inputs(
noisy_audio,
audio_features,
masked_video_features=masked_video_features,
anchor_ids=anchor_ids,
anchor_alignment=anchor_alignment,
)
memory = timestep_emb = self.timestep_emb(time, pos=time).unsqueeze(1)
if text_features is not None:
memory = self.memory_proj(text_features) + timestep_emb
return self.transformer(
aligned_inputs,
time,
padding_mask=audio_pad_mask,
memory=memory,
memory_padding_mask=text_mask,
)
def _get_audio_features(self, audios: torch.Tensor):
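        # Encode the mixture with the codec, then stack the features twice along channels to
        # match the two generated streams (target and residual).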
audio_features = self.audio_codec(audios).transpose(1, 2)
return torch.cat([audio_features, audio_features], dim=2)
def _get_video_features(self, video, audio_features):
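        # Return visual features on the audio timeline; fall back to zeros when no video is given.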
B, T, _ = audio_features.shape
if video is None:
return audio_features.new_zeros(B, self.vision_encoder.dim, T)
else:
return self.vision_encoder(video).transpose(1, 2)
def _repeat_for_reranking(self, tensor, candidates):
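        # Repeat each batch element `candidates` times along the batch dim so that several
        # ODE samples can be drawn per input and reranked afterwards.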
if candidates > 1:
B = tensor.size(0)
rest = tensor.shape[1:]
return (
tensor.unsqueeze(1)
.expand(B, candidates, *rest)
.reshape(B * candidates, *rest)
)
else:
return tensor
def _unrepeat_from_reranking(self, tensor, candidates):
return tensor[::candidates]
def _get_forward_args(self, batch: Batch, candidates: int = 1):
audio_features = self._get_audio_features(batch.audios)
text_features, text_mask = self.text_encoder(batch.descriptions)
masked_video_features = self._get_video_features(
batch.masked_video, audio_features
)
return {
"audio_features": self._repeat_for_reranking(audio_features, candidates),
"text_features": self._repeat_for_reranking(text_features, candidates),
"text_mask": self._repeat_for_reranking(text_mask, candidates),
"masked_video_features": self._repeat_for_reranking(
masked_video_features, candidates
),
"anchor_ids": self._repeat_for_reranking(batch.anchor_ids, candidates),
"anchor_alignment": self._repeat_for_reranking(
batch.anchor_alignment, candidates
),
"audio_pad_mask": self._repeat_for_reranking(
batch.audio_pad_mask, candidates
),
}
def predict_spans(
self, batch: Batch, audio_features: torch.Tensor, audio_pad_mask: torch.Tensor
) -> Batch:
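        """Predict time spans (anchors) for each description with the span predictor and attach
        them to the batch."""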
input = self.span_predictor_transform(text=batch.descriptions).to(
audio_features.device
)
output = self.span_predictor(
input_features=audio_features[:, :, :128],
padding_mask=audio_pad_mask,
return_spans=True,
**input,
)
anchors = [[["+"] + anchor for anchor in anchors] for anchors in output.spans]
batch.process_anchors(anchors)
return batch
@torch.inference_mode()
def separate(
self,
batch: Batch,
noise: Optional[torch.Tensor] = None,
ode_opt: Dict[str, Any] = DFLT_ODE_OPT,
reranking_candidates: int = 1,
predict_spans: bool = False,
) -> SeparationResult:
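        """Separate the audio described by ``batch`` from the mixture.

        Integrates the learned vector field from noise (t=0) to codec features (t=1), decodes
        target and residual waveforms, and, when ``reranking_candidates > 1``, keeps the
        candidate preferred by the visual or text ranker.
        """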
# Encode audio
forward_args = self._get_forward_args(batch, candidates=reranking_candidates)
if predict_spans and hasattr(self, "span_predictor") and batch.anchors is None:
batch = self.predict_spans(
batch=batch,
audio_features=self._unrepeat_from_reranking(
forward_args["audio_features"], reranking_candidates
),
audio_pad_mask=self._unrepeat_from_reranking(
forward_args["audio_pad_mask"], reranking_candidates
),
)
audio_features = forward_args["audio_features"]
B, T, C = audio_features.shape
        C = C // 2  # audio_features are stacked along channels, so the true channel count is half
if noise is None:
noise = torch.randn_like(audio_features)
def vector_field(t, noisy_audio):
res = self.forward(
noisy_audio=noisy_audio,
time=t.expand(noisy_audio.size(0)),
**forward_args,
)
return res
states = odeint(
vector_field,
noise,
torch.tensor([0.0, 1.0], device=noise.device),
**ode_opt,
)
generated_features = states[-1].transpose(1, 2)
# generated_features has shape [B, 2C, T]. Reshape to stack along the batch dimension
wavs = self.audio_codec.decode(generated_features.reshape(2 * B, C, T)).view(
B, 2, -1
)
bsz = wavs.size(0) // reranking_candidates
sizes = self.audio_codec.feature_idx_to_wav_idx(batch.sizes)
target_wavs = self.unbatch(
wavs[:, 0].view(bsz, reranking_candidates, -1), sizes
)
residual_wavs = self.unbatch(
wavs[:, 1].view(bsz, reranking_candidates, -1), sizes
)
if (
reranking_candidates > 1
and batch.masked_video is not None
and self.visual_ranker is not None
):
scores = self.visual_ranker(
extracted_audio=target_wavs,
videos=batch.masked_video,
sample_rate=self.audio_codec.sample_rate,
)
idxs = scores.argmax(dim=1)
elif reranking_candidates > 1 and self.text_ranker is not None:
input_audio = [
audio[:, :size].expand(reranking_candidates, -1)
for audio, size in zip(batch.audios, sizes, strict=False)
]
scores = self.text_ranker(
extracted_audio=target_wavs,
input_audio=input_audio,
descriptions=batch.descriptions,
sample_rate=self.audio_codec.sample_rate,
)
idxs = scores.argmax(dim=1)
else:
idxs = torch.zeros(bsz, dtype=torch.long, device=noise.device)
return SeparationResult(
target=[wav[idx] for wav, idx in zip(target_wavs, idxs, strict=False)],
residual=[
wavs[idx] for wavs, idx in zip(residual_wavs, idxs, strict=False)
],
noise=noise,
)
def unbatch(self, wavs: torch.Tensor, sizes: torch.Tensor, time_dim: int = -1):
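        # Trim each padded waveform back to its true length.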
result = []
for row, size in zip(wavs, sizes, strict=False):
result.append(row.narrow(dim=time_dim, start=0, length=size))
return result
def load_state_dict(self, state_dict, strict=True):
if strict:
missing_keys, unexpected_keys = super().load_state_dict(
state_dict, strict=False
)
            # These submodules are loaded directly from HF, not from the checkpoint
skip_regex = re.compile(
"(^text_encoder|^visual_ranker|^text_ranker|^span_predictor)"
)
missing_keys = [x for x in missing_keys if not re.search(skip_regex, x)]
if len(missing_keys) > 0 or len(unexpected_keys) > 0:
raise RuntimeError(
f"Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}"
)
__all__ = ["SAMAudio"]
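# Usage sketch (illustrative only): how ``separate`` is typically driven. The config values,
# weight loading, and ``Batch`` construction below are assumptions; ``sam_audio.processor``
# and ``BaseModel`` may expose different entry points.
#
#   cfg = SAMAudioConfig(...)                                       # hypothetical config values
#   model = SAMAudio(cfg).eval()                                    # weights via load_state_dict(...)
#   batch = processor(audios=wavs, descriptions=["a dog barking"])  # hypothetical processor call
#   result = model.separate(batch, reranking_candidates=4, predict_spans=True)
#   target, residual = result.target, result.residual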