File size: 7,838 Bytes
a53ae3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bb4d60
a53ae3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abe6ed2
a53ae3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import os
import tempfile

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download

# Optional HF Spaces GPU decorator: use `spaces.GPU` when the `spaces` package
# is importable, otherwise fall back to an identity decorator so the app still
# runs outside Hugging Face Spaces.
try:
    import spaces
except ImportError:
    def gpu_decorator(fn):
        """Return *fn* unchanged (stand-in used when `spaces` is unavailable)."""
        return fn
else:
    gpu_decorator = spaces.GPU
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig

from f5_tts.model import CFM, DiT
from f5_tts.infer.utils_infer import (
    device,
    load_checkpoint,
    load_vocoder,
    preprocess_ref_audio_text,
    infer_process,
    target_sample_rate,
    hop_length,
    n_fft,
    win_length,
    n_mel_channels,
)
from f5_tts.model.utils import get_tokenizer

# Config
# Hugging Face Hub repo holding the fine-tuned checkpoint. It is read from the
# environment so each deployment can point at its own repo (HF_TOKEN may be unset).
_MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID")
if not _MODEL_REPO_ID:
    # Fail fast with an actionable message rather than handing None/"" to
    # hf_hub_download, which presumably rejects it with a less specific error.
    raise RuntimeError(
        "MODEL_REPO_ID environment variable is not set; point it at the "
        "Hugging Face repo that contains model_slim.pt."
    )
# Local filesystem path of the downloaded checkpoint.
MODEL_CKPT = hf_hub_download(_MODEL_REPO_ID, "model_slim.pt", token=os.environ.get("HF_TOKEN"))
# Relative path; presumably resolved against the process working directory.
VOCAB_FILE = "data/Bengali/vocab.txt"
WHISPER_MODEL = "bengaliAI/tugstugi_bengaliai-asr_whisper-medium"

# DiT backbone hyper-parameters (same as F5TTS_v1_Base). load_models() passes
# these to DiT together with the tokenizer's vocab size and the mel-channel count.
model_cfg = {
    "dim": 1024,
    "depth": 22,
    "heads": 16,
    "ff_mult": 2,
    "text_dim": 512,
    "text_mask_padding": True,
    "qk_norm": None,
    "conv_layers": 4,
    "pe_attn_head": None,
}

# Lazily-populated singletons. load_models() fills in the TTS pair and
# init_bengali_asr() fills in the ASR pair; all four start out unloaded (None).
ema_model = vocoder = None
bn_asr_model = bn_asr_processor = None


def load_models():
    """Load the Bengali F5-TTS model (EMA weights) and the Vocos vocoder once.

    Populates the module-level ``ema_model`` and ``vocoder`` globals. Each one
    is loaded only if it is still ``None``. If one load fails (e.g. the vocoder
    download), the next call retries just the missing piece instead of
    returning early with ``vocoder`` left as ``None``.
    """
    global ema_model, vocoder
    if ema_model is not None and vocoder is not None:
        return

    if ema_model is None:
        print("Loading Bengali TTS model...")
        vocab_char_map, vocab_size = get_tokenizer(VOCAB_FILE, "custom")
        model = CFM(
            transformer=DiT(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
            mel_spec_kwargs=dict(
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                n_mel_channels=n_mel_channels,
                target_sample_rate=target_sample_rate,
                mel_spec_type="vocos",
            ),
            odeint_kwargs=dict(method="euler"),
            vocab_char_map=vocab_char_map,
        ).to(device)
        ema_model = load_checkpoint(model, MODEL_CKPT, device, use_ema=True)

    if vocoder is None:
        print("Loading vocoder...")
        vocoder = load_vocoder(vocoder_name="vocos", is_local=False, device=device)

    print("Models loaded.")


def init_bengali_asr():
    """Lazily load the Bengali Whisper ASR processor and model.

    Everything is built in locals and published to the ``bn_asr_processor`` /
    ``bn_asr_model`` globals only after every step has succeeded, including
    the generation-config override. A failure part-way through therefore
    leaves ``bn_asr_model`` as ``None`` and the next call retries, rather than
    keeping a model with the outdated config.
    """
    global bn_asr_model, bn_asr_processor
    if bn_asr_model is not None:
        return

    print("Loading Bengali ASR...")
    processor = WhisperProcessor.from_pretrained(WHISPER_MODEL)
    model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL).to(device)
    # Fix outdated generation config: swap in the base openai/whisper-medium one,
    # which presumably carries what generate(language=..., task=...) relies on.
    model.generation_config = GenerationConfig.from_pretrained("openai/whisper-medium")

    bn_asr_processor = processor
    bn_asr_model = model
    print("Bengali ASR loaded.")


