File size: 3,509 Bytes
45b7e2b
 
 
96a4150
 
 
 
 
45b7e2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""TextSyncMimi model configuration"""

from transformers.utils import logging

try:
    from .configuration_mimi import MimiConfig
except ImportError:
    from configuration_mimi import MimiConfig


logger = logging.get_logger(__name__)


class TextSyncMimiConfig(MimiConfig):
    r"""
    Configuration class for a [`TextSyncMimi`] model.

    Carries every parameter of the underlying [`MimiConfig`] audio codec and
    adds the TextSyncMimi-specific settings that control text-speech alignment
    and autoregressive generation.

    Args:
        mimi_model_id (`str`, *optional*, defaults to `"kyutai/mimi"`):
            Hub ID of the Mimi model used as the audio codec backbone.
        vocab_size (`int`, *optional*, defaults to 128256):
            Size of the text tokenizer vocabulary (LLaMA-3 tokenizer).
        alpha (`float`, *optional*, defaults to 1.0):
            Weight of the BCE end-token loss within the total loss.
        cross_attention_layers (`int`, *optional*, defaults to 4):
            Number of cross-attention transformer layers used for
            text-speech alignment.
        causal_attention_layers (`int`, *optional*, defaults to 4):
            Number of causal transformer layers used for autoregressive
            generation.
        bce_threshold (`float`, *optional*, defaults to 0.1):
            BCE loss threshold; optimization of the BCE term stops once it
            falls below this value.
        end_token_threshold (`float`, *optional*, defaults to 0.5):
            BCE probability at or above which generation stops.
        max_z_tokens (`int`, *optional*, defaults to 50):
            Upper bound on z tokens generated per text position at inference.
        text_embedding_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the text embeddings (LLaMA embedding size).
        **kwargs:
            Forwarded unchanged to [`MimiConfig`] (hidden_size, sample_rate,
            etc.).

    Example:

    ```python
    >>> from transformers import TextSyncMimiConfig, TextSyncMimi

    >>> # Build a default configuration
    >>> configuration = TextSyncMimiConfig()

    >>> # Instantiate a randomly-initialized model from it
    >>> model = TextSyncMimi(configuration)

    >>> # The model exposes its configuration
    >>> configuration = model.config
    ```
    """

    model_type = "text_sync_mimi"

    def __init__(
        self,
        mimi_model_id: str = "kyutai/mimi",
        vocab_size: int = 128256,
        alpha: float = 1.0,
        cross_attention_layers: int = 4,
        causal_attention_layers: int = 4,
        bce_threshold: float = 0.1,
        end_token_threshold: float = 0.5,
        max_z_tokens: int = 50,
        text_embedding_dim: int = 4096,
        **kwargs,
    ):
        # Let MimiConfig populate every codec-level attribute first so the
        # TextSyncMimi-specific values below cannot be clobbered by kwargs.
        super().__init__(**kwargs)

        # Attach the TextSyncMimi-specific parameters as plain attributes.
        extra_params = {
            "mimi_model_id": mimi_model_id,
            "vocab_size": vocab_size,
            "alpha": alpha,
            "cross_attention_layers": cross_attention_layers,
            "causal_attention_layers": causal_attention_layers,
            "bce_threshold": bce_threshold,
            "end_token_threshold": end_token_threshold,
            "max_z_tokens": max_z_tokens,
            "text_embedding_dim": text_embedding_dim,
        }
        for attr_name, attr_value in extra_params.items():
            setattr(self, attr_name, attr_value)


__all__ = ["TextSyncMimiConfig"]