mlinmg commited on
Commit
0d089b3
·
verified ·
1 Parent(s): 515a606

Delete gpt_config.py

Browse files
Files changed (1) hide show
  1. gpt_config.py +0 -172
gpt_config.py DELETED
@@ -1,172 +0,0 @@
1
- from dataclasses import asdict, dataclass, field
2
- from typing import Dict, Optional, List
3
- from transformers.configuration_utils import PretrainedConfig
4
- from transformers.utils import logging
5
-
6
- logger = logging.get_logger(__name__)
7
-
8
-
9
- @dataclass
10
- class XTTSAudioConfig:
11
- """Configuration for audio processing parameters"""
12
- sample_rate: int = 22050
13
- output_sample_rate: int = 24000
14
- mel_channels: int = 80
15
- hop_length: int = 256
16
- win_length: int = 1024
17
- n_fft: int = 1024
18
- fmin: int = 0
19
- fmax: int = 8000
20
- power: float = 1.0
21
- mel_norms_file: Optional[str] = None
22
-
23
-
24
- class XTTSGPTConfig(PretrainedConfig):
25
- """Configuration class for the GPT component of XTTS"""
26
- model_type = "xtts_gpt"
27
-
28
- def __init__(
29
- self,
30
- # Model architecture
31
- vocab_size: int = 256,
32
- num_chars: int = 255,
33
-
34
- # GPT parameters
35
- gpt_batch_size: int = 1,
36
- gpt_max_audio_tokens: int = 605,
37
- gpt_max_text_tokens: int = 402,
38
- gpt_max_prompt_tokens: int = 70,
39
- gpt_layers: int = 30,
40
- gpt_n_model_channels: int = 1024,
41
- gpt_n_heads: int = 16,
42
- gpt_number_text_tokens: int = 6681,
43
- gpt_start_text_token: Optional[int] = None,
44
- gpt_stop_text_token: Optional[int] = None,
45
- gpt_num_audio_tokens: int = 1026,
46
- gpt_start_audio_token: int = 1024,
47
- gpt_stop_audio_token: int = 1025,
48
- gpt_code_stride_len: int = 1024,
49
- gpt_use_masking_gt_prompt_approach: bool = True,
50
- gpt_use_perceiver_resampler: bool = True,
51
- gpt_checkpointing: bool = False,
52
- gpt_train_solo_embeddings: bool = False,
53
-
54
- # Training parameters
55
- enable_redaction: bool = False,
56
- kv_cache: bool = True,
57
- perceiver_cond_length_compression: int = 256,
58
- label_smoothing: float = 0.0,
59
-
60
- # Generation parameters
61
- temperature: float = 0.75,
62
- length_penalty: float = 1.0,
63
- repetition_penalty: float = 5.0,
64
- top_k: int = 50,
65
- top_p: float = 0.85,
66
- gpt_cond_len: int = 30,
67
- gpt_cond_chunk_len: int = 4,
68
- max_ref_len: int = 30,
69
- sound_norm_refs: bool = False,
70
-
71
- # Audio processing
72
- audio_config: Optional[XTTSAudioConfig] = None,
73
-
74
- # Constants and limits
75
- duration_const: int = 102400,
76
- char_limits: Optional[Dict[str, int]] = None,
77
- languages: Optional[List[str]] = None,
78
- pad_token_id: Optional[int] = None,
79
- bos_token_id: Optional[int] = None,
80
- eos_token_id: Optional[int] = None,
81
- **kwargs,
82
- ):
83
- if char_limits is None:
84
- char_limits = {
85
- "en": 250, "de": 253, "fr": 273, "es": 239,
86
- "it": 213, "pt": 203, "pl": 224, "zh": 82,
87
- "ar": 166, "cs": 186, "ru": 182, "nl": 251,
88
- "tr": 226, "ja": 71, "hu": 224, "ko": 95,
89
- }
90
-
91
- if languages is None:
92
- languages = [
93
- "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl",
94
- "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi"
95
- ]
96
-
97
- if audio_config is None:
98
- audio_config = XTTSAudioConfig()
99
-
100
- super().__init__(
101
- pad_token_id=pad_token_id,
102
- bos_token_id=bos_token_id,
103
- eos_token_id=eos_token_id,
104
- **kwargs
105
- )
106
-
107
- self.vocab_size = vocab_size
108
- self.num_chars = num_chars
109
-
110
- # GPT parameters
111
- self.gpt_batch_size = gpt_batch_size
112
- self.gpt_max_audio_tokens = gpt_max_audio_tokens
113
- self.gpt_max_text_tokens = gpt_max_text_tokens
114
- self.gpt_max_prompt_tokens = gpt_max_prompt_tokens
115
- self.gpt_layers = gpt_layers
116
- self.gpt_n_model_channels = gpt_n_model_channels
117
- self.gpt_n_heads = gpt_n_heads
118
- self.gpt_number_text_tokens = gpt_number_text_tokens
119
- self.gpt_start_text_token = gpt_start_text_token
120
- self.gpt_stop_text_token = gpt_stop_text_token
121
- self.gpt_num_audio_tokens = gpt_num_audio_tokens
122
- self.gpt_start_audio_token = gpt_start_audio_token
123
- self.gpt_stop_audio_token = gpt_stop_audio_token
124
- self.gpt_code_stride_len = gpt_code_stride_len
125
- self.gpt_use_masking_gt_prompt_approach = gpt_use_masking_gt_prompt_approach
126
- self.gpt_use_perceiver_resampler = gpt_use_perceiver_resampler
127
- self.gpt_checkpointing = gpt_checkpointing
128
- self.gpt_train_solo_embeddings = gpt_train_solo_embeddings
129
-
130
- # Training parameters
131
- self.enable_redaction = enable_redaction
132
- self.kv_cache = kv_cache
133
- self.perceiver_cond_length_compression = perceiver_cond_length_compression
134
- self.label_smoothing = label_smoothing
135
-
136
- # Generation parameters
137
- self.temperature = temperature
138
- self.length_penalty = length_penalty
139
- self.repetition_penalty = repetition_penalty
140
- self.top_k = top_k
141
- self.top_p = top_p
142
- self.gpt_cond_len = gpt_cond_len
143
- self.gpt_cond_chunk_len = gpt_cond_chunk_len
144
- self.max_ref_len = max_ref_len
145
- self.sound_norm_refs = sound_norm_refs
146
-
147
- # Audio processing
148
- self.audio_config = audio_config
149
-
150
- # Constants and limits
151
- self.duration_const = duration_const
152
- self.char_limits = char_limits
153
- self.languages = languages
154
-
155
- def to_dict(self):
156
- """Convert config to dictionary"""
157
- config_dict = super().to_dict()
158
- config_dict["audio_config"] = asdict(self.audio_config)
159
- return config_dict
160
-
161
- @classmethod
162
- def from_dict(cls, config_dict):
163
- """Create config from dictionary"""
164
- audio_config = XTTSAudioConfig(**config_dict.pop("audio_config", {}))
165
- return cls(audio_config=audio_config, **config_dict)
166
-
167
- def update_with_tokenizer(self, tokenizer=None):
168
- """Update configuration values based on tokenizer"""
169
- if tokenizer is not None:
170
- self.gpt_number_text_tokens = tokenizer.get_vocab_size()
171
- self.gpt_start_text_token = tokenizer.bos_token_id
172
- self.gpt_stop_text_token = tokenizer.eos_token_id