def transcribe_bengali(audio_path: str) -> str:
    """Transcribe the audio file at *audio_path* to Bengali text with Whisper.

    Loads the ASR model on first use. The audio is down-mixed to mono and
    resampled to 16 kHz, the rate declared to the Whisper processor.

    Returns:
        The decoded transcription with surrounding whitespace stripped.
    """
    asr_rate = 16000
    init_bengali_asr()

    waveform, sr = torchaudio.load(audio_path)  # (channels, frames)
    # Down-mix before resampling so the resampler only processes one channel.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != asr_rate:
        waveform = torchaudio.transforms.Resample(sr, asr_rate)(waveform)

    # Index the single channel rather than squeeze(): squeeze would collapse a
    # 1-frame clip to a 0-d array.
    input_features = bn_asr_processor(
        waveform[0].numpy(), sampling_rate=asr_rate, return_tensors="pt"
    ).input_features.to(device)
    predicted_ids = bn_asr_model.generate(input_features, language="bn", task="transcribe")
    text = bn_asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return text.strip()


def _join_segments_within_limit(segments, min_ms=6000, max_ms=15000):
    """Concatenate *segments* in order, stopping before the first segment that
    would push the total past *max_ms* once more than *min_ms* is collected.

    Returns:
        ``(joined_audio, stopped_early)``, where ``stopped_early`` is True when
        a segment was dropped because of the length limit.
    """
    from pydub import AudioSegment

    joined = AudioSegment.silent(duration=0)
    for seg in segments:
        if len(joined) > min_ms and len(joined + seg) > max_ms:
            return joined, True
        joined += seg
    return joined, False


def _clip_reference_audio(ref_audio, show_info):
    """Trim *ref_audio* to at most 15 s, preferring silence boundaries.

    The trimmed audio is written to a new temporary ``.wav`` file, which is
    not deleted automatically. Returns that file's path.
    """
    from f5_tts.infer.utils_infer import remove_silence_edges
    from pydub import AudioSegment, silence

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        temp_path = f.name

    aseg = AudioSegment.from_file(ref_audio)

    # First pass: split on long (1 s) pauses.
    segs = silence.split_on_silence(
        aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
    )
    wave, clipped = _join_segments_within_limit(segs)
    if clipped:
        show_info("Audio over 15s, clipping.")

    # Second pass if still too long: split on shorter (100 ms) pauses.
    if len(wave) > 15000:
        segs = silence.split_on_silence(
            aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
        )
        wave, _ = _join_segments_within_limit(segs)

    # Last resort: hard cut at 15 s.
    if len(wave) > 15000:
        wave = wave[:15000]
        show_info("Audio over 15s, hard clip.")

    wave = remove_silence_edges(wave) + AudioSegment.silent(duration=50)
    wave.export(temp_path, format="wav")
    return temp_path


