File size: 2,064 Bytes
43e61e8
eda855d
55e5fd0
63ef4fe
55e5fd0
eda855d
 
55e5fd0
eda855d
 
 
 
55e5fd0
eda855d
 
55e5fd0
eda855d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55e5fd0
eda855d
 
 
 
 
 
43e61e8
ceb6264
eda855d
 
55e5fd0
 
eda855d
 
55e5fd0
 
eda855d
 
55e5fd0
eda855d
63ef4fe
ccf3842
63ef4fe
55e5fd0
 
 
 
 
 
 
ceb6264
55e5fd0
 
 
 
63ef4fe
55e5fd0
148e787
 
55e5fd0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
from pathlib import Path

import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
from TTS.api import TTS

# ---------- Model repository settings ----------
# Private Hugging Face model repo that holds the checkpoint and its config.
REPO_ID = "DhanuakaDev/SinTts-prev-v0.1"
# Adjust these two names if the files in the repo are named differently.
CHECKPOINT_FILENAME = "checkpoint_80000.pth"
CONFIG_FILENAME = "config.json"

# Access token read from the environment; on a Space it is supplied as a
# secret (Settings -> Variables and secrets). None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")


# ---------- Fetch model files from the Hub ----------
def _fetch_from_repo(filename):
    """Download *filename* from REPO_ID and return its local cache path."""
    return hf_hub_download(
        repo_id=REPO_ID,
        filename=filename,
        token=HF_TOKEN,      # the repo is private, so a token is needed
        repo_type="model",   # spelled out even though "model" is the default
    )


# Checkpoint first, then config — same order as two sequential downloads.
checkpoint_path = _fetch_from_repo(CHECKPOINT_FILENAME)
config_path = _fetch_from_repo(CONFIG_FILENAME)

# ---------- Build the synthesizer ----------
# CPU inference; flip gpu to True only when the hardware actually provides one.
_tts_options = {
    "model_path": str(checkpoint_path),
    "config_path": str(config_path),
    "progress_bar": False,
    "gpu": False,
}
tts = TTS(**_tts_options)

# Sample rate reported by the loaded synthesizer; paired with every waveform
# handed back to the UI.
SAMPLE_RATE = tts.synthesizer.output_sample_rate

# ---------- Inference function ----------
def tts_generate(text: str):
    """Synthesize speech for *text* using the module-level ``tts`` model.

    Args:
        text: Text to speak. ``None`` and whitespace-only input are both
            treated as "nothing to synthesize".

    Returns:
        A ``(SAMPLE_RATE, samples)`` tuple, where ``samples`` is a 1-D
        float32 NumPy array, or ``None`` when there is no text to speak.
    """
    # A request may arrive with no value at all; guard before calling str
    # methods so that case yields "no audio" instead of an AttributeError.
    if text is None:
        return None

    text = text.strip()
    if not text:
        return None

    # Run the model on the cleaned text.
    wav = tts.tts(text)

    # Coerce whatever sequence the model returned into a flat float32 array.
    samples = np.asarray(wav, dtype=np.float32).flatten()

    # Gradio Audio(type="numpy") expects (sample_rate, np.ndarray).
    return (SAMPLE_RATE, samples)

# ---------- Gradio UI ----------
# Build the two components first so the Interface call stays short.
_text_input = gr.Textbox(
    lines=3,
    placeholder="Type Sinhala text here…",
    label="Input text",
)
# type="numpy" matches the (sample_rate, ndarray) tuple tts_generate returns.
_audio_output = gr.Audio(type="numpy", label="Generated speech")

# Single-function app: one text box in, one audio player out.
demo = gr.Interface(
    tts_generate,
    inputs=_text_input,
    outputs=_audio_output,
    title="Sinhala TTS ",
    description="Sinhala TTS model - Research-stage model",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()