# File size: 3,183 Bytes — revision 9c3e596
"""
AuriStream Configuration for HuggingFace Transformers.
AuriStream is a speech language model by Greta Tuckute and Klemen Kotar.
"""
from transformers import PretrainedConfig
class AuriStreamConfig(PretrainedConfig):
    """
    Configuration class for AuriStream models.

    This configuration supports various model sizes and prediction head
    configurations for the AuriStream speech language model family.

    Args:
        vocab_size (`int`, *optional*, defaults to 8192):
            Vocabulary size of the model (number of cochlear tokens).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of transformer layers.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer.
        n_pred_steps (`int`, *optional*, defaults to 1):
            Number of future prediction steps (multi-token prediction heads).
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for all fully connected layers.
        bias (`bool`, *optional*, defaults to False):
            Whether to use bias in linear layers.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Base theta for RoPE embeddings.
        input_conv_kernel_size (`int`, *optional*, defaults to 0):
            Kernel size for input convolution layer (0 means no input conv).
    """

    model_type = "AuriStream"

    def __init__(
        self,
        vocab_size: int = 8192,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        n_pred_steps: int = 1,
        dropout: float = 0.0,
        bias: bool = False,
        rope_theta: float = 10000.0,
        input_conv_kernel_size: int = 0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_pred_steps = n_pred_steps
        self.dropout = dropout
        self.bias = bias
        self.rope_theta = rope_theta
        self.input_conv_kernel_size = input_conv_kernel_size
        # Let PretrainedConfig consume the remaining generic kwargs
        # (e.g. bos/eos token ids) after model-specific attributes are set.
        super().__init__(**kwargs)

    @classmethod
    def from_local_config(cls, local_cfg):
        """
        Create an AuriStreamConfig from a local dataclass config.

        Only attributes actually present on ``local_cfg`` are forwarded;
        anything missing falls back to this class's constructor defaults
        (including ``n_pred_steps``, which defaults to 1).

        Args:
            local_cfg: A dataclass config object (e.g., AuriStream100M20PredConfig)

        Returns:
            AuriStreamConfig instance
        """
        # Model-specific attributes recognized by __init__.
        known_attrs = (
            'vocab_size', 'n_embd', 'n_layer', 'n_head', 'n_pred_steps',
            'dropout', 'bias', 'rope_theta', 'input_conv_kernel_size',
        )
        config_dict = {
            attr: getattr(local_cfg, attr)
            for attr in known_attrs
            if hasattr(local_cfg, attr)
        }
        # NOTE: the previous explicit n_pred_steps fallback was redundant —
        # cls(**config_dict) already applies the constructor default of 1.
        return cls(**config_dict)