Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n | |
| from typing import Optional, Tuple | |
| import numpy as np | |
| from core.audio_visual_encoder.config import TransformerConfig as PEAVTransformerConfig | |
| from transformers import ModernBertConfig | |
class DACVAEConfig:
    """Settings container for the DAC-VAE audio codec.

    Every constructor argument is stored on the instance under the same
    name. ``SAMAudioConfig`` and ``SAMAudioJudgeConfig`` build this class
    from their ``audio_codec`` keyword dictionaries.

    Args:
        encoder_dim: Stored as ``self.encoder_dim``.
        encoder_rates: List of integer rates for the encoder; their product
            is what :meth:`hop_length` returns. ``None`` (the default)
            selects a fresh ``[2, 8, 10, 12]``.
        latent_dim: Stored as ``self.latent_dim``.
        decoder_dim: Stored as ``self.decoder_dim``.
        decoder_rates: List of integer rates for the decoder. ``None``
            (the default) selects a fresh ``[12, 10, 8, 2]``.
        n_codebooks: Stored as ``self.n_codebooks``.
        codebook_size: Stored as ``self.codebook_size``.
        codebook_dim: Stored as ``self.codebook_dim``.
        quantizer_dropout: Stored as ``self.quantizer_dropout``.
        sample_rate: Stored as ``self.sample_rate`` (presumably Hz --
            confirm against the codec implementation).
        mean: Stored as ``self.mean``.
        std: Stored as ``self.std``.
    """

    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: Optional[list[int]] = None,
        latent_dim: int = 1024,
        decoder_dim: int = 1536,
        decoder_rates: Optional[list[int]] = None,
        n_codebooks: int = 16,
        codebook_size: int = 1024,
        codebook_dim: int = 128,
        quantizer_dropout: bool = False,
        sample_rate: int = 48_000,
        mean: float = 0.0,
        std: float = 1.0,
    ):
        self.encoder_dim = encoder_dim
        # A list literal in the signature would be created once and shared by
        # every instance that relies on the default, so mutating one config's
        # rates would leak into all the others. Build a fresh list per
        # instance instead; explicitly supplied values are stored untouched.
        self.encoder_rates = [2, 8, 10, 12] if encoder_rates is None else encoder_rates
        self.latent_dim = latent_dim
        self.decoder_dim = decoder_dim
        self.decoder_rates = [12, 10, 8, 2] if decoder_rates is None else decoder_rates
        self.n_codebooks = n_codebooks
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_dropout = quantizer_dropout
        self.sample_rate = sample_rate
        self.mean = mean
        self.std = std

    def hop_length(self):
        """Return the product of ``encoder_rates`` as a plain ``int``."""
        return int(np.prod(self.encoder_rates))
class TextEncoderConfig:
    """Base settings shared by text-encoder configurations.

    Args:
        dim: Value stored on ``self.dim`` (presumably the encoder's feature
            width -- confirm against the encoder implementation).
    """

    def __init__(self, dim: int = 768):
        self.dim: int = dim
class T5EncoderConfig(TextEncoderConfig):
    """Text-encoder settings for a T5 model.

    Args:
        name: Model identifier, stored as ``self.name``.
        max_length: Stored as ``self.max_length``; may be ``None``.
        pad_mode: Stored as ``self.pad_mode``.
        dim: Forwarded to ``TextEncoderConfig.__init__``.
    """

    def __init__(
        self,
        name: str = "t5-base",
        max_length: Optional[int] = 512,
        pad_mode: str = "longest",
        dim: int = 768,
    ):
        # The base class owns ``dim``; everything else is T5-specific.
        super().__init__(dim=dim)
        self.name: str = name
        self.max_length: Optional[int] = max_length
        self.pad_mode: str = pad_mode
class VisionEncoderConfig:
    """Base settings shared by vision-encoder configurations.

    Args:
        dim: Stored as ``self.dim``.
        batch_size: Stored as ``self.batch_size``.
    """

    def __init__(self, dim: int = 1024, batch_size: int = 300):
        self.dim: int = dim
        self.batch_size: int = batch_size
class PerceptionEncoderConfig(VisionEncoderConfig):
    """Vision-encoder settings for a Perception Encoder model.

    Args:
        dim: Forwarded to ``VisionEncoderConfig.__init__``.
        batch_size: Forwarded to ``VisionEncoderConfig.__init__``.
        name: Model identifier, stored as ``self.name``.
        normalize_feature: Stored as ``self.normalize_feature``.
        interpolation_mode: Stored as ``self.interpolation_mode``.
        image_size: Stored as ``self.image_size``.
    """

    def __init__(
        self,
        dim: int = 1024,
        batch_size: int = 300,
        name: str = "PE-Core-L14-336",
        normalize_feature: bool = True,
        interpolation_mode: str = "BICUBIC",
        image_size: int = 336,
    ):
        # Generic vision-encoder fields are handled by the parent class.
        super().__init__(dim=dim, batch_size=batch_size)
        self.name: str = name
        self.normalize_feature: bool = normalize_feature
        self.interpolation_mode: str = interpolation_mode
        self.image_size: int = image_size
