File size: 3,861 Bytes
c0c84cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebf9ad3
c0c84cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebf9ad3
 
 
c0c84cf
 
 
 
 
ebf9ad3
c0c84cf
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import torch
import soundfile as sf
from tiny_tts.text.english import normalize_text, grapheme_to_phoneme
from tiny_tts.text import phonemes_to_ids
from tiny_tts.nn import commons
from tiny_tts.models.synthesizer import VoiceSynthesizer
from tiny_tts.text.symbols import symbols
from tiny_tts.utils.config import (
    SAMPLING_RATE, SEGMENT_FRAMES, ADD_BLANK, SPEC_CHANNELS,
    N_SPEAKERS, SPK2ID, MODEL_PARAMS,
)
from tiny_tts.infer import load_engine

class TinyTTS:
    """Convenience wrapper that loads a TinyTTS checkpoint and synthesizes speech.

    Example::

        tts = TinyTTS()
        tts.speak("Hello world", output_path="hello.wav")
    """

    def __init__(self, checkpoint_path=None, device=None):
        """Load the synthesizer onto a device.

        Args:
            checkpoint_path: Path to a generator checkpoint. If None, the file
                ``checkpoints/G.pth`` in the parent of this module's directory
                is used when it exists; otherwise ``G.pth`` is fetched from the
                Hugging Face Hub repo ``backtracking/tiny-tts``.
            device: Torch device string. Defaults to ``'cuda'`` when CUDA is
                available, else ``'cpu'``.

        Raises:
            ImportError: If a download is needed but ``huggingface_hub`` is
                not installed.
            ValueError: If the Hugging Face download fails.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        if checkpoint_path is None:
            checkpoint_path = self._resolve_default_checkpoint()

        self.model = load_engine(checkpoint_path, self.device)

    @staticmethod
    def _resolve_default_checkpoint():
        """Return a local path to the default checkpoint, downloading it if needed."""
        # 1. Look for a checkpoint shipped alongside the package.
        pkg_dir = os.path.dirname(os.path.abspath(__file__))
        local_ckpt = os.path.join(os.path.dirname(pkg_dir), "checkpoints", "G.pth")
        if os.path.exists(local_ckpt):
            return local_ckpt

        # 2. Fall back to the Hugging Face Hub (hf_hub_download reuses its
        #    local cache when the file was fetched before).
        # The import is guarded on its own so that an ImportError raised
        # *inside* the download is not misreported as a missing dependency.
        try:
            from huggingface_hub import hf_hub_download
        except ImportError as e:
            raise ImportError("huggingface_hub is required to auto-download the model. Run: pip install huggingface_hub") from e

        print("Downloading/Loading checkpoint from Hugging Face Hub (backtracking/tiny-tts)...")
        try:
            return hf_hub_download(repo_id="backtracking/tiny-tts", filename="G.pth")
        except Exception as e:
            # Keep the original cause attached for debugging (network, auth, ...).
            raise ValueError(f"Failed to download checkpoint from Hugging Face: {e}") from e

    def speak(self, text, output_path="output.wav", speaker="MALE", speed=1.0):
        """Synthesize ``text`` to speech and save it to ``output_path``.

        Args:
            text: Input text; it is normalized and phonemized (English
                front-end, language tag ``"EN"``) before synthesis.
            output_path: Destination audio file, written with ``soundfile``
                at ``SAMPLING_RATE``.
            speaker: Key into ``SPK2ID``. Unknown names fall back to speaker
                ID 0 after printing a warning.
            speed: Speaking-rate multiplier: > 1.0 is faster, < 1.0 slower.
                Must be a finite positive number.

        Returns:
            The synthesized waveform as a NumPy array (element ``[0, 0]`` of
            the model's audio output, moved to CPU).

        Raises:
            ValueError: If ``speed`` is not a finite positive number.
        """
        # Validate before the (costly) text front-end: the model receives
        # length_scale = 1 / speed, so 0 would only fail late with
        # ZeroDivisionError, and negative / NaN / inf are not meaningful rates.
        if not (0 < speed < float("inf")):  # comparison is False for NaN too
            raise ValueError(f"speed must be a finite positive number, got {speed!r}")

        print(f"Synthesizing: {text}")

        normalized = normalize_text(text)
        # The word-to-phoneme mapping (third value) is not needed here.
        phones, tones, _word2ph = grapheme_to_phoneme(normalized)
        phone_ids, tone_ids, lang_ids = phonemes_to_ids(phones, tones, "EN")

        if ADD_BLANK:
            # Apply the same blank insertion (token 0) to all three ID
            # streams so they remain parallel.
            phone_ids = commons.insert_blanks(phone_ids, 0)
            tone_ids = commons.insert_blanks(tone_ids, 0)
            lang_ids = commons.insert_blanks(lang_ids, 0)

        device = self.device
        seq_len = len(phone_ids)

        # Batch of one: each sequence input gets a leading batch dimension.
        x = torch.LongTensor(phone_ids).unsqueeze(0).to(device)
        x_lengths = torch.LongTensor([seq_len]).to(device)
        tone = torch.LongTensor(tone_ids).unsqueeze(0).to(device)
        language = torch.LongTensor(lang_ids).unsqueeze(0).to(device)

        if speaker in SPK2ID:
            speaker_id = SPK2ID[speaker]
        else:
            print(f"Warning: Speaker '{speaker}' not found, using ID 0. Available: {list(SPK2ID.keys())}")
            speaker_id = 0
        sid = torch.LongTensor([speaker_id]).to(device)

        # BERT features are disabled, but model.infer still takes them, so
        # pass zeros shaped (1, 1024, seq_len) and (1, 768, seq_len).
        # Allocate directly on the target device instead of copying from CPU.
        bert = torch.zeros(1, 1024, seq_len, device=device)
        ja_bert = torch.zeros(1, 768, seq_len, device=device)

        # speed > 1.0 = faster speech, < 1.0 = slower speech
        length_scale = 1.0 / speed

        with torch.no_grad():
            audio, *_ = self.model.infer(
                x, x_lengths, sid, tone, language, bert, ja_bert,
                noise_scale=0.667,
                noise_scale_w=0.8,
                length_scale=length_scale,
            )

        audio_np = audio[0, 0].cpu().numpy()
        sf.write(output_path, audio_np, SAMPLING_RATE)
        print(f"Saved audio to {output_path}")
        return audio_np