"""
V6 Tokenizer β char-level for Bulgarian TTS with MioCodec
==========================================================
Same character set as V5, but adapted for:
- MioCodec single codebook (no interleaving)
- Speaker embedding (no speaker tokens in encoder input)
"""
import re
import torch
from typing import Optional
from config import (
TEXT_CHARS, TEXT_OFFSET, AUDIO_OFFSET,
SPECIAL_TOKENS, NUM_SPECIAL_TOKENS, CODEC_CODEBOOK_SIZE,
TOTAL_VOCAB_SIZE,
PAD_TOKEN_ID, START_OF_TEXT_TOKEN_ID, END_OF_TEXT_TOKEN_ID,
START_OF_SPEECH_TOKEN_ID, END_OF_SPEECH_TOKEN_ID,
is_audio_token, is_special_token, is_text_token,
)
class TTSTokenizer:
    """Character-level tokenizer for the V6 Bulgarian TTS model.

    Maps text characters, special control tokens, and MioCodec audio codes
    into one shared vocabulary. Speaker identity is supplied via an
    embedding, so no speaker tokens appear in the encoder input.
    """

    def __init__(self):
        # Bidirectional char <-> token-id tables covering the text region
        # [TEXT_OFFSET, TEXT_OFFSET + len(TEXT_CHARS)).
        self.char2id: dict[str, int] = {
            ch: TEXT_OFFSET + idx for idx, ch in enumerate(TEXT_CHARS)
        }
        self.id2char: dict[int, str] = {tid: ch for ch, tid in self.char2id.items()}
        # Reverse lookup for pretty-printing special tokens in describe().
        self._special_id_to_name = {tid: name for name, tid in SPECIAL_TOKENS.items()}
        self.vocab_size = TOTAL_VOCAB_SIZE
        self.text_vocab_size = len(TEXT_CHARS)

    def normalize_text(self, text: str) -> str:
        """Collapse whitespace and unify dash/quote variants before encoding."""
        cleaned = re.sub(r'\s+', ' ', text).strip()
        cleaned = re.sub(r'[ββ]', '-', cleaned)
        return re.sub(r'[«»β""]', '"', cleaned)

    def encode_text(self, text: str) -> list[int]:
        """Normalize *text* and map each known character to its token id.

        Characters outside the vocabulary are silently dropped.
        """
        table = self.char2id
        return [table[ch] for ch in self.normalize_text(text) if ch in table]

    def decode_text(self, ids: list[int]) -> str:
        """Inverse of encode_text; ids outside the text region are skipped."""
        chars = (self.id2char.get(tid, "") for tid in ids if is_text_token(tid))
        return "".join(chars)

    # ββ Encoder-Decoder methods ββββββββββββββββββββββββββββββ
    def build_encoder_input(self, text: str) -> torch.Tensor:
        """
        Encoder input: <sot> text_chars <eot>
        No speaker token β speaker info comes from embedding.
        """
        ids = [START_OF_TEXT_TOKEN_ID, *self.encode_text(text), END_OF_TEXT_TOKEN_ID]
        return torch.tensor(ids, dtype=torch.long)

    def build_decoder_input(self, audio_codes: torch.Tensor) -> torch.Tensor:
        """
        Decoder input: <sos> [audio_codes + AUDIO_OFFSET] <eos>
        audio_codes: raw MioCodec codes in [0, 12799]
        """
        # Shift raw codec codes into the audio region of the shared vocab.
        shifted = (audio_codes + AUDIO_OFFSET).tolist()
        return torch.tensor(
            [START_OF_SPEECH_TOKEN_ID, *shifted, END_OF_SPEECH_TOKEN_ID],
            dtype=torch.long,
        )

    def build_decoder_prefix(self) -> torch.Tensor:
        """For inference: just <sos> to start generation."""
        return torch.tensor([START_OF_SPEECH_TOKEN_ID], dtype=torch.long)

    def extract_audio_codes(self, sequence: torch.Tensor) -> Optional[torch.Tensor]:
        """Extract raw MioCodec codes from a token sequence.

        Returns None when the sequence contains no audio tokens.
        """
        keep = torch.tensor([is_audio_token(int(t)) for t in sequence])
        if keep.any():
            return sequence[keep] - AUDIO_OFFSET
        return None

    def describe(self, seq: torch.Tensor, max_tok: int = 30) -> str:
        """Render the first *max_tok* tokens of *seq* as a readable string."""
        rendered: list[str] = []
        for tok in seq[:max_tok]:
            tid = int(tok)
            if is_special_token(tid):
                rendered.append(self._special_id_to_name.get(tid, f"<sp_{tid}>"))
            elif is_text_token(tid):
                ch = self.id2char.get(tid, "?")
                # Show spaces as a visible middle dot.
                rendered.append("Β·" if ch == " " else ch)
            elif is_audio_token(tid):
                rendered.append(f"βͺ{tid - AUDIO_OFFSET}")
            else:
                rendered.append(f"?{tid}")
        out = " ".join(rendered)
        if len(seq) > max_tok:
            out += f" ... [{len(seq) - max_tok} more]"
        return out