import re
from collections import deque
from time import time

import torch
import torch.nn.functional as F
from torch import nn
import sphn
from einops import rearrange
from safetensors.torch import load_file
from sentencepiece import SentencePieceProcessor
from transformers import Wav2Vec2PreTrainedModel, PretrainedConfig
from huggingface_hub import hf_hub_download

torch.set_flush_denormal(True)
torch.use_deterministic_algorithms(True)

class ActivationGating(nn.Module):
    '''SiLU-gated feed-forward block (no biases).'''

    def __init__(self, dim_feedforward=4224):
        super().__init__()
        # Hidden width of the checkpoint: 2816 for the main transformer
        # (dim_feedforward=4224), 2048 for the depth transformer (3072).
        d = 2816 if dim_feedforward == 4224 else 2048
        self.linear_in = nn.Linear(1024, 2 * d, bias=False)
        self.linear_out = nn.Linear(d, 1024, bias=False)

    def forward(self, x):
        x = F.linear(x, self.linear_in.weight)
        B, T, _ = x.shape
        x = x.view(B, T, 2, -1)
        x = F.silu(x[:, :, 0, :]) * x[:, :, 1, :]
        return F.linear(x, self.linear_out.weight)

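# ActivationGating computes out = W_out(silu(W_a x) * W_b x), with W_a and
# W_b the two halves of linear_in -- a SwiGLU-style gated feed-forward.
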
def apply_rope(q, k, offset=0):
    '''Rotary position embedding for a single streaming step (t == 1).'''
    q_type = q.dtype
    q = q.to(torch.float)
    k = k.to(torch.float)
    bs, h, _1, d = k.shape

    # 18.420680743952367 = 2 * ln(10000): frequencies 10000 ** (-2i / d).
    fr = torch.exp(-18.420680743952367 / d * torch.arange(d // 2, device=q.device, dtype=torch.float))
    t = offset * fr[None, None, :, None]
    r = torch.cos(t)
    i = torch.sin(t)

    # Treat consecutive channel pairs as complex numbers and rotate them by
    # the position-dependent angle t.
    q = q.view(bs, h, d // 2, 2)
    k = k.view(bs, h, d // 2, 2)
    qor = q[:, :, :, :1] * r - q[:, :, :, 1:] * i
    qoi = q[:, :, :, :1] * i + q[:, :, :, 1:] * r
    kor = k[:, :, :, :1] * r - k[:, :, :, 1:] * i
    koi = k[:, :, :, :1] * i + k[:, :, :, 1:] * r

    qo = torch.cat([qor.to(dtype=q_type), qoi.to(dtype=q_type)], dim=3)
    ko = torch.cat([kor.to(dtype=q_type), koi.to(dtype=q_type)], dim=3)
    return qo.view(bs, h, 1, d), ko.view(bs, h, 1, d)

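# Sanity check for apply_rope (runnable if uncommented): at offset 0 the
# rotation angle is zero, so the function must be the identity.
#
#   q = k = torch.randn(1, 16, 1, 64)
#   qo, ko = apply_rope(q, k, offset=0)
#   assert torch.allclose(qo, q) and torch.allclose(ko, k)
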
class RMSNorm(nn.Module):
    def __init__(self, d=1024):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1, 1, d), 1.0, dtype=torch.float64))

    def forward(self, x):
        # Computed in float64 for numerical stability, returned as bfloat16.
        x = x.to(torch.float64)
        v = 9e-9 + torch.mean(x * x, dim=2, keepdim=True)
        return (x * (self.alpha * torch.rsqrt(v))).to(torch.bfloat16)

class LLMAttention(nn.Module):
    '''Streaming self-attention with a growing KV cache.

    When weights_per_step is set (depth transformer), each generation step
    uses its own projection weights; steps >= 8 share the last set.
    '''

    def __init__(self, weights_per_step):
        super().__init__()
        self.weights_per_step = weights_per_step
        self.k_history = None
        self.v_history = None
        p = 9 if weights_per_step else 1
        self.out_projs = nn.ModuleList([nn.Linear(1024, 1024, bias=False) for _ in range(p)])
        self.in_projs = nn.ModuleList([nn.Linear(1024, 3 * 1024, bias=False) for _ in range(p)])

    def forward(self, query):
        offset = 0 if self.k_history is None else self.k_history.shape[2]

        # Reset the cache at the start of every depth-transformer pass, and
        # drop the main transformer's context every 473 steps.
        if (self.weights_per_step and offset % self.weights_per_step == 0) or (offset % 473 == 0):
            self.k_history = None
            self.v_history = None
            offset = 0

        proj = min(offset, 8) if self.weights_per_step else 0
        x = self.in_projs[proj](query)
        q, k, v = rearrange(x, "b t (p h d) -> p b h t d", p=3, h=16)
        q, k = apply_rope(q, k, offset=offset)

        if self.k_history is not None:
            self.k_history = torch.cat([self.k_history, k], 2)
            self.v_history = torch.cat([self.v_history, v], 2)
        else:
            self.k_history = k
            self.v_history = v
        k = self.k_history
        v = self.v_history

        # All-True mask: the single query step attends to the whole cache.
        x = F.scaled_dot_product_attention(
            q, k, v,
            torch.ones(k.shape[0], 1, 1, k.shape[2], dtype=torch.bool, device=k.device))
        x = rearrange(x, "b h t d -> b t (h d)")
        return self.out_projs[proj](x)

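# Note: k_history / v_history make this module stateful across calls: the
# main transformer drops its whole context every 473 steps (no sliding
# window), and the depth transformer resets at the start of every
# 16-codebook pass.
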
class LLMTransformerLayer(nn.Module):

    def __init__(self, weights_per_step=None):
        super().__init__()
        self.self_attn = LLMAttention(weights_per_step=weights_per_step)
        self.norm1 = RMSNorm()
        self.norm2 = RMSNorm()
        self.weights_per_step = weights_per_step
        if self.weights_per_step:
            self.gating = nn.ModuleList([ActivationGating(3072) for _ in range(9)])
        else:
            self.gating = ActivationGating()

    def forward(self, x):
        x = self.self_attn(self.norm1(x)) + x
        if self.weights_per_step:
            # Pick the gating of the current depth step (steps >= 8 share one).
            p = min(self.self_attn.k_history.shape[2] - 1, 8)
            return x + self.gating[p](self.norm2(x))
        return x + self.gating(self.norm2(x))

class LLMTransformer(nn.Module):

    def __init__(self, num_layers=24, weights_per_step=False):
        super().__init__()
        self.layers = nn.ModuleList(
            [LLMTransformerLayer(weights_per_step=weights_per_step)
             for _ in range(num_layers)])

    def forward(self, x):
        for lay in self.layers:
            x = lay(x)
        return x

class Voc(Wav2Vec2PreTrainedModel):
    '''Streaming neural audio codec (encoder / quantizer / decoder).

    Call Voc._flush() before switching to a different batch size.
    '''

    def __init__(self, config=PretrainedConfig()):
        super().__init__(config=config)
        self.encoder_transformer = VocTransformer()
        self.decoder_transformer = VocTransformer()
        self.encoder = SEANetEncoder()
        self.decoder = SEANetDecoder()
        self.sample_rate = 24000
        self.quantizer = SplitResidualVectorQuantizer()
        self.downsample = BufferConv1d(512, 512, kernel_size=4, stride=2, groups=1, bias=False)
        upsample_channel_wise_bug = True  # kept for checkpoint compatibility
        self.upsample = BufferConvTranspose1d(512, 512, kernel_size=4,
                                              groups=512 if upsample_channel_wise_bug else 1,
                                              stride=2, bias=False)
        self.frame_rate = 12.5
        self.encode_buffer = None

    def _flush(self):
        '''Clear all stream buffers (they hold tensors of the old batch size).'''
        self.encode_buffer = None
        self.downsample.previous = None
        self.upsample.partial = None
        for arch in [self.encoder, self.decoder]:
            for _m in arch.model:
                if type(_m) is SEANetResnetBlock:
                    for _b in _m.block:
                        if type(_b) is BufferConv1d:
                            _b.previous = None
                elif type(_m) is BufferConv1d:
                    _m.previous = None
                elif type(_m) is BufferConvTranspose1d:
                    _m.partial = None

    @torch.no_grad()
    def encode(self, x):
        '''24 kHz audio to codes.
        x : [bs, 1, T] waveform at 24 kHz
        c : [bs, 16, time] -- every 1920 samples produce one frame of 16 codebooks
        '''
        if self.encode_buffer is not None:
            x = torch.cat([self.encode_buffer, x], 2)
        _bs, _1, _len = x.shape
        num_frames = _len // 1920
        # Keep the samples that do not fill a whole frame for the next call.
        leftover = x[:, :, num_frames * 1920:]
        if leftover.shape[2] > 0:
            self.encode_buffer = leftover
        else:
            self.encode_buffer = None
        torch.cuda.empty_cache()
        if num_frames > 0:
            c = []
            for n in range(num_frames):
                e = self.encoder(x[:, :, n * 1920:(n + 1) * 1920])
                e = self.encoder_transformer(e)
                e = self.downsample(e)
                c.append(self.quantizer.encode(e))
            c = torch.cat(c, 2)
        else:
            c = torch.empty(_bs, 16, 0, dtype=torch.long, device=x.device)
        return c

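    # Illustrative usage, assuming a loaded codec `mimi = Voc.from_pretrained(...)`:
    #
    #   wav = torch.zeros(1, 1, 24000)   # 1 s of silence at 24 kHz
    #   codes = mimi.encode(wav)         # -> [1, 16, 12]; 960 samples stay buffered
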
    @torch.no_grad()
    def decode(self, c):
        '''Codes to 24 kHz audio.
        c: [bs, 16, n_frames]
        x: [bs, 1, n_frames * 1920]
        '''
        _hidden = []
        for i in range(c.shape[2]):
            x = self.quantizer.decode(c[:, :, i:i + 1])
            x = self.upsample(x)
            x = self.decoder_transformer(x)
            x = self.decoder(x)
            _hidden.append(x)
        return torch.cat(_hidden, 2)

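# Round trip of the encode example above: 12 frames decode back to
# 12 * 1920 = 23040 samples, i.e. `mimi.decode(codes)` -> [1, 1, 23040].
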
class SEANetResnetBlock(nn.Module):
    def __init__(self, dim, kernel_sizes=[3, 1]):
        super().__init__()
        block = []
        for i, kernel_size in enumerate(kernel_sizes):
            block += [
                nn.ELU(),
                BufferConv1d(
                    dim if i == 0 else dim // 2,
                    dim // 2 if i == 0 else dim,
                    kernel_size=kernel_size,
                    bias=True,
                ),
            ]
        self.block = nn.Sequential(*block)

    def forward(self, x):
        return x + self.block(x)

class SEANetEncoder(nn.Module):
    def __init__(
        self,
        channels=1,
        dimension=512,
        n_filters=64,
        ratios=[8, 6, 5, 4],
        kernel_size=7,
        last_kernel_size=3,
    ):
        super().__init__()
        self.ratios = list(reversed(ratios))
        mult = 1
        model = [
            BufferConv1d(channels, mult * n_filters, kernel_size, bias=True)
        ]
        for ratio in self.ratios:
            model += [SEANetResnetBlock(mult * n_filters),
                      nn.ELU(),
                      BufferConv1d(mult * n_filters,
                                   mult * n_filters * 2,
                                   kernel_size=ratio * 2,
                                   stride=ratio,
                                   bias=True)]
            mult *= 2
        model += [nn.ELU(),
                  BufferConv1d(mult * n_filters, dimension, last_kernel_size, bias=True)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class SEANetDecoder(nn.Module):
    def __init__(
        self,
        channels=1,
        dimension=512,
        n_filters=64,
        ratios=[8, 6, 5, 4],
        kernel_size=7,
        last_kernel_size=3,
    ):
        super().__init__()
        mult = int(2 ** len(ratios))
        model = [BufferConv1d(dimension, mult * n_filters, kernel_size, bias=True)]
        for ratio in ratios:
            model += [nn.ELU(),
                      BufferConvTranspose1d(mult * n_filters,
                                            mult * n_filters // 2,
                                            kernel_size=ratio * 2,
                                            stride=ratio,
                                            bias=True),
                      SEANetResnetBlock(mult * n_filters // 2)]
            mult //= 2
        model += [nn.ELU(),
                  BufferConv1d(n_filters, channels, last_kernel_size, bias=True)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class BufferConv1d(nn.Conv1d):
    '''Causal streaming Conv1d: unconsumed left context is buffered between calls.'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.previous = None

    def forward(self, x):
        k = self.kernel_size[0]
        if self.previous is not None:
            # Prepend the context left over from the previous chunk.
            x = torch.cat([self.previous, x], 2)
        else:
            # First chunk: left-pad by replication (amounts match the checkpoint).
            pad = {3: 2, 4: 3, 7: 6, 16: 2}.get(k)
            if pad is not None:
                x = F.pad(x, (pad, 0), mode='replicate')
        num_frames = int((x.shape[2] - k) / self.stride[0]) + 1
        offset = num_frames * self.stride[0]
        self.previous = x[..., offset:]
        return super().forward(x)

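# Streaming equivalence sketch (runnable if uncommented): chunked calls match
# a single call over the full input.
#
#   conv = BufferConv1d(1, 1, kernel_size=3)
#   x = torch.randn(1, 1, 10)
#   full = conv(x)
#   conv.previous = None  # reset the stream
#   parts = torch.cat([conv(x[..., :5]), conv(x[..., 5:])], 2)
#   assert torch.allclose(full, parts, atol=1e-6)
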
class BufferConvTranspose1d(nn.ConvTranspose1d):
    '''Streaming ConvTranspose1d via overlap-add between chunks.'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.partial = None

    def forward(self, x):
        out = super().forward(x)
        OT = out.shape[2]
        if self.partial is not None:
            # Overlap-add the tail buffered from the previous chunk. The bias
            # was already added there once, so subtract it before re-adding.
            PT = self.partial.shape[-1]
            if self.bias is not None:
                out[..., :PT] += self.partial - self.bias[:, None]
            else:
                out[..., :PT] += self.partial
        # The last kernel_size - stride steps are still incomplete: buffer them.
        invalid_steps = self.kernel_size[0] - self.stride[0]
        self.partial = out[..., OT - invalid_steps:]
        return out[..., :OT - invalid_steps]

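# Example: Voc.upsample has kernel_size=4, stride=2, so each chunk's last two
# output samples are incomplete; they wait in `partial` and are added onto
# the next chunk's first two samples.
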
class CodeBook(nn.Module):
    def __init__(self, dim, codebook_size):
        super().__init__()
        self.register_buffer('_e', torch.zeros(codebook_size, dim))

    def encode(self, x):
        # Nearest codebook entry by Euclidean distance.
        dist = torch.cdist(x.transpose(1, 2), self._e[None, :, :])
        return dist.argmin(2)

    def decode(self, codes):
        quantized = F.embedding(codes, self._e)
        return quantized.transpose(1, 2)

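# Roundtrip sketch (runnable if uncommented): vectors equal to codebook
# entries encode to those indices and decode back exactly.
#
#   cb = CodeBook(dim=4, codebook_size=8)
#   cb._e.normal_()
#   z = cb._e[[2, 5]].T[None]   # [1, 4, 2]
#   codes = cb.encode(z)        # tensor([[2, 5]])
#   assert torch.equal(cb.decode(codes), z)
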
class SplitResidualVectorQuantizer(nn.Module):
    '''One semantic codebook plus 15 residual acoustic codebooks.'''

    def __init__(self, n_q=None):
        super().__init__()
        self.in_proj_s = torch.nn.Conv1d(512, 256, 1, bias=False)
        self.in_proj_a = torch.nn.Conv1d(512, 256, 1, bias=False)
        self.out_proj_s = torch.nn.Conv1d(256, 512, 1, bias=False)
        self.out_proj_a = torch.nn.Conv1d(256, 512, 1, bias=False)
        # 18 codebooks live in the checkpoint; only book 0 (semantic) and
        # books 1..15 (acoustic) are used here.
        self.layers = nn.ModuleList([CodeBook(dim=256, codebook_size=2048) for _ in range(18)])
        self._acoustic_books = range(1, 16)

    def encode(self, x):
        indices = self.layers[0].encode(self.in_proj_s(x))
        all_indices = [indices[:, None, :]]
        x = self.in_proj_a(x)
        for _cb in self._acoustic_books:
            # Quantize the residual left by the previous books.
            indices = self.layers[_cb].encode(x)
            x = x - self.layers[_cb].decode(indices)
            all_indices.append(indices[:, None, :])
        return torch.cat(all_indices, 1)

    def decode(self, codes):
        _s = self.layers[0].decode(codes[:, 0, :])
        _a = torch.zeros([1, 1], device=codes.device)
        for i, _cb in enumerate(self._acoustic_books):
            _a = _a + self.layers[_cb].decode(codes[:, i + 1, :])
        return self.out_proj_s(_s) + self.out_proj_a(_a)

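# Residual VQ in short: each acoustic book quantizes what the previous books
# missed, so decode is just the sum of per-book embeddings before the output
# projection.
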
class VocAttention(nn.Module):

    def __init__(self, embed_dim):
        super().__init__()
        self.fused_proj = nn.Parameter(torch.zeros(embed_dim, embed_dim))

    def forward(self, x):
        '''Streaming shortcut: inputs longer than one step are mean-pooled,
        then a single fused projection is applied.'''
        if x.shape[1] > 1:
            x = x.mean(1, keepdim=True)
        return torch.matmul(x, self.fused_proj)

class VocTransformerLayer(nn.Module):

    def __init__(self, d_model=512, dim_feedforward=2048):
        super().__init__()
        self.self_attn = VocAttention(embed_dim=d_model)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)

    def forward(self, x):
        x = x + self.self_attn(self.norm1(x))
        return x + self.linear2(F.gelu(self.linear1(self.norm2(x))))

class VocTransformer(nn.Module):

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(VocTransformerLayer() for _ in range(8))

    def forward(self, x):
        x = x.transpose(1, 2)
        for la in self.layers:
            x = la(x)
        return x.transpose(1, 2)

class Entry:
    def __init__(self, tokens=None):
        self.tokens = tokens
        # Number of forced padding steps after this word.
        self.padding = len(tokens) + 1

class TokenState:

    def __init__(self, entries=None):
        self.entries = entries
        self.queued = deque()
        self.lookahead_queued = deque()
        self.end_step = None
        self.forced_padding = 2

class TTSModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.tokenizer = SentencePieceProcessor(str(hf_hub_download(
            repo_id='kyutai/tts-0.75b-en-public',
            filename='tokenizer_spm_8k_en_fr_audio.model')))
        # Build on the meta device, then materialize from the checkpoint.
        with torch.device("meta"):
            self.emb = nn.ModuleList([ScaledEmbedding(2049, 1024) for _ in range(16)])
            self.text_emb = ScaledEmbedding(8001, 1024, demux_second_stream=True)
            self.transformer = LLMTransformer()
            self.out_norm = RMSNorm()
            self.depformer_in = nn.ModuleList([nn.Linear(1024, 1024, bias=False) for _ in range(9)])
            self.depformer_emb = nn.ModuleList([ScaledEmbedding(2049, 128) for _ in range(16 - 1)])
            self.depformer_text_emb = ScaledEmbedding(8001, 128, demux_second_stream=True)
            self.depformer = LLMTransformer(num_layers=4, weights_per_step=16)
            self.linears = nn.ModuleList([nn.Linear(1024, 2048, bias=False) for _ in range(16)])
        s = load_file('tts_075B.safetensors')
        self.load_state_dict(s, assign=True, strict=True)
        self.to(dtype=torch.bfloat16).eval()

    def prepare_script(self, script):
        entries = []
        event_re = re.compile(r"(?:<break\s+time=\"([0-9]+(?:\.[0-9]*)?)s\"\s*/?>)|(?:\s+)")
        line = script.replace('’', "'").replace(':', " ").replace('(', "").replace(')', "")
        while line:
            match = event_re.search(line)
            if match is None:
                break
            word = line[:match.start()]
            line = line[match.end():]
            if word:
                entries.append(Entry(tokens=self.tokenizer.encode(word)))
            if match.group(1):
                # <break time="..."> durations are matched but currently ignored.
                pass
        if line:
            entries.append(Entry(tokens=self.tokenizer.encode(line)))
        return entries

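    # Example (token ids are hypothetical): the script is split on whitespace
    # and <break/> tags, one Entry of SentencePiece ids per word.
    #
    #   entries = tts_model.prepare_script('Hello world')
    #   [e.tokens for e in entries]   # e.g. [[17243], [1173]]
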
    @property
    def device(self):
        return next(iter(self.parameters())).device

    @torch.no_grad()
    def generate(self,
                 text='Type your text <break time="3s"/> here Farover.',
                 voice_path=None,
                 mimi=None):
        # Encode the reference voice into audio tokens used as an acoustic prompt.
        _wav, _ = sphn.read(voice_path, sample_rate=24000)
        _wav = mimi.encode(torch.from_numpy(_wav).to(device=self.device)[None])[0, :, :]
        state = TokenState(entries=deque(self.prepare_script(script=text)))
        upper_lim = 2 * sum(len(p.tokens) for p in state.entries)
        # Two rows (guided / unguided), 1 text stream + 16 audio streams,
        # ring buffer over 4 steps.
        self.cache = torch.full((2, 17, 4), -1, device=self.device, dtype=torch.long)
        pcms = []
        for offset in range(4 * upper_lim):
            print(f'{offset=} of {upper_lim=}', end='\r')
            if state.end_step is not None and offset >= state.end_step + 16 + 4:
                break

            input_ = self.cache[:, :, offset % self.cache.shape[2]].clone()
            if offset == 0:
                input_[:, 0] = 8000   # initial text token
                input_[:, 1:] = 2048  # audio padding token
            if offset < 3:
                input_[:, 2:] = 2048

            # Backbone: sum text and audio embeddings, run the main transformer.
            x = self.text_emb(input_[:, :1])
            for cb_ in range(16):
                x = self.emb[cb_](input_[:, cb_ + 1:cb_ + 2]) + x
            x = self.out_norm(self.transformer(x))

            # Pick the next text token: -1 undecided, 0 new word, 3 padding.
            token = -1
            if offset > _wav.shape[1]:
                token = 0
            if state.queued:
                token = 3
            if state.forced_padding > 0:
                token = 3

            if token == 0:
                if state.entries:
                    e = state.entries.popleft()
                    if e.tokens:
                        state.queued.extend(e.tokens)
                        # Feed the second-next word to the lookahead stream.
                        lookahead = 2
                        for e2 in state.entries:
                            if e2.tokens:
                                lookahead -= 1
                                if lookahead == 0:
                                    state.lookahead_queued.extend(e2.tokens)
                                    break
                    else:
                        token = 3
                    state.forced_padding = e.padding
                else:
                    token = 3
                    if state.end_step is None:
                        token = 0
                        state.end_step = offset

            output = 0
            if token == 3:
                if state.forced_padding > 0:
                    state.forced_padding -= 1
                if state.queued:
                    output = state.queued.popleft()
                else:
                    output = 3

            second = -1
            if output == 0:
                second = 0
                if state.queued:
                    output = state.queued.popleft()
                else:
                    output = 3
            elif state.lookahead_queued:
                second = state.lookahead_queued.popleft()
            # Multiplex main and lookahead text streams into one id
            # (demuxed by ScaledEmbedding).
            token = (second + 1) * 8001 + output

            ac = (offset + 1) % self.cache.shape[2]
            self.cache[0, 0, ac] = token

            if offset > 16:
                # Depth transformer: sample the 16 audio codebooks one by one,
                # with classifier-free guidance (scale 2) between the two rows.
                audio_tokens = input_[:1, 1:]
                prev_token = torch.tensor([[token]], device=x.device, dtype=torch.long)
                for _cb in range(16):
                    if _cb == 0:
                        last_token_input = self.depformer_text_emb(prev_token.repeat(2, 1))
                    else:
                        last_token_input = self.depformer_emb[_cb - 1](prev_token)
                    dep_output = self.depformer(self.depformer_in[min(_cb, 8)](x) + last_token_input)
                    logits = self.linears[_cb](dep_output)
                    prev_token = (2.0 * logits[0, :, :] - logits[1, :, :]).argmax(1)
                    audio_tokens[0, _cb] = prev_token

                # Teacher-force the voice prompt; the semantic book leads the
                # acoustic books by two steps.
                if offset < 17 + _wav.shape[1]:
                    audio_tokens[:, 0] = _wav[0, offset - 17]
                if offset > 18 and offset < 19 + _wav.shape[1]:
                    audio_tokens[:, 1:] = _wav[1:, offset - 19]

                self.cache[0, 1:, ac] = audio_tokens

                # Let the unguided row see the generated tokens once the
                # prompt has been consumed.
                if offset > 16 + 2 + _wav.shape[1]:
                    if offset > 16 + 4 + _wav.shape[1]:
                        self.cache[1, 1:, ac] = self.cache[0, 1:, ac]
                    else:
                        self.cache[1, 1, ac] = self.cache[0, 1, ac]

                if offset > 20 + _wav.shape[1]:
                    # Re-align the semantic book before vocoding this frame.
                    audio_tokens[:, 0] = self.cache[0, 1, (offset - 1) % self.cache.shape[2]]
                    pcms.append(mimi.decode(audio_tokens[:, :, None]))

        x = torch.cat(pcms, dim=2)[0, 0, :]
        return x.cpu().numpy()

class ScaledEmbedding(nn.Embedding):
    '''Embedding with a zero vector at index -1; can demux two token streams
    packed into one id as (second + 1) * num_embeddings + first.'''

    def __init__(self, num_embeddings=None, embedding_dim=None, demux_second_stream=False):
        super().__init__(num_embeddings, embedding_dim)
        self.zero_idx = -1
        self.low_rank = None
        self.demux_second_stream = demux_second_stream
        if self.demux_second_stream:
            self.out1 = nn.Linear(embedding_dim, 1024, bias=False)
            self.out2 = nn.Linear(embedding_dim, 1024, bias=False)
        elif embedding_dim != 1024:
            self.low_rank = nn.Linear(embedding_dim, 1024, bias=False)

    def forward(self, input):
        is_zero = input == self.zero_idx
        zero = torch.zeros(1, dtype=input.dtype, device=input.device)
        input = input.clamp(min=0)
        if self.demux_second_stream:
            left = input % self.num_embeddings
            right = input // self.num_embeddings - 1
            left = super().forward(left)
            right_zero = (right < 0)[..., None]
            right.clamp_(min=0)
            right = super().forward(right)
            y = self.out1(left) + torch.where(right_zero, zero, self.out2(right))
            y = torch.where(is_zero[..., None], zero, y)
        else:
            y = super().forward(input)
            y = torch.where(is_zero[..., None], zero, y)
            if self.low_rank is not None:
                y = self.low_rank(y)
        return y

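# Stream packing used by generate(): with num_embeddings = 8001, the id
# (second + 1) * 8001 + first demuxes back to
#   left  = id % 8001       -> first-stream token
#   right = id // 8001 - 1  -> lookahead-stream token (-1 means absent).
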
text = '''Far over the misty mountains cold
To dungeons deep and caverns old
We must away ere break of day
To seek the pale enchanted gold.

The dwarves of yore made mighty spells,
While hammers fell like ringing bells
In places deep, where dark things sleep,
In hollow halls beneath the fells.

For ancient king and elvish lord
There many a gleaming golden hoard
They shaped and wrought, and light they caught
To hide in gems on hilt of sword.

On silver necklaces they strung
The flowering stars, on crowns they hung
The dragon-fire, in twisted wire
They meshed the light of moon and sun.

Farewell we call to hearth and hall!
Though wind may blow and rain may fall,
We must away ere break of day
Far over wood and mountain tall.'''

device = 'cpu'
tts_model = TTSModel().eval().to(device)
mimi = Voc.from_pretrained('ivao0/voc').eval().to(device)

t_sta = time()
x = tts_model.generate(text=text,
                       voice_path='wav/en_US_m-ailabs_mary_ann.wav',
                       mimi=mimi)
print(f'\nGenerated in {time() - t_sta:.1f}s')

sphn.write_wav('example.wav', x, 24000)
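
# If `mimi` is reused with a different batch size, clear its stream buffers
# first -- they still hold tensors shaped for the old batch:
#
#   mimi._flush()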