import torch
from phonemizer import phonemize
from phonemizer.separator import Separator

class TextEncoder:
    """

    Handles text-to-phoneme conversion for 14 languages.

    """
    def __init__(self, vocab_map=None):
        self.separator = Separator(phone=' ', word='|', syllable='')
        # Maps 15 language codes to phonemizer/eSpeak language codes
        self.lang_map = {
            'en': 'en-us', 'zh': 'cmn', 'es': 'es', 'fr': 'fr-fr',
            'de': 'de', 'ja': 'ja', 'ko': 'ko', 'ru': 'ru',
            'pt': 'pt', 'it': 'it', 'hi': 'hi', 'ar': 'ar',
            'tr': 'tr', 'nl': 'nl', 'bn': 'bn'
        }
        # Simple character-to-id mapping (placeholder). Note that id 0
        # doubles as the unknown-token id, so OOV symbols collide with ' '.
        self.vocab = vocab_map if vocab_map is not None else {
            c: i for i, c in enumerate(" abcdefghijklmnopqrstuvwxyz|")
        }

    def preprocess(self, text, lang_code='en'):
        """

        Converts text to phoneme IDs.

        """
        if lang_code not in self.lang_map:
            print(f"Warning: Language {lang_code} not fully supported, defaulting to English backend.")
            backend_lang = 'en-us'
        else:
            backend_lang = self.lang_map[lang_code]

        try:
            # Phonemize; returns a string of phones separated by ' ',
            # with '|' marking word boundaries.
            phoneme_str = phonemize(
                text,
                language=backend_lang,
                backend='espeak',
                separator=self.separator,
                strip=True,
                preserve_punctuation=True,
                njobs=1
            )
            # Split into full phoneme tokens (keeping '|' as its own token)
            # rather than iterating over the string character by character.
            phonemes = phoneme_str.replace('|', ' | ').split()
        except RuntimeError:
            # phonemizer raises RuntimeError when the eSpeak backend is missing.
            print("Warning: eSpeak not found. Falling back to character-level tokenization.")
            phonemes = list(text)  # Fall back to a list of characters
        
        # Tokenize (simple lookup for now); unknown symbols map to id 0
        token_ids = [self.vocab.get(p, 0) for p in phonemes]
        return torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)  # Add batch dim
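

if __name__ == "__main__":
    # Minimal usage sketch, assuming eSpeak NG is installed so the
    # phonemizer backend is available (otherwise the character-level
    # fallback above kicks in). The sample text is purely illustrative.
    encoder = TextEncoder()
    ids = encoder.preprocess("hello world", lang_code='en')
    print(ids.shape)  # e.g. torch.Size([1, N]) where N is the token count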