Spaces:
Paused
Paused
Delete qhash
Browse files- qhash/autoencoder.py +0 -26
- qhash/backbone.py +0 -50
- qhash/codebook_pattern.py +0 -12
- qhash/conditioning.py +0 -373
- qhash/config.py +0 -38
- qhash/model.py +0 -270
- qhash/sampling.py +0 -141
- qhash/speaker_cloning.py +0 -406
qhash/autoencoder.py
DELETED
|
@@ -1,26 +0,0 @@
|
|
| 1 |
-
import math
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torchaudio
|
| 5 |
-
from transformers.models.dac import DacModel
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class DACAutoencoder:
|
| 9 |
-
def __init__(self):
|
| 10 |
-
super().__init__()
|
| 11 |
-
self.dac = DacModel.from_pretrained("Quantamhash/dac_44khz")
|
| 12 |
-
self.dac.eval().requires_grad_(False)
|
| 13 |
-
self.codebook_size = self.dac.config.codebook_size
|
| 14 |
-
self.num_codebooks = self.dac.quantizer.n_codebooks
|
| 15 |
-
self.sampling_rate = self.dac.config.sampling_rate
|
| 16 |
-
|
| 17 |
-
def preprocess(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
|
| 18 |
-
wav = torchaudio.functional.resample(wav, sr, 44_100)
|
| 19 |
-
right_pad = math.ceil(wav.shape[-1] / 512) * 512 - wav.shape[-1]
|
| 20 |
-
return torch.nn.functional.pad(wav, (0, right_pad))
|
| 21 |
-
|
| 22 |
-
def encode(self, wav: torch.Tensor) -> torch.Tensor:
|
| 23 |
-
return self.dac.encode(wav).audio_codes
|
| 24 |
-
|
| 25 |
-
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
| 26 |
-
return self.dac.decode(audio_codes=codes).audio_values.unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qhash/backbone.py
DELETED
|
@@ -1,50 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
from mamba_ssm.models.mixer_seq_simple import create_block
|
| 4 |
-
from mamba_ssm.ops.triton.layer_norm import layer_norm_fn
|
| 5 |
-
from mamba_ssm.utils.generation import InferenceParams
|
| 6 |
-
|
| 7 |
-
from qhash.config import BackboneConfig
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class ZonosBackbone(nn.Module):
|
| 11 |
-
def __init__(self, config: BackboneConfig):
|
| 12 |
-
super().__init__()
|
| 13 |
-
self.config = config
|
| 14 |
-
|
| 15 |
-
self.layers = nn.ModuleList(
|
| 16 |
-
[
|
| 17 |
-
create_block(
|
| 18 |
-
d_model=config.d_model,
|
| 19 |
-
d_intermediate=config.d_intermediate
|
| 20 |
-
if (i not in config.attn_layer_idx)
|
| 21 |
-
else config.attn_mlp_d_intermediate,
|
| 22 |
-
ssm_cfg=config.ssm_cfg,
|
| 23 |
-
layer_idx=i,
|
| 24 |
-
attn_layer_idx=config.attn_layer_idx,
|
| 25 |
-
attn_cfg=config.attn_cfg,
|
| 26 |
-
norm_epsilon=config.norm_epsilon,
|
| 27 |
-
residual_in_fp32=config.residual_in_fp32,
|
| 28 |
-
fused_add_norm=True,
|
| 29 |
-
rms_norm=config.rms_norm,
|
| 30 |
-
)
|
| 31 |
-
for i in range(config.n_layer)
|
| 32 |
-
]
|
| 33 |
-
)
|
| 34 |
-
|
| 35 |
-
self.norm_f = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
|
| 36 |
-
|
| 37 |
-
def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams | None = None):
|
| 38 |
-
residual = None
|
| 39 |
-
for layer in self.layers:
|
| 40 |
-
hidden_states, residual = layer(hidden_states, residual, inference_params)
|
| 41 |
-
|
| 42 |
-
return layer_norm_fn(
|
| 43 |
-
hidden_states,
|
| 44 |
-
self.norm_f.weight,
|
| 45 |
-
self.norm_f.bias,
|
| 46 |
-
residual,
|
| 47 |
-
eps=self.norm_f.eps,
|
| 48 |
-
residual_in_fp32=self.config.residual_in_fp32,
|
| 49 |
-
is_rms_norm=self.config.rms_norm,
|
| 50 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qhash/codebook_pattern.py
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn.functional as F
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
def apply_delay_pattern(codes: torch.Tensor, mask_token: int):
|
| 6 |
-
codes = F.pad(codes, (0, codes.shape[1]), value=mask_token)
|
| 7 |
-
return torch.stack([codes[:, k].roll(k + 1) for k in range(codes.shape[1])], dim=1)
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def revert_delay_pattern(codes: torch.Tensor):
|
| 11 |
-
_, n_q, seq_len = codes.shape
|
| 12 |
-
return torch.stack([codes[:, k, k + 1 : seq_len - n_q + k + 1] for k in range(n_q)], dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qhash/conditioning.py
DELETED
|
@@ -1,373 +0,0 @@
|
|
| 1 |
-
from functools import cache
|
| 2 |
-
from typing import Any, Literal, Iterable
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn as nn
|
| 6 |
-
|
| 7 |
-
from qhash.config import PrefixConditionerConfig
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class Conditioner(nn.Module):
|
| 11 |
-
def __init__(
|
| 12 |
-
self,
|
| 13 |
-
output_dim: int,
|
| 14 |
-
name: str,
|
| 15 |
-
cond_dim: int | None = None,
|
| 16 |
-
projection: Literal["none", "linear", "mlp"] = "none",
|
| 17 |
-
uncond_type: Literal["learned", "none"] = "none",
|
| 18 |
-
**kwargs,
|
| 19 |
-
):
|
| 20 |
-
super().__init__()
|
| 21 |
-
self.name = name
|
| 22 |
-
self.output_dim = output_dim
|
| 23 |
-
self.cond_dim = cond_dim = cond_dim or output_dim
|
| 24 |
-
|
| 25 |
-
if projection == "linear":
|
| 26 |
-
self.project = nn.Linear(cond_dim, output_dim)
|
| 27 |
-
elif projection == "mlp":
|
| 28 |
-
self.project = nn.Sequential(
|
| 29 |
-
nn.Linear(cond_dim, output_dim),
|
| 30 |
-
nn.SiLU(),
|
| 31 |
-
nn.Linear(output_dim, output_dim),
|
| 32 |
-
)
|
| 33 |
-
else:
|
| 34 |
-
self.project = nn.Identity()
|
| 35 |
-
|
| 36 |
-
self.uncond_vector = None
|
| 37 |
-
if uncond_type == "learned":
|
| 38 |
-
self.uncond_vector = nn.Parameter(torch.zeros(output_dim))
|
| 39 |
-
|
| 40 |
-
def apply_cond(self, *inputs: Any) -> torch.Tensor:
|
| 41 |
-
raise NotImplementedError()
|
| 42 |
-
|
| 43 |
-
def forward(self, inputs: tuple[Any, ...] | None) -> torch.Tensor:
|
| 44 |
-
if inputs is None:
|
| 45 |
-
assert self.uncond_vector is not None
|
| 46 |
-
return self.uncond_vector.data.view(1, 1, -1)
|
| 47 |
-
|
| 48 |
-
cond = self.apply_cond(*inputs)
|
| 49 |
-
cond = self.project(cond)
|
| 50 |
-
return cond
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
# ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------
|
| 54 |
-
import re
|
| 55 |
-
import unicodedata
|
| 56 |
-
|
| 57 |
-
import inflect
|
| 58 |
-
import torch
|
| 59 |
-
import torch.nn as nn
|
| 60 |
-
from kanjize import number2kanji
|
| 61 |
-
from phonemizer.backend import EspeakBackend
|
| 62 |
-
from sudachipy import Dictionary, SplitMode
|
| 63 |
-
|
| 64 |
-
# --- Number normalization code from https://github.com/daniilrobnikov/vits2/blob/main/text/normalize_numbers.py ---
|
| 65 |
-
|
| 66 |
-
_inflect = inflect.engine()
|
| 67 |
-
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
| 68 |
-
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
|
| 69 |
-
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
|
| 70 |
-
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
| 71 |
-
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
| 72 |
-
_number_re = re.compile(r"[0-9]+")
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def _remove_commas(m: re.Match) -> str:
|
| 76 |
-
return m.group(1).replace(",", "")
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def _expand_decimal_point(m: re.Match) -> str:
|
| 80 |
-
return m.group(1).replace(".", " point ")
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def _expand_dollars(m: re.Match) -> str:
|
| 84 |
-
match = m.group(1)
|
| 85 |
-
parts = match.split(".")
|
| 86 |
-
if len(parts) > 2:
|
| 87 |
-
return match + " dollars" # Unexpected format
|
| 88 |
-
dollars = int(parts[0]) if parts[0] else 0
|
| 89 |
-
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
| 90 |
-
if dollars and cents:
|
| 91 |
-
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
| 92 |
-
cent_unit = "cent" if cents == 1 else "cents"
|
| 93 |
-
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
|
| 94 |
-
elif dollars:
|
| 95 |
-
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
| 96 |
-
return "%s %s" % (dollars, dollar_unit)
|
| 97 |
-
elif cents:
|
| 98 |
-
cent_unit = "cent" if cents == 1 else "cents"
|
| 99 |
-
return "%s %s" % (cents, cent_unit)
|
| 100 |
-
else:
|
| 101 |
-
return "zero dollars"
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
def _expand_ordinal(m: re.Match) -> str:
|
| 105 |
-
return _inflect.number_to_words(m.group(0))
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
def _expand_number(m: re.Match) -> str:
|
| 109 |
-
num = int(m.group(0))
|
| 110 |
-
if num > 1000 and num < 3000:
|
| 111 |
-
if num == 2000:
|
| 112 |
-
return "two thousand"
|
| 113 |
-
elif num > 2000 and num < 2010:
|
| 114 |
-
return "two thousand " + _inflect.number_to_words(num % 100)
|
| 115 |
-
elif num % 100 == 0:
|
| 116 |
-
return _inflect.number_to_words(num // 100) + " hundred"
|
| 117 |
-
else:
|
| 118 |
-
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
|
| 119 |
-
else:
|
| 120 |
-
return _inflect.number_to_words(num, andword="")
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
def normalize_numbers(text: str) -> str:
|
| 124 |
-
text = re.sub(_comma_number_re, _remove_commas, text)
|
| 125 |
-
text = re.sub(_pounds_re, r"\1 pounds", text)
|
| 126 |
-
text = re.sub(_dollars_re, _expand_dollars, text)
|
| 127 |
-
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
| 128 |
-
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
| 129 |
-
text = re.sub(_number_re, _expand_number, text)
|
| 130 |
-
return text
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
# --- Number normalization code end ---
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
PAD_ID, UNK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3
|
| 137 |
-
SPECIAL_TOKEN_IDS = [PAD_ID, UNK_ID, BOS_ID, EOS_ID]
|
| 138 |
-
|
| 139 |
-
_punctuation = ';:,.!?¡¿—…"«»“”() *~-/\\&'
|
| 140 |
-
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
| 141 |
-
_letters_ipa = (
|
| 142 |
-
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
| 143 |
-
)
|
| 144 |
-
|
| 145 |
-
symbols = [*_punctuation, *_letters, *_letters_ipa]
|
| 146 |
-
_symbol_to_id = {s: i for i, s in enumerate(symbols, start=len(SPECIAL_TOKEN_IDS))}
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
def _get_symbol_id(s: str) -> int:
|
| 150 |
-
return _symbol_to_id.get(s, 1)
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
def get_symbol_ids(text: str) -> list[int]:
|
| 154 |
-
return list(map(_get_symbol_id, text))
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
def tokenize_phonemes(phonemes: list[str]) -> tuple[torch.Tensor, list[int]]:
|
| 158 |
-
phoneme_ids = [[BOS_ID, *get_symbol_ids(phonemes), EOS_ID] for phonemes in phonemes]
|
| 159 |
-
lengths = list(map(len, phoneme_ids))
|
| 160 |
-
longest = max(lengths)
|
| 161 |
-
phoneme_ids = [[PAD_ID] * (longest - len(ids)) + ids for ids in phoneme_ids]
|
| 162 |
-
return torch.tensor(phoneme_ids), lengths
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
def normalize_jp_text(text: str, tokenizer=Dictionary(dict="full").create()) -> str:
|
| 166 |
-
text = unicodedata.normalize("NFKC", text)
|
| 167 |
-
text = re.sub(r"\d+", lambda m: number2kanji(int(m[0])), text)
|
| 168 |
-
final_text = " ".join([x.reading_form() for x in tokenizer.tokenize(text, SplitMode.A)])
|
| 169 |
-
return final_text
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
def clean(texts: list[str], languages: list[str]) -> list[str]:
|
| 173 |
-
texts_out = []
|
| 174 |
-
for text, language in zip(texts, languages):
|
| 175 |
-
if "ja" in language:
|
| 176 |
-
text = normalize_jp_text(text)
|
| 177 |
-
else:
|
| 178 |
-
text = normalize_numbers(text)
|
| 179 |
-
texts_out.append(text)
|
| 180 |
-
return texts_out
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
@cache
|
| 184 |
-
def get_backend(language: str) -> "EspeakBackend":
|
| 185 |
-
import logging
|
| 186 |
-
|
| 187 |
-
from phonemizer.backend import EspeakBackend
|
| 188 |
-
|
| 189 |
-
logger = logging.getLogger("phonemizer")
|
| 190 |
-
backend = EspeakBackend(
|
| 191 |
-
language,
|
| 192 |
-
preserve_punctuation=True,
|
| 193 |
-
with_stress=True,
|
| 194 |
-
punctuation_marks=_punctuation,
|
| 195 |
-
logger=logger,
|
| 196 |
-
)
|
| 197 |
-
logger.setLevel(logging.ERROR)
|
| 198 |
-
return backend
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
def phonemize(texts: list[str], languages: list[str]) -> list[str]:
|
| 202 |
-
texts = clean(texts, languages)
|
| 203 |
-
|
| 204 |
-
batch_phonemes = []
|
| 205 |
-
for text, language in zip(texts, languages):
|
| 206 |
-
backend = get_backend(language)
|
| 207 |
-
phonemes = backend.phonemize([text], strip=True)
|
| 208 |
-
batch_phonemes.append(phonemes[0])
|
| 209 |
-
|
| 210 |
-
return batch_phonemes
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
class EspeakPhonemeConditioner(Conditioner):
|
| 214 |
-
def __init__(self, output_dim: int, **kwargs):
|
| 215 |
-
super().__init__(output_dim, **kwargs)
|
| 216 |
-
self.phoneme_embedder = nn.Embedding(len(SPECIAL_TOKEN_IDS) + len(symbols), output_dim)
|
| 217 |
-
|
| 218 |
-
def apply_cond(self, texts: list[str], languages: list[str]) -> torch.Tensor:
|
| 219 |
-
"""
|
| 220 |
-
Args:
|
| 221 |
-
texts: list of texts to convert to phonemes
|
| 222 |
-
languages: ISO 639-1 -or otherwise eSpeak compatible- language code
|
| 223 |
-
"""
|
| 224 |
-
device = self.phoneme_embedder.weight.device
|
| 225 |
-
|
| 226 |
-
phonemes = phonemize(texts, languages)
|
| 227 |
-
phoneme_ids, _ = tokenize_phonemes(phonemes)
|
| 228 |
-
phoneme_embeds = self.phoneme_embedder(phoneme_ids.to(device))
|
| 229 |
-
|
| 230 |
-
return phoneme_embeds
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
# ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
class FourierConditioner(Conditioner):
|
| 237 |
-
def __init__(
|
| 238 |
-
self,
|
| 239 |
-
output_dim: int,
|
| 240 |
-
input_dim: int = 1,
|
| 241 |
-
std: float = 1.0,
|
| 242 |
-
min_val: float = 0.0,
|
| 243 |
-
max_val: float = 1.0,
|
| 244 |
-
**kwargs,
|
| 245 |
-
):
|
| 246 |
-
assert output_dim % 2 == 0
|
| 247 |
-
super().__init__(output_dim, **kwargs)
|
| 248 |
-
self.register_buffer("weight", torch.randn([output_dim // 2, input_dim]) * std)
|
| 249 |
-
self.input_dim, self.min_val, self.max_val = input_dim, min_val, max_val
|
| 250 |
-
|
| 251 |
-
def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
|
| 252 |
-
assert x.shape[-1] == self.input_dim
|
| 253 |
-
x = (x - self.min_val) / (self.max_val - self.min_val) # [batch_size, seq_len, input_dim]
|
| 254 |
-
f = 2 * torch.pi * x.to(self.weight.dtype) @ self.weight.T # [batch_size, seq_len, output_dim // 2]
|
| 255 |
-
return torch.cat([f.cos(), f.sin()], dim=-1) # [batch_size, seq_len, output_dim]
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
class IntegerConditioner(Conditioner):
|
| 259 |
-
def __init__(self, output_dim: int, min_val: int = 0, max_val: int = 512, **kwargs):
|
| 260 |
-
super().__init__(output_dim, **kwargs)
|
| 261 |
-
self.min_val = min_val
|
| 262 |
-
self.max_val = max_val
|
| 263 |
-
self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim)
|
| 264 |
-
|
| 265 |
-
def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
|
| 266 |
-
assert x.shape[-1] == 1
|
| 267 |
-
return self.int_embedder(x.squeeze(-1) - self.min_val) # [batch_size, seq_len, output_dim]
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
class PassthroughConditioner(Conditioner):
|
| 271 |
-
def __init__(self, output_dim: int, **kwargs):
|
| 272 |
-
super().__init__(output_dim, **kwargs)
|
| 273 |
-
|
| 274 |
-
def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
|
| 275 |
-
assert x.shape[-1] == self.cond_dim
|
| 276 |
-
return x
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
_cond_cls_map = {
|
| 280 |
-
"PassthroughConditioner": PassthroughConditioner,
|
| 281 |
-
"EspeakPhonemeConditioner": EspeakPhonemeConditioner,
|
| 282 |
-
"FourierConditioner": FourierConditioner,
|
| 283 |
-
"IntegerConditioner": IntegerConditioner,
|
| 284 |
-
}
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
def build_conditioners(conditioners: list[dict], output_dim: int) -> list[Conditioner]:
|
| 288 |
-
return [_cond_cls_map[config["type"]](output_dim, **config) for config in conditioners]
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
class PrefixConditioner(Conditioner):
|
| 292 |
-
def __init__(self, config: PrefixConditionerConfig, output_dim: int):
|
| 293 |
-
super().__init__(output_dim, "prefix", projection=config.projection)
|
| 294 |
-
self.conditioners = nn.ModuleList(build_conditioners(config.conditioners, output_dim))
|
| 295 |
-
self.norm = nn.LayerNorm(output_dim)
|
| 296 |
-
self.required_keys = {c.name for c in self.conditioners if c.uncond_vector is None}
|
| 297 |
-
|
| 298 |
-
def forward(self, cond_dict: dict) -> torch.Tensor:
|
| 299 |
-
if not set(cond_dict).issuperset(self.required_keys):
|
| 300 |
-
raise ValueError(f"Missing required keys: {self.required_keys - set(cond_dict)}")
|
| 301 |
-
conds = []
|
| 302 |
-
for conditioner in self.conditioners:
|
| 303 |
-
conds.append(conditioner(cond_dict.get(conditioner.name)))
|
| 304 |
-
max_bsz = max(map(len, conds))
|
| 305 |
-
assert all(c.shape[0] in (max_bsz, 1) for c in conds)
|
| 306 |
-
conds = [c.expand(max_bsz, -1, -1) for c in conds]
|
| 307 |
-
return self.norm(self.project(torch.cat(conds, dim=-2)))
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
supported_language_codes = [
|
| 311 |
-
'af', 'am', 'an', 'ar', 'as', 'az', 'ba', 'bg', 'bn', 'bpy', 'bs', 'ca', 'cmn',
|
| 312 |
-
'cs', 'cy', 'da', 'de', 'el', 'en-029', 'en-gb', 'en-gb-scotland', 'en-gb-x-gbclan',
|
| 313 |
-
'en-gb-x-gbcwmd', 'en-gb-x-rp', 'en-us', 'eo', 'es', 'es-419', 'et', 'eu', 'fa',
|
| 314 |
-
'fa-latn', 'fi', 'fr-be', 'fr-ch', 'fr-fr', 'ga', 'gd', 'gn', 'grc', 'gu', 'hak',
|
| 315 |
-
'hi', 'hr', 'ht', 'hu', 'hy', 'hyw', 'ia', 'id', 'is', 'it', 'ja', 'jbo', 'ka',
|
| 316 |
-
'kk', 'kl', 'kn', 'ko', 'kok', 'ku', 'ky', 'la', 'lfn', 'lt', 'lv', 'mi', 'mk',
|
| 317 |
-
'ml', 'mr', 'ms', 'mt', 'my', 'nb', 'nci', 'ne', 'nl', 'om', 'or', 'pa', 'pap',
|
| 318 |
-
'pl', 'pt', 'pt-br', 'py', 'quc', 'ro', 'ru', 'ru-lv', 'sd', 'shn', 'si', 'sk',
|
| 319 |
-
'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'tn', 'tr', 'tt', 'ur', 'uz', 'vi',
|
| 320 |
-
'vi-vn-x-central', 'vi-vn-x-south', 'yue'
|
| 321 |
-
] # fmt: off
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
def make_cond_dict(
|
| 325 |
-
text: str = "It would be nice to have time for testing, indeed.",
|
| 326 |
-
language: str = "en-us",
|
| 327 |
-
speaker: torch.Tensor | None = None,
|
| 328 |
-
emotion: list[float] = [0.3077, 0.0256, 0.0256, 0.0256, 0.0256, 0.0256, 0.2564, 0.3077],
|
| 329 |
-
fmax: float = 22050.0,
|
| 330 |
-
pitch_std: float = 20.0,
|
| 331 |
-
speaking_rate: float = 15.0,
|
| 332 |
-
vqscore_8: list[float] = [0.78] * 8,
|
| 333 |
-
ctc_loss: float = 0.0,
|
| 334 |
-
dnsmos_ovrl: float = 4.0,
|
| 335 |
-
speaker_noised: bool = False,
|
| 336 |
-
unconditional_keys: Iterable[str] = {"vqscore_8", "dnsmos_ovrl"},
|
| 337 |
-
device: str = "cuda",
|
| 338 |
-
) -> dict:
|
| 339 |
-
"""
|
| 340 |
-
A helper to build the 'cond_dict' that the model expects.
|
| 341 |
-
By default, it will generate a random speaker embedding
|
| 342 |
-
"""
|
| 343 |
-
assert language.lower() in supported_language_codes, "Please pick a supported language"
|
| 344 |
-
|
| 345 |
-
language_code_to_id = {lang: i for i, lang in enumerate(supported_language_codes)}
|
| 346 |
-
|
| 347 |
-
cond_dict = {
|
| 348 |
-
"espeak": ([text], [language]),
|
| 349 |
-
"speaker": speaker,
|
| 350 |
-
"emotion": emotion,
|
| 351 |
-
"fmax": fmax,
|
| 352 |
-
"pitch_std": pitch_std,
|
| 353 |
-
"speaking_rate": speaking_rate,
|
| 354 |
-
"language_id": language_code_to_id[language],
|
| 355 |
-
"vqscore_8": vqscore_8,
|
| 356 |
-
"ctc_loss": ctc_loss,
|
| 357 |
-
"dnsmos_ovrl": dnsmos_ovrl,
|
| 358 |
-
"speaker_noised": int(speaker_noised),
|
| 359 |
-
}
|
| 360 |
-
|
| 361 |
-
for k in unconditional_keys:
|
| 362 |
-
cond_dict.pop(k, None)
|
| 363 |
-
|
| 364 |
-
for k, v in cond_dict.items():
|
| 365 |
-
if isinstance(v, (float, int, list)):
|
| 366 |
-
v = torch.tensor(v)
|
| 367 |
-
if isinstance(v, torch.Tensor):
|
| 368 |
-
cond_dict[k] = v.view(1, 1, -1).to(device)
|
| 369 |
-
|
| 370 |
-
if k == "emotion":
|
| 371 |
-
cond_dict[k] /= cond_dict[k].sum(dim=-1)
|
| 372 |
-
|
| 373 |
-
return cond_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qhash/config.py
DELETED
|
@@ -1,38 +0,0 @@
|
|
| 1 |
-
from dataclasses import dataclass, field
|
| 2 |
-
from typing import Literal
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
@dataclass
|
| 6 |
-
class BackboneConfig:
|
| 7 |
-
d_model: int = 1024
|
| 8 |
-
d_intermediate: int = 0
|
| 9 |
-
attn_mlp_d_intermediate: int = 0
|
| 10 |
-
n_layer: int = 16
|
| 11 |
-
ssm_cfg: dict = field(default_factory=dict)
|
| 12 |
-
attn_layer_idx: list = field(default_factory=list)
|
| 13 |
-
attn_cfg: dict = field(default_factory=dict)
|
| 14 |
-
rms_norm: bool = False
|
| 15 |
-
residual_in_fp32: bool = False
|
| 16 |
-
norm_epsilon: float = 1e-5
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
@dataclass
|
| 20 |
-
class PrefixConditionerConfig:
|
| 21 |
-
conditioners: list[dict]
|
| 22 |
-
projection: Literal["none", "linear", "mlp"]
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
@dataclass
|
| 26 |
-
class ZonosConfig:
|
| 27 |
-
backbone: BackboneConfig
|
| 28 |
-
prefix_conditioner: PrefixConditionerConfig
|
| 29 |
-
eos_token_id: int = 1024
|
| 30 |
-
masked_token_id: int = 1025
|
| 31 |
-
|
| 32 |
-
@classmethod
|
| 33 |
-
def from_dict(cls, d: dict) -> "ZonosConfig":
|
| 34 |
-
d = d.copy()
|
| 35 |
-
backbone_config = BackboneConfig(**d.pop("backbone"))
|
| 36 |
-
prefix_conditioner_config = PrefixConditionerConfig(**d.pop("prefix_conditioner"))
|
| 37 |
-
config = cls(backbone_config, prefix_conditioner_config, **d)
|
| 38 |
-
return config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qhash/model.py
DELETED
|
@@ -1,270 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
from typing import Callable
|
| 3 |
-
|
| 4 |
-
import safetensors
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
from huggingface_hub import hf_hub_download
|
| 8 |
-
from mamba_ssm.utils.generation import InferenceParams
|
| 9 |
-
from tqdm import tqdm
|
| 10 |
-
|
| 11 |
-
from qhash.backbone import ZonosBackbone
|
| 12 |
-
from qhash.autoencoder import DACAutoencoder
|
| 13 |
-
from qhash.codebook_pattern import apply_delay_pattern, revert_delay_pattern
|
| 14 |
-
from qhash.conditioning import PrefixConditioner
|
| 15 |
-
from qhash.config import ZonosConfig
|
| 16 |
-
from qhash.sampling import sample_from_logits
|
| 17 |
-
from qhash.speaker_cloning import SpeakerEmbeddingLDA
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class Zonos(nn.Module):
|
| 21 |
-
def __init__(self, config: ZonosConfig):
|
| 22 |
-
super().__init__()
|
| 23 |
-
self.config = config
|
| 24 |
-
dim = config.backbone.d_model
|
| 25 |
-
self.eos_token_id = config.eos_token_id
|
| 26 |
-
self.masked_token_id = config.masked_token_id
|
| 27 |
-
|
| 28 |
-
self.autoencoder = DACAutoencoder()
|
| 29 |
-
self.backbone = ZonosBackbone(config.backbone)
|
| 30 |
-
self.prefix_conditioner = PrefixConditioner(config.prefix_conditioner, dim)
|
| 31 |
-
self.spk_clone_model = None
|
| 32 |
-
|
| 33 |
-
# TODO: pad to multiple of at least 8
|
| 34 |
-
self.embeddings = nn.ModuleList([nn.Embedding(1026, dim) for _ in range(self.autoencoder.num_codebooks)])
|
| 35 |
-
self.heads = nn.ModuleList([nn.Linear(dim, 1025, bias=False) for _ in range(self.autoencoder.num_codebooks)])
|
| 36 |
-
|
| 37 |
-
self._cg_graph = None
|
| 38 |
-
self._cg_batch_size = None
|
| 39 |
-
self._cg_input_ids = None
|
| 40 |
-
self._cg_logits = None
|
| 41 |
-
self._cg_inference_params = None
|
| 42 |
-
self._cg_scale = None
|
| 43 |
-
|
| 44 |
-
@classmethod
|
| 45 |
-
def from_pretrained(cls, repo_id: str, revision: str | None = None, device: str = "cuda") -> "Zonos":
|
| 46 |
-
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
|
| 47 |
-
model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
| 48 |
-
return cls.from_local(config_path, model_path, device)
|
| 49 |
-
|
| 50 |
-
@classmethod
|
| 51 |
-
def from_local(cls, config_path: str, model_path: str, device: str = "cuda") -> "Zonos":
|
| 52 |
-
config = ZonosConfig.from_dict(json.load(open(config_path)))
|
| 53 |
-
model = cls(config).to(device, torch.bfloat16)
|
| 54 |
-
model.autoencoder.dac.to(device)
|
| 55 |
-
|
| 56 |
-
sd = model.state_dict()
|
| 57 |
-
with safetensors.safe_open(model_path, framework="pt") as f:
|
| 58 |
-
for k in f.keys():
|
| 59 |
-
sd[k] = f.get_tensor(k)
|
| 60 |
-
model.load_state_dict(sd)
|
| 61 |
-
|
| 62 |
-
return model
|
| 63 |
-
|
| 64 |
-
def make_speaker_embedding(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
|
| 65 |
-
"""Generate a speaker embedding from an audio clip."""
|
| 66 |
-
if self.spk_clone_model is None:
|
| 67 |
-
self.spk_clone_model = SpeakerEmbeddingLDA()
|
| 68 |
-
_, spk_embedding = self.spk_clone_model(wav.to(self.spk_clone_model.device), sr)
|
| 69 |
-
return spk_embedding.unsqueeze(0).bfloat16()
|
| 70 |
-
|
| 71 |
-
def embed_codes(self, codes: torch.Tensor) -> torch.Tensor:
|
| 72 |
-
return sum(emb(codes[:, i]) for i, emb in enumerate(self.embeddings))
|
| 73 |
-
|
| 74 |
-
def apply_heads(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 75 |
-
return torch.stack([head(hidden_states) for head in self.heads], dim=1)
|
| 76 |
-
|
| 77 |
-
def _compute_logits(
|
| 78 |
-
self, hidden_states: torch.Tensor, inference_params: InferenceParams, cfg_scale: float
|
| 79 |
-
) -> torch.Tensor:
|
| 80 |
-
"""
|
| 81 |
-
Pass `hidden_states` into `backbone` and `multi_head`, applying
|
| 82 |
-
classifier-free guidance if `cfg_scale != 1.0`.
|
| 83 |
-
"""
|
| 84 |
-
last_hidden_states = self.backbone(hidden_states, inference_params)[:, -1, :].unsqueeze(1)
|
| 85 |
-
logits = self.apply_heads(last_hidden_states).squeeze(2).float()
|
| 86 |
-
if cfg_scale != 1.0:
|
| 87 |
-
cond_logits, uncond_logits = logits.chunk(2)
|
| 88 |
-
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
|
| 89 |
-
return logits
|
| 90 |
-
|
| 91 |
-
def _decode_one_token(
|
| 92 |
-
self,
|
| 93 |
-
input_ids: torch.Tensor,
|
| 94 |
-
inference_params: InferenceParams,
|
| 95 |
-
cfg_scale: float,
|
| 96 |
-
) -> torch.Tensor:
|
| 97 |
-
"""
|
| 98 |
-
Single-step decode. Prepares the hidden states, possibly replicates them
|
| 99 |
-
for CFG, and then delegates to `_compute_logits`.
|
| 100 |
-
|
| 101 |
-
Below we wrap this function with a simple CUDA Graph capturing mechanism,
|
| 102 |
-
doing 3 warmup steps if needed and then capturing or replaying the graph.
|
| 103 |
-
We only recapture if the batch size changes.
|
| 104 |
-
"""
|
| 105 |
-
# TODO: support cfg_scale==1
|
| 106 |
-
if cfg_scale == 1.0:
|
| 107 |
-
hidden_states = self.embed_codes(input_ids)
|
| 108 |
-
return self._compute_logits(hidden_states, inference_params, cfg_scale)
|
| 109 |
-
|
| 110 |
-
bsz = input_ids.size(0)
|
| 111 |
-
|
| 112 |
-
need_capture = (self._cg_graph is None) or (self._cg_batch_size != bsz)
|
| 113 |
-
|
| 114 |
-
if need_capture:
|
| 115 |
-
self._cg_graph = None
|
| 116 |
-
|
| 117 |
-
self._cg_batch_size = bsz
|
| 118 |
-
self._cg_inference_params = inference_params
|
| 119 |
-
self._cg_scale = cfg_scale
|
| 120 |
-
|
| 121 |
-
for _ in range(3):
|
| 122 |
-
hidden_states = self.embed_codes(input_ids)
|
| 123 |
-
hidden_states = hidden_states.repeat(2, 1, 1) # because cfg != 1.0
|
| 124 |
-
logits = self._compute_logits(hidden_states, inference_params, cfg_scale)
|
| 125 |
-
|
| 126 |
-
self._cg_input_ids = input_ids.clone()
|
| 127 |
-
self._cg_logits = torch.empty_like(logits)
|
| 128 |
-
|
| 129 |
-
g = torch.cuda.CUDAGraph()
|
| 130 |
-
|
| 131 |
-
def capture_region():
|
| 132 |
-
hidden_states_local = self.embed_codes(self._cg_input_ids)
|
| 133 |
-
hidden_states_local = hidden_states_local.repeat(2, 1, 1)
|
| 134 |
-
self._cg_logits = self._compute_logits(hidden_states_local, self._cg_inference_params, self._cg_scale)
|
| 135 |
-
|
| 136 |
-
with torch.cuda.graph(g):
|
| 137 |
-
capture_region()
|
| 138 |
-
|
| 139 |
-
self._cg_graph = g
|
| 140 |
-
|
| 141 |
-
else:
|
| 142 |
-
self._cg_input_ids.copy_(input_ids)
|
| 143 |
-
|
| 144 |
-
self._cg_graph.replay()
|
| 145 |
-
|
| 146 |
-
return self._cg_logits
|
| 147 |
-
|
| 148 |
-
def _prefill(
|
| 149 |
-
self,
|
| 150 |
-
prefix_hidden_states: torch.Tensor,
|
| 151 |
-
input_ids: torch.Tensor,
|
| 152 |
-
inference_params: InferenceParams,
|
| 153 |
-
cfg_scale: float,
|
| 154 |
-
) -> torch.Tensor:
|
| 155 |
-
"""
|
| 156 |
-
"Prefill" mode: we already have `prefix_hidden_states`, and we want
|
| 157 |
-
to append new embeddings, then compute the logits.
|
| 158 |
-
"""
|
| 159 |
-
# Replicate input_ids if CFG is enabled
|
| 160 |
-
if cfg_scale != 1.0:
|
| 161 |
-
input_ids = input_ids.expand(prefix_hidden_states.shape[0], -1, -1)
|
| 162 |
-
hidden_states = torch.cat([prefix_hidden_states, self.embed_codes(input_ids)], dim=1)
|
| 163 |
-
return self._compute_logits(hidden_states, inference_params, cfg_scale)
|
| 164 |
-
|
| 165 |
-
def setup_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16) -> InferenceParams:
|
| 166 |
-
key_value_memory_dict = {
|
| 167 |
-
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
|
| 168 |
-
for i, layer in enumerate(self.backbone.layers)
|
| 169 |
-
}
|
| 170 |
-
lengths_per_sample = torch.full((batch_size,), 0, dtype=torch.int32, device="cuda")
|
| 171 |
-
return InferenceParams(max_seqlen, batch_size, 0, 0, key_value_memory_dict, lengths_per_sample)
|
| 172 |
-
|
| 173 |
-
def prepare_conditioning(self, cond_dict: dict, uncond_dict: dict | None = None) -> torch.Tensor:
|
| 174 |
-
if uncond_dict is None:
|
| 175 |
-
uncond_dict = {k: cond_dict[k] for k in self.prefix_conditioner.required_keys}
|
| 176 |
-
return torch.cat(
|
| 177 |
-
[
|
| 178 |
-
self.prefix_conditioner(cond_dict),
|
| 179 |
-
self.prefix_conditioner(uncond_dict),
|
| 180 |
-
]
|
| 181 |
-
)
|
| 182 |
-
|
| 183 |
-
@torch.inference_mode()
|
| 184 |
-
def generate(
|
| 185 |
-
self,
|
| 186 |
-
prefix_conditioning: torch.Tensor, # [bsz, cond_seq_len, d_model]
|
| 187 |
-
audio_prefix_codes: torch.Tensor | None = None, # [bsz, 9, prefix_audio_seq_len]
|
| 188 |
-
max_new_tokens: int = 86 * 30,
|
| 189 |
-
cfg_scale: float = 2.0,
|
| 190 |
-
batch_size: int = 1,
|
| 191 |
-
sampling_params: dict = dict(min_p=0.1),
|
| 192 |
-
progress_bar: bool = True,
|
| 193 |
-
callback: Callable[[torch.Tensor, int, int], bool] | None = None,
|
| 194 |
-
):
|
| 195 |
-
assert cfg_scale != 1, "TODO: add support for cfg_scale=1"
|
| 196 |
-
prefix_audio_len = 0 if audio_prefix_codes is None else audio_prefix_codes.shape[2]
|
| 197 |
-
|
| 198 |
-
unknown_token = -1
|
| 199 |
-
audio_seq_len = prefix_audio_len + max_new_tokens
|
| 200 |
-
seq_len = prefix_conditioning.shape[1] + audio_seq_len
|
| 201 |
-
|
| 202 |
-
inference_params = self.setup_cache(batch_size=batch_size * 2, max_seqlen=seq_len)
|
| 203 |
-
|
| 204 |
-
codes = torch.full((batch_size, 9, audio_seq_len), unknown_token, device="cuda")
|
| 205 |
-
if audio_prefix_codes is not None:
|
| 206 |
-
codes[..., :prefix_audio_len] = audio_prefix_codes
|
| 207 |
-
|
| 208 |
-
delayed_codes = apply_delay_pattern(codes, self.masked_token_id)
|
| 209 |
-
|
| 210 |
-
delayed_prefix_audio_codes = delayed_codes[..., : prefix_audio_len + 1]
|
| 211 |
-
|
| 212 |
-
logits = self._prefill(prefix_conditioning, delayed_prefix_audio_codes, inference_params, cfg_scale)
|
| 213 |
-
next_token = sample_from_logits(logits, **sampling_params)
|
| 214 |
-
|
| 215 |
-
offset = delayed_prefix_audio_codes.shape[2]
|
| 216 |
-
frame = delayed_codes[..., offset : offset + 1]
|
| 217 |
-
frame.masked_scatter_(frame == unknown_token, next_token)
|
| 218 |
-
|
| 219 |
-
prefix_length = prefix_conditioning.shape[1] + prefix_audio_len + 1
|
| 220 |
-
inference_params.seqlen_offset += prefix_length
|
| 221 |
-
inference_params.lengths_per_sample[:] += prefix_length
|
| 222 |
-
|
| 223 |
-
logit_bias = torch.zeros_like(logits)
|
| 224 |
-
logit_bias[:, 1:, self.eos_token_id] = -torch.inf # only allow codebook 0 to predict EOS
|
| 225 |
-
|
| 226 |
-
stopping = torch.zeros(batch_size, dtype=torch.bool, device="cuda")
|
| 227 |
-
max_steps = delayed_codes.shape[2] - offset
|
| 228 |
-
remaining_steps = torch.full((batch_size,), max_steps, device="cuda")
|
| 229 |
-
progress = tqdm(total=max_steps, desc="Generating", disable=not progress_bar)
|
| 230 |
-
|
| 231 |
-
step = 0
|
| 232 |
-
while torch.max(remaining_steps) > 0:
|
| 233 |
-
offset += 1
|
| 234 |
-
input_ids = delayed_codes[..., offset - 1 : offset]
|
| 235 |
-
logits = self._decode_one_token(input_ids, inference_params, cfg_scale)
|
| 236 |
-
|
| 237 |
-
next_token = sample_from_logits(logits, generated_tokens=delayed_codes[..., :offset], **sampling_params)
|
| 238 |
-
eos_in_cb0 = next_token[:, 0] == self.eos_token_id
|
| 239 |
-
|
| 240 |
-
remaining_steps[eos_in_cb0[:, 0]] = torch.minimum(remaining_steps[eos_in_cb0[:, 0]], torch.tensor(9))
|
| 241 |
-
stopping |= eos_in_cb0[:, 0]
|
| 242 |
-
|
| 243 |
-
eos_codebook_idx = 9 - remaining_steps
|
| 244 |
-
eos_codebook_idx = torch.clamp(eos_codebook_idx, max=9 - 1)
|
| 245 |
-
for i in range(next_token.shape[0]):
|
| 246 |
-
if stopping[i]:
|
| 247 |
-
idx = eos_codebook_idx[i].item()
|
| 248 |
-
next_token[i, :idx] = self.masked_token_id
|
| 249 |
-
next_token[i, idx] = self.eos_token_id
|
| 250 |
-
|
| 251 |
-
frame = delayed_codes[..., offset : offset + 1]
|
| 252 |
-
frame.masked_scatter_(frame == unknown_token, next_token)
|
| 253 |
-
inference_params.seqlen_offset += 1
|
| 254 |
-
inference_params.lengths_per_sample[:] += 1
|
| 255 |
-
|
| 256 |
-
remaining_steps -= 1
|
| 257 |
-
|
| 258 |
-
progress.update()
|
| 259 |
-
step += 1
|
| 260 |
-
|
| 261 |
-
if callback is not None and not callback(frame, step, max_steps):
|
| 262 |
-
break
|
| 263 |
-
|
| 264 |
-
out_codes = revert_delay_pattern(delayed_codes)
|
| 265 |
-
out_codes.masked_fill_(out_codes >= 1024, 0)
|
| 266 |
-
out_codes = out_codes[..., : offset - 9]
|
| 267 |
-
|
| 268 |
-
self._cg_graph = None # reset cuda graph to avoid cache changes
|
| 269 |
-
|
| 270 |
-
return out_codes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qhash/sampling.py
DELETED
|
@@ -1,141 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
|
| 5 |
-
"""torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
|
| 6 |
-
|
| 7 |
-
Args:
|
| 8 |
-
input (torch.Tensor): The input tensor containing probabilities.
|
| 9 |
-
num_samples (int): Number of samples to draw.
|
| 10 |
-
replacement (bool): Whether to draw with replacement or not.
|
| 11 |
-
Keywords args:
|
| 12 |
-
generator (torch.Generator): A pseudorandom number generator for sampling.
|
| 13 |
-
Returns:
|
| 14 |
-
torch.Tensor: Last dimension contains num_samples indices
|
| 15 |
-
sampled from the multinomial probability distribution
|
| 16 |
-
located in the last dimension of tensor input.
|
| 17 |
-
"""
|
| 18 |
-
|
| 19 |
-
if num_samples == 1:
|
| 20 |
-
q = torch.empty_like(input).exponential_(1, generator=generator)
|
| 21 |
-
return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
|
| 22 |
-
|
| 23 |
-
input_ = input.reshape(-1, input.shape[-1])
|
| 24 |
-
output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
|
| 25 |
-
output = output_.reshape(*list(input.shape[:-1]), -1)
|
| 26 |
-
return output
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def apply_top_k(
|
| 30 |
-
probs: torch.Tensor,
|
| 31 |
-
k: int,
|
| 32 |
-
) -> torch.Tensor:
|
| 33 |
-
"""Sample next token from top K values along the last dimension of the input probs tensor.
|
| 34 |
-
|
| 35 |
-
Args:
|
| 36 |
-
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
| 37 |
-
k (int): The k in “top-k”.
|
| 38 |
-
Returns:
|
| 39 |
-
torch.Tensor: Sampled tokens.
|
| 40 |
-
"""
|
| 41 |
-
v, _ = torch.topk(probs, min(k, probs.size(-1)))
|
| 42 |
-
pivot = v.select(-1, -1).unsqueeze(-1)
|
| 43 |
-
probs = torch.where(probs < pivot, 0.0, probs)
|
| 44 |
-
probs.div_(probs.sum(dim=-1, keepdim=True))
|
| 45 |
-
return probs
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def apply_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
|
| 49 |
-
"""Sample next token from top P probabilities along the last dimension of the input probs tensor.
|
| 50 |
-
|
| 51 |
-
Args:
|
| 52 |
-
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
| 53 |
-
p (int): The p in “top-p”.
|
| 54 |
-
Returns:
|
| 55 |
-
torch.Tensor: Sampled tokens.
|
| 56 |
-
"""
|
| 57 |
-
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
| 58 |
-
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
| 59 |
-
mask = probs_sum - probs_sort > p
|
| 60 |
-
probs_sort *= (~mask).float()
|
| 61 |
-
probs = probs.scatter(-1, probs_idx, probs_sort)
|
| 62 |
-
probs.div_(probs.sum(dim=-1, keepdim=True))
|
| 63 |
-
return probs
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def apply_min_p(probs: torch.Tensor, min_p: float) -> torch.Tensor:
|
| 67 |
-
"""Sample next token using min-p sampling.
|
| 68 |
-
|
| 69 |
-
Args:
|
| 70 |
-
scores (torch.FloatTensor): Input logits with token candidates on the last dimension.
|
| 71 |
-
min_p (float): Minimum token probability, scaled by the probability of the most likely token.
|
| 72 |
-
Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
|
| 73 |
-
Returns:
|
| 74 |
-
torch.Tensor: Sampled tokens.
|
| 75 |
-
"""
|
| 76 |
-
top_probs, _ = probs.max(dim=-1, keepdim=True)
|
| 77 |
-
tokens_to_remove = probs < (min_p * top_probs)
|
| 78 |
-
probs = probs.masked_fill(tokens_to_remove, 0.0)
|
| 79 |
-
probs.div_(probs.sum(dim=-1, keepdim=True))
|
| 80 |
-
return probs
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def modify_logit_for_repetition_penalty(
|
| 84 |
-
logits: torch.Tensor,
|
| 85 |
-
generated_tokens: torch.Tensor,
|
| 86 |
-
repetition_penalty: float,
|
| 87 |
-
repetition_penalty_window: int,
|
| 88 |
-
):
|
| 89 |
-
"""See https://arxiv.org/abs/1909.05858
|
| 90 |
-
Apply repetition penalty over a sliding window of the last `repetition_penalty_window` tokens.
|
| 91 |
-
logits: (batch_size, n_codebooks, vocab_size)
|
| 92 |
-
generated_tokens: (batch_size, n_codebooks, seq_len)
|
| 93 |
-
"""
|
| 94 |
-
generated_tokens = generated_tokens[..., -repetition_penalty_window:]
|
| 95 |
-
generated_tokens = generated_tokens.clamp_max(logits.shape[-1] - 1).to(torch.int64)
|
| 96 |
-
rp = torch.full_like(logits, repetition_penalty)
|
| 97 |
-
factors = torch.ones_like(logits).scatter_reduce(2, generated_tokens, rp, reduce="prod")
|
| 98 |
-
return torch.where(logits <= 0, logits * factors, logits / factors)
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
def sample_from_logits(
|
| 102 |
-
logits: torch.Tensor,
|
| 103 |
-
temperature: float = 1.0,
|
| 104 |
-
top_p: float = 0.0,
|
| 105 |
-
top_k: int = 0,
|
| 106 |
-
min_p: float = 0.0,
|
| 107 |
-
generated_tokens: torch.Tensor | None = None,
|
| 108 |
-
repetition_penalty: float = 3.0,
|
| 109 |
-
repetition_penalty_window: float = 2,
|
| 110 |
-
) -> torch.Tensor:
|
| 111 |
-
"""Sample next token from logits using temperature, top-p, top-k, or min-p sampling.
|
| 112 |
-
|
| 113 |
-
Args:
|
| 114 |
-
logits (torch.Tensor): Input logits with token candidates on the last dimension.
|
| 115 |
-
temperature (float): Sampling temperature. Lower temperature results in more deterministic samples.
|
| 116 |
-
top_p (float): The p in “top-p”.
|
| 117 |
-
top_k (int): The k in “top-k”.
|
| 118 |
-
min_p (float): Minimum token probability, scaled by the probability of the most likely token.
|
| 119 |
-
Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
|
| 120 |
-
|
| 121 |
-
Returns:
|
| 122 |
-
torch.Tensor: Sampled tokens.
|
| 123 |
-
"""
|
| 124 |
-
if repetition_penalty != 1.0 and generated_tokens is not None:
|
| 125 |
-
logits = modify_logit_for_repetition_penalty(logits, generated_tokens, repetition_penalty, repetition_penalty_window)
|
| 126 |
-
|
| 127 |
-
if temperature > 0:
|
| 128 |
-
probs = torch.softmax(logits / temperature, dim=-1)
|
| 129 |
-
|
| 130 |
-
if top_p > 0:
|
| 131 |
-
probs = apply_top_p(probs, top_p)
|
| 132 |
-
if top_k > 0:
|
| 133 |
-
probs = apply_top_k(probs, top_k)
|
| 134 |
-
if min_p > 0:
|
| 135 |
-
probs = apply_min_p(probs, min_p)
|
| 136 |
-
|
| 137 |
-
next_token = multinomial(probs, num_samples=1)
|
| 138 |
-
else:
|
| 139 |
-
next_token = torch.argmax(logits, dim=-1, keepdim=True)
|
| 140 |
-
|
| 141 |
-
return next_token # [batch_size, num_codebooks, 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qhash/speaker_cloning.py
DELETED
|
@@ -1,406 +0,0 @@
|
|
| 1 |
-
import math
|
| 2 |
-
from functools import cache
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn as nn
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
import torchaudio
|
| 8 |
-
from huggingface_hub import hf_hub_download
|
| 9 |
-
import os
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class logFbankCal(nn.Module):
|
| 13 |
-
def __init__(
|
| 14 |
-
self,
|
| 15 |
-
sample_rate: int = 16_000,
|
| 16 |
-
n_fft: int = 512,
|
| 17 |
-
win_length: float = 0.025,
|
| 18 |
-
hop_length: float = 0.01,
|
| 19 |
-
n_mels: int = 80,
|
| 20 |
-
):
|
| 21 |
-
super().__init__()
|
| 22 |
-
self.fbankCal = torchaudio.transforms.MelSpectrogram(
|
| 23 |
-
sample_rate=sample_rate,
|
| 24 |
-
n_fft=n_fft,
|
| 25 |
-
win_length=int(win_length * sample_rate),
|
| 26 |
-
hop_length=int(hop_length * sample_rate),
|
| 27 |
-
n_mels=n_mels,
|
| 28 |
-
)
|
| 29 |
-
|
| 30 |
-
def forward(self, x):
|
| 31 |
-
out = self.fbankCal(x)
|
| 32 |
-
out = torch.log(out + 1e-6)
|
| 33 |
-
out = out - out.mean(axis=2).unsqueeze(dim=2)
|
| 34 |
-
return out
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
class ASP(nn.Module):
|
| 38 |
-
# Attentive statistics pooling
|
| 39 |
-
def __init__(self, in_planes, acoustic_dim):
|
| 40 |
-
super(ASP, self).__init__()
|
| 41 |
-
outmap_size = int(acoustic_dim / 8)
|
| 42 |
-
self.out_dim = in_planes * 8 * outmap_size * 2
|
| 43 |
-
|
| 44 |
-
self.attention = nn.Sequential(
|
| 45 |
-
nn.Conv1d(in_planes * 8 * outmap_size, 128, kernel_size=1),
|
| 46 |
-
nn.ReLU(),
|
| 47 |
-
nn.BatchNorm1d(128),
|
| 48 |
-
nn.Conv1d(128, in_planes * 8 * outmap_size, kernel_size=1),
|
| 49 |
-
nn.Softmax(dim=2),
|
| 50 |
-
)
|
| 51 |
-
|
| 52 |
-
def forward(self, x):
|
| 53 |
-
x = x.reshape(x.size()[0], -1, x.size()[-1])
|
| 54 |
-
w = self.attention(x)
|
| 55 |
-
mu = torch.sum(x * w, dim=2)
|
| 56 |
-
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
|
| 57 |
-
x = torch.cat((mu, sg), 1)
|
| 58 |
-
|
| 59 |
-
x = x.view(x.size()[0], -1)
|
| 60 |
-
return x
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
class SimAMBasicBlock(nn.Module):
|
| 64 |
-
expansion = 1
|
| 65 |
-
|
| 66 |
-
def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
|
| 67 |
-
super(SimAMBasicBlock, self).__init__()
|
| 68 |
-
self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
| 69 |
-
self.bn1 = NormLayer(planes)
|
| 70 |
-
self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
| 71 |
-
self.bn2 = NormLayer(planes)
|
| 72 |
-
self.relu = nn.ReLU(inplace=True)
|
| 73 |
-
self.sigmoid = nn.Sigmoid()
|
| 74 |
-
|
| 75 |
-
self.downsample = nn.Sequential()
|
| 76 |
-
if stride != 1 or in_planes != self.expansion * planes:
|
| 77 |
-
self.downsample = nn.Sequential(
|
| 78 |
-
ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
| 79 |
-
NormLayer(self.expansion * planes),
|
| 80 |
-
)
|
| 81 |
-
|
| 82 |
-
def forward(self, x):
|
| 83 |
-
out = self.relu(self.bn1(self.conv1(x)))
|
| 84 |
-
out = self.bn2(self.conv2(out))
|
| 85 |
-
out = self.SimAM(out)
|
| 86 |
-
out += self.downsample(x)
|
| 87 |
-
out = self.relu(out)
|
| 88 |
-
return out
|
| 89 |
-
|
| 90 |
-
def SimAM(self, X, lambda_p=1e-4):
|
| 91 |
-
n = X.shape[2] * X.shape[3] - 1
|
| 92 |
-
d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2)
|
| 93 |
-
v = d.sum(dim=[2, 3], keepdim=True) / n
|
| 94 |
-
E_inv = d / (4 * (v + lambda_p)) + 0.5
|
| 95 |
-
return X * self.sigmoid(E_inv)
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
class BasicBlock(nn.Module):
|
| 99 |
-
expansion = 1
|
| 100 |
-
|
| 101 |
-
def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
|
| 102 |
-
super(BasicBlock, self).__init__()
|
| 103 |
-
self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
| 104 |
-
self.bn1 = NormLayer(planes)
|
| 105 |
-
self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
| 106 |
-
self.bn2 = NormLayer(planes)
|
| 107 |
-
self.relu = nn.ReLU(inplace=True)
|
| 108 |
-
|
| 109 |
-
self.downsample = nn.Sequential()
|
| 110 |
-
if stride != 1 or in_planes != self.expansion * planes:
|
| 111 |
-
self.downsample = nn.Sequential(
|
| 112 |
-
ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
| 113 |
-
NormLayer(self.expansion * planes),
|
| 114 |
-
)
|
| 115 |
-
|
| 116 |
-
def forward(self, x):
|
| 117 |
-
out = self.relu(self.bn1(self.conv1(x)))
|
| 118 |
-
out = self.bn2(self.conv2(out))
|
| 119 |
-
out += self.downsample(x)
|
| 120 |
-
out = self.relu(out)
|
| 121 |
-
return out
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
class Bottleneck(nn.Module):
|
| 125 |
-
expansion = 4
|
| 126 |
-
|
| 127 |
-
def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
|
| 128 |
-
super(Bottleneck, self).__init__()
|
| 129 |
-
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
| 130 |
-
self.bn1 = nn.BatchNorm2d(planes)
|
| 131 |
-
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
| 132 |
-
self.bn2 = nn.BatchNorm2d(planes)
|
| 133 |
-
self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
|
| 134 |
-
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
|
| 135 |
-
|
| 136 |
-
self.shortcut = nn.Sequential()
|
| 137 |
-
if stride != 1 or in_planes != self.expansion * planes:
|
| 138 |
-
self.shortcut = nn.Sequential(
|
| 139 |
-
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
| 140 |
-
nn.BatchNorm2d(self.expansion * planes),
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
def forward(self, x):
|
| 144 |
-
out = F.relu(self.bn1(self.conv1(x)))
|
| 145 |
-
out = F.relu(self.bn2(self.conv2(out)))
|
| 146 |
-
out = self.bn3(self.conv3(out))
|
| 147 |
-
out += self.shortcut(x)
|
| 148 |
-
out = F.relu(out)
|
| 149 |
-
return out
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
class ResNet(nn.Module):
|
| 153 |
-
def __init__(self, in_planes, block, num_blocks, in_ch=1, feat_dim="2d", **kwargs):
|
| 154 |
-
super(ResNet, self).__init__()
|
| 155 |
-
if feat_dim == "1d":
|
| 156 |
-
self.NormLayer = nn.BatchNorm1d
|
| 157 |
-
self.ConvLayer = nn.Conv1d
|
| 158 |
-
elif feat_dim == "2d":
|
| 159 |
-
self.NormLayer = nn.BatchNorm2d
|
| 160 |
-
self.ConvLayer = nn.Conv2d
|
| 161 |
-
elif feat_dim == "3d":
|
| 162 |
-
self.NormLayer = nn.BatchNorm3d
|
| 163 |
-
self.ConvLayer = nn.Conv3d
|
| 164 |
-
else:
|
| 165 |
-
print("error")
|
| 166 |
-
|
| 167 |
-
self.in_planes = in_planes
|
| 168 |
-
|
| 169 |
-
self.conv1 = self.ConvLayer(in_ch, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
|
| 170 |
-
self.bn1 = self.NormLayer(in_planes)
|
| 171 |
-
self.relu = nn.ReLU(inplace=True)
|
| 172 |
-
self.layer1 = self._make_layer(block, in_planes, num_blocks[0], stride=1, block_id=1)
|
| 173 |
-
self.layer2 = self._make_layer(block, in_planes * 2, num_blocks[1], stride=2, block_id=2)
|
| 174 |
-
self.layer3 = self._make_layer(block, in_planes * 4, num_blocks[2], stride=2, block_id=3)
|
| 175 |
-
self.layer4 = self._make_layer(block, in_planes * 8, num_blocks[3], stride=2, block_id=4)
|
| 176 |
-
|
| 177 |
-
def _make_layer(self, block, planes, num_blocks, stride, block_id=1):
|
| 178 |
-
strides = [stride] + [1] * (num_blocks - 1)
|
| 179 |
-
layers = []
|
| 180 |
-
for stride in strides:
|
| 181 |
-
layers.append(block(self.ConvLayer, self.NormLayer, self.in_planes, planes, stride, block_id))
|
| 182 |
-
self.in_planes = planes * block.expansion
|
| 183 |
-
return nn.Sequential(*layers)
|
| 184 |
-
|
| 185 |
-
def forward(self, x):
|
| 186 |
-
x = self.relu(self.bn1(self.conv1(x)))
|
| 187 |
-
x = self.layer1(x)
|
| 188 |
-
x = self.layer2(x)
|
| 189 |
-
x = self.layer3(x)
|
| 190 |
-
x = self.layer4(x)
|
| 191 |
-
return x
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
def ResNet293(in_planes: int, **kwargs):
|
| 195 |
-
return ResNet(in_planes, SimAMBasicBlock, [10, 20, 64, 3], **kwargs)
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
class ResNet293_based(nn.Module):
|
| 199 |
-
def __init__(
|
| 200 |
-
self,
|
| 201 |
-
in_planes: int = 64,
|
| 202 |
-
embd_dim: int = 256,
|
| 203 |
-
acoustic_dim: int = 80,
|
| 204 |
-
featCal=None,
|
| 205 |
-
dropout: float = 0,
|
| 206 |
-
**kwargs,
|
| 207 |
-
):
|
| 208 |
-
super(ResNet293_based, self).__init__()
|
| 209 |
-
self.featCal = featCal
|
| 210 |
-
self.front = ResNet293(in_planes)
|
| 211 |
-
block_expansion = SimAMBasicBlock.expansion
|
| 212 |
-
self.pooling = ASP(in_planes * block_expansion, acoustic_dim)
|
| 213 |
-
self.bottleneck = nn.Linear(self.pooling.out_dim, embd_dim)
|
| 214 |
-
self.drop = nn.Dropout(dropout) if dropout else None
|
| 215 |
-
|
| 216 |
-
def forward(self, x):
|
| 217 |
-
x = self.featCal(x)
|
| 218 |
-
x = self.front(x.unsqueeze(dim=1))
|
| 219 |
-
x = self.pooling(x)
|
| 220 |
-
if self.drop:
|
| 221 |
-
x = self.drop(x)
|
| 222 |
-
x = self.bottleneck(x)
|
| 223 |
-
return x
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
class SEModule(nn.Module):
|
| 227 |
-
def __init__(self, channels, bottleneck=128):
|
| 228 |
-
super(SEModule, self).__init__()
|
| 229 |
-
self.se = nn.Sequential(
|
| 230 |
-
nn.AdaptiveAvgPool1d(1),
|
| 231 |
-
nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
|
| 232 |
-
nn.ReLU(),
|
| 233 |
-
# nn.BatchNorm1d(bottleneck), # Removed
|
| 234 |
-
nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
|
| 235 |
-
nn.Sigmoid(),
|
| 236 |
-
)
|
| 237 |
-
|
| 238 |
-
def forward(self, input):
|
| 239 |
-
x = self.se(input)
|
| 240 |
-
return input * x
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
class Bottle2neck(nn.Module):
|
| 244 |
-
def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale=8):
|
| 245 |
-
super(Bottle2neck, self).__init__()
|
| 246 |
-
width = int(math.floor(planes / scale))
|
| 247 |
-
self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
|
| 248 |
-
self.bn1 = nn.BatchNorm1d(width * scale)
|
| 249 |
-
self.nums = scale - 1
|
| 250 |
-
convs = []
|
| 251 |
-
bns = []
|
| 252 |
-
num_pad = math.floor(kernel_size / 2) * dilation
|
| 253 |
-
for i in range(self.nums):
|
| 254 |
-
convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad))
|
| 255 |
-
bns.append(nn.BatchNorm1d(width))
|
| 256 |
-
self.convs = nn.ModuleList(convs)
|
| 257 |
-
self.bns = nn.ModuleList(bns)
|
| 258 |
-
self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
|
| 259 |
-
self.bn3 = nn.BatchNorm1d(planes)
|
| 260 |
-
self.relu = nn.ReLU()
|
| 261 |
-
self.width = width
|
| 262 |
-
self.se = SEModule(planes)
|
| 263 |
-
|
| 264 |
-
def forward(self, x):
|
| 265 |
-
residual = x
|
| 266 |
-
out = self.conv1(x)
|
| 267 |
-
out = self.relu(out)
|
| 268 |
-
out = self.bn1(out)
|
| 269 |
-
|
| 270 |
-
spx = torch.split(out, self.width, 1)
|
| 271 |
-
for i in range(self.nums):
|
| 272 |
-
if i == 0:
|
| 273 |
-
sp = spx[i]
|
| 274 |
-
else:
|
| 275 |
-
sp = sp + spx[i]
|
| 276 |
-
sp = self.convs[i](sp)
|
| 277 |
-
sp = self.relu(sp)
|
| 278 |
-
sp = self.bns[i](sp)
|
| 279 |
-
if i == 0:
|
| 280 |
-
out = sp
|
| 281 |
-
else:
|
| 282 |
-
out = torch.cat((out, sp), 1)
|
| 283 |
-
out = torch.cat((out, spx[self.nums]), 1)
|
| 284 |
-
|
| 285 |
-
out = self.conv3(out)
|
| 286 |
-
out = self.relu(out)
|
| 287 |
-
out = self.bn3(out)
|
| 288 |
-
|
| 289 |
-
out = self.se(out)
|
| 290 |
-
out += residual
|
| 291 |
-
return out
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
class ECAPA_TDNN(nn.Module):
|
| 295 |
-
def __init__(self, C, featCal):
|
| 296 |
-
super(ECAPA_TDNN, self).__init__()
|
| 297 |
-
self.featCal = featCal
|
| 298 |
-
self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2)
|
| 299 |
-
self.relu = nn.ReLU()
|
| 300 |
-
self.bn1 = nn.BatchNorm1d(C)
|
| 301 |
-
self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8)
|
| 302 |
-
self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8)
|
| 303 |
-
self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8)
|
| 304 |
-
# I fixed the shape of the output from MFA layer, that is close to the setting from ECAPA paper.
|
| 305 |
-
self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)
|
| 306 |
-
self.attention = nn.Sequential(
|
| 307 |
-
nn.Conv1d(4608, 256, kernel_size=1),
|
| 308 |
-
nn.ReLU(),
|
| 309 |
-
nn.BatchNorm1d(256),
|
| 310 |
-
nn.Tanh(), # Added
|
| 311 |
-
nn.Conv1d(256, 1536, kernel_size=1),
|
| 312 |
-
nn.Softmax(dim=2),
|
| 313 |
-
)
|
| 314 |
-
self.bn5 = nn.BatchNorm1d(3072)
|
| 315 |
-
self.fc6 = nn.Linear(3072, 192)
|
| 316 |
-
self.bn6 = nn.BatchNorm1d(192)
|
| 317 |
-
|
| 318 |
-
def forward(self, x):
|
| 319 |
-
x = self.featCal(x)
|
| 320 |
-
x = self.conv1(x)
|
| 321 |
-
x = self.relu(x)
|
| 322 |
-
x = self.bn1(x)
|
| 323 |
-
|
| 324 |
-
x1 = self.layer1(x)
|
| 325 |
-
x2 = self.layer2(x + x1)
|
| 326 |
-
x3 = self.layer3(x + x1 + x2)
|
| 327 |
-
|
| 328 |
-
x = self.layer4(torch.cat((x1, x2, x3), dim=1))
|
| 329 |
-
x = self.relu(x)
|
| 330 |
-
|
| 331 |
-
t = x.size()[-1]
|
| 332 |
-
|
| 333 |
-
global_x = torch.cat(
|
| 334 |
-
(
|
| 335 |
-
x,
|
| 336 |
-
torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
|
| 337 |
-
torch.sqrt(torch.var(x, dim=2, keepdim=True).clamp(min=1e-4)).repeat(1, 1, t),
|
| 338 |
-
),
|
| 339 |
-
dim=1,
|
| 340 |
-
)
|
| 341 |
-
|
| 342 |
-
w = self.attention(global_x)
|
| 343 |
-
|
| 344 |
-
mu = torch.sum(x * w, dim=2)
|
| 345 |
-
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4))
|
| 346 |
-
|
| 347 |
-
x = torch.cat((mu, sg), 1)
|
| 348 |
-
x = self.bn5(x)
|
| 349 |
-
x = self.fc6(x)
|
| 350 |
-
x = self.bn6(x)
|
| 351 |
-
|
| 352 |
-
return x
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
class SpeakerEmbedding(nn.Module):
|
| 356 |
-
def __init__(self, ckpt_path: str = "ResNet293_SimAM_ASP_base.pt", device: str = "cuda"):
|
| 357 |
-
super().__init__()
|
| 358 |
-
self.device = device
|
| 359 |
-
with torch.device(device):
|
| 360 |
-
self.model = ResNet293_based()
|
| 361 |
-
self.model.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
|
| 362 |
-
self.model.featCal = logFbankCal()
|
| 363 |
-
|
| 364 |
-
self.requires_grad_(False).eval()
|
| 365 |
-
|
| 366 |
-
@property
|
| 367 |
-
def dtype(self):
|
| 368 |
-
return next(self.parameters()).dtype
|
| 369 |
-
|
| 370 |
-
@cache
|
| 371 |
-
def _get_resampler(self, orig_sample_rate: int):
|
| 372 |
-
return torchaudio.transforms.Resample(orig_sample_rate, 16_000).to(self.device)
|
| 373 |
-
|
| 374 |
-
def prepare_input(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
| 375 |
-
assert wav.ndim < 3
|
| 376 |
-
if wav.ndim == 2:
|
| 377 |
-
wav = wav.mean(0, keepdim=True)
|
| 378 |
-
wav = self._get_resampler(sample_rate)(wav)
|
| 379 |
-
return wav
|
| 380 |
-
|
| 381 |
-
def forward(self, wav: torch.Tensor, sample_rate: int):
|
| 382 |
-
wav = self.prepare_input(wav, sample_rate).to(self.device, self.dtype)
|
| 383 |
-
return self.model(wav).to(wav.device)
|
| 384 |
-
|
| 385 |
-
class SpeakerEmbeddingLDA(nn.Module):
|
| 386 |
-
def __init__(
|
| 387 |
-
self,
|
| 388 |
-
device: str = "cuda",
|
| 389 |
-
):
|
| 390 |
-
super().__init__()
|
| 391 |
-
spk_model_path = hf_hub_download(repo_id="Quantamhash/Qhash-v0.1-speaker-embedding", filename="ResNet293_SimAM_ASP_base.pt")
|
| 392 |
-
lda_spk_model_path = hf_hub_download(repo_id="Quantamhash/Qhash-v0.1-speaker-embedding", filename="ResNet293_SimAM_ASP_base_LDA-128.pt")
|
| 393 |
-
|
| 394 |
-
self.device = device
|
| 395 |
-
with torch.device(device):
|
| 396 |
-
self.model = SpeakerEmbedding(spk_model_path, device)
|
| 397 |
-
lda_sd = torch.load(lda_spk_model_path, weights_only=True)
|
| 398 |
-
out_features, in_features = lda_sd["weight"].shape
|
| 399 |
-
self.lda = nn.Linear(in_features, out_features, bias=True, dtype=torch.float32)
|
| 400 |
-
self.lda.load_state_dict(lda_sd)
|
| 401 |
-
|
| 402 |
-
self.requires_grad_(False).eval()
|
| 403 |
-
|
| 404 |
-
def forward(self, wav: torch.Tensor, sample_rate: int):
|
| 405 |
-
emb = self.model(wav, sample_rate).to(torch.float32)
|
| 406 |
-
return emb, self.lda(emb)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|