"""Gradio demo for Accent Vectors.

Lets users synthesise speech with a controllable accent directly in the
browser — no local setup required.

Models are downloaded from Hugging Face on first use and cached for the
lifetime of the Space instance.
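
The demo can also be launched locally with `python app.py`; set the
MODEL_CACHE_DIR environment variable to change where checkpoints and LoRA
adapters are cached (default: "model_cache").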
"""

import os
import json
import tempfile

import gradio as gr
import torch
from huggingface_hub import snapshot_download

from accent_task_vectors.inference import load_xtts_model, attach_lora_adapter
from accent_task_vectors.inference.inference import _scale_lora

# ---------------------------------------------------------------------------
# Model registry (mirrors download_checkpoints.py)
# ---------------------------------------------------------------------------

PRETRAINED_REPO = "NewGame/pretrained-xtts"

MODELS = {
    ("English",  "English"):  "NewGame/english-accent-english-xtts",
    ("English",  "Hindi"):    "NewGame/hindi-accent-english-xtts",
    ("English",  "German"):   "NewGame/german-accent-english-xtts",
    ("English",  "French"):   "NewGame/french-accent-english-xtts",
    ("English",  "Spanish"):  "NewGame/spanish-accent-english-xtts",
    ("English",  "Mandarin"): "NewGame/mandarin-accent-english-xtts",
    ("Spanish",  "English"):  "NewGame/english-accent-spanish-xtts",
    ("German",   "English"):  "NewGame/english-accent-german-xtts",
    ("Mandarin", "English"):  "NewGame/english-accent-mandarin-xtts",
}

# Language code passed to the TTS model
LANGUAGE_CODES = {
    "English":  "en",
    "Spanish":  "es",
    "German":   "de",
    "Mandarin": "zh-cn",
}

# Accents available for each output language
ACCENTS_BY_LANGUAGE = {
    "English":  ["English", "Hindi", "German", "French", "Spanish", "Mandarin"],
    "Spanish":  ["English"],
    "German":   ["English"],
    "Mandarin": ["English"],
}

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------

CACHE_DIR      = os.environ.get("MODEL_CACHE_DIR", "model_cache")
PRETRAINED_DIR = os.path.join(CACHE_DIR, "pretrained")

_PRETRAINED_PATH_FIELDS = {
    "mel_norm_file":   "mel_stats.pth",
    "dvae_checkpoint": "dvae.pth",
    "xtts_checkpoint": "model.pth",
    "tokenizer_file":  "vocab.json",
}

# ---------------------------------------------------------------------------
# In-memory model cache
#   _model_cache:    (language, accent1, accent2|None) -> tts
#   _current_coeffs: same key -> (coeff1, coeff2)
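#   e.g. ("English", "Hindi", None)     -> model with the Hindi adapter only
#        ("English", "Hindi", "French") -> model with both adapters attached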
# ---------------------------------------------------------------------------

_model_cache:    dict = {}
_current_coeffs: dict = {}
_device = "cuda" if torch.cuda.is_available() else "cpu"


def _patch_config(config_path: str, pretrained_dir: str) -> None:
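    """Rewrite pretrained-asset paths inside a downloaded config.json.

    Adapter configs reference the mel-norm, DVAE, XTTS checkpoint and tokenizer
    files by path; this walks the (possibly nested) config dict, repoints those
    fields at the locally cached pretrained directory, and rewrites the file
    only if something actually changed.
    """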
    with open(config_path) as f:
        config = json.load(f)

    abs_pretrained = os.path.abspath(pretrained_dir)
    changed = False

    def _patch(obj):
        nonlocal changed
        if isinstance(obj, dict):
            for key, filename in _PRETRAINED_PATH_FIELDS.items():
                if key in obj:
                    new_val = os.path.join(abs_pretrained, filename)
                    if obj[key] != new_val:
                        obj[key] = new_val
                        changed = True
            for v in obj.values():
                _patch(v)

    _patch(config)

    if changed:
        with open(config_path, "w") as f:
            json.dump(config, f, indent=2)


def _ensure_pretrained() -> None:
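    """Download the base pretrained XTTS checkpoint on first use."""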
    if not os.path.isdir(PRETRAINED_DIR):
        print(f"Downloading pretrained model from {PRETRAINED_REPO} …")
        snapshot_download(
            repo_id=PRETRAINED_REPO,
            repo_type="model",
            local_dir=PRETRAINED_DIR,
        )


def _download_lora(language: str, accent: str) -> str:
    """Download a LoRA adapter if needed; return its local directory."""
    lora_dir = os.path.join(CACHE_DIR, f"{accent.lower()}-accent-{language.lower()}")
    if not os.path.isdir(lora_dir):
        repo_id = MODELS[(language, accent)]
        print(f"Downloading LoRA adapter from {repo_id} …")
        snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            local_dir=lora_dir,
            allow_patterns=["config.json", "lora/best_model/**"],
        )
        _patch_config(os.path.join(lora_dir, "config.json"), PRETRAINED_DIR)
    return lora_dir


def _load_model(language: str, accent1: str, accent2: str | None):
    """Return a cached TTS model with adapter(s) loaded at coeff=1.0."""
    key = (language, accent1, accent2)
    if key in _model_cache:
        return _model_cache[key]

    _ensure_pretrained()

    lora_dir1 = _download_lora(language, accent1)
    checkpoint_path = os.path.join(PRETRAINED_DIR, "checkpoint_0.pth")
    config_path     = os.path.join(lora_dir1, "config.json")
    lora_path1      = os.path.join(lora_dir1, "lora", "best_model")

    tts = load_xtts_model(checkpoint_path, config_path, device=_device)
    tts = attach_lora_adapter(tts, lora_path=lora_path1, adapter_name="default", scaling_coef=1.0)

    if accent2 is not None:
        lora_dir2  = _download_lora(language, accent2)
        lora_path2 = os.path.join(lora_dir2, "lora", "best_model")
        tts = attach_lora_adapter(tts, lora_path=lora_path2, adapter_name="other", scaling_coef=1.0)
        tts.synthesizer.tts_model.set_adapter(["default", "other"])

    _model_cache[key]    = tts
    _current_coeffs[key] = (1.0, 1.0)
    return tts


# ---------------------------------------------------------------------------
# Inference function called by Gradio
# ---------------------------------------------------------------------------

def synthesise(
    text: str,
    speaker_audio: str,
    language: str,
    accent1: str,
    coeff1: float,
    enable_second: bool,
    accent2: str,
    coeff2: float,
):
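    """Synthesise `text` with the selected accent adapter(s); return a wav path."""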
    if not text.strip():
        raise gr.Error("Please enter some text to synthesise.")
    if speaker_audio is None:
        raise gr.Error("Please upload a reference speaker audio file.")
    if (language, accent1) not in MODELS:
        raise gr.Error(f"Unsupported combination: language={language}, accent={accent1}.")

    accent2_key = accent2 if enable_second else None

    if enable_second and (language, accent2) not in MODELS:
        raise gr.Error(f"Unsupported combination: language={language}, accent={accent2}.")

    tts = _load_model(language, accent1, accent2_key)
    key = (language, accent1, accent2_key)

    # Rescale adapters from their current cached coefficients to the desired
    # ones. Coefficients are clamped to a small epsilon so that a requested
    # strength of 0 can never make a later rescale divide by zero.
    eps = 1e-6
    coeff1, coeff2 = max(coeff1, eps), max(coeff2, eps)
    prev_coeff1, prev_coeff2 = _current_coeffs[key]
    _scale_lora(tts, coeff1 / prev_coeff1, adapter_name="default")
    if accent2_key is not None:
        _scale_lora(tts, coeff2 / prev_coeff2, adapter_name="other")
    _current_coeffs[key] = (coeff1, coeff2 if accent2_key else 1.0)
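    # (e.g. going from a cached coefficient of 0.5 to a requested 0.8 rescales
    # the adapter by 0.8 / 0.5 = 1.6 in place, with no model reload)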

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name

    tts.tts_to_file(
        text=text,
        speaker_wav=speaker_audio,
        language=LANGUAGE_CODES[language],
        file_path=output_path,
    )

    return output_path


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

def update_accent_choices(language: str):
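    """Return a Dropdown update restricted to the accents valid for `language`."""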
    accents = ACCENTS_BY_LANGUAGE.get(language, [])
    return gr.update(choices=accents, value=accents[0] if accents else None)


with gr.Blocks(title="Accent Vectors") as demo:
    gr.Markdown(
        """
# Accent Vectors
Synthesise speech with a controllable accent — pick the output **language**,
the speaker's **accent**, upload a short reference audio clip, and type your text.

> **Paper:** *Accent Vector: Controllable Accent Manipulation for Multilingual TTS
> Without Accented Data* (submitted to Interspeech 2026)
"""
    )

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Text to synthesise",
                placeholder="Type something here…",
                lines=3,
            )
            speaker_audio = gr.Audio(
                label="Reference speaker audio (3–10 s)",
                type="filepath",
            )

            with gr.Row():
                language_dd = gr.Dropdown(
                    label="Output language",
                    choices=list(ACCENTS_BY_LANGUAGE.keys()),
                    value="English",
                )
                accent1_dd = gr.Dropdown(
                    label="Speaker accent",
                    choices=ACCENTS_BY_LANGUAGE["English"],
                    value="English",
                )
            coeff1_slider = gr.Slider(
                label="Accent strength",
                minimum=0.0, maximum=1.0, step=0.05, value=1.0,
            )

            with gr.Accordion("Mix a second accent (optional)", open=False):
                enable_second = gr.Checkbox(label="Enable second accent", value=False)
                accent2_dd = gr.Dropdown(
                    label="Second accent",
                    choices=ACCENTS_BY_LANGUAGE["English"],
                    value="Hindi",
                    interactive=True,
                )
                coeff2_slider = gr.Slider(
                    label="Second accent strength",
                    minimum=0.0, maximum=1.0, step=0.05, value=0.5,
                )

            generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(label="Generated speech", type="filepath")

    # Update both accent dropdowns when language changes
    language_dd.change(fn=update_accent_choices, inputs=language_dd, outputs=accent1_dd)
    language_dd.change(fn=update_accent_choices, inputs=language_dd, outputs=accent2_dd)

    generate_btn.click(
        fn=synthesise,
        inputs=[
            text_input, speaker_audio,
            language_dd, accent1_dd, coeff1_slider,
            enable_second, accent2_dd, coeff2_slider,
        ],
        outputs=audio_output,
    )

    gr.Markdown(
        """
---
### How to use
1. **Output language** — the language the model will speak in.
2. **Speaker accent** — the native-language (L1) accent applied to the generated speech.
3. **Reference audio** — a clean 3–10 second clip of any speaker; the model
   clones the voice while applying the chosen accent.
4. **Accent strength** — LoRA adapter contribution (0 = no accent effect, 1 = full).
5. **Mix a second accent** — optionally blend two accents together by enabling
   a second adapter and setting its strength independently.

Models are downloaded automatically on first use.
"""
    )

if __name__ == "__main__":
    demo.launch()