import os
import tempfile

import gradio as gr
from huggingface_hub import hf_hub_download

# Optional HF Spaces GPU decorator; fall back to a no-op outside Spaces
try:
    import spaces

    gpu_decorator = spaces.GPU
except ImportError:
    def gpu_decorator(fn):
        return fn

import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig

from f5_tts.model import CFM, DiT
from f5_tts.infer.utils_infer import (
    device,
    load_checkpoint,
    load_vocoder,
    infer_process,
    target_sample_rate,
    hop_length,
    n_fft,
    win_length,
    n_mel_channels,
)
from f5_tts.model.utils import get_tokenizer

# Config
MODEL_CKPT = hf_hub_download(
    os.environ.get("MODEL_REPO_ID"), "model_slim.pt", token=os.environ.get("HF_TOKEN")
)
VOCAB_FILE = "data/Bengali/vocab.txt"
WHISPER_MODEL = "bengaliAI/tugstugi_bengaliai-asr_whisper-medium"

# Model architecture (same as F5TTS_v1_Base)
model_cfg = dict(
    dim=1024,
    depth=22,
    heads=16,
    ff_mult=2,
    text_dim=512,
    text_mask_padding=True,
    qk_norm=None,
    conv_layers=4,
    pe_attn_head=None,
)

# Globals, populated lazily on first use
ema_model = None
vocoder = None
bn_asr_model = None
bn_asr_processor = None


def load_models():
    """Load the TTS model and vocoder once; subsequent calls are no-ops."""
    global ema_model, vocoder
    if ema_model is not None:
        return
    print("Loading Bengali TTS model...")
    vocab_char_map, vocab_size = get_tokenizer(VOCAB_FILE, "custom")
    model = CFM(
        transformer=DiT(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type="vocos",
        ),
        odeint_kwargs=dict(method="euler"),
        vocab_char_map=vocab_char_map,
    ).to(device)
    ema_model = load_checkpoint(model, MODEL_CKPT, device, use_ema=True)
    print("Loading vocoder...")
    vocoder = load_vocoder(vocoder_name="vocos", is_local=False, device=device)
    print("Models loaded.")


def init_bengali_asr():
    """Load the Bengali Whisper ASR model once; subsequent calls are no-ops."""
    global bn_asr_model, bn_asr_processor
    if bn_asr_model is not None:
        return
    print("Loading Bengali ASR...")
    bn_asr_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL)
    bn_asr_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL).to(device)
    # The fine-tuned checkpoint ships an outdated generation config; use the
    # base model's so generate() accepts language/task arguments
    bn_asr_model.generation_config = GenerationConfig.from_pretrained("openai/whisper-medium")
    print("Bengali ASR loaded.")


def transcribe_bengali(audio_path: str) -> str:
    """Transcribe a Bengali audio file with the fine-tuned Whisper model."""
    init_bengali_asr()
    waveform, sr = torchaudio.load(audio_path)
    # Whisper expects 16 kHz mono input
    if sr != 16000:
        waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    input_features = bn_asr_processor(
        waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features.to(device)
    predicted_ids = bn_asr_model.generate(input_features, language="bn", task="transcribe")
    text = bn_asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return text.strip()
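
# Example (sketch, not part of the app flow): transcribe_bengali can be called
# standalone for a quick check; "sample_bn.wav" is a hypothetical local file.
#
#   print(transcribe_bengali("sample_bn.wav"))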


def preprocess_ref_audio_text_bn(ref_audio, ref_text, show_info=print):
    """Wrapper that uses Bengali ASR instead of the default Whisper transcription."""
    # Reuse the original preprocessing approach for audio clipping/silence removal
    from f5_tts.infer.utils_infer import (
        _ref_audio_cache,
        remove_silence_edges,
    )
    from pydub import AudioSegment, silence
    import hashlib

    show_info("Converting audio...")
    # Cache processed audio by content hash to avoid reprocessing the same clip
    with open(ref_audio, "rb") as f:
        audio_hash = hashlib.md5(f.read()).hexdigest()

    if audio_hash in _ref_audio_cache:
        processed_audio = _ref_audio_cache[audio_hash]
    else:
        tempfile_kwargs = {"delete": False, "suffix": ".wav"}
        with tempfile.NamedTemporaryFile(**tempfile_kwargs) as f:
            temp_path = f.name

        aseg = AudioSegment.from_file(ref_audio)
        # Clip to 15s using silence detection: accumulate non-silent segments
        # until adding another would exceed the budget
        non_silent_segs = silence.split_on_silence(
            aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
        )
        non_silent_wave = AudioSegment.silent(duration=0)
        for seg in non_silent_segs:
            if len(non_silent_wave) > 6000 and len(non_silent_wave + seg) > 15000:
                show_info("Audio over 15s, clipping.")
                break
            non_silent_wave += seg

        # Still too long: retry with a more aggressive silence split
        if len(non_silent_wave) > 15000:
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for seg in non_silent_segs:
                if len(non_silent_wave) > 6000 and len(non_silent_wave + seg) > 15000:
                    break
                non_silent_wave += seg

        aseg = non_silent_wave
        if len(aseg) > 15000:
            aseg = aseg[:15000]
            show_info("Audio over 15s, hard clip.")

        aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
        aseg.export(temp_path, format="wav")
        processed_audio = temp_path
        _ref_audio_cache[audio_hash] = processed_audio

    # Transcribe with Bengali ASR if no reference text was supplied
    if not ref_text.strip():
        show_info("Transcribing with Bengali ASR...")
        ref_text = transcribe_bengali(processed_audio)

    # Ensure the reference text ends with proper punctuation (Bengali danda or period)
    if not ref_text.endswith(". ") and not ref_text.endswith("।"):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            ref_text += "। "

    print("ref_text:", ref_text)
    return processed_audio, ref_text


@gpu_decorator
def generate_tts(ref_audio, gen_text, speed):
    if ref_audio is None:
        return None, "Please provide reference audio."
    if not gen_text.strip():
        return None, "Please enter text to generate."

    load_models()
    try:
        # Reference text is always auto-transcribed (empty string triggers ASR)
        ref_audio_processed, ref_text_processed = preprocess_ref_audio_text_bn(ref_audio, "")
        audio, sr, _ = infer_process(
            ref_audio_processed,
            ref_text_processed,
            gen_text,
            ema_model,
            vocoder,
            mel_spec_type="vocos",
            speed=speed,
            device=device,
        )
        return (sr, audio), f"Generated with ref: '{ref_text_processed[:50]}...'"
    except Exception as e:
        return None, f"Error: {str(e)}"


# Gradio UI
with gr.Blocks(title="Bengali TTS") as demo:
    gr.Markdown("# Bengali Text-to-Speech")
    gr.Markdown("Upload or record Bengali audio (max 15s) as reference, then generate speech.")
    with gr.Row():
        with gr.Column():
            ref_audio = gr.Audio(
                label="Reference Audio (record or upload, max 15s)",
                type="filepath",
            )
            gen_text = gr.Textbox(
                label="Text to Generate (Bengali)",
                placeholder="Enter Bengali text here...",
                lines=3,
            )
            speed = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Speed",
            )
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            output_audio = gr.Audio(label="Generated Audio", type="numpy")
            status = gr.Textbox(label="Status", interactive=False)

    generate_btn.click(
        fn=generate_tts,
        inputs=[ref_audio, gen_text, speed],
        outputs=[output_audio, status],
    )


if __name__ == "__main__":
    demo.launch()
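
# Usage sketch (assumptions: MODEL_REPO_ID points at a HF Hub repo containing
# model_slim.pt, HF_TOKEN can read it, and this file is saved as app.py):
#
#   export MODEL_REPO_ID=<user>/<bengali-f5-checkpoint>
#   export HF_TOKEN=hf_...
#   python app.py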