Upload configuration_text_sync_mimi.py with huggingface_hub
Browse files
configuration_text_sync_mimi.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TextSyncMimi model configuration"""
|
| 2 |
+
|
| 3 |
+
from transformers.utils import logging
|
| 4 |
+
from configuration_mimi import MimiConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Module-level logger, following the transformers naming convention.
logger = logging.get_logger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TextSyncMimiConfig(MimiConfig):
    r"""
    Configuration class for a [`TextSyncMimi`] model.

    Instantiating a configuration with the defaults yields a TextSyncMimi
    architecture built on top of the Mimi audio codec. Every [`MimiConfig`]
    parameter is inherited; the arguments documented below are the
    TextSyncMimi-specific additions.

    Args:
        mimi_model_id (`str`, *optional*, defaults to `"kyutai/mimi"`):
            Hub ID of the Mimi model used as the audio codec backbone.
        vocab_size (`int`, *optional*, defaults to 128256):
            Size of the text tokenizer vocabulary (LLaMA-3 tokenizer).
        alpha (`float`, *optional*, defaults to 1.0):
            Weight applied to the BCE end-token loss within the total loss.
        cross_attention_layers (`int`, *optional*, defaults to 4):
            Number of cross-attention transformer layers used for
            text-speech alignment.
        causal_attention_layers (`int`, *optional*, defaults to 4):
            Number of causal-attention transformer layers used for
            autoregressive generation.
        bce_threshold (`float`, *optional*, defaults to 0.1):
            Stop optimizing the BCE term once it drops below this value.
        end_token_threshold (`float`, *optional*, defaults to 0.5):
            BCE probability threshold at which generation stops
            (probability >= threshold means stop).
        max_z_tokens (`int`, *optional*, defaults to 50):
            Upper bound on z tokens generated per text position during
            inference.
        text_embedding_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the text embeddings (LLaMA embedding size).
        **kwargs:
            Forwarded verbatim to [`MimiConfig`] (e.g. `hidden_size`,
            `sample_rate`).

    Example:

    ```python
    >>> from configuration_text_sync_mimi import TextSyncMimiConfig

    >>> # Build a configuration with default values
    >>> configuration = TextSyncMimiConfig()
    ```
    """

    model_type = "text_sync_mimi"

    def __init__(
        self,
        mimi_model_id: str = "kyutai/mimi",
        vocab_size: int = 128256,
        alpha: float = 1.0,
        cross_attention_layers: int = 4,
        causal_attention_layers: int = 4,
        bce_threshold: float = 0.1,
        end_token_threshold: float = 0.5,
        max_z_tokens: int = 50,
        text_embedding_dim: int = 4096,
        **kwargs,
    ):
        # Let MimiConfig consume its own keyword arguments first so every
        # inherited attribute is populated before the subclass fields.
        super().__init__(**kwargs)

        # Backbone and text-side wiring.
        self.mimi_model_id = mimi_model_id
        self.vocab_size = vocab_size
        self.text_embedding_dim = text_embedding_dim

        # Alignment / generation architecture.
        self.cross_attention_layers = cross_attention_layers
        self.causal_attention_layers = causal_attention_layers

        # Loss weighting and stopping criteria.
        self.alpha = alpha
        self.bce_threshold = bce_threshold
        self.end_token_threshold = end_token_threshold
        self.max_z_tokens = max_z_tokens
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Explicit public API for `from <module> import *`.
__all__ = ["TextSyncMimiConfig"]
|
| 88 |
+
|