# (extraction metadata removed: file size, revision hashes, line-number index)
"""TextSyncMimi model configuration"""
from transformers.utils import logging
try:
from .configuration_mimi import MimiConfig
except ImportError:
from configuration_mimi import MimiConfig
logger = logging.get_logger(__name__)
class TextSyncMimiConfig(MimiConfig):
    r"""
    Configuration class for a [`TextSyncMimi`] model.

    Instantiating a configuration with the defaults below yields a TextSyncMimi
    architecture description. The class inherits every parameter of
    [`MimiConfig`] (hidden_size, sample_rate, ...) via ``**kwargs`` and adds the
    TextSyncMimi-specific parameters documented here.

    Args:
        mimi_model_id (`str`, *optional*, defaults to `"kyutai/mimi"`):
            The Mimi model ID to use as the audio codec backbone.
        vocab_size (`int`, *optional*, defaults to 128256):
            Vocabulary size of the text tokenizer (LLaMA-3 tokenizer).
        alpha (`float`, *optional*, defaults to 1.0):
            Weight for the BCE end token loss in the total loss.
        cross_attention_layers (`int`, *optional*, defaults to 4):
            Number of cross-attention transformer layers for text-speech alignment.
        causal_attention_layers (`int`, *optional*, defaults to 4):
            Number of causal attention transformer layers for autoregressive generation.
        bce_threshold (`float`, *optional*, defaults to 0.1):
            BCE loss threshold - stop optimizing when BCE < threshold.
        end_token_threshold (`float`, *optional*, defaults to 0.5):
            BCE probability threshold for stopping generation (>= threshold means stop).
        max_z_tokens (`int`, *optional*, defaults to 50):
            Maximum z tokens to generate per text position during inference.
        text_embedding_dim (`int`, *optional*, defaults to 4096):
            Dimension of the text embeddings (LLaMA embedding size).
        **kwargs: Additional parameters passed to MimiConfig (hidden_size, sample_rate, etc.)

    Example:
    ```python
    >>> from transformers import TextSyncMimiConfig, TextSyncMimi
    >>> # Initializing a TextSyncMimi configuration
    >>> configuration = TextSyncMimiConfig()
    >>> # Initializing a model (with random weights) from the configuration
    >>> model = TextSyncMimi(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "text_sync_mimi"

    def __init__(
        self,
        mimi_model_id: str = "kyutai/mimi",
        vocab_size: int = 128256,
        alpha: float = 1.0,
        cross_attention_layers: int = 4,
        causal_attention_layers: int = 4,
        bce_threshold: float = 0.1,
        end_token_threshold: float = 0.5,
        max_z_tokens: int = 50,
        text_embedding_dim: int = 4096,
        **kwargs,
    ):
        # Populate every inherited Mimi attribute first; the assignments below
        # then layer the TextSyncMimi-specific fields on top.
        super().__init__(**kwargs)

        # Backbone / tokenizer.
        self.mimi_model_id = mimi_model_id
        self.vocab_size = vocab_size
        self.text_embedding_dim = text_embedding_dim

        # Architecture depth.
        self.cross_attention_layers = cross_attention_layers
        self.causal_attention_layers = causal_attention_layers

        # Training loss weighting / stopping criteria.
        self.alpha = alpha
        self.bce_threshold = bce_threshold

        # Inference-time generation controls.
        self.end_token_threshold = end_token_threshold
        self.max_z_tokens = max_z_tokens
__all__ = ["TextSyncMimiConfig"]