class TransformerConfig:
    """Hyper-parameter container for the transformer.

    ``SAMAudioConfig`` builds this class from its ``transformer`` keyword
    dictionary. Each constructor argument is copied verbatim onto an
    attribute of the same name; no validation or derivation happens here.
    """

    def __init__(
        self,
        dim: int = 2048,
        n_heads: int = 16,
        n_layers: int = 16,
        dropout: float = 0.1,
        norm_eps: float = 1.0e-05,
        qk_norm: bool = True,
        fc_bias: bool = False,
        ffn_exp: int = 4,
        ffn_dim_multiplier: int = 1,
        multiple_of: int = 64,
        non_linearity: str = "swiglu",
        use_rope: bool = True,
        max_positions: int = 10000,
        frequency_embedding_dim: int = 256,
        timestep_non_linearity: str = "swiglu",
        t_block_non_linearity: str = "silu",
        t_block_bias: bool = True,
        context_dim: int = 2048,
        context_non_linearity: str = "swiglu",
        context_embedder_dropout: float = 0.0,
        context_norm: bool = False,
        out_channels: int = 256,
        in_channels: Optional[int] = None,
    ):
        # Width, depth and normalisation options.
        self.dim: int = dim
        self.n_heads: int = n_heads
        self.n_layers: int = n_layers
        self.dropout: float = dropout
        self.norm_eps: float = norm_eps
        self.qk_norm: bool = qk_norm

        # Feed-forward options.
        self.fc_bias: bool = fc_bias
        self.ffn_exp: int = ffn_exp
        self.ffn_dim_multiplier: int = ffn_dim_multiplier
        self.multiple_of: int = multiple_of
        self.non_linearity: str = non_linearity

        # Positional options.
        self.use_rope: bool = use_rope
        self.max_positions: int = max_positions

        # Timestep-related options.
        self.frequency_embedding_dim: int = frequency_embedding_dim
        self.timestep_non_linearity: str = timestep_non_linearity
        self.t_block_non_linearity: str = t_block_non_linearity
        self.t_block_bias: bool = t_block_bias

        # Context-related options.
        self.context_dim: int = context_dim
        self.context_non_linearity: str = context_non_linearity
        self.context_embedder_dropout: float = context_embedder_dropout
        self.context_norm: bool = context_norm

        # Channel counts; ``in_channels`` may be left as ``None``.
        self.out_channels: int = out_channels
        self.in_channels: Optional[int] = in_channels
class RankerConfig:
    """Common base for ranker configurations.

    Subclasses assign ``kind`` a unique string; ``parse_ranker_config``
    compares against those values to pick the class to instantiate.
    """

    # Declared only; each concrete subclass provides the actual value.
    kind: str
class ImageBindRankerConfig(RankerConfig):
    """Ranker configuration selected by ``kind == "imagebind"``.

    Args:
        checkpoint: Optional checkpoint reference, stored as
            ``self.checkpoint``.
    """

    kind: str = "imagebind"

    def __init__(self, checkpoint: Optional[str] = None):
        self.checkpoint: Optional[str] = checkpoint
class ClapRankerConfig(RankerConfig):
    """Ranker configuration selected by ``kind == "clap"``.

    Args:
        checkpoint: Optional checkpoint reference, stored as
            ``self.checkpoint``.
    """

    kind: str = "clap"

    def __init__(self, checkpoint: Optional[str] = None):
        self.checkpoint: Optional[str] = checkpoint
class JudgeRankerConfig(RankerConfig):
    """Ranker configuration selected by ``kind == "judge"``.

    Args:
        checkpoint_or_model_id: Stored as ``self.checkpoint_or_model_id``;
            defaults to ``"facebook/sam-audio-judge"``.
    """

    kind: str = "judge"

    def __init__(self, checkpoint_or_model_id: str = "facebook/sam-audio-judge"):
        self.checkpoint_or_model_id: str = checkpoint_or_model_id
class SoundActivityRankerConfig(RankerConfig):
    """Ranker configuration selected by ``kind == "sound_activity"``.

    Args:
        threshold_mode: Stored as ``self.threshold_mode``.
        sil_threshold: Stored as ``self.sil_threshold`` (presumably a
            silence threshold in dB -- confirm against the ranker).
        metric: Stored as ``self.metric``.
    """

    kind: str = "sound_activity"

    def __init__(
        self,
        threshold_mode: str = "rel_to_max",
        sil_threshold: float = -40,
        metric: str = "iou",
    ):
        self.threshold_mode: str = threshold_mode
        self.sil_threshold: float = sil_threshold
        self.metric: str = metric
class EnsembleRankerConfig(RankerConfig):
    """Ranker configuration selected by ``kind == "ensemble"``.

    Args:
        rankers: Mapping from a name to a ``(ranker_config, float)`` pair,
            stored as ``self.rankers``. The float is presumably the
            member's weight -- confirm against the ensemble ranker.
    """

    kind: str = "ensemble"

    def __init__(self, rankers: dict[str, Tuple[RankerConfig, float]]):
        self.rankers: dict[str, Tuple[RankerConfig, float]] = rankers
