---
license: cc-by-4.0
pipeline_tag: text-to-speech
language:
- en
tags:
- audio
- text-to-speech
- attentionless
- vocoder
base_model:
- ivao0/voc
---

# TTS Attentionless VOcoder Streaming

Applies [kyutai TTS 0.75b](https://huggingface.co/kyutai/tts-0.75b-en-public) using [Attentionless VOcoder streaming](https://huggingface.co/ivao0/voc).

<table>
<tr>
<td>

[Voice files](https://huggingface.co/Dionyssos/_TTS075B/tree/main/wav)

</td>
<td>

<audio controls>
<source src="https://huggingface.co/Dionyssos/_TTS075B/resolve/main/example.wav" type="audio/wav">
TTS 0.75B using attentionless vocoder streaming example audio
</audio>
</td>
</tr>
</table>

## Example

```python
import torch  # 2.9.0 cu126
from torch import nn
import torch.nn.functional as F
from transformers import Wav2Vec2PreTrainedModel, PretrainedConfig  # 4.49.0
from huggingface_hub import hf_hub_download
import re
from collections import deque
import sphn
from safetensors.torch import load_file
from sentencepiece import SentencePieceProcessor
from einops import rearrange
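
# Overview: TTSModel (below) autoregressively generates audio codes from text;
# Voc (ivao0/voc) is the streaming codec that turns those codes into 24 kHz
# audio, one 1920-sample frame per step.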


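# Gated feed-forward (SwiGLU-style): project up to 2*d, gate with SiLU, project back down.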
class ActivationGating(nn.Module):

    def __init__(self, dim_feedforward=4224):
        super().__init__()
        d = 2816 if dim_feedforward == 4224 else 2048
        self.linear_in = nn.Linear(1024, 2 * d, bias=False)
        self.linear_out = nn.Linear(d, 1024, bias=False)

    def forward(self, x):
        x = F.linear(x, self.linear_in.weight)
        B, T, _ = x.shape
        x = x.view(B, T, 2, -1)
        x = F.silu(x[:, :, 0, :]) * x[:, :, 1, :]
        x = F.linear(x, self.linear_out.weight)
        return x


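# Rotary position embedding for a single streaming timestep at position `offset`;
# the constant -18.4206809997 is -2*ln(10000), i.e. theta_i = 10000^(-2i/d).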
def apply_rope(q, k, offset=0):
    q_type = q.dtype
    q = q.to(torch.float)
    k = k.to(torch.float)
    bs, h, _1, d = k.shape

    fr = torch.exp(-18.4206809997 / d * torch.arange(d // 2, device=q.device, dtype=torch.float))

    t = offset * fr[None, None, :, None]

    r = torch.cos(t)
    i = torch.sin(t)

    q = q.view(bs, h, d // 2, 2)  # interleave
    k = k.view(bs, h, d // 2, 2)

    qor = q[:, :, :, :1] * r - q[:, :, :, 1:] * i
    qoi = q[:, :, :, :1] * i + q[:, :, :, 1:] * r
    kor = k[:, :, :, :1] * r - k[:, :, :, 1:] * i
    koi = k[:, :, :, :1] * i + k[:, :, :, 1:] * r

    qo = torch.cat([qor.to(dtype=q_type), qoi.to(dtype=q_type)], dim=3)
    ko = torch.cat([kor.to(dtype=q_type), koi.to(dtype=q_type)], dim=3)

    return qo.view(bs, h, 1, d), ko.view(bs, h, 1, d)


class RMSNorm(nn.Module):
    def __init__(self, d=1024):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1, 1, d), 1.0, dtype=torch.float64))

    def forward(self, x):
        x = x.to(torch.float64)
        v = 9e-9 + torch.mean(x * x, dim=2, keepdim=True)
        return (x * (self.alpha * torch.rsqrt(v))).to(torch.bfloat16)


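# Streaming self-attention with a KV cache. `weights_per_step` selects per-step
# projection weights (used by the depth transformer); without it, one shared set is used.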
class LLMAttention(nn.Module):

    def __init__(self, weights_per_step):
        super().__init__()
        self.weights_per_step = weights_per_step
        self.k_history = None
        self.v_history = None
        p = 9 if weights_per_step else 1
        self.out_projs = nn.ModuleList([nn.Linear(1024, 1024, bias=False) for _ in range(p)])
        self.in_projs = nn.ModuleList([nn.Linear(1024, 3 * 1024, bias=False) for _ in range(p)])

    def forward(self, query):

        offset = 0 if self.k_history is None else self.k_history.shape[2]  # current position in the KV cache

        # reset the cache every `weights_per_step` steps (depformer), or at 473 steps, beyond which RoPE is untrained
        if (self.weights_per_step and offset % self.weights_per_step == 0) or (offset % 473 == 0):
            self.k_history = None
            self.v_history = None
            offset = 0

        if self.weights_per_step:
            x = self.in_projs[offset if offset < 9 else 8](query)
        else:
            x = self.in_projs[0](query)
        q, k, v = rearrange(x, "b t (p h d) -> p b h t d", p=3, h=16)
        q, k = apply_rope(q, k, offset=offset)
        # KV cache
        if self.k_history is not None:
            self.k_history = torch.cat([self.k_history, k], 2)
            self.v_history = torch.cat([self.v_history, v], 2)
        else:
            self.k_history = k
            self.v_history = v
        k = self.k_history
        v = self.v_history
        # an all-ones bool attn mask sounds better than passing no mask argument
        x = F.scaled_dot_product_attention(q, k, v,
                                           torch.ones(k.shape[0], 1, 1, k.shape[2], dtype=torch.bool, device=k.device))
        x = rearrange(x, "b h t d -> b t (h d)")
        if self.weights_per_step:
            return self.out_projs[offset if offset < 9 else 8](x)
        return self.out_projs[0](x)


class LLMTransformerLayer(nn.Module):

    def __init__(self, weights_per_step=None):
        super().__init__()
        self.self_attn = LLMAttention(weights_per_step=weights_per_step)
        self.norm1 = RMSNorm()
        self.norm2 = RMSNorm()
        self.weights_per_step = weights_per_step
        if self.weights_per_step:
            self.gating = nn.ModuleList([ActivationGating(3072) for _ in range(9)])
        else:
            self.gating = ActivationGating()

    def forward(self, x):
        x = self.self_attn(self.norm1(x)) + x
        if self.weights_per_step:
            p = self.self_attn.k_history.shape[2] - 1
            return x + self.gating[p if p < 9 else 8](self.norm2(x))
        return x + self.gating(self.norm2(x))


class LLMTransformer(nn.Module):

    def __init__(
            self,
            num_layers=24,
            weights_per_step=False):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                LLMTransformerLayer(weights_per_step=weights_per_step)
                for _ in range(num_layers)
            ])

    def forward(self, x):
        for lay in self.layers:
            x = lay(x)
        return x


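# Voc: the streaming neural audio codec (Mimi-style). encode() maps 24 kHz mono
# audio to 16 codebook indices per frame (12.5 frames/s, 1920 samples each);
# decode() is the inverse. All convolutions keep buffers so that chunk-by-chunk
# calls match full-sequence processing.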
class Voc(Wav2Vec2PreTrainedModel):

    '''When switching batch size, call Voc()._flush() to clear streaming buffers.
    '''

    def __init__(self, config=PretrainedConfig()):
        super().__init__(config=config)
        self.encoder_transformer = VocTransformer()
        self.decoder_transformer = VocTransformer()
        self.encoder = SEANetEncoder()
        self.decoder = SEANetDecoder()
        self.sample_rate = 24000
        self.quantizer = SplitResidualVectorQuantizer()
        self.downsample = BufferConv1d(512, 512, kernel_size=4, stride=2, groups=1, bias=False)
        upsample_channel_wise_bug = True
        self.upsample = BufferConvTranspose1d(512, 512, kernel_size=4,
                                              groups=512 if upsample_channel_wise_bug else 1,
                                              stride=2, bias=False)
        self.frame_rate = 12.5
        self.encode_buffer = None

    def _flush(self):
        '''Stream buffers hold tensors of the old batch size; _flush() clears them.
        '''
        self.encode_buffer = None  # holds unused (incomplete windows of len < 1920) - we need 1920 samples to produce 1 token
        if self.downsample.previous is not None:
            self.downsample.previous = None
        if self.upsample.partial is not None:
            self.upsample.partial = None
        for arch in [self.encoder, self.decoder]:
            for _m in arch.model:
                if type(_m) is SEANetResnetBlock:
                    for _b in _m.block:
                        if type(_b) is BufferConv1d:
                            if _b.previous is not None:
                                _b.previous = None
                if type(_m) is BufferConv1d:
                    if _m.previous is not None:
                        _m.previous = None
                if type(_m) is BufferConvTranspose1d:
                    if _m.partial is not None:
                        _m.partial = None

    @torch.no_grad()
    def encode(self, x):
        '''24 kHz audio to codes
        x : [bs, 1, time]
        c : [bs, 16, frames] - 1920 audio samples produce 1 frame (of n_q codebooks)
        '''
        if self.encode_buffer is not None:
            x = torch.cat([self.encode_buffer, x], 2)
        _bs, _1, _len = x.shape
        num_frames = int(_len / 1920)
        leftover = x[:, :, num_frames * 1920:]  # incomplete window, kept for the next call
        if leftover.shape[2] > 0:
            self.encode_buffer = leftover
        else:
            self.encode_buffer = None
        torch.cuda.empty_cache()
        if num_frames > 0:
            c = []
            for n in range(num_frames):
                e = self.encoder(x[:, :, n * 1920:(n + 1) * 1920])
                e = self.encoder_transformer(e)
                e = self.downsample(e)
                _c = self.quantizer.encode(e)
                c.append(_c)
            c = torch.cat(c, 2)
        else:
            # num_frames = 0 early exit: x.shape[2] < 1920 fills conv buffers but can't output a token
            c = torch.empty(_bs, 16, 0, dtype=torch.long, device=x.device)
        return c

    @torch.no_grad()
    def decode(self, c):
        '''codes to 24 kHz audio
        c: [bs, 16, n_tokens]
        x: [bs, 1, n_tokens * 1920]
        '''
        _hidden = []
        for i in range(c.shape[2]):
            x = self.quantizer.decode(c[:, :, i:i+1])
            x = self.upsample(x)
            x = self.decoder_transformer(x)
            x = self.decoder(x)
            _hidden.append(x)
        return torch.cat(_hidden, 2)  # [bs, 1, n_tokens * 1920]


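# SEANet blocks: strided convolutions with residual blocks. The encoder's total
# stride is 8*6*5*4 = 960, times the 2x down/upsample around the quantizer
# = 1920 samples per frame.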
class SEANetResnetBlock(nn.Module):
    def __init__(
            self,
            dim,
            kernel_sizes=[3, 1],
    ):
        super().__init__()

        block = []
        for i, kernel_size in enumerate(kernel_sizes):

            block += [
                nn.ELU(),
                BufferConv1d(
                    dim if i == 0 else dim // 2,
                    dim // 2 if i == 0 else dim,
                    kernel_size=kernel_size,
                    bias=True,
                ),
            ]

        self.block = nn.Sequential(*block)

    def forward(self, x):
        return x + self.block(x)


class SEANetEncoder(nn.Module):
    def __init__(
            self,
            channels=1,  # does not support stereo
            dimension=512,
            n_filters=64,
            ratios=[8, 6, 5, 4],
            kernel_size=7,
            last_kernel_size=3,
    ):
        super().__init__()
        self.ratios = list(reversed(ratios))
        del ratios
        mult = 1
        model = [
            BufferConv1d(
                channels,
                mult * n_filters,
                kernel_size,
                bias=True
            )
        ]
        for i, ratio in enumerate(self.ratios):
            model += [SEANetResnetBlock(mult * n_filters),
                      nn.ELU(),
                      BufferConv1d(mult * n_filters,
                                   mult * n_filters * 2,
                                   kernel_size=ratio * 2,
                                   stride=ratio,
                                   bias=True)]
            mult *= 2
        model += [nn.ELU(),
                  BufferConv1d(mult * n_filters,
                               dimension,
                               last_kernel_size,
                               bias=True)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class SEANetDecoder(nn.Module):

    def __init__(
            self,
            channels=1,
            dimension=512,
            n_filters=64,
            ratios=[8, 6, 5, 4],
            kernel_size=7,
            last_kernel_size=3):

        super().__init__()
        mult = int(2 ** len(ratios))
        model = [BufferConv1d(dimension,
                              mult * n_filters,
                              kernel_size,
                              bias=True)]
        # upsampling blocks
        for i, ratio in enumerate(ratios):
            model += [nn.ELU(),
                      BufferConvTranspose1d(mult * n_filters,
                                            mult * n_filters // 2,
                                            kernel_size=ratio * 2,
                                            stride=ratio,
                                            bias=True),
                      SEANetResnetBlock(mult * n_filters // 2)]
            mult //= 2
        # last conv
        model += [
            nn.ELU(),
            BufferConv1d(
                n_filters,
                channels,
                last_kernel_size,
                bias=True
            ),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


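# Causal Conv1d that caches the unconsumed tail of each chunk, so consecutive
# calls on audio chunks give the same output as one call on the full signal.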
class BufferConv1d(nn.Conv1d):
    def __init__(self,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.previous = None

    def forward(self, x):
        k = self.kernel_size[0]

        if self.previous is not None:

            x = torch.cat([self.previous, x], 2)

        else:  # first chunk: left-pad to emulate causal padding

            if k == 3:

                p = (2, 0)
                x = F.pad(x, p, mode='replicate')  # skip connections of SEANetResnetBlock

            elif k == 4:  # the ConvTr upsample is the first conv encountered by decode; replicate avoids a pulse

                p = (3, 0)
                x = F.pad(x, p, mode='replicate')

            elif k == 7:

                p = (6, 0)
                x = F.pad(x, p, mode='replicate')

            elif k == 16:

                p = (2, 0)
                x = F.pad(x, p, mode='replicate')  # this can also be constant padding without a pulse occurring

        num_frames = int((x.shape[2] - self.kernel_size[0]) / self.stride[0]) + 1  # +1: the kernel starts at the left of x and makes (I-k)/s jumps
        offset = num_frames * self.stride[0]
        self.previous = x[..., offset:]
        return super().forward(x)


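# Streaming transposed conv: the last (kernel - stride) output steps of a chunk
# are incomplete, so they are held back and overlap-added onto the next chunk.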
class BufferConvTranspose1d(nn.ConvTranspose1d):
    # with kernel close to stride, each input step mostly "clones" into the output
    # https://distill.pub/2016/deconv-checkerboard/
    def __init__(self,
                 *args,
                 **kwargs):
        super().__init__(*args,
                         **kwargs)
        self.partial = None

    def forward(self, x):
        out = super().forward(x)
        OT = out.shape[2]
        invalid_steps = self.kernel_size[0] - self.stride[0]
        if self.partial is not None:
            PT = self.partial.shape[-1]
            if self.bias is not None:
                out[..., :PT] += self.partial - self.bias[:, None]
            else:
                out[..., :PT] += self.partial  # for the ConvTr upsample
        self.partial = out[..., OT - invalid_steps:]
        out = out[..., :OT - invalid_steps]
        return out


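# A single vector-quantization codebook: encode() picks the nearest embedding
# by L2 distance, decode() is an embedding lookup.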
class CodeBook(nn.Module):
    def __init__(self, dim, codebook_size):
        super().__init__()
        self.register_buffer('_e', torch.zeros(codebook_size, dim))

    def encode(self, x):
        dist = torch.cdist(
            x.transpose(1, 2),   # [bs, time, 256]
            self._e[None, :, :]  # [1, 2048, 256]
        )
        codes = dist.argmin(2)
        return codes

    def decode(self, codes):
        quantized = F.embedding(codes, self._e)
        return quantized.transpose(1, 2)  # [1, 256, time]


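# Split residual VQ (as in Mimi): codebook 0 quantizes a "semantic" projection;
# the remaining books quantize successive residuals of an "acoustic" projection.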
class SplitResidualVectorQuantizer(nn.Module):

    def __init__(self,
                 n_q=None):
        super().__init__()
        self.in_proj_s = torch.nn.Conv1d(512, 256, 1, bias=False)
        self.in_proj_a = torch.nn.Conv1d(512, 256, 1, bias=False)
        self.out_proj_s = torch.nn.Conv1d(256, 512, 1, bias=False)
        self.out_proj_a = torch.nn.Conv1d(256, 512, 1, bias=False)  # shared by all _acoustic_books
        self.layers = nn.ModuleList([CodeBook(dim=256, codebook_size=2048) for _ in range(18)])
        self._acoustic_books = range(1, 16)  # official Mimi
        # CODEBOOKS: the RVQ codebooks can be re-used for higher fidelity.
        # Book 0 is excluded as it has a different projection (in_proj_s), e.g.:
        # self._acoustic_books = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 17, 17, 17, 17]

    def encode(self, x):
        indices = self.layers[0].encode(self.in_proj_s(x))  # integer codes of the semantic book
        all_indices = [indices[:, None, :], ]
        x = self.in_proj_a(x)
        for _cb in self._acoustic_books:
            indices = self.layers[_cb].encode(x)
            x = x - self.layers[_cb].decode(indices)
            all_indices.append(indices[:, None, :])
        codes = torch.cat(all_indices, 1)
        return codes

    def decode(self, codes):
        _s = self.layers[0].decode(codes[:, 0, :])
        _a = torch.zeros([1, 1], device=codes.device)
        for i, _cb in enumerate(self._acoustic_books):
            _a = _a + self.layers[_cb].decode(codes[:, i+1, :])
        return self.out_proj_s(_s) + self.out_proj_a(_a)  # [bs, 512, time]


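# The "attentionless" part: instead of self-attention, the sequence is
# mean-pooled over time and passed through a single linear projection.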
class VocAttention(nn.Module):

    def __init__(self,
                 embed_dim):

        super().__init__()
        self.fused_proj = nn.Parameter(torch.zeros(embed_dim, embed_dim))

    def forward(self, x):
        '''bypass of attention for streaming'''
        if x.shape[1] > 1:
            x = x.mean(1, keepdim=True)
        x = torch.matmul(x, self.fused_proj)
        return x  # the residual add in the layer broadcasts back over time (x.shape[1] = 2 in the encoder)


class VocTransformerLayer(nn.Module):

    def __init__(self, d_model=512, dim_feedforward=2048):
        super().__init__()
        self.self_attn = VocAttention(embed_dim=d_model)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)

    def forward(self, x):
        x = x + self.self_attn(self.norm1(x))
        return x + self.linear2(F.gelu(self.linear1(self.norm2(x))))


class VocTransformer(nn.Module):

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(VocTransformerLayer() for _ in range(8))

    def forward(self, x):
        x = x.transpose(1, 2)
        for la in self.layers:
            x = la(x)
        return x.transpose(1, 2)


class Entry:
    def __init__(self, tokens=None):
        self.tokens = tokens
        self.padding = len(tokens) + 1


class TokenState:

    def __init__(self, entries=None):
        self.entries = entries
        self.queued = deque([])
        self.lookahead_queued = deque()
        self.end_step = None
        self.forced_padding = 2

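# TTSModel: a 24-layer temporal transformer predicts one step of 17 streams
# (1 text + 16 audio); a 4-layer depth transformer (depformer) then decodes the
# 16 audio codebooks of that step one at a time. Batch dim 2 carries the
# conditional / unconditional passes for classifier-free guidance.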
class TTSModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.tokenizer = SentencePieceProcessor(str(hf_hub_download(repo_id='kyutai/tts-0.75b-en-public',
                                                                    filename='tokenizer_spm_8k_en_fr_audio.model')))
        with torch.device("meta"):
            self.emb = nn.ModuleList([ScaledEmbedding(2049, 1024) for _ in range(16)])
            self.text_emb = ScaledEmbedding(8001, 1024, demux_second_stream=True)
            self.transformer = LLMTransformer()
            self.out_norm = RMSNorm()
            self.depformer_in = nn.ModuleList([nn.Linear(1024, 1024, bias=False) for _ in range(9)])
            self.depformer_emb = nn.ModuleList([ScaledEmbedding(2049, 128) for _ in range(16 - 1)])
            self.depformer_text_emb = ScaledEmbedding(8001, 128, demux_second_stream=True)
            self.depformer = LLMTransformer(num_layers=4, weights_per_step=16)
            self.linears = nn.ModuleList([nn.Linear(1024, 2048, bias=False) for _ in range(16)])  # depformer heads

        state_d = load_file(hf_hub_download(repo_id='Dionyssos/_TTS075B', filename='tts_075B.safetensors'))
        self.load_state_dict(state_d, assign=True, strict=True)  # assign=True materialises the meta-device params
        self.to(dtype=torch.bfloat16).eval()

    def prepare_script(self, script='Type your text here.'):
        entries = []
        # a pause would be indicated as e.g. <break time="3s"/>
        event_re = re.compile(r"(?:<break\s+time=\"([0-9]+(?:\.[0-9]*)?)s\"\s*/?>)|(?:\s+)")
        line = script.replace('’', "'").replace(':', " ").replace('(', "").replace(')', "")
        while line:
            match = event_re.search(line)
            if match is None:
                break
            word = line[:match.start()]
            line = line[match.end():]
            if word:
                entries.append(Entry(tokens=self.tokenizer.encode(word)))
            if match.group(1):
                # break tags are not handled in this example; they would map to
                # padding frames: int(round(float(match.group(1)) * frame_rate))
                raise ValueError('<break/> tags are not supported here')
        if line:
            entries.append(Entry(tokens=self.tokenizer.encode(line)))
        return entries

    @property
    def device(self):
        return next(iter(self.parameters())).device

    @torch.no_grad()
    def generate(self, text=None,
                 voice_path=None, mimi=None,
                 play=16):
        _wav, _ = sphn.read(voice_path,
                            sample_rate=24000)
        _wav = mimi.encode(torch.from_numpy(_wav).to(device=self.device)[None])[0, :, :]  # codes of the voice prefix: [16, frames]
        state = TokenState(entries=deque(self.prepare_script(script=text)))
        upper_lim = 2 * sum([len(p.tokens) for p in state.entries])
        self.cache = torch.full((2, 17, 4), -1, device=self.device, dtype=torch.long)  # [cfg cond/uncond, 1 text + 16 audio streams, ring buffer]
        pcms = []  # final audio to return
        for offset in range(4 * upper_lim):
            print(f'{offset=} of {upper_lim=}', end='\r')
            if state.end_step is not None:
                if offset >= state.end_step + 16 + 4:
                    break

            input_ = self.cache[:, :, offset % self.cache.shape[2]].clone()

            if offset == 0:
                input_[:, 0] = 8000  # text BOS; for offset > 0 the cfg text stream stays -1 (zero embedding)
                input_[:, 1:] = 2048

            if offset < 3:
                input_[:, 2:] = 2048

            x = self.text_emb(input_[:, :1])
            for cb_ in range(16):
                x = self.emb[cb_](input_[:, cb_ + 1: cb_ + 2]) + x
            x = self.out_norm(self.transformer(x))

            token = -1
            if offset > _wav.shape[1]:
                token = 0
            # START
            if state.queued:
                token = 3
            if state.forced_padding > 0:
                token = 3
            # ===================================
            if token == 0:
                if state.entries:
                    e = state.entries.popleft()
                    if e.tokens:
                        state.queued.extend(e.tokens)
                        lookahead = 2
                        for e2 in state.entries:
                            if e2.tokens:
                                lookahead -= 1
                                if lookahead == 0:
                                    state.lookahead_queued.extend(e2.tokens)
                                    break
                    else:
                        token = 3
                        state.forced_padding = e.padding
                else:
                    token = 3
                    if state.end_step is None:
                        token = 0
                        state.end_step = offset
            # ==============================================
            output = 0
            if token == 3:
                if state.forced_padding > 0:
                    state.forced_padding -= 1
                if state.queued:
                    output = state.queued.popleft()
                else:
                    output = 3
            # ==========================
            second = -1
            if output == 0:
                second = 0
                if state.queued:
                    output = state.queued.popleft()
                else:
                    output = 3
            elif state.lookahead_queued:
                second = state.lookahead_queued.popleft()  # the lookahead stream feeds the model the next word early
            token = (second + 1) * 8001 + output  # mux the two text streams (demuxed in ScaledEmbedding)

            # audio tokens
            ac = (offset + 1) % self.cache.shape[2]
            self.cache[0, 0, ac] = token
            audio_tokens = torch.ones([1, 16], device=x.device, dtype=torch.long)
            if offset > play:
                prev_token = torch.tensor([[token]], device=x.device, dtype=torch.long)
                for _cb in range(16):
                    if _cb == 0:
                        last_token_input = self.depformer_text_emb(prev_token.repeat(2, 1))
                    else:
                        last_token_input = self.depformer_emb[_cb - 1](prev_token)
                    dep_output = self.depformer(self.depformer_in[_cb if _cb < 9 else 8](x) + last_token_input)
                    logits = self.linears[_cb](dep_output)
                    prev_token = (2.0 * logits[0, :, :] - logits[1, :, :]).argmax(1)  # classifier-free guidance: 2*cond - uncond
                    audio_tokens[0, _cb] = prev_token
            # voice copy: overwrite generated codes with the voice-prefix codes
            if offset > play and offset < play + 1 + _wav.shape[1]:
                audio_tokens[:, 0] = _wav[0, offset - play - 1]
            if offset > play and offset < play + 2 + _wav.shape[1]:
                audio_tokens[:, 1:] = _wav[1:, offset - play - 2]
            # next turn
            self.cache[0, 1:, ac] = audio_tokens
            # cfg
            if offset > 16 + 2 + _wav.shape[1]:
                if offset > 16 + 4 + _wav.shape[1]:
                    self.cache[1, 1:, ac] = self.cache[0, 1:, ac]
                else:
                    self.cache[1, 1, ac] = self.cache[0, 1, ac]
            # ivao0/voc
            if offset > 20 + _wav.shape[1]:
                audio_tokens[:, 0] = self.cache[0, 1, (offset - 1) % self.cache.shape[2]]  # previous semantic token
                pcms.append(mimi.decode(audio_tokens[:, :, None]))  # [1, 1, 1920]
        x = torch.cat(pcms, dim=2)[0, 0, :]
        return x.cpu().numpy()

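# Embedding with a reserved index: -1 maps to the zero vector. With
# demux_second_stream=True, one integer encodes two text tokens
# (token = (second + 1) * 8001 + first); both are embedded and summed.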
class ScaledEmbedding(nn.Embedding):
    def __init__(self, num_embeddings=None, embedding_dim=None, demux_second_stream=False):
        super().__init__(num_embeddings, embedding_dim)
        self.zero_idx = -1
        self.low_rank = None
        self.demux_second_stream = demux_second_stream
        if self.demux_second_stream:
            self.out1 = nn.Linear(embedding_dim, 1024, bias=False)
            self.out2 = nn.Linear(embedding_dim, 1024, bias=False)
        else:
            if embedding_dim != 1024:
                self.low_rank = nn.Linear(embedding_dim, 1024, bias=False)

    def forward(self, input):
        is_zero = input == self.zero_idx
        zero = torch.zeros(1, dtype=input.dtype, device=input.device)
        input = input.clamp(min=0)
        if self.demux_second_stream:
            left = super().forward(input % self.num_embeddings)
            right = input // self.num_embeddings - 1
            right_zero = (right < 0)[..., None]
            right.clamp_(min=0)
            right = super().forward(right)
            y = self.out1(left) + torch.where(right_zero, zero, self.out2(right))
            y = torch.where(is_zero[..., None], zero, y)
        else:
            y = super().forward(input)
            y = torch.where(is_zero[..., None], zero, y)
            if self.low_rank is not None:
                # low_rank is only used when there is no demuxed second stream
                y = self.low_rank(y)
        return y

text = '''Far over the misty mountains cold
To dungeons deep and caverns old
We must away ere break of day
To seek the pale enchanted gold.

The dwarves of yore made mighty spells,
While hammers fell like ringing bells
In places deep, where dark things sleep,
In hollow halls beneath the fells.

For ancient king and elvish lord
There many a gleaming golden hoard
They shaped and wrought, and light they caught
To hide in gems on hilt of sword.

On silver necklaces they strung
The flowering stars, on crowns they hung
The dragon-fire, in twisted wire
They meshed the light of moon and sun.

Far over the misty mountains cold
To dungeons deep and caverns old
We must away, ere break of day,
To claim our long-forgotten gold.
Farewell we call to hearth and hall!
Though wind may blow and rain may fall,
We must away ere break of day
Far over wood and mountain tall.'''

device = 'cpu'  # or 'cuda:0'
tts_model = TTSModel().eval().to(device)
mimi = Voc.from_pretrained('ivao0/voc').eval().to(device)
x = tts_model.generate(text=text,
                       voice_path=hf_hub_download(repo_id='Dionyssos/_TTS075B', filename='wav/en_US_m-ailabs_mary_ann.wav'),
                       mimi=mimi)
sphn.write_wav('dsm_tts.wav', x, 24000)
```
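
The `Voc` codec can also be run on its own for a streaming round trip. Below is a minimal sketch, assuming the classes above are already defined; the voice file and the 1000-sample chunk size are arbitrary choices here, since `encode()` buffers incomplete 1920-sample windows internally:

```python
import torch
import sphn
from huggingface_hub import hf_hub_download

mimi = Voc.from_pretrained('ivao0/voc').eval()
wav, _ = sphn.read(hf_hub_download(repo_id='Dionyssos/_TTS075B',
                                   filename='wav/en_US_m-ailabs_mary_ann.wav'),
                   sample_rate=24000)
x = torch.from_numpy(wav)[None]  # [1, 1, time] - mono 24 kHz

mimi._flush()  # clear streaming buffers (e.g. after a batch-size change)
chunks = []
for t in range(0, x.shape[2], 1000):
    codes = mimi.encode(x[:, :, t:t + 1000])  # [1, 16, n] - n may be 0 while buffers fill
    if codes.shape[2] > 0:
        chunks.append(mimi.decode(codes))     # [1, 1, n * 1920]
out = torch.cat(chunks, 2)[0, 0]
sphn.write_wav('roundtrip.wav', out.cpu().numpy(), 24000)
```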