|
|
""" |
|
|
AuriStream configuration for Hugging Face Transformers.
|
|
|
|
|
AuriStream is a speech language model by Greta Tuckute and Klemen Kotar. |
|
|
""" |
|
|
|
|
|
from transformers import PretrainedConfig |
|
|
|
|
|
|
|
|
class AuriStreamConfig(PretrainedConfig):
    """
    Configuration class for AuriStream models.

    This configuration supports various model sizes and prediction head configurations
    for the AuriStream speech language model family.

    Args:
        vocab_size (`int`, *optional*, defaults to 8192):
            Vocabulary size of the model (number of cochlear tokens).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of transformer layers.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer.
        n_pred_steps (`int`, *optional*, defaults to 1):
            Number of future prediction steps (multi-token prediction heads).
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for all fully connected layers.
        bias (`bool`, *optional*, defaults to False):
            Whether to use bias in linear layers.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Base theta for RoPE embeddings.
        input_conv_kernel_size (`int`, *optional*, defaults to 0):
            Kernel size for input convolution layer (0 means no input conv).
        kwargs:
            Any remaining keyword arguments, forwarded unchanged to
            `PretrainedConfig.__init__`.
    """

    model_type = "AuriStream"

    def __init__(
        self,
        vocab_size: int = 8192,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        n_pred_steps: int = 1,
        dropout: float = 0.0,
        bias: bool = False,
        rope_theta: float = 10000.0,
        input_conv_kernel_size: int = 0,
        **kwargs,
    ) -> None:
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_pred_steps = n_pred_steps
        self.dropout = dropout
        self.bias = bias
        self.rope_theta = rope_theta
        self.input_conv_kernel_size = input_conv_kernel_size

        # Model-specific attributes are set first; everything else in kwargs is
        # handled by the PretrainedConfig base class.
        super().__init__(**kwargs)

    @classmethod
    def from_local_config(cls, local_cfg) -> "AuriStreamConfig":
        """
        Create an AuriStreamConfig from a local config object.

        Each constructor field is read from ``local_cfg`` as an attribute
        (e.g. a dataclass such as AuriStream100M20PredConfig). When
        ``local_cfg`` is a mapping (e.g. a plain ``dict`` or the output of
        ``dataclasses.asdict``), a field that is not available as an attribute
        is looked up by key instead. Fields found neither way fall back to the
        ``__init__`` defaults.

        Args:
            local_cfg: A dataclass config object (e.g., AuriStream100M20PredConfig)
                or a mapping from field names to values.

        Returns:
            AuriStreamConfig instance
        """
        from collections.abc import Mapping

        # Must match the model-specific parameters of __init__.
        known_attrs = (
            'vocab_size', 'n_embd', 'n_layer', 'n_head', 'n_pred_steps',
            'dropout', 'bias', 'rope_theta', 'input_conv_kernel_size',
        )

        is_mapping = isinstance(local_cfg, Mapping)
        config_dict = {}
        for attr in known_attrs:
            if hasattr(local_cfg, attr):
                # Attribute access takes precedence (dataclass-style configs).
                config_dict[attr] = getattr(local_cfg, attr)
            elif is_mapping and attr in local_cfg:
                # Plain mappings expose fields as keys, not attributes; without
                # this branch a dict input would silently produce all defaults.
                config_dict[attr] = local_cfg[attr]

        # Missing fields (including n_pred_steps) take the __init__ defaults.
        return cls(**config_dict)
|
|
|