| def parse_ranker_config(config_dict: dict): | |
| kind = config_dict.pop("kind") | |
| match kind: | |
| case ImageBindRankerConfig.kind: | |
| return ImageBindRankerConfig(**config_dict) | |
| case ClapRankerConfig.kind: | |
| return ClapRankerConfig(**config_dict) | |
| case JudgeRankerConfig.kind: | |
| return JudgeRankerConfig(**config_dict) | |
| case SoundActivityRankerConfig.kind: | |
| return SoundActivityRankerConfig(**config_dict) | |
| case EnsembleRankerConfig.kind: | |
| return EnsembleRankerConfig( | |
| { | |
| k: (parse_ranker_config(v), w) | |
| for k, (v, w) in config_dict["rankers"].items() | |
| } | |
| ) | |
class SAMAudioConfig:
    """Top-level configuration for the SAM-Audio model.

    The component arguments are keyword dictionaries that are expanded into
    their config classes; ``None`` (or any falsy value) yields the class
    defaults.

    Args:
        in_channels: Stored as ``self.in_channels``.
        audio_codec: Keyword arguments for ``DACVAEConfig``.
        text_encoder: Keyword arguments for ``T5EncoderConfig``.
        vision_encoder: Keyword arguments for ``PerceptionEncoderConfig``.
        transformer: Keyword arguments for ``TransformerConfig``.
        num_anchors: Stored as ``self.num_anchors``.
        anchor_embedding_dim: Stored as ``self.anchor_embedding_dim``.
        visual_ranker: Ranker description handed to ``parse_ranker_config``,
            or ``None`` to leave ``self.visual_ranker`` unset (``None``).
        text_ranker: Same as ``visual_ranker``, for ``self.text_ranker``.
        span_predictor: Stored as ``self.span_predictor``; may be ``None``.
    """

    def __init__(
        self,
        in_channels: int = 768,
        audio_codec=None,
        text_encoder=None,
        vision_encoder=None,
        transformer=None,
        num_anchors: int = 3,
        anchor_embedding_dim: int = 128,
        visual_ranker=None,
        text_ranker=None,
        span_predictor: Optional[str] = "pe-a-frame-large",
    ):
        self.in_channels = in_channels

        # Component configs: a missing/falsy dict means "use the defaults".
        codec_kwargs = audio_codec or {}
        self.audio_codec = DACVAEConfig(**codec_kwargs)
        text_kwargs = text_encoder or {}
        self.text_encoder = T5EncoderConfig(**text_kwargs)
        vision_kwargs = vision_encoder or {}
        self.vision_encoder = PerceptionEncoderConfig(**vision_kwargs)
        transformer_kwargs = transformer or {}
        self.transformer = TransformerConfig(**transformer_kwargs)

        self.num_anchors = num_anchors
        self.anchor_embedding_dim = anchor_embedding_dim

        # Rankers are optional; only parse a description when one was given.
        if visual_ranker is None:
            self.visual_ranker = None
        else:
            self.visual_ranker = parse_ranker_config(visual_ranker)

        if text_ranker is None:
            self.text_ranker = None
        else:
            self.text_ranker = parse_ranker_config(text_ranker)

        self.span_predictor = span_predictor
class SAMAudioJudgeConfig:
    """Configuration for the SAM-Audio judge model.

    Each component argument may be given as ``None`` (use the component's
    defaults), a dictionary of keyword arguments for the component's config
    class, or an already-built instance of that class, which is stored
    as-is.

    Args:
        audio_codec: Settings for ``DACVAEConfig``.
        transformer: Settings for ``PEAVTransformerConfig``.
        text_model: Settings for ``ModernBertConfig``.
        finetune_transformer: Settings for a second
            ``PEAVTransformerConfig``.
        nth_text_layer: Stored as ``self.nth_text_layer``.
        bottleneck_dim: Stored as ``self.bottleneck_dim``.
    """

    def __init__(
        self,
        audio_codec: DACVAEConfig | dict | None = None,
        transformer: PEAVTransformerConfig | dict | None = None,
        text_model: ModernBertConfig | dict | None = None,
        finetune_transformer: PEAVTransformerConfig | dict | None = None,
        nth_text_layer: int = 22,
        bottleneck_dim: int = 256,
    ):
        self.audio_codec = self._coerce(DACVAEConfig, audio_codec)
        self.transformer = self._coerce(PEAVTransformerConfig, transformer)
        self.text_model = self._coerce(ModernBertConfig, text_model)
        self.finetune_transformer = self._coerce(
            PEAVTransformerConfig, finetune_transformer
        )
        self.nth_text_layer = nth_text_layer
        self.bottleneck_dim = bottleneck_dim

    @staticmethod
    def _coerce(config_cls, value):
        """Return ``value`` if it is a ``config_cls``, else build one from kwargs."""
        # The parameters are annotated as config instances, but unpacking an
        # instance with ``**`` raises TypeError, so pass instances through.
        if isinstance(value, config_cls):
            return value
        return config_cls(**(value or {}))