# Update sam_audio/model/config.py (commit e387e4c, verified; author: ray-006)
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Optional, Tuple
import numpy as np
from core.audio_visual_encoder.config import TransformerConfig as PEAVTransformerConfig
from transformers import ModernBertConfig
class DACVAEConfig:
    """Configuration for the DAC-VAE audio codec.

    Holds the encoder/decoder architecture hyper-parameters, quantizer
    settings, audio sample rate, and latent normalization statistics.
    """

    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: Optional[list[int]] = None,
        latent_dim: int = 1024,
        decoder_dim: int = 1536,
        decoder_rates: Optional[list[int]] = None,
        n_codebooks: int = 16,
        codebook_size: int = 1024,
        codebook_dim: int = 128,
        quantizer_dropout: bool = False,
        sample_rate: int = 48_000,
        mean: float = 0.0,
        std: float = 1.0,
    ):
        """Build the codec config.

        Args:
            encoder_dim: Base channel width of the encoder.
            encoder_rates: Per-stage downsampling rates; defaults to
                ``[2, 8, 10, 12]`` when None.
            latent_dim: Dimensionality of the latent representation.
            decoder_dim: Base channel width of the decoder.
            decoder_rates: Per-stage upsampling rates; defaults to
                ``[12, 10, 8, 2]`` when None.
            n_codebooks: Number of residual quantizer codebooks.
            codebook_size: Entries per codebook.
            codebook_dim: Dimensionality of each codebook entry.
            quantizer_dropout: Whether quantizer dropout is enabled.
            sample_rate: Audio sample rate in Hz.
            mean: Latent normalization mean.
            std: Latent normalization standard deviation.
        """
        self.encoder_dim = encoder_dim
        # None-sentinel instead of a mutable list default: the previous
        # `= [2, 8, 10, 12]` default shared one list object across every
        # instance, so mutating one instance's rates corrupted all others.
        self.encoder_rates = [2, 8, 10, 12] if encoder_rates is None else encoder_rates
        self.latent_dim = latent_dim
        self.decoder_dim = decoder_dim
        self.decoder_rates = [12, 10, 8, 2] if decoder_rates is None else decoder_rates
        self.n_codebooks = n_codebooks
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_dropout = quantizer_dropout
        self.sample_rate = sample_rate
        self.mean = mean
        self.std = std

    @property
    def hop_length(self):
        """Total samples per latent frame: the product of all encoder rates."""
        return int(np.prod(self.encoder_rates))
class TextEncoderConfig:
    """Base configuration shared by all text encoders."""

    def __init__(self, dim: int = 768):
        # Dimensionality of the encoder's output embeddings.
        self.dim = dim
class T5EncoderConfig(TextEncoderConfig):
    """Configuration for a T5-based text encoder."""

    def __init__(
        self,
        name: str = "t5-base",
        max_length: Optional[int] = 512,
        pad_mode: str = "longest",
        dim: int = 768,
    ):
        super().__init__(dim=dim)
        # Model identifier of the T5 checkpoint to load.
        self.name = name
        # Padding strategy and optional token-length cap for tokenization.
        self.pad_mode = pad_mode
        self.max_length = max_length
class VisionEncoderConfig:
    """Base configuration shared by all vision encoders."""

    def __init__(self, dim: int = 1024, batch_size: int = 300):
        # Per-call batch size used when encoding frames.
        self.batch_size = batch_size
        # Dimensionality of the encoder's output features.
        self.dim = dim
class PerceptionEncoderConfig(VisionEncoderConfig):
    """Configuration for a Perception Encoder (PE) vision backbone."""

    def __init__(
        self,
        dim: int = 1024,
        batch_size: int = 300,
        name: str = "PE-Core-L14-336",
        normalize_feature: bool = True,
        interpolation_mode: str = "BICUBIC",
        image_size: int = 336,
    ):
        super().__init__(dim=dim, batch_size=batch_size)
        # Identifier of the PE checkpoint/variant.
        self.name = name
        # Square input resolution fed to the encoder, and the resize
        # interpolation used to reach it.
        self.image_size = image_size
        self.interpolation_mode = interpolation_mode
        # Whether output features are L2-normalized — presumably; confirm
        # against the encoder implementation.
        self.normalize_feature = normalize_feature
class TransformerConfig:
    """Configuration for the flow-matching transformer backbone.

    Covers attention/FFN geometry, normalization, rotary position
    embeddings, timestep embedding, conditioning ("t_block"/context)
    projections, and input/output channel counts.
    """

    def __init__(
        self,
        dim: int = 2048,
        n_heads: int = 16,
        n_layers: int = 16,
        dropout: float = 0.1,
        norm_eps: float = 1.0e-05,
        qk_norm: bool = True,
        fc_bias: bool = False,
        ffn_exp: int = 4,
        ffn_dim_multiplier: int = 1,
        multiple_of: int = 64,
        non_linearity: str = "swiglu",
        use_rope: bool = True,
        max_positions: int = 10000,
        frequency_embedding_dim: int = 256,
        timestep_non_linearity: str = "swiglu",
        t_block_non_linearity: str = "silu",
        t_block_bias: bool = True,
        context_dim: int = 2048,
        context_non_linearity: str = "swiglu",
        context_embedder_dropout: float = 0.0,
        context_norm: bool = False,
        out_channels: int = 256,
        in_channels: Optional[int] = None,
    ):
        # Every constructor argument is stored on the instance verbatim,
        # under the same name — copy them all in one shot.
        params = dict(locals())
        del params["self"]
        self.__dict__.update(params)
