"""
AuriStream Configuration for HuggingFace Transformers.

AuriStream is a speech language model by Greta Tuckute and Klemen Kotar.
"""

from transformers import PretrainedConfig


class AuriStreamConfig(PretrainedConfig):
    """
    Hyperparameter container for the AuriStream speech language model family.

    The same class covers the different model sizes and prediction-head
    setups of the family; every setting is a plain attribute on the
    instance, and any extra keyword arguments are handed to
    `PretrainedConfig`.

    Args:
        vocab_size (`int`, *optional*, defaults to 8192):
            Vocabulary size of the model (number of cochlear tokens).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of transformer layers.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer.
        n_pred_steps (`int`, *optional*, defaults to 1):
            Number of future prediction steps (multi-token prediction heads).
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for all fully connected layers.
        bias (`bool`, *optional*, defaults to False):
            Whether to use bias in linear layers.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Base theta for RoPE embeddings.
        input_conv_kernel_size (`int`, *optional*, defaults to 0):
            Kernel size for input convolution layer (0 means no input conv).
    """

    model_type = "AuriStream"

    def __init__(
        self,
        vocab_size: int = 8192,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        n_pred_steps: int = 1,
        dropout: float = 0.0,
        bias: bool = False,
        rope_theta: float = 10000.0,
        input_conv_kernel_size: int = 0,
        **kwargs,
    ):
        # Vocabulary and transformer dimensions.
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head

        # Prediction-head count and layer options.
        self.n_pred_steps = n_pred_steps
        self.dropout = dropout
        self.bias = bias

        # Positional-encoding base and optional input convolution.
        self.rope_theta = rope_theta
        self.input_conv_kernel_size = input_conv_kernel_size

        # The base class is initialised last and receives only the keyword
        # arguments that were not consumed by the named parameters above.
        super().__init__(**kwargs)

    @classmethod
    def from_local_config(cls, local_cfg):
        """
        Build an AuriStreamConfig from the attributes of a local config object.

        Only the hyperparameters this class knows about are read; any of them
        that `local_cfg` does not define falls back to the constructor
        default. Other attributes on `local_cfg` are ignored.

        Args:
            local_cfg: A dataclass config object
                (e.g., AuriStream100M20PredConfig).

        Returns:
            AuriStreamConfig instance
        """
        field_names = (
            'vocab_size',
            'n_embd',
            'n_layer',
            'n_head',
            'n_pred_steps',
            'dropout',
            'bias',
            'rope_theta',
            'input_conv_kernel_size',
        )

        # Copy across just the fields the local config actually carries.
        overrides = {
            name: getattr(local_cfg, name)
            for name in field_names
            if hasattr(local_cfg, name)
        }

        # A local config without n_pred_steps is given a single prediction
        # step explicitly.
        overrides.setdefault('n_pred_steps', 1)

        return cls(**overrides)