mlinmg committed on
Commit
8b6d69d
·
verified ·
1 Parent(s): 7db4254

Upload 6 files

Browse files
Files changed (6) hide show
  1. config.json +60 -0
  2. gpt_config.py +172 -0
  3. tokenizer.py +233 -0
  4. xtts2_config.py +418 -0
  5. xtts2_modeling.py +259 -0
  6. xttsv2-hifigan-mel.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "AstraMindAI/xtts2",
3
+ "architectures": [
4
+ "Xtts"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "AstraMindAI/xtts2--xtts2_config.XTTSConfig",
8
+ "AutoModelForCausalLM": "AstraMindAI/xtts2--xtts2_modeling.Xtts"
9
+ },
10
+ "cond_d_vector_in_each_upsampling_layer": true,
11
+ "d_vector_dim": 512,
12
+ "decoder_input_dim": 1024,
13
+ "input_sample_rate": 22050,
14
+ "model_type": "xtts_hifigan",
15
+ "output_hop_length": 256,
16
+ "output_sample_rate": 24000,
17
+ "resblock_dilation_sizes": [
18
+ [
19
+ 1,
20
+ 3,
21
+ 5
22
+ ],
23
+ [
24
+ 1,
25
+ 3,
26
+ 5
27
+ ],
28
+ [
29
+ 1,
30
+ 3,
31
+ 5
32
+ ]
33
+ ],
34
+ "resblock_kernel_sizes": [
35
+ 3,
36
+ 7,
37
+ 11
38
+ ],
39
+ "speaker_encoder_config": {
40
+ "model_config": null,
41
+ "model_name": "speaker_encoder",
42
+ "preprocess_config": null,
43
+ "speaker_embedding_dim": 512,
44
+ "use_torch_spec": true
45
+ },
46
+ "transformers_version": "4.45.1",
47
+ "upsample_initial_channel": 512,
48
+ "upsample_kernel_sizes": [
49
+ 16,
50
+ 16,
51
+ 4,
52
+ 4
53
+ ],
54
+ "upsample_rates": [
55
+ 8,
56
+ 8,
57
+ 2,
58
+ 2
59
+ ]
60
+ }
gpt_config.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Dict, Optional, List
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.utils import logging
5
+
6
+ logger = logging.get_logger(__name__)
7
+
8
+
9
@dataclass
class XTTSAudioConfig:
    """Audio-processing settings for the XTTS mel/vocoder pipeline.

    The defaults mirror the stock XTTS-v2 checkpoint.
    """
    sample_rate: int = 22050               # model-side (input) audio rate
    output_sample_rate: int = 24000        # vocoder output audio rate
    mel_channels: int = 80                 # number of mel filterbank channels
    hop_length: int = 256                  # STFT hop size, in samples
    win_length: int = 1024                 # STFT window size, in samples
    n_fft: int = 1024                      # FFT size
    fmin: int = 0                          # lowest mel-filter frequency (Hz)
    fmax: int = 8000                       # highest mel-filter frequency (Hz)
    power: float = 1.0                     # spectrogram magnitude exponent
    mel_norms_file: Optional[str] = None   # optional path to mel normalization stats
