sbapan41 committed
Commit 46270db · verified · 1 Parent(s): 1ca37ed

Delete qhash

qhash/autoencoder.py DELETED
@@ -1,26 +0,0 @@
-import math
-
-import torch
-import torchaudio
-from transformers.models.dac import DacModel
-
-
-class DACAutoencoder:
-    def __init__(self):
-        super().__init__()
-        self.dac = DacModel.from_pretrained("Quantamhash/dac_44khz")
-        self.dac.eval().requires_grad_(False)
-        self.codebook_size = self.dac.config.codebook_size
-        self.num_codebooks = self.dac.quantizer.n_codebooks
-        self.sampling_rate = self.dac.config.sampling_rate
-
-    def preprocess(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
-        wav = torchaudio.functional.resample(wav, sr, 44_100)
-        right_pad = math.ceil(wav.shape[-1] / 512) * 512 - wav.shape[-1]
-        return torch.nn.functional.pad(wav, (0, right_pad))
-
-    def encode(self, wav: torch.Tensor) -> torch.Tensor:
-        return self.dac.encode(wav).audio_codes
-
-    def decode(self, codes: torch.Tensor) -> torch.Tensor:
-        return self.dac.decode(audio_codes=codes).audio_values.unsqueeze(1)
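
For context, the deleted class wraps the 44.1 kHz Descript Audio Codec from transformers. A minimal round-trip sketch of how it was used before deletion (the file name and mono input are illustrative assumptions):

import torchaudio

from qhash.autoencoder import DACAutoencoder  # the module deleted above

ae = DACAutoencoder()
wav, sr = torchaudio.load("sample.wav")    # hypothetical mono clip, shape [1, T]
wav = ae.preprocess(wav.unsqueeze(0), sr)  # resample to 44.1 kHz, pad T to a multiple of 512
codes = ae.encode(wav)                     # discrete codes: [1, num_codebooks, frames]
audio = ae.decode(codes)                   # reconstruction: [1, 1, T']
print(codes.shape, audio.shape)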
qhash/backbone.py DELETED
@@ -1,50 +0,0 @@
-import torch
-import torch.nn as nn
-from mamba_ssm.models.mixer_seq_simple import create_block
-from mamba_ssm.ops.triton.layer_norm import layer_norm_fn
-from mamba_ssm.utils.generation import InferenceParams
-
-from qhash.config import BackboneConfig
-
-
-class ZonosBackbone(nn.Module):
-    def __init__(self, config: BackboneConfig):
-        super().__init__()
-        self.config = config
-
-        self.layers = nn.ModuleList(
-            [
-                create_block(
-                    d_model=config.d_model,
-                    d_intermediate=config.d_intermediate
-                    if (i not in config.attn_layer_idx)
-                    else config.attn_mlp_d_intermediate,
-                    ssm_cfg=config.ssm_cfg,
-                    layer_idx=i,
-                    attn_layer_idx=config.attn_layer_idx,
-                    attn_cfg=config.attn_cfg,
-                    norm_epsilon=config.norm_epsilon,
-                    residual_in_fp32=config.residual_in_fp32,
-                    fused_add_norm=True,
-                    rms_norm=config.rms_norm,
-                )
-                for i in range(config.n_layer)
-            ]
-        )
-
-        self.norm_f = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
-
-    def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams | None = None):
-        residual = None
-        for layer in self.layers:
-            hidden_states, residual = layer(hidden_states, residual, inference_params)
-
-        return layer_norm_fn(
-            hidden_states,
-            self.norm_f.weight,
-            self.norm_f.bias,
-            residual,
-            eps=self.norm_f.eps,
-            residual_in_fp32=self.config.residual_in_fp32,
-            is_rms_norm=self.config.rms_norm,
-        )
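
Each block here is a Mamba (or hybrid attention) layer from mamba-ssm, with a fused residual-add + norm applied at the output. A minimal forward-pass sketch, assuming mamba-ssm is installed and a CUDA device is available (the tiny hyperparameters are illustrative, not the released ones):

import torch

from qhash.backbone import ZonosBackbone
from qhash.config import BackboneConfig

config = BackboneConfig(d_model=256, n_layer=4)  # toy sizes for illustration
backbone = ZonosBackbone(config).cuda()

hidden = torch.randn(1, 16, config.d_model, device="cuda")
out = backbone(hidden)  # output keeps the input shape: [1, 16, 256]
print(out.shape)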
qhash/codebook_pattern.py DELETED
@@ -1,12 +0,0 @@
-import torch
-import torch.nn.functional as F
-
-
-def apply_delay_pattern(codes: torch.Tensor, mask_token: int):
-    codes = F.pad(codes, (0, codes.shape[1]), value=mask_token)
-    return torch.stack([codes[:, k].roll(k + 1) for k in range(codes.shape[1])], dim=1)
-
-
-def revert_delay_pattern(codes: torch.Tensor):
-    _, n_q, seq_len = codes.shape
-    return torch.stack([codes[:, k, k + 1 : seq_len - n_q + k + 1] for k in range(n_q)], dim=1)
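
The two helpers are exact inverses: codebook row k is shifted right by k + 1 positions, with mask tokens filling the gap, so all codebooks can be sampled one frame at a time. A quick self-contained check:

import torch

from qhash.codebook_pattern import apply_delay_pattern, revert_delay_pattern

codes = torch.arange(2 * 4 * 6).reshape(2, 4, 6)     # [batch=2, n_q=4 codebooks, 6 frames]
delayed = apply_delay_pattern(codes, mask_token=-1)  # [2, 4, 10]: row k shifted by k + 1
restored = revert_delay_pattern(delayed)
assert torch.equal(restored, codes)                  # the delay pattern round-trips exactly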
qhash/conditioning.py DELETED
@@ -1,373 +0,0 @@
-from functools import cache
-from typing import Any, Literal, Iterable
-
-import torch
-import torch.nn as nn
-
-from qhash.config import PrefixConditionerConfig
-
-
-class Conditioner(nn.Module):
-    def __init__(
-        self,
-        output_dim: int,
-        name: str,
-        cond_dim: int | None = None,
-        projection: Literal["none", "linear", "mlp"] = "none",
-        uncond_type: Literal["learned", "none"] = "none",
-        **kwargs,
-    ):
-        super().__init__()
-        self.name = name
-        self.output_dim = output_dim
-        self.cond_dim = cond_dim = cond_dim or output_dim
-
-        if projection == "linear":
-            self.project = nn.Linear(cond_dim, output_dim)
-        elif projection == "mlp":
-            self.project = nn.Sequential(
-                nn.Linear(cond_dim, output_dim),
-                nn.SiLU(),
-                nn.Linear(output_dim, output_dim),
-            )
-        else:
-            self.project = nn.Identity()
-
-        self.uncond_vector = None
-        if uncond_type == "learned":
-            self.uncond_vector = nn.Parameter(torch.zeros(output_dim))
-
-    def apply_cond(self, *inputs: Any) -> torch.Tensor:
-        raise NotImplementedError()
-
-    def forward(self, inputs: tuple[Any, ...] | None) -> torch.Tensor:
-        if inputs is None:
-            assert self.uncond_vector is not None
-            return self.uncond_vector.data.view(1, 1, -1)
-
-        cond = self.apply_cond(*inputs)
-        cond = self.project(cond)
-        return cond
-
-
-# ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------
-import re
-import unicodedata
-
-import inflect
-import torch
-import torch.nn as nn
-from kanjize import number2kanji
-from phonemizer.backend import EspeakBackend
-from sudachipy import Dictionary, SplitMode
-
-# --- Number normalization code from https://github.com/daniilrobnikov/vits2/blob/main/text/normalize_numbers.py ---
-
-_inflect = inflect.engine()
-_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
-_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
-_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
-_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
-_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
-_number_re = re.compile(r"[0-9]+")
-
-
-def _remove_commas(m: re.Match) -> str:
-    return m.group(1).replace(",", "")
-
-
-def _expand_decimal_point(m: re.Match) -> str:
-    return m.group(1).replace(".", " point ")
-
-
-def _expand_dollars(m: re.Match) -> str:
-    match = m.group(1)
-    parts = match.split(".")
-    if len(parts) > 2:
-        return match + " dollars"  # Unexpected format
-    dollars = int(parts[0]) if parts[0] else 0
-    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
-    if dollars and cents:
-        dollar_unit = "dollar" if dollars == 1 else "dollars"
-        cent_unit = "cent" if cents == 1 else "cents"
-        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
-    elif dollars:
-        dollar_unit = "dollar" if dollars == 1 else "dollars"
-        return "%s %s" % (dollars, dollar_unit)
-    elif cents:
-        cent_unit = "cent" if cents == 1 else "cents"
-        return "%s %s" % (cents, cent_unit)
-    else:
-        return "zero dollars"
-
-
-def _expand_ordinal(m: re.Match) -> str:
-    return _inflect.number_to_words(m.group(0))
-
-
-def _expand_number(m: re.Match) -> str:
-    num = int(m.group(0))
-    if num > 1000 and num < 3000:
-        if num == 2000:
-            return "two thousand"
-        elif num > 2000 and num < 2010:
-            return "two thousand " + _inflect.number_to_words(num % 100)
-        elif num % 100 == 0:
-            return _inflect.number_to_words(num // 100) + " hundred"
-        else:
-            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
-    else:
-        return _inflect.number_to_words(num, andword="")
-
-
-def normalize_numbers(text: str) -> str:
-    text = re.sub(_comma_number_re, _remove_commas, text)
-    text = re.sub(_pounds_re, r"\1 pounds", text)
-    text = re.sub(_dollars_re, _expand_dollars, text)
-    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
-    text = re.sub(_ordinal_re, _expand_ordinal, text)
-    text = re.sub(_number_re, _expand_number, text)
-    return text
-
-
-# --- Number normalization code end ---
-
-
-PAD_ID, UNK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3
-SPECIAL_TOKEN_IDS = [PAD_ID, UNK_ID, BOS_ID, EOS_ID]
-
-_punctuation = ';:,.!?¡¿—…"«»“”() *~-/\\&'
-_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
-_letters_ipa = (
-    "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
-)
-
-symbols = [*_punctuation, *_letters, *_letters_ipa]
-_symbol_to_id = {s: i for i, s in enumerate(symbols, start=len(SPECIAL_TOKEN_IDS))}
-
-
-def _get_symbol_id(s: str) -> int:
-    return _symbol_to_id.get(s, 1)
-
-
-def get_symbol_ids(text: str) -> list[int]:
-    return list(map(_get_symbol_id, text))
-
-
-def tokenize_phonemes(phonemes: list[str]) -> tuple[torch.Tensor, list[int]]:
-    phoneme_ids = [[BOS_ID, *get_symbol_ids(phonemes), EOS_ID] for phonemes in phonemes]
-    lengths = list(map(len, phoneme_ids))
-    longest = max(lengths)
-    phoneme_ids = [[PAD_ID] * (longest - len(ids)) + ids for ids in phoneme_ids]
-    return torch.tensor(phoneme_ids), lengths
-
-
-def normalize_jp_text(text: str, tokenizer=Dictionary(dict="full").create()) -> str:
-    text = unicodedata.normalize("NFKC", text)
-    text = re.sub(r"\d+", lambda m: number2kanji(int(m[0])), text)
-    final_text = " ".join([x.reading_form() for x in tokenizer.tokenize(text, SplitMode.A)])
-    return final_text
-
-
-def clean(texts: list[str], languages: list[str]) -> list[str]:
-    texts_out = []
-    for text, language in zip(texts, languages):
-        if "ja" in language:
-            text = normalize_jp_text(text)
-        else:
-            text = normalize_numbers(text)
-        texts_out.append(text)
-    return texts_out
-
-
-@cache
-def get_backend(language: str) -> "EspeakBackend":
-    import logging
-
-    from phonemizer.backend import EspeakBackend
-
-    logger = logging.getLogger("phonemizer")
-    backend = EspeakBackend(
-        language,
-        preserve_punctuation=True,
-        with_stress=True,
-        punctuation_marks=_punctuation,
-        logger=logger,
-    )
-    logger.setLevel(logging.ERROR)
-    return backend
-
-
-def phonemize(texts: list[str], languages: list[str]) -> list[str]:
-    texts = clean(texts, languages)
-
-    batch_phonemes = []
-    for text, language in zip(texts, languages):
-        backend = get_backend(language)
-        phonemes = backend.phonemize([text], strip=True)
-        batch_phonemes.append(phonemes[0])
-
-    return batch_phonemes
-
-
-class EspeakPhonemeConditioner(Conditioner):
-    def __init__(self, output_dim: int, **kwargs):
-        super().__init__(output_dim, **kwargs)
-        self.phoneme_embedder = nn.Embedding(len(SPECIAL_TOKEN_IDS) + len(symbols), output_dim)
-
-    def apply_cond(self, texts: list[str], languages: list[str]) -> torch.Tensor:
-        """
-        Args:
-            texts: list of texts to convert to phonemes
-            languages: ISO 639-1 -or otherwise eSpeak compatible- language code
-        """
-        device = self.phoneme_embedder.weight.device
-
-        phonemes = phonemize(texts, languages)
-        phoneme_ids, _ = tokenize_phonemes(phonemes)
-        phoneme_embeds = self.phoneme_embedder(phoneme_ids.to(device))
-
-        return phoneme_embeds
-
-
-# ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------
-
-
-class FourierConditioner(Conditioner):
-    def __init__(
-        self,
-        output_dim: int,
-        input_dim: int = 1,
-        std: float = 1.0,
-        min_val: float = 0.0,
-        max_val: float = 1.0,
-        **kwargs,
-    ):
-        assert output_dim % 2 == 0
-        super().__init__(output_dim, **kwargs)
-        self.register_buffer("weight", torch.randn([output_dim // 2, input_dim]) * std)
-        self.input_dim, self.min_val, self.max_val = input_dim, min_val, max_val
-
-    def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
-        assert x.shape[-1] == self.input_dim
-        x = (x - self.min_val) / (self.max_val - self.min_val)  # [batch_size, seq_len, input_dim]
-        f = 2 * torch.pi * x.to(self.weight.dtype) @ self.weight.T  # [batch_size, seq_len, output_dim // 2]
-        return torch.cat([f.cos(), f.sin()], dim=-1)  # [batch_size, seq_len, output_dim]
-
-
-class IntegerConditioner(Conditioner):
-    def __init__(self, output_dim: int, min_val: int = 0, max_val: int = 512, **kwargs):
-        super().__init__(output_dim, **kwargs)
-        self.min_val = min_val
-        self.max_val = max_val
-        self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim)
-
-    def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
-        assert x.shape[-1] == 1
-        return self.int_embedder(x.squeeze(-1) - self.min_val)  # [batch_size, seq_len, output_dim]
-
-
-class PassthroughConditioner(Conditioner):
-    def __init__(self, output_dim: int, **kwargs):
-        super().__init__(output_dim, **kwargs)
-
-    def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
-        assert x.shape[-1] == self.cond_dim
-        return x
-
-
-_cond_cls_map = {
-    "PassthroughConditioner": PassthroughConditioner,
-    "EspeakPhonemeConditioner": EspeakPhonemeConditioner,
-    "FourierConditioner": FourierConditioner,
-    "IntegerConditioner": IntegerConditioner,
-}
-
-
-def build_conditioners(conditioners: list[dict], output_dim: int) -> list[Conditioner]:
-    return [_cond_cls_map[config["type"]](output_dim, **config) for config in conditioners]
-
-
-class PrefixConditioner(Conditioner):
-    def __init__(self, config: PrefixConditionerConfig, output_dim: int):
-        super().__init__(output_dim, "prefix", projection=config.projection)
-        self.conditioners = nn.ModuleList(build_conditioners(config.conditioners, output_dim))
-        self.norm = nn.LayerNorm(output_dim)
-        self.required_keys = {c.name for c in self.conditioners if c.uncond_vector is None}
-
-    def forward(self, cond_dict: dict) -> torch.Tensor:
-        if not set(cond_dict).issuperset(self.required_keys):
-            raise ValueError(f"Missing required keys: {self.required_keys - set(cond_dict)}")
-        conds = []
-        for conditioner in self.conditioners:
-            conds.append(conditioner(cond_dict.get(conditioner.name)))
-        max_bsz = max(map(len, conds))
-        assert all(c.shape[0] in (max_bsz, 1) for c in conds)
-        conds = [c.expand(max_bsz, -1, -1) for c in conds]
-        return self.norm(self.project(torch.cat(conds, dim=-2)))
-
-
-supported_language_codes = [
-    'af', 'am', 'an', 'ar', 'as', 'az', 'ba', 'bg', 'bn', 'bpy', 'bs', 'ca', 'cmn',
-    'cs', 'cy', 'da', 'de', 'el', 'en-029', 'en-gb', 'en-gb-scotland', 'en-gb-x-gbclan',
-    'en-gb-x-gbcwmd', 'en-gb-x-rp', 'en-us', 'eo', 'es', 'es-419', 'et', 'eu', 'fa',
-    'fa-latn', 'fi', 'fr-be', 'fr-ch', 'fr-fr', 'ga', 'gd', 'gn', 'grc', 'gu', 'hak',
-    'hi', 'hr', 'ht', 'hu', 'hy', 'hyw', 'ia', 'id', 'is', 'it', 'ja', 'jbo', 'ka',
-    'kk', 'kl', 'kn', 'ko', 'kok', 'ku', 'ky', 'la', 'lfn', 'lt', 'lv', 'mi', 'mk',
-    'ml', 'mr', 'ms', 'mt', 'my', 'nb', 'nci', 'ne', 'nl', 'om', 'or', 'pa', 'pap',
-    'pl', 'pt', 'pt-br', 'py', 'quc', 'ro', 'ru', 'ru-lv', 'sd', 'shn', 'si', 'sk',
-    'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'tn', 'tr', 'tt', 'ur', 'uz', 'vi',
-    'vi-vn-x-central', 'vi-vn-x-south', 'yue'
-]  # fmt: off
-
-
-def make_cond_dict(
-    text: str = "It would be nice to have time for testing, indeed.",
-    language: str = "en-us",
-    speaker: torch.Tensor | None = None,
-    emotion: list[float] = [0.3077, 0.0256, 0.0256, 0.0256, 0.0256, 0.0256, 0.2564, 0.3077],
-    fmax: float = 22050.0,
-    pitch_std: float = 20.0,
-    speaking_rate: float = 15.0,
-    vqscore_8: list[float] = [0.78] * 8,
-    ctc_loss: float = 0.0,
-    dnsmos_ovrl: float = 4.0,
-    speaker_noised: bool = False,
-    unconditional_keys: Iterable[str] = {"vqscore_8", "dnsmos_ovrl"},
-    device: str = "cuda",
-) -> dict:
-    """
-    A helper to build the 'cond_dict' that the model expects.
-    By default, it will generate a random speaker embedding
-    """
-    assert language.lower() in supported_language_codes, "Please pick a supported language"
-
-    language_code_to_id = {lang: i for i, lang in enumerate(supported_language_codes)}
-
-    cond_dict = {
-        "espeak": ([text], [language]),
-        "speaker": speaker,
-        "emotion": emotion,
-        "fmax": fmax,
-        "pitch_std": pitch_std,
-        "speaking_rate": speaking_rate,
-        "language_id": language_code_to_id[language],
-        "vqscore_8": vqscore_8,
-        "ctc_loss": ctc_loss,
-        "dnsmos_ovrl": dnsmos_ovrl,
-        "speaker_noised": int(speaker_noised),
-    }
-
-    for k in unconditional_keys:
-        cond_dict.pop(k, None)
-
-    for k, v in cond_dict.items():
-        if isinstance(v, (float, int, list)):
-            v = torch.tensor(v)
-        if isinstance(v, torch.Tensor):
-            cond_dict[k] = v.view(1, 1, -1).to(device)
-
-        if k == "emotion":
-            cond_dict[k] /= cond_dict[k].sum(dim=-1)
-
-    return cond_dict
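
A sketch of how `make_cond_dict` was consumed, assuming the module's dependencies (phonemizer/espeak, sudachipy, kanjize, inflect) are installed; the 128-dim speaker vector is a placeholder whose size matches the LDA projection in `speaker_cloning.py`:

import torch

from qhash.conditioning import make_cond_dict

speaker = torch.randn(1, 1, 128)  # placeholder; normally from Zonos.make_speaker_embedding
cond = make_cond_dict(
    text="The conditioning module normalizes numbers, e.g. $15.50.",
    language="en-us",
    speaker=speaker,
    device="cpu",
)
print(sorted(cond))  # espeak, speaker, emotion, fmax, pitch_std, speaking_rate, language_id, ...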
qhash/config.py DELETED
@@ -1,38 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Literal
-
-
-@dataclass
-class BackboneConfig:
-    d_model: int = 1024
-    d_intermediate: int = 0
-    attn_mlp_d_intermediate: int = 0
-    n_layer: int = 16
-    ssm_cfg: dict = field(default_factory=dict)
-    attn_layer_idx: list = field(default_factory=list)
-    attn_cfg: dict = field(default_factory=dict)
-    rms_norm: bool = False
-    residual_in_fp32: bool = False
-    norm_epsilon: float = 1e-5
-
-
-@dataclass
-class PrefixConditionerConfig:
-    conditioners: list[dict]
-    projection: Literal["none", "linear", "mlp"]
-
-
-@dataclass
-class ZonosConfig:
-    backbone: BackboneConfig
-    prefix_conditioner: PrefixConditionerConfig
-    eos_token_id: int = 1024
-    masked_token_id: int = 1025
-
-    @classmethod
-    def from_dict(cls, d: dict) -> "ZonosConfig":
-        d = d.copy()
-        backbone_config = BackboneConfig(**d.pop("backbone"))
-        prefix_conditioner_config = PrefixConditionerConfig(**d.pop("prefix_conditioner"))
-        config = cls(backbone_config, prefix_conditioner_config, **d)
-        return config
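
A sketch of how `ZonosConfig.from_dict` parses the repo's config.json; the field values below are illustrative, not the released hyperparameters:

from qhash.config import ZonosConfig

raw = {
    "backbone": {"d_model": 1024, "n_layer": 16},
    "prefix_conditioner": {
        "conditioners": [{"type": "EspeakPhonemeConditioner", "name": "espeak"}],
        "projection": "none",
    },
}
config = ZonosConfig.from_dict(raw)
print(config.backbone.d_model, config.eos_token_id, config.masked_token_id)  # 1024 1024 1025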
qhash/model.py DELETED
@@ -1,270 +0,0 @@
-import json
-from typing import Callable
-
-import safetensors
-import torch
-import torch.nn as nn
-from huggingface_hub import hf_hub_download
-from mamba_ssm.utils.generation import InferenceParams
-from tqdm import tqdm
-
-from qhash.backbone import ZonosBackbone
-from qhash.autoencoder import DACAutoencoder
-from qhash.codebook_pattern import apply_delay_pattern, revert_delay_pattern
-from qhash.conditioning import PrefixConditioner
-from qhash.config import ZonosConfig
-from qhash.sampling import sample_from_logits
-from qhash.speaker_cloning import SpeakerEmbeddingLDA
-
-
-class Zonos(nn.Module):
-    def __init__(self, config: ZonosConfig):
-        super().__init__()
-        self.config = config
-        dim = config.backbone.d_model
-        self.eos_token_id = config.eos_token_id
-        self.masked_token_id = config.masked_token_id
-
-        self.autoencoder = DACAutoencoder()
-        self.backbone = ZonosBackbone(config.backbone)
-        self.prefix_conditioner = PrefixConditioner(config.prefix_conditioner, dim)
-        self.spk_clone_model = None
-
-        # TODO: pad to multiple of at least 8
-        self.embeddings = nn.ModuleList([nn.Embedding(1026, dim) for _ in range(self.autoencoder.num_codebooks)])
-        self.heads = nn.ModuleList([nn.Linear(dim, 1025, bias=False) for _ in range(self.autoencoder.num_codebooks)])
-
-        self._cg_graph = None
-        self._cg_batch_size = None
-        self._cg_input_ids = None
-        self._cg_logits = None
-        self._cg_inference_params = None
-        self._cg_scale = None
-
-    @classmethod
-    def from_pretrained(cls, repo_id: str, revision: str | None = None, device: str = "cuda") -> "Zonos":
-        config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
-        model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
-        return cls.from_local(config_path, model_path, device)
-
-    @classmethod
-    def from_local(cls, config_path: str, model_path: str, device: str = "cuda") -> "Zonos":
-        config = ZonosConfig.from_dict(json.load(open(config_path)))
-        model = cls(config).to(device, torch.bfloat16)
-        model.autoencoder.dac.to(device)
-
-        sd = model.state_dict()
-        with safetensors.safe_open(model_path, framework="pt") as f:
-            for k in f.keys():
-                sd[k] = f.get_tensor(k)
-        model.load_state_dict(sd)
-
-        return model
-
-    def make_speaker_embedding(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
-        """Generate a speaker embedding from an audio clip."""
-        if self.spk_clone_model is None:
-            self.spk_clone_model = SpeakerEmbeddingLDA()
-        _, spk_embedding = self.spk_clone_model(wav.to(self.spk_clone_model.device), sr)
-        return spk_embedding.unsqueeze(0).bfloat16()
-
-    def embed_codes(self, codes: torch.Tensor) -> torch.Tensor:
-        return sum(emb(codes[:, i]) for i, emb in enumerate(self.embeddings))
-
-    def apply_heads(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        return torch.stack([head(hidden_states) for head in self.heads], dim=1)
-
-    def _compute_logits(
-        self, hidden_states: torch.Tensor, inference_params: InferenceParams, cfg_scale: float
-    ) -> torch.Tensor:
-        """
-        Pass `hidden_states` into `backbone` and `multi_head`, applying
-        classifier-free guidance if `cfg_scale != 1.0`.
-        """
-        last_hidden_states = self.backbone(hidden_states, inference_params)[:, -1, :].unsqueeze(1)
-        logits = self.apply_heads(last_hidden_states).squeeze(2).float()
-        if cfg_scale != 1.0:
-            cond_logits, uncond_logits = logits.chunk(2)
-            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
-        return logits
-
-    def _decode_one_token(
-        self,
-        input_ids: torch.Tensor,
-        inference_params: InferenceParams,
-        cfg_scale: float,
-    ) -> torch.Tensor:
-        """
-        Single-step decode. Prepares the hidden states, possibly replicates them
-        for CFG, and then delegates to `_compute_logits`.
-
-        Below we wrap this function with a simple CUDA Graph capturing mechanism,
-        doing 3 warmup steps if needed and then capturing or replaying the graph.
-        We only recapture if the batch size changes.
-        """
-        # TODO: support cfg_scale==1
-        if cfg_scale == 1.0:
-            hidden_states = self.embed_codes(input_ids)
-            return self._compute_logits(hidden_states, inference_params, cfg_scale)
-
-        bsz = input_ids.size(0)
-
-        need_capture = (self._cg_graph is None) or (self._cg_batch_size != bsz)
-
-        if need_capture:
-            self._cg_graph = None
-
-            self._cg_batch_size = bsz
-            self._cg_inference_params = inference_params
-            self._cg_scale = cfg_scale
-
-            for _ in range(3):
-                hidden_states = self.embed_codes(input_ids)
-                hidden_states = hidden_states.repeat(2, 1, 1)  # because cfg != 1.0
-                logits = self._compute_logits(hidden_states, inference_params, cfg_scale)
-
-            self._cg_input_ids = input_ids.clone()
-            self._cg_logits = torch.empty_like(logits)
-
-            g = torch.cuda.CUDAGraph()
-
-            def capture_region():
-                hidden_states_local = self.embed_codes(self._cg_input_ids)
-                hidden_states_local = hidden_states_local.repeat(2, 1, 1)
-                self._cg_logits = self._compute_logits(hidden_states_local, self._cg_inference_params, self._cg_scale)
-
-            with torch.cuda.graph(g):
-                capture_region()
-
-            self._cg_graph = g
-
-        else:
-            self._cg_input_ids.copy_(input_ids)
-
-        self._cg_graph.replay()
-
-        return self._cg_logits
-
-    def _prefill(
-        self,
-        prefix_hidden_states: torch.Tensor,
-        input_ids: torch.Tensor,
-        inference_params: InferenceParams,
-        cfg_scale: float,
-    ) -> torch.Tensor:
-        """
-        "Prefill" mode: we already have `prefix_hidden_states`, and we want
-        to append new embeddings, then compute the logits.
-        """
-        # Replicate input_ids if CFG is enabled
-        if cfg_scale != 1.0:
-            input_ids = input_ids.expand(prefix_hidden_states.shape[0], -1, -1)
-        hidden_states = torch.cat([prefix_hidden_states, self.embed_codes(input_ids)], dim=1)
-        return self._compute_logits(hidden_states, inference_params, cfg_scale)
-
-    def setup_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16) -> InferenceParams:
-        key_value_memory_dict = {
-            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
-            for i, layer in enumerate(self.backbone.layers)
-        }
-        lengths_per_sample = torch.full((batch_size,), 0, dtype=torch.int32, device="cuda")
-        return InferenceParams(max_seqlen, batch_size, 0, 0, key_value_memory_dict, lengths_per_sample)
-
-    def prepare_conditioning(self, cond_dict: dict, uncond_dict: dict | None = None) -> torch.Tensor:
-        if uncond_dict is None:
-            uncond_dict = {k: cond_dict[k] for k in self.prefix_conditioner.required_keys}
-        return torch.cat(
-            [
-                self.prefix_conditioner(cond_dict),
-                self.prefix_conditioner(uncond_dict),
-            ]
-        )
-
-    @torch.inference_mode()
-    def generate(
-        self,
-        prefix_conditioning: torch.Tensor,  # [bsz, cond_seq_len, d_model]
-        audio_prefix_codes: torch.Tensor | None = None,  # [bsz, 9, prefix_audio_seq_len]
-        max_new_tokens: int = 86 * 30,
-        cfg_scale: float = 2.0,
-        batch_size: int = 1,
-        sampling_params: dict = dict(min_p=0.1),
-        progress_bar: bool = True,
-        callback: Callable[[torch.Tensor, int, int], bool] | None = None,
-    ):
-        assert cfg_scale != 1, "TODO: add support for cfg_scale=1"
-        prefix_audio_len = 0 if audio_prefix_codes is None else audio_prefix_codes.shape[2]
-
-        unknown_token = -1
-        audio_seq_len = prefix_audio_len + max_new_tokens
-        seq_len = prefix_conditioning.shape[1] + audio_seq_len
-
-        inference_params = self.setup_cache(batch_size=batch_size * 2, max_seqlen=seq_len)
-
-        codes = torch.full((batch_size, 9, audio_seq_len), unknown_token, device="cuda")
-        if audio_prefix_codes is not None:
-            codes[..., :prefix_audio_len] = audio_prefix_codes
-
-        delayed_codes = apply_delay_pattern(codes, self.masked_token_id)
-
-        delayed_prefix_audio_codes = delayed_codes[..., : prefix_audio_len + 1]
-
-        logits = self._prefill(prefix_conditioning, delayed_prefix_audio_codes, inference_params, cfg_scale)
-        next_token = sample_from_logits(logits, **sampling_params)
-
-        offset = delayed_prefix_audio_codes.shape[2]
-        frame = delayed_codes[..., offset : offset + 1]
-        frame.masked_scatter_(frame == unknown_token, next_token)
-
-        prefix_length = prefix_conditioning.shape[1] + prefix_audio_len + 1
-        inference_params.seqlen_offset += prefix_length
-        inference_params.lengths_per_sample[:] += prefix_length
-
-        logit_bias = torch.zeros_like(logits)
-        logit_bias[:, 1:, self.eos_token_id] = -torch.inf  # only allow codebook 0 to predict EOS
-
-        stopping = torch.zeros(batch_size, dtype=torch.bool, device="cuda")
-        max_steps = delayed_codes.shape[2] - offset
-        remaining_steps = torch.full((batch_size,), max_steps, device="cuda")
-        progress = tqdm(total=max_steps, desc="Generating", disable=not progress_bar)
-
-        step = 0
-        while torch.max(remaining_steps) > 0:
-            offset += 1
-            input_ids = delayed_codes[..., offset - 1 : offset]
-            logits = self._decode_one_token(input_ids, inference_params, cfg_scale)
-            logits += logit_bias  # apply the EOS restriction computed above; unused in the original, which looks like a bug
-
-            next_token = sample_from_logits(logits, generated_tokens=delayed_codes[..., :offset], **sampling_params)
-            eos_in_cb0 = next_token[:, 0] == self.eos_token_id
-
-            remaining_steps[eos_in_cb0[:, 0]] = torch.minimum(remaining_steps[eos_in_cb0[:, 0]], torch.tensor(9))
-            stopping |= eos_in_cb0[:, 0]
-
-            eos_codebook_idx = 9 - remaining_steps
-            eos_codebook_idx = torch.clamp(eos_codebook_idx, max=9 - 1)
-            for i in range(next_token.shape[0]):
-                if stopping[i]:
-                    idx = eos_codebook_idx[i].item()
-                    next_token[i, :idx] = self.masked_token_id
-                    next_token[i, idx] = self.eos_token_id
-
-            frame = delayed_codes[..., offset : offset + 1]
-            frame.masked_scatter_(frame == unknown_token, next_token)
-            inference_params.seqlen_offset += 1
-            inference_params.lengths_per_sample[:] += 1
-
-            remaining_steps -= 1
-
-            progress.update()
-            step += 1
-
-            if callback is not None and not callback(frame, step, max_steps):
-                break
-
-        out_codes = revert_delay_pattern(delayed_codes)
-        out_codes.masked_fill_(out_codes >= 1024, 0)
-        out_codes = out_codes[..., : offset - 9]
-
-        self._cg_graph = None  # reset cuda graph to avoid cache changes
-
-        return out_codes
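
Taken together, the deleted package implemented a Zonos-style pipeline: build prefix conditioning, autoregressively sample delayed codebooks with classifier-free guidance, then decode with DAC. A usage sketch (the repo id and audio file name are assumptions, not confirmed by this commit; requires a CUDA device):

import torchaudio

from qhash.conditioning import make_cond_dict
from qhash.model import Zonos

model = Zonos.from_pretrained("Quantamhash/Qhash-v0.1")  # repo id assumed
wav, sr = torchaudio.load("reference_speaker.wav")       # hypothetical reference clip
speaker = model.make_speaker_embedding(wav, sr)

cond = make_cond_dict(text="Hello world!", speaker=speaker, language="en-us")
conditioning = model.prepare_conditioning(cond)

codes = model.generate(conditioning)              # [1, 9, n_frames]
audio = model.autoencoder.decode(codes).cpu()[0]  # [1, n_samples] at 44.1 kHz
torchaudio.save("output.wav", audio, model.autoencoder.sampling_rate)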
qhash/sampling.py DELETED
@@ -1,141 +0,0 @@
-import torch
-
-
-def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
-    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
-
-    Args:
-        input (torch.Tensor): The input tensor containing probabilities.
-        num_samples (int): Number of samples to draw.
-        replacement (bool): Whether to draw with replacement or not.
-    Keyword args:
-        generator (torch.Generator): A pseudorandom number generator for sampling.
-    Returns:
-        torch.Tensor: Last dimension contains num_samples indices
-            sampled from the multinomial probability distribution
-            located in the last dimension of tensor input.
-    """
-    if num_samples == 1:
-        q = torch.empty_like(input).exponential_(1, generator=generator)
-        return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
-
-    input_ = input.reshape(-1, input.shape[-1])
-    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
-    output = output_.reshape(*list(input.shape[:-1]), -1)
-    return output
-
-
-def apply_top_k(
-    probs: torch.Tensor,
-    k: int,
-) -> torch.Tensor:
-    """Zero out all but the top K probabilities along the last dimension, then renormalize.
-
-    Args:
-        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
-        k (int): The k in "top-k".
-    Returns:
-        torch.Tensor: Filtered, renormalized probabilities.
-    """
-    v, _ = torch.topk(probs, min(k, probs.size(-1)))
-    pivot = v.select(-1, -1).unsqueeze(-1)
-    probs = torch.where(probs < pivot, 0.0, probs)
-    probs.div_(probs.sum(dim=-1, keepdim=True))
-    return probs
-
-
-def apply_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
-    """Keep the smallest set of tokens whose cumulative probability exceeds P, then renormalize.
-
-    Args:
-        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
-        p (float): The p in "top-p".
-    Returns:
-        torch.Tensor: Filtered, renormalized probabilities.
-    """
-    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = probs_sum - probs_sort > p
-    probs_sort *= (~mask).float()
-    probs = probs.scatter(-1, probs_idx, probs_sort)
-    probs.div_(probs.sum(dim=-1, keepdim=True))
-    return probs
-
-
-def apply_min_p(probs: torch.Tensor, min_p: float) -> torch.Tensor:
-    """Zero out tokens below a probability threshold scaled by the most likely token, then renormalize.
-
-    Args:
-        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
-        min_p (float): Minimum token probability, scaled by the probability of the most likely token.
-            Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
-    Returns:
-        torch.Tensor: Filtered, renormalized probabilities.
-    """
-    top_probs, _ = probs.max(dim=-1, keepdim=True)
-    tokens_to_remove = probs < (min_p * top_probs)
-    probs = probs.masked_fill(tokens_to_remove, 0.0)
-    probs.div_(probs.sum(dim=-1, keepdim=True))
-    return probs
-
-
-def modify_logit_for_repetition_penalty(
-    logits: torch.Tensor,
-    generated_tokens: torch.Tensor,
-    repetition_penalty: float,
-    repetition_penalty_window: int,
-):
-    """See https://arxiv.org/abs/1909.05858
-    Apply repetition penalty over a sliding window of the last `repetition_penalty_window` tokens.
-    logits: (batch_size, n_codebooks, vocab_size)
-    generated_tokens: (batch_size, n_codebooks, seq_len)
-    """
-    generated_tokens = generated_tokens[..., -repetition_penalty_window:]
-    generated_tokens = generated_tokens.clamp_max(logits.shape[-1] - 1).to(torch.int64)
-    rp = torch.full_like(logits, repetition_penalty)
-    factors = torch.ones_like(logits).scatter_reduce(2, generated_tokens, rp, reduce="prod")
-    return torch.where(logits <= 0, logits * factors, logits / factors)
-
-
-def sample_from_logits(
-    logits: torch.Tensor,
-    temperature: float = 1.0,
-    top_p: float = 0.0,
-    top_k: int = 0,
-    min_p: float = 0.0,
-    generated_tokens: torch.Tensor | None = None,
-    repetition_penalty: float = 3.0,
-    repetition_penalty_window: int = 2,
-) -> torch.Tensor:
-    """Sample next token from logits using temperature, top-p, top-k, or min-p sampling.
-
-    Args:
-        logits (torch.Tensor): Input logits with token candidates on the last dimension.
-        temperature (float): Sampling temperature. Lower temperature results in more deterministic samples.
-        top_p (float): The p in "top-p".
-        top_k (int): The k in "top-k".
-        min_p (float): Minimum token probability, scaled by the probability of the most likely token.
-            Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
-
-    Returns:
-        torch.Tensor: Sampled tokens.
-    """
-    if repetition_penalty != 1.0 and generated_tokens is not None:
-        logits = modify_logit_for_repetition_penalty(logits, generated_tokens, repetition_penalty, repetition_penalty_window)
-
-    if temperature > 0:
-        probs = torch.softmax(logits / temperature, dim=-1)
-
-        if top_p > 0:
-            probs = apply_top_p(probs, top_p)
-        if top_k > 0:
-            probs = apply_top_k(probs, top_k)
-        if min_p > 0:
-            probs = apply_min_p(probs, min_p)
-
-        next_token = multinomial(probs, num_samples=1)
-    else:
-        next_token = torch.argmax(logits, dim=-1, keepdim=True)
-
-    return next_token  # [batch_size, num_codebooks, 1]
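
A small self-contained check of the sampler's two paths (stochastic min-p, the default used by `generate`, and greedy argmax at temperature 0):

import torch

from qhash.sampling import sample_from_logits

logits = torch.randn(2, 9, 1025)                      # [batch, n_codebooks, vocab]
stochastic = sample_from_logits(logits, min_p=0.1)    # min-p filtered multinomial draw
greedy = sample_from_logits(logits, temperature=0.0)  # plain argmax
print(stochastic.shape, greedy.shape)                 # both: torch.Size([2, 9, 1])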
qhash/speaker_cloning.py DELETED
@@ -1,406 +0,0 @@
-import math
-from functools import cache
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchaudio
-from huggingface_hub import hf_hub_download
-
-
-class logFbankCal(nn.Module):
-    def __init__(
-        self,
-        sample_rate: int = 16_000,
-        n_fft: int = 512,
-        win_length: float = 0.025,
-        hop_length: float = 0.01,
-        n_mels: int = 80,
-    ):
-        super().__init__()
-        self.fbankCal = torchaudio.transforms.MelSpectrogram(
-            sample_rate=sample_rate,
-            n_fft=n_fft,
-            win_length=int(win_length * sample_rate),
-            hop_length=int(hop_length * sample_rate),
-            n_mels=n_mels,
-        )
-
-    def forward(self, x):
-        out = self.fbankCal(x)
-        out = torch.log(out + 1e-6)
-        out = out - out.mean(axis=2).unsqueeze(dim=2)
-        return out
-
-
-class ASP(nn.Module):
-    # Attentive statistics pooling
-    def __init__(self, in_planes, acoustic_dim):
-        super(ASP, self).__init__()
-        outmap_size = int(acoustic_dim / 8)
-        self.out_dim = in_planes * 8 * outmap_size * 2
-
-        self.attention = nn.Sequential(
-            nn.Conv1d(in_planes * 8 * outmap_size, 128, kernel_size=1),
-            nn.ReLU(),
-            nn.BatchNorm1d(128),
-            nn.Conv1d(128, in_planes * 8 * outmap_size, kernel_size=1),
-            nn.Softmax(dim=2),
-        )
-
-    def forward(self, x):
-        x = x.reshape(x.size()[0], -1, x.size()[-1])
-        w = self.attention(x)
-        mu = torch.sum(x * w, dim=2)
-        sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
-        x = torch.cat((mu, sg), 1)
-
-        x = x.view(x.size()[0], -1)
-        return x
-
-
-class SimAMBasicBlock(nn.Module):
-    expansion = 1
-
-    def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
-        super(SimAMBasicBlock, self).__init__()
-        self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
-        self.bn1 = NormLayer(planes)
-        self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn2 = NormLayer(planes)
-        self.relu = nn.ReLU(inplace=True)
-        self.sigmoid = nn.Sigmoid()
-
-        self.downsample = nn.Sequential()
-        if stride != 1 or in_planes != self.expansion * planes:
-            self.downsample = nn.Sequential(
-                ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
-                NormLayer(self.expansion * planes),
-            )
-
-    def forward(self, x):
-        out = self.relu(self.bn1(self.conv1(x)))
-        out = self.bn2(self.conv2(out))
-        out = self.SimAM(out)
-        out += self.downsample(x)
-        out = self.relu(out)
-        return out
-
-    def SimAM(self, X, lambda_p=1e-4):
-        n = X.shape[2] * X.shape[3] - 1
-        d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2)
-        v = d.sum(dim=[2, 3], keepdim=True) / n
-        E_inv = d / (4 * (v + lambda_p)) + 0.5
-        return X * self.sigmoid(E_inv)
-
-
-class BasicBlock(nn.Module):
-    expansion = 1
-
-    def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
-        super(BasicBlock, self).__init__()
-        self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
-        self.bn1 = NormLayer(planes)
-        self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn2 = NormLayer(planes)
-        self.relu = nn.ReLU(inplace=True)
-
-        self.downsample = nn.Sequential()
-        if stride != 1 or in_planes != self.expansion * planes:
-            self.downsample = nn.Sequential(
-                ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
-                NormLayer(self.expansion * planes),
-            )
-
-    def forward(self, x):
-        out = self.relu(self.bn1(self.conv1(x)))
-        out = self.bn2(self.conv2(out))
-        out += self.downsample(x)
-        out = self.relu(out)
-        return out
-
-
-class Bottleneck(nn.Module):
-    expansion = 4
-
-    def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
-        super(Bottleneck, self).__init__()
-        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
-        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
-
-        self.shortcut = nn.Sequential()
-        if stride != 1 or in_planes != self.expansion * planes:
-            self.shortcut = nn.Sequential(
-                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
-                nn.BatchNorm2d(self.expansion * planes),
-            )
-
-    def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
-        out = F.relu(self.bn2(self.conv2(out)))
-        out = self.bn3(self.conv3(out))
-        out += self.shortcut(x)
-        out = F.relu(out)
-        return out
-
-
-class ResNet(nn.Module):
-    def __init__(self, in_planes, block, num_blocks, in_ch=1, feat_dim="2d", **kwargs):
-        super(ResNet, self).__init__()
-        if feat_dim == "1d":
-            self.NormLayer = nn.BatchNorm1d
-            self.ConvLayer = nn.Conv1d
-        elif feat_dim == "2d":
-            self.NormLayer = nn.BatchNorm2d
-            self.ConvLayer = nn.Conv2d
-        elif feat_dim == "3d":
-            self.NormLayer = nn.BatchNorm3d
-            self.ConvLayer = nn.Conv3d
-        else:
-            print("error")
-
-        self.in_planes = in_planes
-
-        self.conv1 = self.ConvLayer(in_ch, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn1 = self.NormLayer(in_planes)
-        self.relu = nn.ReLU(inplace=True)
-        self.layer1 = self._make_layer(block, in_planes, num_blocks[0], stride=1, block_id=1)
-        self.layer2 = self._make_layer(block, in_planes * 2, num_blocks[1], stride=2, block_id=2)
-        self.layer3 = self._make_layer(block, in_planes * 4, num_blocks[2], stride=2, block_id=3)
-        self.layer4 = self._make_layer(block, in_planes * 8, num_blocks[3], stride=2, block_id=4)
-
-    def _make_layer(self, block, planes, num_blocks, stride, block_id=1):
-        strides = [stride] + [1] * (num_blocks - 1)
-        layers = []
-        for stride in strides:
-            layers.append(block(self.ConvLayer, self.NormLayer, self.in_planes, planes, stride, block_id))
-            self.in_planes = planes * block.expansion
-        return nn.Sequential(*layers)
-
-    def forward(self, x):
-        x = self.relu(self.bn1(self.conv1(x)))
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-        x = self.layer4(x)
-        return x
-
-
-def ResNet293(in_planes: int, **kwargs):
-    return ResNet(in_planes, SimAMBasicBlock, [10, 20, 64, 3], **kwargs)
-
-
-class ResNet293_based(nn.Module):
-    def __init__(
-        self,
-        in_planes: int = 64,
-        embd_dim: int = 256,
-        acoustic_dim: int = 80,
-        featCal=None,
-        dropout: float = 0,
-        **kwargs,
-    ):
-        super(ResNet293_based, self).__init__()
-        self.featCal = featCal
-        self.front = ResNet293(in_planes)
-        block_expansion = SimAMBasicBlock.expansion
-        self.pooling = ASP(in_planes * block_expansion, acoustic_dim)
-        self.bottleneck = nn.Linear(self.pooling.out_dim, embd_dim)
-        self.drop = nn.Dropout(dropout) if dropout else None
-
-    def forward(self, x):
-        x = self.featCal(x)
-        x = self.front(x.unsqueeze(dim=1))
-        x = self.pooling(x)
-        if self.drop:
-            x = self.drop(x)
-        x = self.bottleneck(x)
-        return x
-
-
-class SEModule(nn.Module):
-    def __init__(self, channels, bottleneck=128):
-        super(SEModule, self).__init__()
-        self.se = nn.Sequential(
-            nn.AdaptiveAvgPool1d(1),
-            nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
-            nn.ReLU(),
-            # nn.BatchNorm1d(bottleneck),  # Removed
-            nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
-            nn.Sigmoid(),
-        )
-
-    def forward(self, input):
-        x = self.se(input)
-        return input * x
-
-
-class Bottle2neck(nn.Module):
-    def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale=8):
-        super(Bottle2neck, self).__init__()
-        width = int(math.floor(planes / scale))
-        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
-        self.bn1 = nn.BatchNorm1d(width * scale)
-        self.nums = scale - 1
-        convs = []
-        bns = []
-        num_pad = math.floor(kernel_size / 2) * dilation
-        for i in range(self.nums):
-            convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad))
-            bns.append(nn.BatchNorm1d(width))
-        self.convs = nn.ModuleList(convs)
-        self.bns = nn.ModuleList(bns)
-        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
-        self.bn3 = nn.BatchNorm1d(planes)
-        self.relu = nn.ReLU()
-        self.width = width
-        self.se = SEModule(planes)
-
-    def forward(self, x):
-        residual = x
-        out = self.conv1(x)
-        out = self.relu(out)
-        out = self.bn1(out)
-
-        spx = torch.split(out, self.width, 1)
-        for i in range(self.nums):
-            if i == 0:
-                sp = spx[i]
-            else:
-                sp = sp + spx[i]
-            sp = self.convs[i](sp)
-            sp = self.relu(sp)
-            sp = self.bns[i](sp)
-            if i == 0:
-                out = sp
-            else:
-                out = torch.cat((out, sp), 1)
-        out = torch.cat((out, spx[self.nums]), 1)
-
-        out = self.conv3(out)
-        out = self.relu(out)
-        out = self.bn3(out)
-
-        out = self.se(out)
-        out += residual
-        return out
-
-
-class ECAPA_TDNN(nn.Module):
-    def __init__(self, C, featCal):
-        super(ECAPA_TDNN, self).__init__()
-        self.featCal = featCal
-        self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2)
-        self.relu = nn.ReLU()
-        self.bn1 = nn.BatchNorm1d(C)
-        self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8)
-        self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8)
-        self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8)
-        # I fixed the shape of the output from MFA layer, that is close to the setting from ECAPA paper.
-        self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)
-        self.attention = nn.Sequential(
-            nn.Conv1d(4608, 256, kernel_size=1),
-            nn.ReLU(),
-            nn.BatchNorm1d(256),
-            nn.Tanh(),  # Added
-            nn.Conv1d(256, 1536, kernel_size=1),
-            nn.Softmax(dim=2),
-        )
-        self.bn5 = nn.BatchNorm1d(3072)
-        self.fc6 = nn.Linear(3072, 192)
-        self.bn6 = nn.BatchNorm1d(192)
-
-    def forward(self, x):
-        x = self.featCal(x)
-        x = self.conv1(x)
-        x = self.relu(x)
-        x = self.bn1(x)
-
-        x1 = self.layer1(x)
-        x2 = self.layer2(x + x1)
-        x3 = self.layer3(x + x1 + x2)
-
-        x = self.layer4(torch.cat((x1, x2, x3), dim=1))
-        x = self.relu(x)
-
-        t = x.size()[-1]
-
-        global_x = torch.cat(
-            (
-                x,
-                torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
-                torch.sqrt(torch.var(x, dim=2, keepdim=True).clamp(min=1e-4)).repeat(1, 1, t),
-            ),
-            dim=1,
-        )
-
-        w = self.attention(global_x)
-
-        mu = torch.sum(x * w, dim=2)
-        sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4))
-
-        x = torch.cat((mu, sg), 1)
-        x = self.bn5(x)
-        x = self.fc6(x)
-        x = self.bn6(x)
-
-        return x
-
-
-class SpeakerEmbedding(nn.Module):
-    def __init__(self, ckpt_path: str = "ResNet293_SimAM_ASP_base.pt", device: str = "cuda"):
-        super().__init__()
-        self.device = device
-        with torch.device(device):
-            self.model = ResNet293_based()
-            self.model.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
-            self.model.featCal = logFbankCal()
-
-        self.requires_grad_(False).eval()
-
-    @property
-    def dtype(self):
-        return next(self.parameters()).dtype
-
-    @cache
-    def _get_resampler(self, orig_sample_rate: int):
-        return torchaudio.transforms.Resample(orig_sample_rate, 16_000).to(self.device)
-
-    def prepare_input(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
-        assert wav.ndim < 3
-        if wav.ndim == 2:
-            wav = wav.mean(0, keepdim=True)
-        wav = self._get_resampler(sample_rate)(wav)
-        return wav
-
-    def forward(self, wav: torch.Tensor, sample_rate: int):
-        wav = self.prepare_input(wav, sample_rate).to(self.device, self.dtype)
-        return self.model(wav).to(wav.device)
-
-
-class SpeakerEmbeddingLDA(nn.Module):
-    def __init__(
-        self,
-        device: str = "cuda",
-    ):
-        super().__init__()
-        spk_model_path = hf_hub_download(
-            repo_id="Quantamhash/Qhash-v0.1-speaker-embedding",
-            filename="ResNet293_SimAM_ASP_base.pt",
-        )
-        lda_spk_model_path = hf_hub_download(
-            repo_id="Quantamhash/Qhash-v0.1-speaker-embedding",
-            filename="ResNet293_SimAM_ASP_base_LDA-128.pt",
-        )
-
-        self.device = device
-        with torch.device(device):
-            self.model = SpeakerEmbedding(spk_model_path, device)
-            lda_sd = torch.load(lda_spk_model_path, weights_only=True)
-            out_features, in_features = lda_sd["weight"].shape
-            self.lda = nn.Linear(in_features, out_features, bias=True, dtype=torch.float32)
-            self.lda.load_state_dict(lda_sd)
-
-        self.requires_grad_(False).eval()
-
-    def forward(self, wav: torch.Tensor, sample_rate: int):
-        emb = self.model(wav, sample_rate).to(torch.float32)
-        return emb, self.lda(emb)
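
A usage sketch of the speaker-cloning stack above: the base ResNet293 model returns a 256-dim embedding and the LDA head projects it to 128 dims. The file name is illustrative; the checkpoints are downloaded from the Hub and a CUDA device is assumed:

import torchaudio

from qhash.speaker_cloning import SpeakerEmbeddingLDA

spk = SpeakerEmbeddingLDA(device="cuda")
wav, sr = torchaudio.load("reference_speaker.wav")  # hypothetical clip; stereo is downmixed
emb, emb_lda = spk(wav, sr)
print(emb.shape, emb_lda.shape)                     # torch.Size([1, 256]) torch.Size([1, 128])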