potsawee committed on
Commit
45b7e2b
·
verified ·
1 Parent(s): a9f7000

Upload configuration_text_sync_mimi.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. configuration_text_sync_mimi.py +88 -0
configuration_text_sync_mimi.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TextSyncMimi model configuration"""
2
+
3
+ from transformers.utils import logging
4
+ from configuration_mimi import MimiConfig
5
+
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
class TextSyncMimiConfig(MimiConfig):
    r"""
    Configuration class for a [`TextSyncMimi`] model.

    Instantiating this configuration defines the architecture of a TextSyncMimi
    model. It extends [`MimiConfig`], so every Mimi codec parameter is accepted
    via ``**kwargs`` in addition to the TextSyncMimi-specific arguments below.

    Args:
        mimi_model_id (`str`, *optional*, defaults to `"kyutai/mimi"`):
            Identifier of the Mimi model used as the audio codec backbone.
        vocab_size (`int`, *optional*, defaults to 128256):
            Size of the text tokenizer vocabulary (LLaMA-3 tokenizer).
        alpha (`float`, *optional*, defaults to 1.0):
            Weight applied to the BCE end-token loss inside the total loss.
        cross_attention_layers (`int`, *optional*, defaults to 4):
            Number of cross-attention transformer layers used for
            text-speech alignment.
        causal_attention_layers (`int`, *optional*, defaults to 4):
            Number of causal attention transformer layers used for
            autoregressive generation.
        bce_threshold (`float`, *optional*, defaults to 0.1):
            Threshold on the BCE loss — optimization of that term stops
            once BCE falls below it.
        end_token_threshold (`float`, *optional*, defaults to 0.5):
            Probability threshold for halting generation: a BCE probability
            at or above this value means stop.
        max_z_tokens (`int`, *optional*, defaults to 50):
            Upper bound on z tokens generated per text position at inference.
        text_embedding_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the text embeddings (LLaMA embedding size).
        **kwargs:
            Remaining parameters forwarded to [`MimiConfig`]
            (e.g. ``hidden_size``, ``sample_rate``).

    Example:

    ```python
    >>> from transformers import TextSyncMimiConfig, TextSyncMimi

    >>> # Build a configuration with default values
    >>> configuration = TextSyncMimiConfig()

    >>> # Build a randomly-initialized model from it
    >>> model = TextSyncMimi(configuration)

    >>> # The configuration is retrievable from the model
    >>> configuration = model.config
    ```
    """

    model_type = "text_sync_mimi"

    def __init__(
        self,
        mimi_model_id: str = "kyutai/mimi",
        vocab_size: int = 128256,
        alpha: float = 1.0,
        cross_attention_layers: int = 4,
        causal_attention_layers: int = 4,
        bce_threshold: float = 0.1,
        end_token_threshold: float = 0.5,
        max_z_tokens: int = 50,
        text_embedding_dim: int = 4096,
        **kwargs,
    ):
        # Let the parent populate every inherited Mimi codec attribute
        # before the TextSyncMimi-specific ones are attached.
        super().__init__(**kwargs)

        # Codec backbone and text-side dimensions.
        self.mimi_model_id = mimi_model_id
        self.vocab_size = vocab_size
        self.text_embedding_dim = text_embedding_dim

        # Alignment / generation transformer depths.
        self.cross_attention_layers = cross_attention_layers
        self.causal_attention_layers = causal_attention_layers

        # Loss weighting and stopping criteria.
        self.alpha = alpha
        self.bce_threshold = bce_threshold
        self.end_token_threshold = end_token_threshold
        self.max_z_tokens = max_z_tokens
85
+
86
+
87
+ __all__ = ["TextSyncMimiConfig"]
88
+