class RankerConfig:
    """Abstract base for ranker configurations.

    Each concrete subclass sets a unique ``kind`` tag; ``parse_ranker_config``
    uses that tag to select which subclass to instantiate from a plain dict.
    """

    # Discriminator string identifying the concrete ranker type.
    kind: str
class ImageBindRankerConfig(RankerConfig):
    """Configuration for an ImageBind-based ranker."""

    kind: str = "imagebind"

    def __init__(self, checkpoint: Optional[str] = None):
        # Optional checkpoint path/identifier for the ImageBind model.
        self.checkpoint = checkpoint
class ClapRankerConfig(RankerConfig):
    """Configuration for a CLAP-based ranker."""

    kind: str = "clap"

    def __init__(self, checkpoint: Optional[str] = None):
        # Optional checkpoint path/identifier for the CLAP model.
        self.checkpoint = checkpoint
class JudgeRankerConfig(RankerConfig):
    """Configuration for a judge-model-based ranker."""

    kind: str = "judge"

    def __init__(self, checkpoint_or_model_id: str = "facebook/sam-audio-judge"):
        # Local checkpoint path or hub model id of the judge model.
        self.checkpoint_or_model_id = checkpoint_or_model_id
class SoundActivityRankerConfig(RankerConfig):
    """Configuration for the sound-activity ranker."""

    kind: str = "sound_activity"

    def __init__(
        self,
        threshold_mode: str = "rel_to_max",
        sil_threshold: float = -40,
        metric: str = "iou",
    ):
        # Scoring metric name (e.g. "iou") — semantics defined by the ranker.
        self.metric = metric
        # Silence threshold in dB and how it is interpreted
        # (threshold_mode, e.g. relative to the max level).
        self.threshold_mode = threshold_mode
        self.sil_threshold = sil_threshold
class EnsembleRankerConfig(RankerConfig):
    """Configuration for an ensemble of rankers.

    ``rankers`` maps each ranker's name to a ``(config, weight)`` pair.
    """

    kind: str = "ensemble"

    def __init__(self, rankers: dict[str, Tuple[RankerConfig, float]]):
        self.rankers = rankers
def parse_ranker_config(config_dict: dict):
kind = config_dict.pop("kind")
match kind:
case ImageBindRankerConfig.kind:
return ImageBindRankerConfig(**config_dict)
case ClapRankerConfig.kind:
return ClapRankerConfig(**config_dict)
case JudgeRankerConfig.kind:
return JudgeRankerConfig(**config_dict)
case SoundActivityRankerConfig.kind:
return SoundActivityRankerConfig(**config_dict)
case EnsembleRankerConfig.kind:
return EnsembleRankerConfig(
{
k: (parse_ranker_config(v), w)
for k, (v, w) in config_dict["rankers"].items()
}
)
class SAMAudioConfig:
    """Top-level configuration for the SAM-Audio model.

    Sub-component configs (codec, text/vision encoders, transformer,
    rankers) are supplied as plain keyword dicts and materialized into
    their config classes here.
    """

    def __init__(
        self,
        in_channels: int = 768,
        audio_codec=None,
        text_encoder=None,
        vision_encoder=None,
        transformer=None,
        num_anchors: int = 3,
        anchor_embedding_dim: int = 128,
        visual_ranker=None,
        text_ranker=None,
        span_predictor: Optional[str] = "pe-a-frame-large",
    ):
        self.in_channels = in_channels
        self.num_anchors = num_anchors
        self.anchor_embedding_dim = anchor_embedding_dim
        self.span_predictor = span_predictor
        # Each sub-config dict defaults to {} -> the sub-config's defaults.
        self.audio_codec = DACVAEConfig(**(audio_codec or {}))
        self.text_encoder = T5EncoderConfig(**(text_encoder or {}))
        self.vision_encoder = PerceptionEncoderConfig(**(vision_encoder or {}))
        self.transformer = TransformerConfig(**(transformer or {}))
        # Ranker configs arrive as "kind"-tagged dicts; None disables them.
        self.visual_ranker = (
            parse_ranker_config(visual_ranker) if visual_ranker is not None else None
        )
        self.text_ranker = (
            parse_ranker_config(text_ranker) if text_ranker is not None else None
        )
class SAMAudioJudgeConfig:
    """Configuration for the SAM-Audio judge model.

    Sub-component configs are supplied as plain keyword dicts and
    materialized into their config classes here.
    """

    def __init__(
        self,
        # NOTE: these were previously annotated with the config classes
        # themselves (DACVAEConfig, PEAVTransformerConfig, ModernBertConfig),
        # but the values are **-unpacked below, so they must be mappings
        # (or None). The annotations now reflect that.
        audio_codec: Optional[dict] = None,
        transformer: Optional[dict] = None,
        text_model: Optional[dict] = None,
        finetune_transformer: Optional[dict] = None,
        nth_text_layer: int = 22,
        bottleneck_dim: int = 256,
    ):
        """Build the judge config.

        Args:
            audio_codec: Kwargs for :class:`DACVAEConfig`; None -> defaults.
            transformer: Kwargs for the PE-AV ``TransformerConfig``.
            text_model: Kwargs for ``ModernBertConfig``.
            finetune_transformer: Kwargs for the fine-tuned PE-AV
                ``TransformerConfig``.
            nth_text_layer: Index of the text-model layer whose features
                are used.
            bottleneck_dim: Width of the bottleneck projection.
        """
        self.audio_codec = DACVAEConfig(**(audio_codec or {}))
        self.transformer = PEAVTransformerConfig(**(transformer or {}))
        self.text_model = ModernBertConfig(**(text_model or {}))
        self.finetune_transformer = PEAVTransformerConfig(
            **(finetune_transformer or {})
        )
        self.nth_text_layer = nth_text_layer
        self.bottleneck_dim = bottleneck_dim