22
+
23
+
24
class XTTSGPTConfig(PretrainedConfig):
    """Configuration class for the GPT component of XTTS.

    Holds the autoregressive model's architecture sizes, the token-space
    layout (text/audio start/stop ids), training switches, and default
    generation parameters. Serializes through ``to_dict``/``from_dict`` like
    any ``PretrainedConfig``; the nested :class:`XTTSAudioConfig` is kept as a
    dataclass instance in memory and flattened to a plain dict on export.

    Fixes vs. the previous revision:
    - ``audio_config`` passed as a plain dict (e.g. when the config is
      rehydrated from JSON) is now converted to ``XTTSAudioConfig``, matching
      ``XTTSConfig`` and keeping ``to_dict``'s ``asdict`` call safe.
    - ``from_dict`` no longer mutates the caller's dictionary.
    """
    model_type = "xtts_gpt"

    def __init__(
        self,
        # Model architecture
        vocab_size: int = 256,
        num_chars: int = 255,

        # GPT parameters
        gpt_batch_size: int = 1,
        gpt_max_audio_tokens: int = 605,
        gpt_max_text_tokens: int = 402,
        gpt_max_prompt_tokens: int = 70,
        gpt_layers: int = 30,
        gpt_n_model_channels: int = 1024,
        gpt_n_heads: int = 16,
        gpt_number_text_tokens: int = 6681,
        gpt_start_text_token: Optional[int] = None,
        gpt_stop_text_token: Optional[int] = None,
        gpt_num_audio_tokens: int = 1026,
        gpt_start_audio_token: int = 1024,
        gpt_stop_audio_token: int = 1025,
        gpt_code_stride_len: int = 1024,
        gpt_use_masking_gt_prompt_approach: bool = True,
        gpt_use_perceiver_resampler: bool = True,
        gpt_checkpointing: bool = False,
        gpt_train_solo_embeddings: bool = False,

        # Training parameters
        enable_redaction: bool = False,
        kv_cache: bool = True,
        perceiver_cond_length_compression: int = 256,
        label_smoothing: float = 0.0,

        # Generation parameters
        temperature: float = 0.75,
        length_penalty: float = 1.0,
        repetition_penalty: float = 5.0,
        top_k: int = 50,
        top_p: float = 0.85,
        gpt_cond_len: int = 30,
        gpt_cond_chunk_len: int = 4,
        max_ref_len: int = 30,
        sound_norm_refs: bool = False,

        # Audio processing
        audio_config: Optional[XTTSAudioConfig] = None,

        # Constants and limits
        duration_const: int = 102400,
        char_limits: Optional[Dict[str, int]] = None,
        languages: Optional[List[str]] = None,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs,
    ):
        # Mutable defaults are created per-instance, never as parameter defaults.
        if char_limits is None:
            char_limits = {
                "en": 250, "de": 253, "fr": 273, "es": 239,
                "it": 213, "pt": 203, "pl": 224, "zh": 82,
                "ar": 166, "cs": 186, "ru": 182, "nl": 251,
                "tr": 226, "ja": 71, "hu": 224, "ko": 95,
            }

        if languages is None:
            languages = [
                "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl",
                "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi"
            ]

        # Normalize audio_config: accept None, a ready dataclass, or a plain
        # dict (as produced by to_dict()/JSON round-trips). Converting the dict
        # here keeps to_dict()'s asdict() call safe and matches XTTSConfig.
        if audio_config is None:
            audio_config = XTTSAudioConfig()
        elif isinstance(audio_config, dict):
            audio_config = XTTSAudioConfig(**audio_config)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )

        self.vocab_size = vocab_size
        self.num_chars = num_chars

        # GPT parameters
        self.gpt_batch_size = gpt_batch_size
        self.gpt_max_audio_tokens = gpt_max_audio_tokens
        self.gpt_max_text_tokens = gpt_max_text_tokens
        self.gpt_max_prompt_tokens = gpt_max_prompt_tokens
        self.gpt_layers = gpt_layers
        self.gpt_n_model_channels = gpt_n_model_channels
        self.gpt_n_heads = gpt_n_heads
        self.gpt_number_text_tokens = gpt_number_text_tokens
        self.gpt_start_text_token = gpt_start_text_token
        self.gpt_stop_text_token = gpt_stop_text_token
        self.gpt_num_audio_tokens = gpt_num_audio_tokens
        self.gpt_start_audio_token = gpt_start_audio_token
        self.gpt_stop_audio_token = gpt_stop_audio_token
        self.gpt_code_stride_len = gpt_code_stride_len
        self.gpt_use_masking_gt_prompt_approach = gpt_use_masking_gt_prompt_approach
        self.gpt_use_perceiver_resampler = gpt_use_perceiver_resampler
        self.gpt_checkpointing = gpt_checkpointing
        self.gpt_train_solo_embeddings = gpt_train_solo_embeddings

        # Training parameters
        self.enable_redaction = enable_redaction
        self.kv_cache = kv_cache
        self.perceiver_cond_length_compression = perceiver_cond_length_compression
        self.label_smoothing = label_smoothing

        # Generation parameters
        self.temperature = temperature
        self.length_penalty = length_penalty
        self.repetition_penalty = repetition_penalty
        self.top_k = top_k
        self.top_p = top_p
        self.gpt_cond_len = gpt_cond_len
        self.gpt_cond_chunk_len = gpt_cond_chunk_len
        self.max_ref_len = max_ref_len
        self.sound_norm_refs = sound_norm_refs

        # Audio processing (always an XTTSAudioConfig instance after the
        # normalization above)
        self.audio_config = audio_config

        # Constants and limits
        self.duration_const = duration_const
        self.char_limits = char_limits
        self.languages = languages

    def to_dict(self):
        """Convert config to a plain dictionary, flattening the nested audio config."""
        config_dict = super().to_dict()
        config_dict["audio_config"] = asdict(self.audio_config)
        return config_dict

    @classmethod
    def from_dict(cls, config_dict):
        """Create a config from a dictionary (inverse of :meth:`to_dict`)."""
        # Work on a shallow copy so the caller's dict is not mutated by pop().
        config_dict = dict(config_dict)
        audio_config = XTTSAudioConfig(**config_dict.pop("audio_config", {}))
        return cls(audio_config=audio_config, **config_dict)

    def update_with_tokenizer(self, tokenizer=None):
        """Update text-token bookkeeping (vocab size, BOS/EOS ids) from a tokenizer."""
        if tokenizer is not None:
            self.gpt_number_text_tokens = tokenizer.get_vocab_size()
            self.gpt_start_text_token = tokenizer.bos_token_id
            self.gpt_stop_text_token = tokenizer.eos_token_id
