# TextSyncMimi-v1 / configuration_text_sync_mimi.py
# Uploaded by potsawee via huggingface_hub (commit 96a4150, verified)
"""TextSyncMimi model configuration"""
from transformers.utils import logging
try:
from .configuration_mimi import MimiConfig
except ImportError:
from configuration_mimi import MimiConfig
logger = logging.get_logger(__name__)
class TextSyncMimiConfig(MimiConfig):
    r"""
    Configuration class for a [`TextSyncMimi`] model.

    Instantiating a configuration with the defaults yields a TextSyncMimi
    architecture built on top of the Mimi audio codec. Since this class
    inherits from [`MimiConfig`], every Mimi parameter is available here as
    well, alongside the TextSyncMimi-specific ones documented below.

    Args:
        mimi_model_id (`str`, *optional*, defaults to `"kyutai/mimi"`):
            The Mimi model ID to use as the audio codec backbone.
        vocab_size (`int`, *optional*, defaults to 128256):
            Vocabulary size of the text tokenizer (LLaMA-3 tokenizer).
        alpha (`float`, *optional*, defaults to 1.0):
            Weight for the BCE end token loss in the total loss.
        cross_attention_layers (`int`, *optional*, defaults to 4):
            Number of cross-attention transformer layers for text-speech alignment.
        causal_attention_layers (`int`, *optional*, defaults to 4):
            Number of causal attention transformer layers for autoregressive generation.
        bce_threshold (`float`, *optional*, defaults to 0.1):
            BCE loss threshold - stop optimizing when BCE < threshold.
        end_token_threshold (`float`, *optional*, defaults to 0.5):
            BCE probability threshold for stopping generation (>= threshold means stop).
        max_z_tokens (`int`, *optional*, defaults to 50):
            Maximum z tokens to generate per text position during inference.
        text_embedding_dim (`int`, *optional*, defaults to 4096):
            Dimension of the text embeddings (LLaMA embedding size).
        **kwargs: Additional parameters passed to MimiConfig (hidden_size, sample_rate, etc.)

    Example:
    ```python
    >>> from transformers import TextSyncMimiConfig, TextSyncMimi
    >>> # Initializing a TextSyncMimi configuration
    >>> configuration = TextSyncMimiConfig()
    >>> # Initializing a model (with random weights) from the configuration
    >>> model = TextSyncMimi(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "text_sync_mimi"

    def __init__(
        self,
        mimi_model_id: str = "kyutai/mimi",
        vocab_size: int = 128256,
        alpha: float = 1.0,
        cross_attention_layers: int = 4,
        causal_attention_layers: int = 4,
        bce_threshold: float = 0.1,
        end_token_threshold: float = 0.5,
        max_z_tokens: int = 50,
        text_embedding_dim: int = 4096,
        **kwargs,
    ):
        # Let the parent set up every base Mimi attribute first so the
        # TextSyncMimi-specific values below are applied on top of them.
        super().__init__(**kwargs)

        # TextSyncMimi-specific parameters, applied in declaration order.
        extra_params = {
            "mimi_model_id": mimi_model_id,
            "vocab_size": vocab_size,
            "alpha": alpha,
            "cross_attention_layers": cross_attention_layers,
            "causal_attention_layers": causal_attention_layers,
            "bce_threshold": bce_threshold,
            "end_token_threshold": end_token_threshold,
            "max_z_tokens": max_z_tokens,
            "text_embedding_dim": text_embedding_dim,
        }
        for attr_name, attr_value in extra_params.items():
            setattr(self, attr_name, attr_value)
__all__ = ["TextSyncMimiConfig"]