|
|
"""TextSyncMimi model configuration""" |
|
|
|
|
|
from transformers.utils import logging |
|
|
|
|
|
try: |
|
|
from .configuration_mimi import MimiConfig |
|
|
except ImportError: |
|
|
from configuration_mimi import MimiConfig |
|
|
|
|
|
|
|
|
# Module-level logger, following the transformers convention of one logger per module.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
class TextSyncMimiConfig(MimiConfig):
    r"""
    Configuration class for a [`TextSyncMimi`] model.

    Instantiating a configuration with default arguments yields the default
    TextSyncMimi architecture. Because this class extends [`MimiConfig`],
    every Mimi codec parameter is accepted in addition to the
    TextSyncMimi-specific arguments documented below.

    Args:
        mimi_model_id (`str`, *optional*, defaults to `"kyutai/mimi"`):
            The Mimi model ID to use as the audio codec backbone.
        vocab_size (`int`, *optional*, defaults to 128256):
            Vocabulary size of the text tokenizer (LLaMA-3 tokenizer).
        alpha (`float`, *optional*, defaults to 1.0):
            Weight for the BCE end token loss in the total loss.
        cross_attention_layers (`int`, *optional*, defaults to 4):
            Number of cross-attention transformer layers for text-speech alignment.
        causal_attention_layers (`int`, *optional*, defaults to 4):
            Number of causal attention transformer layers for autoregressive generation.
        bce_threshold (`float`, *optional*, defaults to 0.1):
            BCE loss threshold - stop optimizing when BCE < threshold.
        end_token_threshold (`float`, *optional*, defaults to 0.5):
            BCE probability threshold for stopping generation (>= threshold means stop).
        max_z_tokens (`int`, *optional*, defaults to 50):
            Maximum z tokens to generate per text position during inference.
        text_embedding_dim (`int`, *optional*, defaults to 4096):
            Dimension of the text embeddings (LLaMA embedding size).
        **kwargs:
            Remaining keyword arguments are forwarded unchanged to
            [`MimiConfig`] (hidden_size, sample_rate, etc.).

    Example:

    ```python
    >>> from transformers import TextSyncMimiConfig, TextSyncMimi

    >>> # Initializing a TextSyncMimi configuration
    >>> configuration = TextSyncMimiConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = TextSyncMimi(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "text_sync_mimi"

    def __init__(
        self,
        mimi_model_id: str = "kyutai/mimi",
        vocab_size: int = 128256,
        alpha: float = 1.0,
        cross_attention_layers: int = 4,
        causal_attention_layers: int = 4,
        bce_threshold: float = 0.1,
        end_token_threshold: float = 0.5,
        max_z_tokens: int = 50,
        text_embedding_dim: int = 4096,
        **kwargs,
    ):
        # Let the Mimi base configuration consume all codec-level kwargs first.
        super().__init__(**kwargs)

        # Audio codec backbone selection.
        self.mimi_model_id = mimi_model_id
        # Text tokenizer vocabulary size (LLaMA-3 tokenizer).
        self.vocab_size = vocab_size
        # Weight of the BCE end token loss within the total loss.
        self.alpha = alpha
        # Depth of the text-speech alignment (cross-attention) stack.
        self.cross_attention_layers = cross_attention_layers
        # Depth of the autoregressive (causal attention) stack.
        self.causal_attention_layers = causal_attention_layers
        # Stop optimizing the BCE term once it falls below this value.
        self.bce_threshold = bce_threshold
        # Generation stops when the end-token probability reaches this value.
        self.end_token_threshold = end_token_threshold
        # Hard cap on z tokens generated per text position at inference time.
        self.max_z_tokens = max_z_tokens
        # Width of the incoming text embeddings (LLaMA embedding size).
        self.text_embedding_dim = text_embedding_dim
|
|
|
|
|
|
|
|
# Explicit public API of this module.
__all__ = ["TextSyncMimiConfig"]
|
|
|
|
|
|