tokenizer.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union, Dict, Tuple, Any
2
+ import os
3
+ from functools import cached_property
4
+
5
+ from transformers import PreTrainedTokenizerFast
6
+ from transformers.tokenization_utils_base import TruncationStrategy, PaddingStrategy
7
+ from tokenizers import Tokenizer, processors
8
+ from tokenizers.pre_tokenizers import WhitespaceSplit
9
+ from tokenizers.processors import TemplateProcessing
10
+ import torch
11
+ from hangul_romanize import Transliter
12
+ from hangul_romanize.rule import academic
13
+ import cutlet
14
+
15
+ from TTS.tts.layers.xtts.tokenizer import (multilingual_cleaners, basic_cleaners,
16
+ chinese_transliterate, korean_transliterate,
17
+ japanese_cleaners)
18
+
19
class XTTSTokenizerFast(PreTrainedTokenizerFast):
    """
    Fast Tokenizer implementation for XTTS model using HuggingFace's PreTrainedTokenizerFast

    Layers XTTS-specific behaviour on top of a ``tokenizers.Tokenizer``:
    per-language character-limit warnings, language-dependent cleaning and
    transliteration, a leading ``[lang]`` tag, and ``[SPACE]`` word markers.
    """
    def __init__(
        self,
        vocab_file: str = None,
        tokenizer_object: Optional[Tokenizer] = None,
        unk_token: str = "[UNK]",
        pad_token: str = "[PAD]",
        bos_token: str = "[START]",
        eos_token: str = "[STOP]",
        clean_up_tokenization_spaces: bool = True,
        **kwargs
    ):
        # A ready Tokenizer object takes precedence; the vocab file is only
        # loaded when no object was supplied.
        if tokenizer_object is None and vocab_file is not None:
            tokenizer_object = Tokenizer.from_file(vocab_file)

        if tokenizer_object is not None:
            # Configure the tokenizer: split on whitespace, right-pad with the
            # pad token, and wrap every sequence as "[START] $A [STOP]".
            tokenizer_object.pre_tokenizer = WhitespaceSplit()
            tokenizer_object.enable_padding(
                direction='right',
                # NOTE(review): `or 0` falls back to id 0 when token_to_id
                # returns None, but it would also remap a legitimate pad id of
                # 0 — confirm the vocab never assigns id 0 to the pad token.
                pad_id=tokenizer_object.token_to_id(pad_token) or 0,
                pad_token=pad_token
            )
            tokenizer_object.post_processor = TemplateProcessing(
                single=f"{bos_token} $A {eos_token}",
                special_tokens=[
                    (bos_token, tokenizer_object.token_to_id(bos_token)),
                    (eos_token, tokenizer_object.token_to_id(eos_token)),
                ],
            )

        super().__init__(
            tokenizer_object=tokenizer_object,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs
        )

        # Character limits per language (keys are bare language codes,
        # region suffixes stripped — see check_input_length)
        self.char_limits = {
            "en": 250, "de": 253, "fr": 273, "es": 239,
            "it": 213, "pt": 203, "pl": 224, "zh": 82,
            "ar": 166, "cs": 186, "ru": 182, "nl": 251,
            "tr": 226, "ja": 71, "hu": 224, "ko": 95,
        }

        # Initialize language tools (cutlet is created lazily in `katsu`
        # because its construction is deferred until Japanese input appears)
        self._katsu = None
        self._korean_transliter = Transliter(academic)

    @cached_property
    def katsu(self):
        # Lazily build the Japanese romanizer. The explicit `_katsu` guard is
        # redundant with @cached_property (which caches the first return
        # value) but harmless.
        if self._katsu is None:
            self._katsu = cutlet.Cutlet()
        return self._katsu

    def check_input_length(self, text: str, lang: str):
        """Check if input text length is within limits for language"""
        lang = lang.split("-")[0]  # remove region
        # Unknown languages fall back to the English-style 250-char limit.
        limit = self.char_limits.get(lang, 250)
        if len(text) > limit:
            # Warning only — the text is still tokenized (and may be truncated).
            print(f"Warning: Text length exceeds {limit} char limit for '{lang}', may cause truncation.")

    def preprocess_text(self, text: str, lang: str) -> str:
        """Apply text preprocessing for language"""
        # Languages with dedicated cleaning rules; zh and ko additionally get
        # transliterated after cleaning.
        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it",
                    "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
            text = multilingual_cleaners(text, lang)
            if lang == "zh":
                text = chinese_transliterate(text)
            if lang == "ko":
                text = korean_transliterate(text)
        elif lang == "ja":
            # Japanese uses its own cleaner driven by the cutlet romanizer.
            text = japanese_cleaners(text, self.katsu)
        else:
            # Any other language gets only the basic cleaning pass.
            text = basic_cleaners(text)
        return text

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs,
        add_special_tokens: bool = True,
        padding_strategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = 402,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Override batch encoding to handle language-specific preprocessing

        Accepts an extra ``lang`` kwarg (str or list of str, one entry per
        text); defaults to English for every item.
        """
        lang = kwargs.pop("lang", ["en"] * len(batch_text_or_text_pairs))
        if isinstance(lang, str):
            # Broadcast a single language code to the whole batch.
            lang = [lang] * len(batch_text_or_text_pairs)

        # Preprocess each text in the batch with its corresponding language
        processed_texts = []
        for text, text_lang in zip(batch_text_or_text_pairs, lang):
            if isinstance(text, str):
                # Check length and preprocess
                self.check_input_length(text, text_lang)
                processed_text = self.preprocess_text(text, text_lang)

                # Format text with language tag and spaces; "zh" is widened to
                # the "zh-cn" tag the model vocabulary expects.
                lang_code = "zh-cn" if text_lang == "zh" else text_lang
                processed_text = f"[{lang_code}]{processed_text}"
                processed_text = processed_text.replace(" ", "[SPACE]")

                processed_texts.append(processed_text)
            else:
                # Non-string entries (e.g. pre-tokenized input) pass through untouched.
                processed_texts.append(text)

        # Call the parent class's encoding method with processed texts
        return super()._batch_encode_plus(
            processed_texts,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs
        )

    def __call__(
        self,
        text: Union[str, List[str]],
        lang: Union[str, List[str]] = "en",
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = True,  # Changed default to True
        truncation: Union[bool, str, TruncationStrategy] = True,  # Changed default to True
        max_length: Optional[int] = 402,
        stride: int = 0,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = True,  # Changed default to True
        **kwargs
    ):
        """
        Main tokenization method
        Args:
            text: Text or list of texts to tokenize
            lang: Language code or list of language codes corresponding to each text
            add_special_tokens: Whether to add special tokens
            padding: Padding strategy (default True)
            truncation: Truncation strategy (default True)
            max_length: Maximum length
            stride: Stride for truncation
            return_tensors: Format of output tensors ("pt" for PyTorch)
            return_token_type_ids: Whether to return token type IDs
            return_attention_mask: Whether to return attention mask (default True)
        Raises:
            ValueError: if the number of texts and language codes differ.
        """
        # Convert single string to list for batch processing
        if isinstance(text, str):
            text = [text]
        if isinstance(lang, str):
            lang = [lang]

        # Ensure text and lang lists have same length
        if len(text) != len(lang):
            raise ValueError(f"Number of texts ({len(text)}) must match number of language codes ({len(lang)})")

        # Convert padding strategy (bool True means pad to max_length here,
        # not to the longest sequence in the batch)
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.MAX_LENGTH if padding else PaddingStrategy.DO_NOT_PAD
        else:
            padding_strategy = PaddingStrategy(padding)

        # Convert truncation strategy
        if isinstance(truncation, bool):
            truncation_strategy = TruncationStrategy.LONGEST_FIRST if truncation else TruncationStrategy.DO_NOT_TRUNCATE
        else:
            truncation_strategy = TruncationStrategy(truncation)

        # Use the batch encoding method
        encoded = self._batch_encode_plus(
            text,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            lang=lang,
            **kwargs
        )

        return encoded
xtts2_config.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import asdict, dataclass
2
+ from typing import Dict, List, Optional
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+
6
@dataclass
class SpeakerEncoderConfig:
    """Settings for the speaker-encoder that produces speaker embeddings."""
    model_name: str = "speaker_encoder"       # identifier of the encoder model
    preprocess_config: Optional[Dict] = None  # optional preprocessing options
    model_config: Optional[Dict] = None       # optional model-architecture options
    speaker_embedding_dim: int = 512          # dimensionality of the speaker embedding
    use_torch_spec: bool = True               # whether spectrograms are computed with torch
14
+
15
+
16
@dataclass
class XTTSAudioConfig:
    """Audio-processing settings for the XTTS mel/vocoder pipeline.

    Defaults mirror the stock XTTS-v2 checkpoint.
    """
    sample_rate: int = 22050               # model-side (input) audio rate
    output_sample_rate: int = 24000        # vocoder output audio rate
    mel_channels: int = 80                 # number of mel filterbank channels
    hop_length: int = 256                  # STFT hop size, in samples
    win_length: int = 1024                 # STFT window size, in samples
    n_fft: int = 1024                      # FFT size
    fmin: int = 0                          # lowest mel-filter frequency (Hz)
    fmax: int = 8000                       # highest mel-filter frequency (Hz)
    power: float = 1.0                     # spectrogram magnitude exponent
    mel_norms_file: Optional[str] = None   # optional path to mel normalization stats
29
+
30
+
31
class XTTSConfig(PretrainedConfig):
    """Combined configuration class for XTTS including both HifiGAN and GPT components

    Groups three concerns in one PretrainedConfig:
    - HiFi-GAN vocoder architecture (upsampling/resblock shapes, d-vector sizes),
    - GPT architecture/training/generation parameters,
    - nested sub-configs (speaker encoder stored as a plain dict,
      audio settings stored as an XTTSAudioConfig instance).
    """
    model_type = "xtts"

    def __init__(
        self,
        # HifiGAN Audio parameters
        input_sample_rate: int = 22050,
        output_sample_rate: int = 24000,
        output_hop_length: int = 256,

        # HifiGAN Model architecture
        decoder_input_dim: int = 1024,
        d_vector_dim: int = 512,
        cond_d_vector_in_each_upsampling_layer: bool = True,

        # HifiGAN Upsampling parameters (None defaults filled in below —
        # mutable defaults must not appear in the signature)
        upsample_rates: List[int] = None,
        upsample_kernel_sizes: List[int] = None,
        upsample_initial_channel: int = 512,

        # HifiGAN Resblock parameters
        resblock_kernel_sizes: List[int] = None,
        resblock_dilation_sizes: List[List[int]] = None,

        # HifiGAN Speaker encoder
        speaker_encoder_config: Optional[Dict] = None,

        # GPT Model architecture
        vocab_size: int = 256,
        num_chars: int = 255,

        # GPT parameters
        gpt_batch_size: int = 1,
        gpt_max_audio_tokens: int = 605,
        gpt_max_text_tokens: int = 402,
        gpt_max_prompt_tokens: int = 70,
        gpt_layers: int = 30,
        gpt_n_model_channels: int = 1024,
        gpt_n_heads: int = 16,
        gpt_number_text_tokens: int = 6681,
        gpt_start_text_token: Optional[int] = None,
        gpt_stop_text_token: Optional[int] = None,
        gpt_num_audio_tokens: int = 1026,
        gpt_start_audio_token: int = 1024,
        gpt_stop_audio_token: int = 1025,
        gpt_code_stride_len: int = 1024,
        gpt_use_masking_gt_prompt_approach: bool = True,
        gpt_use_perceiver_resampler: bool = True,
        gpt_checkpointing: bool = False,
        gpt_train_solo_embeddings: bool = False,

        # GPT Training parameters
        enable_redaction: bool = False,
        kv_cache: bool = True,
        perceiver_cond_length_compression: int = 256,
        label_smoothing: float = 0.0,

        # GPT Generation parameters
        temperature: float = 0.75,
        length_penalty: float = 1.0,
        repetition_penalty: float = 5.0,
        top_k: int = 50,
        top_p: float = 0.85,
        gpt_cond_len: int = 30,
        gpt_cond_chunk_len: int = 4,
        max_ref_len: int = 30,
        sound_norm_refs: bool = False,

        # GPT Audio processing
        audio_config: Optional[XTTSAudioConfig] = None,

        # GPT Constants and limits
        duration_const: int = 102400,
        char_limits: Optional[Dict[str, int]] = None,
        languages: Optional[List[str]] = None,

        # Base config parameters
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )

        # Set default lists for HifiGAN
        if upsample_rates is None:
            upsample_rates = [8, 8, 2, 2]
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = [16, 16, 4, 4]
        if resblock_kernel_sizes is None:
            resblock_kernel_sizes = [3, 7, 11]
        if resblock_dilation_sizes is None:
            resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]

        # Set default dicts for GPT
        if char_limits is None:
            char_limits = {
                "en": 250, "de": 253, "fr": 273, "es": 239,
                "it": 213, "pt": 203, "pl": 224, "zh": 82,
                "ar": 166, "cs": 186, "ru": 182, "nl": 251,
                "tr": 226, "ja": 71, "hu": 224, "ko": 95,
            }

        if languages is None:
            languages = [
                "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl",
                "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi"
            ]

        # Initialize HifiGAN parameters
        # Audio parameters
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.output_hop_length = output_hop_length

        # Model architecture
        self.decoder_input_dim = decoder_input_dim
        self.d_vector_dim = d_vector_dim
        self.cond_d_vector_in_each_upsampling_layer = cond_d_vector_in_each_upsampling_layer

        # Upsampling parameters
        self.upsample_rates = upsample_rates
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.upsample_initial_channel = upsample_initial_channel

        # Resblock parameters
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes

        # Speaker encoder - store as dictionary so the config stays
        # JSON-serializable; dataclass input is flattened, dict input is
        # merged over the dataclass defaults.
        if speaker_encoder_config is None:
            self.speaker_encoder_config = asdict(SpeakerEncoderConfig())
        elif isinstance(speaker_encoder_config, dict):
            default_config = asdict(SpeakerEncoderConfig())
            default_config.update(speaker_encoder_config)
            self.speaker_encoder_config = default_config
        elif isinstance(speaker_encoder_config, SpeakerEncoderConfig):
            self.speaker_encoder_config = asdict(speaker_encoder_config)
        else:
            raise ValueError("speaker_encoder_config must be either a dictionary or SpeakerEncoderConfig instance")

        # Initialize GPT parameters
        self.vocab_size = vocab_size
        self.num_chars = num_chars

        # GPT model parameters
        self.gpt_batch_size = gpt_batch_size
        self.gpt_max_audio_tokens = gpt_max_audio_tokens
        self.gpt_max_text_tokens = gpt_max_text_tokens
        self.gpt_max_prompt_tokens = gpt_max_prompt_tokens
        self.gpt_layers = gpt_layers
        self.gpt_n_model_channels = gpt_n_model_channels
        self.gpt_n_heads = gpt_n_heads
        self.gpt_number_text_tokens = gpt_number_text_tokens
        self.gpt_start_text_token = gpt_start_text_token
        self.gpt_stop_text_token = gpt_stop_text_token
        self.gpt_num_audio_tokens = gpt_num_audio_tokens
        self.gpt_start_audio_token = gpt_start_audio_token
        self.gpt_stop_audio_token = gpt_stop_audio_token
        self.gpt_code_stride_len = gpt_code_stride_len
        self.gpt_use_masking_gt_prompt_approach = gpt_use_masking_gt_prompt_approach
        self.gpt_use_perceiver_resampler = gpt_use_perceiver_resampler
        self.gpt_checkpointing = gpt_checkpointing
        self.gpt_train_solo_embeddings = gpt_train_solo_embeddings

        # Training parameters
        self.enable_redaction = enable_redaction
        self.kv_cache = kv_cache
        self.perceiver_cond_length_compression = perceiver_cond_length_compression
        self.label_smoothing = label_smoothing

        # Generation parameters
        self.temperature = temperature
        self.length_penalty = length_penalty
        self.repetition_penalty = repetition_penalty
        self.top_k = top_k
        self.top_p = top_p
        self.gpt_cond_len = gpt_cond_len
        self.gpt_cond_chunk_len = gpt_cond_chunk_len
        self.max_ref_len = max_ref_len
        self.sound_norm_refs = sound_norm_refs

        # Audio processing — unlike the speaker encoder, this is normalized to
        # a dataclass instance (dict input is converted).
        if audio_config is None:
            audio_config = XTTSAudioConfig()
        elif isinstance(audio_config, dict):
            audio_config = XTTSAudioConfig(**audio_config)
        self.audio_config = audio_config

        # Constants and limits
        self.duration_const = duration_const
        self.char_limits = char_limits
        self.languages = languages

    def to_dict(self) -> Dict:
        """Convert the config to a dictionary format."""
        # Get parent class dict
        output = super().to_dict()

        # Add all attributes (audio_config is flattened via asdict so the
        # result stays JSON-serializable)
        output.update({
            # HifiGAN parameters
            "input_sample_rate": self.input_sample_rate,
            "output_sample_rate": self.output_sample_rate,
            "output_hop_length": self.output_hop_length,
            "decoder_input_dim": self.decoder_input_dim,
            "d_vector_dim": self.d_vector_dim,
            "cond_d_vector_in_each_upsampling_layer": self.cond_d_vector_in_each_upsampling_layer,
            "upsample_rates": self.upsample_rates,
            "upsample_kernel_sizes": self.upsample_kernel_sizes,
            "upsample_initial_channel": self.upsample_initial_channel,
            "resblock_kernel_sizes": self.resblock_kernel_sizes,
            "resblock_dilation_sizes": self.resblock_dilation_sizes,
            "speaker_encoder_config": self.speaker_encoder_config,

            # GPT parameters
            "vocab_size": self.vocab_size,
            "num_chars": self.num_chars,
            "gpt_batch_size": self.gpt_batch_size,
            "gpt_max_audio_tokens": self.gpt_max_audio_tokens,
            "gpt_max_text_tokens": self.gpt_max_text_tokens,
            "gpt_max_prompt_tokens": self.gpt_max_prompt_tokens,
            "gpt_layers": self.gpt_layers,
            "gpt_n_model_channels": self.gpt_n_model_channels,
            "gpt_n_heads": self.gpt_n_heads,
            "gpt_number_text_tokens": self.gpt_number_text_tokens,
            "gpt_start_text_token": self.gpt_start_text_token,
            "gpt_stop_text_token": self.gpt_stop_text_token,
            "gpt_num_audio_tokens": self.gpt_num_audio_tokens,
            "gpt_start_audio_token": self.gpt_start_audio_token,
            "gpt_stop_audio_token": self.gpt_stop_audio_token,
            "gpt_code_stride_len": self.gpt_code_stride_len,
            "gpt_use_masking_gt_prompt_approach": self.gpt_use_masking_gt_prompt_approach,
            "gpt_use_perceiver_resampler": self.gpt_use_perceiver_resampler,
            "gpt_checkpointing": self.gpt_checkpointing,
            "gpt_train_solo_embeddings": self.gpt_train_solo_embeddings,
            "enable_redaction": self.enable_redaction,
            "kv_cache": self.kv_cache,
            "perceiver_cond_length_compression": self.perceiver_cond_length_compression,
            "label_smoothing": self.label_smoothing,
            "temperature": self.temperature,
            "length_penalty": self.length_penalty,
            "repetition_penalty": self.repetition_penalty,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "gpt_cond_len": self.gpt_cond_len,
            "gpt_cond_chunk_len": self.gpt_cond_chunk_len,
            "max_ref_len": self.max_ref_len,
            "sound_norm_refs": self.sound_norm_refs,
            "audio_config": asdict(self.audio_config),
            "duration_const": self.duration_const,
            "char_limits": self.char_limits,
            "languages": self.languages,
        })

        return output

    @classmethod
    def from_dict(cls, config_dict: Dict) -> "XTTSConfig":
        """Create a config instance from a dictionary."""
        # Copy first so the caller's dict is never mutated.
        config_copy = config_dict.copy()

        # Handle special nested configs
        if "audio_config" in config_copy:
            config_copy["audio_config"] = XTTSAudioConfig(**config_copy["audio_config"])

        return cls(**config_copy)

    def get_speaker_encoder_config(self) -> SpeakerEncoderConfig:
        """Get speaker encoder config as a SpeakerEncoderConfig instance"""
        return SpeakerEncoderConfig(**self.speaker_encoder_config)

    def update_with_tokenizer(self, tokenizer=None):
        """Update configuration values based on tokenizer

        Syncs vocab size and BOS/EOS/pad token ids from the tokenizer into
        both the GPT-specific fields and the base-config token ids.
        """
        if tokenizer is not None:
            self.gpt_number_text_tokens = tokenizer.get_vocab_size()
            self.gpt_start_text_token = tokenizer.bos_token_id
            self.gpt_stop_text_token = tokenizer.eos_token_id
            self.vocab_size = tokenizer.get_vocab_size()
            self.pad_token_id = tokenizer.pad_token_id
            self.bos_token_id = tokenizer.bos_token_id
            self.eos_token_id = tokenizer.eos_token_id

    def get_hifigan_config(self) -> Dict:
        """Extract HiFiGAN-specific configuration"""
        return {
            "input_sample_rate": self.input_sample_rate,
            "output_sample_rate": self.output_sample_rate,
            "output_hop_length": self.output_hop_length,
            "decoder_input_dim": self.decoder_input_dim,
            "d_vector_dim": self.d_vector_dim,
            "cond_d_vector_in_each_upsampling_layer": self.cond_d_vector_in_each_upsampling_layer,
            "upsample_rates": self.upsample_rates,
            "upsample_kernel_sizes": self.upsample_kernel_sizes,
            "upsample_initial_channel": self.upsample_initial_channel,
            "resblock_kernel_sizes": self.resblock_kernel_sizes,
            "resblock_dilation_sizes": self.resblock_dilation_sizes,
            "speaker_encoder_config": self.speaker_encoder_config
        }

    def get_gpt_config(self) -> Dict:
        """Extract GPT-specific configuration

        NOTE(review): "audio_config" is returned as the XTTSAudioConfig
        instance itself (not asdict) — presumably because the consumer is
        XTTSGPTConfig, which accepts the dataclass; confirm against callers.
        """
        return {
            "vocab_size": self.vocab_size,
            "num_chars": self.num_chars,
            "gpt_batch_size": self.gpt_batch_size,
            "gpt_max_audio_tokens": self.gpt_max_audio_tokens,
            "gpt_max_text_tokens": self.gpt_max_text_tokens,
            "gpt_max_prompt_tokens": self.gpt_max_prompt_tokens,
            "gpt_layers": self.gpt_layers,
            "gpt_n_model_channels": self.gpt_n_model_channels,
            "gpt_n_heads": self.gpt_n_heads,
            "gpt_number_text_tokens": self.gpt_number_text_tokens,
            "gpt_start_text_token": self.gpt_start_text_token,
            "gpt_stop_text_token": self.gpt_stop_text_token,
            "gpt_num_audio_tokens": self.gpt_num_audio_tokens,
            "gpt_start_audio_token": self.gpt_start_audio_token,
            "gpt_stop_audio_token": self.gpt_stop_audio_token,
            "gpt_code_stride_len": self.gpt_code_stride_len,
            "gpt_use_masking_gt_prompt_approach": self.gpt_use_masking_gt_prompt_approach,
            "gpt_use_perceiver_resampler": self.gpt_use_perceiver_resampler,
            "gpt_checkpointing": self.gpt_checkpointing,
            "gpt_train_solo_embeddings": self.gpt_train_solo_embeddings,
            "enable_redaction": self.enable_redaction,
            "kv_cache": self.kv_cache,
            "perceiver_cond_length_compression": self.perceiver_cond_length_compression,
            "label_smoothing": self.label_smoothing,
            "audio_config": self.audio_config,
            "pad_token_id": self.pad_token_id,
            "bos_token_id": self.bos_token_id,
            "eos_token_id": self.eos_token_id
        }

    def get_generation_config(self) -> Dict:
        """Extract generation-specific configuration"""
        return {
            "temperature": self.temperature,
            "length_penalty": self.length_penalty,
            "repetition_penalty": self.repetition_penalty,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "gpt_cond_len": self.gpt_cond_len,
            "gpt_cond_chunk_len": self.gpt_cond_chunk_len,
            "max_ref_len": self.max_ref_len,
            "sound_norm_refs": self.sound_norm_refs
        }

    def validate(self):
        """Validate configuration values

        Raises ValueError on the first violated constraint. Note only a subset
        of fields is checked (GPT sizes and upsampling lists).
        """
        if self.gpt_max_text_tokens <= 0:
            raise ValueError("gpt_max_text_tokens must be positive")
        if self.gpt_max_audio_tokens <= 0:
            raise ValueError("gpt_max_audio_tokens must be positive")
        if self.gpt_layers <= 0:
            raise ValueError("gpt_layers must be positive")
        if self.gpt_n_heads <= 0:
            raise ValueError("gpt_n_heads must be positive")
        if self.gpt_n_model_channels <= 0:
            raise ValueError("gpt_n_model_channels must be positive")
        if len(self.upsample_rates) != len(self.upsample_kernel_sizes):
            raise ValueError("upsample_rates and upsample_kernel_sizes must have same length")
        if not all(isinstance(x, int) and x > 0 for x in self.upsample_rates):
            raise ValueError("all upsample_rates must be positive integers")

    def get_audio_config(self) -> XTTSAudioConfig:
        """Get the audio configuration"""
        return self.audio_config

    @property
    def num_hidden_layers(self) -> int:
        """Get number of hidden layers (alias for gpt_layers)"""
        return self.gpt_layers

    @property
    def hidden_size(self) -> int:
        """Get hidden size (alias for gpt_n_model_channels)"""
        return self.gpt_n_model_channels

    @property
    def num_attention_heads(self) -> int:
        """Get number of attention heads (alias for gpt_n_heads)"""
        return self.gpt_n_heads
xtts2_modeling.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from dataclasses import dataclass
3
+ from typing import Optional, List, Tuple
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ import torch
6
+ import numpy as np
7
+ from transformers import PreTrainedModel
8
+
9
+ from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams, TokensPrompt
10
+ from vllm.multimodal import MultiModalDataDict
11
+ from vllm.utils import Counter
12
+
13
+ from TTS.TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
14
+ from gpt_config import XTTSGPTConfig
15
+ from xtts2_config import XTTSConfig
16
+ from tokenizer import XTTSTokenizerFast
17
+
18
+
19
@dataclass
class XTTSRequest:
    """Container for XTTS inference request data"""
    # Identification and textual input.
    request_id: str  # unique id used to track the request through the engine
    text: str  # text to synthesize
    language: str  # language code passed to the tokenizer
    # Conditioning derived from reference audio.
    gpt_cond_latent: torch.Tensor  # GPT conditioning latent for the reference voice
    speaker_embedding: torch.Tensor  # d-vector fed to the HiFi-GAN decoder
    # Sampling knobs; defaults match the values hard-coded here.
    temperature: float = 0.75
    top_p: float = 0.85
    top_k: int = 50
    repetition_penalty: float = 10.0
    length_penalty: float = 1.0
    do_sample: bool = True  # NOTE(review): not consumed by generate_speech_async below — confirm intent
33
+
34
+
35
@dataclass
class XTTSOutput:
    """Container for XTTS inference output"""
    request_id: str  # echoes XTTSRequest.request_id
    wav: np.ndarray  # synthesized waveform (squeezed, on CPU)
    gpt_latents: np.ndarray  # hidden states that were fed to the vocoder
    speaker_embedding: torch.Tensor  # speaker embedding used for this request
42
+
43
+
44
class Xtts(PreTrainedModel):
    """Async XTTS model: vLLM's AsyncLLMEngine serves the GPT stage and a
    HiFi-GAN decoder renders the waveform.

    Public surface (``from_pretrained``, ``prepare_inputs``,
    ``generate_speech_async``) is unchanged; the fixes below address
    initialisation and engine-handling bugs in the original.
    """

    def __init__(self, hifi_config: XTTSConfig, gpt_config: XTTSGPTConfig,
                 tensor_parallel_size: int = 1, **kwargs):
        """Build the model wrapper.

        Args:
            hifi_config: HiFi-GAN / audio-side configuration.
            gpt_config: GPT-side configuration (token limits, model dir, ...).
            tensor_parallel_size: number of GPUs for vLLM tensor parallelism.
        """
        # nn.Module / PreTrainedModel state must exist before register_buffer();
        # the original never called super().__init__() and would crash below.
        super().__init__(hifi_config)
        self.hifi_config = hifi_config
        self.gpt_config = gpt_config
        self.tp = tensor_parallel_size
        self.tokenizer = XTTSTokenizerFast.from_pretrained("AstraMindAI/xtts2-gpt")
        self.request_counter = Counter()
        self.executor = ThreadPoolExecutor(max_workers=4)  # blocking vocoder work
        self.init_models()
        self.register_buffer("mel_stats", torch.ones(80))

    @staticmethod
    def get_memory_percentage(memory: int) -> float:
        """Return *memory* (in bytes) as a fraction of GPU 0's total memory."""
        return memory / torch.cuda.get_device_properties(0).total_memory

    def init_models(self):
        """Initialize the vLLM engine and the HiFi-GAN decoder.

        Made synchronous on purpose: nothing in here awaits, and the original
        ``async`` version was called from ``__init__`` without being awaited,
        so the engine coroutine was created but never executed.
        """
        engine_args = AsyncEngineArgs(
            model=self.gpt_config.model_dir,
            tensor_parallel_size=self.tp,
            dtype="auto",  # was "auto " (trailing space) — an invalid dtype string
            max_model_len=self.gpt_config.gpt_max_text_tokens
                          + self.gpt_config.gpt_max_audio_tokens,
            # The GPT stage needs roughly 2 GiB (per the original comment); the
            # original passed the literal 2, i.e. two *bytes*.
            gpu_memory_utilization=self.get_memory_percentage(2 * 1024 ** 3),
            trust_remote_code=True,
            skip_tokenizer_init=True,  # no need to init a tokenizer, we use our own
            max_num_batched_tokens=4096,
            max_num_seqs=256,
        )
        # Keep the engine *instance*; the original immediately overwrote it
        # with the AsyncLLMEngine class object, breaking every later call.
        self.llm_engine = AsyncLLMEngine.from_engine_args(engine_args)

        # HiFi-GAN vocoder that turns GPT latents into waveforms.
        self.hifigan_decoder = HifiDecoder(
            input_sample_rate=self.hifi_config.input_sample_rate,
            output_sample_rate=self.hifi_config.output_sample_rate,
            output_hop_length=self.hifi_config.output_hop_length,
            ar_mel_length_compression=self.hifi_config.gpt_code_stride_len,
            decoder_input_dim=self.hifi_config.decoder_input_dim,
            d_vector_dim=self.hifi_config.d_vector_dim,
            cond_d_vector_in_each_upsampling_layer=self.hifi_config.cond_d_vector_in_each_upsampling_layer,
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: torch.dtype = torch.float16,
        device_map: Optional[str] = "auto",
        tensor_parallel_size: int = 1,
        **kwargs,
    ) -> "Xtts":
        """Load pretrained XTTS model from a local path or the HuggingFace Hub.

        Args:
            pretrained_model_name_or_path: local directory or HF Hub model id.
            torch_dtype: dtype to cast the model to. Defaults to float16.
            device_map: device mapping strategy for ``accelerate``.
            tensor_parallel_size: GPUs for vLLM tensor parallelism.
            **kwargs: forwarded to the constructor.

        Returns:
            Xtts: loaded model instance.
        """
        from huggingface_hub import hf_hub_download
        import json
        import os

        if not os.path.exists(pretrained_model_name_or_path):
            # Remote repo: only config.json is needed here. The original also
            # fetched gpt_config.py / xtts2_config.py and json.loads()-ed them
            # (Python source is not JSON) and then discarded both results; it
            # also used filename="../xtts2_gpt/config.json", which escapes the
            # repo root and is rejected by hf_hub_download.
            config_file = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="config.json",
            )
            with open(config_file, "r") as f:
                config = json.load(f)
        else:
            # Load from local path.
            with open(os.path.join(pretrained_model_name_or_path, "config.json"), "r") as f:
                config = json.load(f)

        # Both configs are built from the same merged config.json.
        gpt_config = XTTSGPTConfig(**config)
        hifi_config = XTTSConfig(**config)

        model = cls(
            hifi_config=hifi_config,
            gpt_config=gpt_config,
            tensor_parallel_size=tensor_parallel_size,
            **kwargs,
        )

        # Resolve weight files.
        if not os.path.exists(pretrained_model_name_or_path):
            # NOTE(review): the GPT weights appear to live in the companion
            # "xtts2-gpt" repo (the tokenizer loads from it above); the original
            # "../xtts2_gpt/..." path cannot work. Confirm the repo id.
            gpt_weights = hf_hub_download(
                repo_id="AstraMindAI/xtts2-gpt",
                filename="xttsv2-gpt.safetensors",
            )
            hifigan_weights = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="xttsv2-hifigan-mel.safetensors",
            )
        else:
            gpt_weights = os.path.join(pretrained_model_name_or_path, "xttsv2-gpt.safetensors")
            hifigan_weights = os.path.join(pretrained_model_name_or_path, "xttsv2-hifigan-mel.safetensors")

        import safetensors.torch

        # NOTE(review): __init__ never creates a `gpt` submodule (vLLM owns the
        # GPT weights), so the original unconditional load raised
        # AttributeError. Load only if such a submodule exists; confirm wiring.
        gpt_state = safetensors.torch.load_file(gpt_weights)
        if hasattr(model, "gpt"):
            model.gpt.load_state_dict(gpt_state)

        # Load HiFi-GAN weights.
        hifigan_state = safetensors.torch.load_file(hifigan_weights)
        model.hifigan_decoder.load_state_dict(hifigan_state)

        # NOTE(review): this replaces the PretrainedConfig stored by
        # super().__init__() with a plain dict, which breaks PreTrainedModel
        # helpers — confirm downstream callers really expect a dict here.
        model.config = config

        # Cast to the requested dtype, then apply device mapping.
        model = model.to(torch_dtype)
        if device_map:
            from accelerate import dispatch_model
            model = dispatch_model(model, device_map=device_map)

        return model

    def prepare_inputs(self, text: str, language: str,
                       gpt_cond_latent: torch.Tensor) -> Tuple[List[int], torch.Tensor]:
        """Tokenize *text* for *language*; pass the conditioning latent through.

        Returns:
            (token ids, gpt_cond_latent) pair fed to the vLLM engine.
        """
        text_tokens = self.tokenizer.encode(text, lang=language)
        return text_tokens, gpt_cond_latent

    async def generate_speech_async(self, request: XTTSRequest) -> XTTSOutput:
        """Generate speech for a single request asynchronously.

        Runs the GPT stage through vLLM, converts the generated tokens to
        hidden states, and renders audio with HiFi-GAN off the event loop.
        """
        tokens, gpt_cond_latent = self.prepare_inputs(
            request.text,
            request.language,
            request.gpt_cond_latent,
        )

        sampling_params = SamplingParams(
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            repetition_penalty=request.repetition_penalty,
            max_tokens=self.gpt_config.gpt_max_audio_tokens,
            # NOTE(review): with skip_tokenizer_init=True vLLM cannot match
            # string stop sequences; stop_token_ids may be needed instead.
            stop=['</s>', '<|endoftext|>'],
        )

        engine_inputs = TokensPrompt(prompt_token_ids=tokens)
        if gpt_cond_latent is not None:
            engine_inputs["multi_modal_data"] = MultiModalDataDict({"audio": gpt_cond_latent})

        output_generator = self.llm_engine.generate(
            inputs=engine_inputs,
            sampling_params=sampling_params,
            request_id=request.request_id,
        )

        # Drain the stream and keep the final (complete) output; the original
        # returned from inside the loop after the first *partial* result.
        final_output = None
        async for partial in output_generator:
            final_output = partial
        if final_output is None:
            raise RuntimeError(f"vLLM produced no output for request {request.request_id}")

        generated_tokens = final_output.outputs[0].token_ids

        # Convert tokens to hidden states (architecture-dependent placeholder).
        hidden_states = await self._tokens_to_hidden_states(generated_tokens)

        # HiFi-GAN synthesis is blocking; run it in the thread pool so the
        # event loop stays responsive. get_running_loop() is the modern form.
        loop = asyncio.get_running_loop()
        wav = await loop.run_in_executor(
            self.executor,
            lambda: self.hifigan_decoder(
                hidden_states,
                g=request.speaker_embedding,
            ).cpu().numpy().squeeze(),
        )

        return XTTSOutput(
            request_id=request.request_id,
            wav=wav,
            gpt_latents=hidden_states.cpu().numpy(),
            speaker_embedding=request.speaker_embedding,
        )

    async def _tokens_to_hidden_states(self, tokens: List[int]) -> torch.Tensor:
        """Convert generated tokens to hidden states.

        Placeholder implementation: the real mapping depends on the model
        architecture. NOTE(review): AsyncLLMEngine.encode expects a prompt plus
        pooling params, not a raw tensor — confirm the intended API before
        relying on this.
        """
        token_tensor = torch.tensor(tokens, device=self.device)
        hidden_states = await self.llm_engine.encode(token_tensor)
        return hidden_states
259
+ return hidden_states
xttsv2-hifigan-mel.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6eaf6c236291478363be6da06c0869551c51bf0c8983fd2dd70561a4a1f1ace3
3
+ size 103599512