# onnx_export/standalone_config.py — from the sam-audio-large-onnx repository
# (uploaded by matbee via huggingface_hub, revision 07823f7)
"""
Standalone configuration classes for ONNX export.
These are copied from sam_audio/model/config.py but without the problematic
imports that require the 'perception_models' library.
"""
from typing import Optional
import numpy as np
class DACVAEConfig:
    """Configuration for the DAC-VAE audio codec (encoder/decoder + quantizer).

    Fix over the original: ``encoder_rates`` and ``decoder_rates`` previously
    used mutable list defaults, which Python evaluates once at function
    definition time — every instance built with defaults shared the *same*
    list object, so mutating one instance's rates corrupted all later
    instances. ``None`` sentinels with per-instance fresh lists preserve the
    same default values while removing the aliasing hazard.
    """

    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: Optional[list[int]] = None,
        latent_dim: int = 1024,
        decoder_dim: int = 1536,
        decoder_rates: Optional[list[int]] = None,
        n_codebooks: int = 16,
        codebook_size: int = 1024,
        codebook_dim: int = 128,
        quantizer_dropout: bool = False,
        sample_rate: int = 48_000,
        mean: float = 0.0,
        std: float = 1.0,
    ):
        self.encoder_dim = encoder_dim
        # Fresh list per instance when the caller relies on the default;
        # an explicitly passed list is stored as given (caller owns it).
        self.encoder_rates = (
            encoder_rates if encoder_rates is not None else [2, 8, 10, 12]
        )
        self.latent_dim = latent_dim
        self.decoder_dim = decoder_dim
        self.decoder_rates = (
            decoder_rates if decoder_rates is not None else [12, 10, 8, 2]
        )
        self.n_codebooks = n_codebooks
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_dropout = quantizer_dropout
        self.sample_rate = sample_rate
        # Normalization statistics for the latent space (presumably applied
        # by the model when scaling latents — confirm against callers).
        self.mean = mean
        self.std = std

    @property
    def hop_length(self) -> int:
        """Total downsampling factor: the product of all encoder strides."""
        return int(np.prod(self.encoder_rates))
class T5EncoderConfig:
    """Settings for the T5 text encoder used to embed text prompts."""

    def __init__(
        self,
        name: str = "t5-base",
        max_length: Optional[int] = 512,
        pad_mode: str = "longest",
        dim: int = 768,
    ):
        # Checkpoint identifier for the T5 model.
        self.name = name
        # Token cap for the tokenizer; None presumably means unbounded —
        # confirm against the tokenizer call site.
        self.max_length = max_length
        # Padding strategy string (e.g. "longest") forwarded downstream.
        self.pad_mode = pad_mode
        # Dimensionality of the encoder's output embeddings.
        self.dim = dim
class VisionEncoderConfig:
    """Base configuration shared by vision encoder variants."""

    def __init__(self, dim: int = 1024, batch_size: int = 300):
        # Feature dimension of the encoder output.
        self.dim = dim
        # Batch size used when encoding — presumably frames per forward
        # pass; confirm against the encoder's caller.
        self.batch_size = batch_size
class PerceptionEncoderConfig(VisionEncoderConfig):
    """Configuration for the Perception Encoder (PE) vision backbone.

    Extends the generic vision-encoder settings with the PE checkpoint
    name and its image preprocessing parameters.
    """

    def __init__(
        self,
        dim: int = 1024,
        batch_size: int = 300,
        name: str = "PE-Core-L14-336",
        normalize_feature: bool = True,
        interpolation_mode: str = "BICUBIC",
        image_size: int = 336,
    ):
        # Shared fields (dim, batch_size) live on the base class.
        super().__init__(dim=dim, batch_size=batch_size)
        # PE checkpoint identifier.
        self.name = name
        # Square input resolution expected by the model.
        self.image_size = image_size
        # Resampling filter name used when resizing inputs.
        self.interpolation_mode = interpolation_mode
        # Whether extracted features are normalized — presumably
        # L2-normalized; confirm against the encoder implementation.
        self.normalize_feature = normalize_feature
class TransformerConfig:
    """Configuration for the DiT (diffusion transformer) backbone.

    Plain value holder: every constructor argument is stored verbatim as an
    attribute of the same name. Defaults describe a 16-layer, 2048-dim model.
    """

    def __init__(
        self,
        dim: int = 2048,
        n_heads: int = 16,
        n_layers: int = 16,
        dropout: float = 0.1,
        norm_eps: float = 1.0e-05,
        qk_norm: bool = True,
        fc_bias: bool = False,
        ffn_exp: int = 4,
        ffn_dim_multiplier: int = 1,
        multiple_of: int = 64,
        non_linearity: str = "swiglu",
        use_rope: bool = True,
        max_positions: int = 10000,
        frequency_embedding_dim: int = 256,
        timestep_non_linearity: str = "swiglu",
        t_block_non_linearity: str = "silu",
        t_block_bias: bool = True,
        context_dim: int = 2048,
        context_non_linearity: str = "swiglu",
        context_embedder_dropout: float = 0.0,
        context_norm: bool = False,
        out_channels: int = 256,
        in_channels: Optional[int] = None,
    ):
        # Attention-stack geometry and regularization.
        self.dim, self.n_heads, self.n_layers = dim, n_heads, n_layers
        self.dropout, self.norm_eps = dropout, norm_eps
        self.qk_norm, self.fc_bias = qk_norm, fc_bias
        # Feed-forward sizing and activation.
        self.ffn_exp, self.ffn_dim_multiplier = ffn_exp, ffn_dim_multiplier
        self.multiple_of, self.non_linearity = multiple_of, non_linearity
        # Positional encoding (rotary embeddings).
        self.use_rope, self.max_positions = use_rope, max_positions
        # Timestep-embedding branch.
        self.frequency_embedding_dim = frequency_embedding_dim
        self.timestep_non_linearity = timestep_non_linearity
        self.t_block_non_linearity = t_block_non_linearity
        self.t_block_bias = t_block_bias
        # Conditioning-context embedder.
        self.context_dim = context_dim
        self.context_non_linearity = context_non_linearity
        self.context_embedder_dropout = context_embedder_dropout
        self.context_norm = context_norm
        # Input/output channel counts; in_channels may be None (presumably
        # derived from the data elsewhere — confirm against the model code).
        self.out_channels, self.in_channels = out_channels, in_channels

    @property
    def d_model(self) -> int:
        """Alias for ``dim``, kept for transformer code that expects it."""
        return self.dim