File size: 3,183 Bytes
9c3e596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
AuriStream Configuration for HuggingFace Transformers.

AuriStream is a speech language model by Greta Tuckute and Klemen Kotar.
"""

from transformers import PretrainedConfig


class AuriStreamConfig(PretrainedConfig):
    """
    Hyperparameter container for the AuriStream speech language model family.

    Every argument listed below is stored unchanged as an instance attribute
    of the same name. Any extra keyword arguments are passed through to
    `PretrainedConfig.__init__`.

    Args:
        vocab_size (`int`, *optional*, defaults to 8192):
            Size of the model vocabulary (number of cochlear tokens).
        n_embd (`int`, *optional*, defaults to 768):
            Width of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            How many transformer layers the model has.
        n_head (`int`, *optional*, defaults to 12):
            How many attention heads each attention layer has.
        n_pred_steps (`int`, *optional*, defaults to 1):
            How many future steps are predicted (multi-token prediction heads).
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for all fully connected layers.
        bias (`bool`, *optional*, defaults to False):
            Whether linear layers use a bias term.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Base theta used for RoPE embeddings.
        input_conv_kernel_size (`int`, *optional*, defaults to 0):
            Kernel size of the input convolution layer; 0 means there is no
            input convolution.
    """

    model_type = "AuriStream"

    def __init__(
        self,
        vocab_size: int = 8192,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        n_pred_steps: int = 1,
        dropout: float = 0.0,
        bias: bool = False,
        rope_theta: float = 10000.0,
        input_conv_kernel_size: int = 0,
        **kwargs,
    ):
        # Model-specific fields are recorded first; the base initializer then
        # receives only the keyword arguments not named in the signature.
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_pred_steps = n_pred_steps
        self.dropout = dropout
        self.bias = bias
        self.rope_theta = rope_theta
        self.input_conv_kernel_size = input_conv_kernel_size

        super().__init__(**kwargs)

    @classmethod
    def from_local_config(cls, local_cfg):
        """
        Build an AuriStreamConfig from a local config object.

        Only the attributes this class knows about are read from `local_cfg`;
        any that the object does not define fall back to the constructor
        defaults.

        Args:
            local_cfg: A dataclass config object (e.g., AuriStream100M20PredConfig)

        Returns:
            AuriStreamConfig instance
        """
        field_names = (
            'vocab_size', 'n_embd', 'n_layer', 'n_head', 'n_pred_steps',
            'dropout', 'bias', 'rope_theta', 'input_conv_kernel_size',
        )

        # Copy across only the fields the source object actually carries.
        init_kwargs = {
            name: getattr(local_cfg, name)
            for name in field_names
            if hasattr(local_cfg, name)
        }

        # A config without n_pred_steps is treated as single-step prediction.
        init_kwargs.setdefault('n_pred_steps', 1)

        return cls(**init_kwargs)