def preprocess_ref_audio_text_bn(ref_audio, ref_text, show_info=print):
    """Prepare reference audio/text, using Bengali ASR instead of default whisper.

    The audio is clipped and silence-trimmed once per distinct file content.
    Results are keyed by the file's MD5 in ``f5_tts``'s ``_ref_audio_cache``.
    When *ref_text* is blank, the processed audio is transcribed with
    :func:`transcribe_bengali`.

    Args:
        ref_audio: Path to the reference audio file.
        ref_text: Transcript of the reference; blank means "transcribe it".
        show_info: Callable used for progress messages.

    Returns:
        ``(processed_audio_path, ref_text)``. The returned text ends with
        ". " or "। ".
    """
    from f5_tts.infer.utils_infer import _ref_audio_cache
    import hashlib

    show_info("Converting audio...")

    with open(ref_audio, "rb") as f:
        audio_hash = hashlib.md5(f.read()).hexdigest()

    if audio_hash in _ref_audio_cache:
        processed_audio = _ref_audio_cache[audio_hash]
    else:
        processed_audio = _clip_reference_audio(ref_audio, show_info)
        _ref_audio_cache[audio_hash] = processed_audio

    # Bengali transcription if no ref_text
    if not ref_text.strip():
        show_info("Transcribing with Bengali ASR...")
        ref_text = transcribe_bengali(processed_audio)

    # Normalise the ending to a sentence-final mark ("." or the Bengali dari
    # "।") followed by a space. Treat "।" like ".": previously "…।" got no
    # space and "…। " was double-punctuated as "…। । ".
    # NOTE(review): the trailing space presumably keeps gen_text separated when
    # infer_process joins the two texts; confirm against f5_tts.
    if not ref_text.endswith((". ", "। ")):
        if ref_text.endswith((".", "।")):
            ref_text += " "
        else:
            ref_text += "। "

    print("ref_text:", ref_text)
    return processed_audio, ref_text


@gpu_decorator
def generate_tts(ref_audio, gen_text, speed, ref_text=""):
    """Synthesize *gen_text* in the voice of *ref_audio*.

    Args:
        ref_audio: Path to the reference recording (Gradio ``filepath`` audio), or None.
        gen_text: Bengali text to synthesize.
        speed: Speed factor forwarded to ``infer_process``.
        ref_text: Optional transcript of *ref_audio*. When blank (the default),
            the reference is transcribed with the Bengali ASR model.

    Returns:
        ``((sample_rate, audio), status_message)`` on success, or
        ``(None, message)`` when inputs are missing or any step fails.
    """
    if ref_audio is None:
        return None, "Please provide reference audio."
    if not (gen_text or "").strip():
        return None, "Please enter text to generate."

    try:
        # Inside the try so model-loading failures are reported through the
        # status box like every other failure.
        load_models()

        ref_audio_processed, ref_text_processed = preprocess_ref_audio_text_bn(
            ref_audio, ref_text or ""
        )

        audio, sr, _ = infer_process(
            ref_audio_processed,
            ref_text_processed,
            gen_text,
            ema_model,
            vocoder,
            mel_spec_type="vocos",
            speed=speed,
            device=device,
        )
    except Exception as e:
        # UI boundary: keep the full traceback in the server logs and show a
        # short message to the user.
        import traceback

        traceback.print_exc()
        return None, f"Error: {str(e)}"

    return (sr, audio), f"Generated with ref: '{ref_text_processed[:50]}...'"


# Gradio UI. The left column holds the inputs (reference audio, Bengali text,
# speed slider, generate button); the right column shows the synthesized audio
# and a status message. The button runs generate_tts on the three inputs.
with gr.Blocks(title="Bengali TTS") as demo:
    gr.Markdown("# Bengali Text-to-Speech")
    gr.Markdown("Upload or record Bengali audio (max 15s) as reference, then generate speech.")

    with gr.Row():
        with gr.Column():
            ref_audio = gr.Audio(type="filepath", label="Reference Audio (record or upload, max 15s)")
            gen_text = gr.Textbox(
                lines=3,
                placeholder="Enter Bengali text here...",
                label="Text to Generate (Bengali)",
            )
            speed = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Speed")
            generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            output_audio = gr.Audio(type="numpy", label="Generated Audio")
            status = gr.Textbox(interactive=False, label="Status")

    generate_btn.click(
        generate_tts,
        inputs=[ref_audio, gen_text, speed],
        outputs=[output_audio, status],
    )


if __name__ == "__main__":
    demo.launch()