AuriStream-base / configuration_auristream.py
klemenk's picture
Upload AuriStream base model code
9c3e596 verified
"""
AuriStream Configuration for HuggingFace Transformers.
AuriStream is a speech language model by Greta Tuckute and Klemen Kotar.
"""
from transformers import PretrainedConfig
class AuriStreamConfig(PretrainedConfig):
    """
    Configuration class for AuriStream models.

    This configuration supports various model sizes and prediction head configurations
    for the AuriStream speech language model family.

    Args:
        vocab_size (`int`, *optional*, defaults to 8192):
            Vocabulary size of the model (number of cochlear tokens).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of transformer layers.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer.
        n_pred_steps (`int`, *optional*, defaults to 1):
            Number of future prediction steps (multi-token prediction heads).
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for all fully connected layers.
        bias (`bool`, *optional*, defaults to False):
            Whether to use bias in linear layers.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Base theta for RoPE embeddings.
        input_conv_kernel_size (`int`, *optional*, defaults to 0):
            Kernel size for input convolution layer (0 means no input conv).
        **kwargs:
            Remaining keyword arguments, forwarded unchanged to `PretrainedConfig`.
    """

    model_type = "AuriStream"

    # Let generic transformers code read/write the standard HF attribute names;
    # `PretrainedConfig` redirects these aliases to the GPT-style field names
    # used by this config (same scheme as `GPT2Config`).
    attribute_map = {
        "hidden_size": "n_embd",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size: int = 8192,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        n_pred_steps: int = 1,
        dropout: float = 0.0,
        bias: bool = False,
        rope_theta: float = 10000.0,
        input_conv_kernel_size: int = 0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_pred_steps = n_pred_steps
        self.dropout = dropout
        self.bias = bias
        self.rope_theta = rope_theta
        self.input_conv_kernel_size = input_conv_kernel_size
        super().__init__(**kwargs)

    @classmethod
    def from_local_config(cls, local_cfg, **kwargs):
        """
        Create an AuriStreamConfig from a local dataclass config.

        Only the model fields accepted by `__init__` are copied. Fields that
        `local_cfg` does not define fall back to the constructor defaults, except
        `n_pred_steps`, which is set to 1 when absent.

        Args:
            local_cfg: A dataclass config object (e.g., AuriStream100M20PredConfig).
            **kwargs: Extra keyword arguments for the constructor. They take
                precedence over values read from `local_cfg`, and may also carry
                `PretrainedConfig` options (e.g. `architectures`).

        Returns:
            AuriStreamConfig instance
        """
        known_attrs = (
            "vocab_size", "n_embd", "n_layer", "n_head", "n_pred_steps",
            "dropout", "bias", "rope_theta", "input_conv_kernel_size",
        )
        config_dict = {
            attr: getattr(local_cfg, attr)
            for attr in known_attrs
            if hasattr(local_cfg, attr)
        }
        # If the local config has no `n_pred_steps`, use a single prediction step.
        config_dict.setdefault("n_pred_steps", 1)
        # Explicit caller arguments override whatever was read from `local_cfg`.
        config_dict.update(kwargs)
        return cls(**config_dict)