# Hugging Face Space app: Bengali F5-TTS voice cloning demo.
| import os | |
| import tempfile | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
# Optional HF Spaces GPU decorator
try:
    import spaces
except ImportError:
    # Not running on HF Spaces: use an identity decorator instead.
    def gpu_decorator(fn):
        return fn
else:
    gpu_decorator = spaces.GPU
| import torchaudio | |
| from transformers import WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig | |
| from f5_tts.model import CFM, DiT | |
| from f5_tts.infer.utils_infer import ( | |
| device, | |
| load_checkpoint, | |
| load_vocoder, | |
| preprocess_ref_audio_text, | |
| infer_process, | |
| target_sample_rate, | |
| hop_length, | |
| n_fft, | |
| win_length, | |
| n_mel_channels, | |
| ) | |
| from f5_tts.model.utils import get_tokenizer | |
# Config
_MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID")
if not _MODEL_REPO_ID:
    # Fail fast with a clear message instead of an opaque hf_hub_download error
    # when the Space secret is missing.
    raise RuntimeError("MODEL_REPO_ID environment variable must be set")
# Checkpoint is downloaded once at import time (HF_TOKEN needed for private repos).
MODEL_CKPT = hf_hub_download(_MODEL_REPO_ID, "model_slim.pt", token=os.environ.get("HF_TOKEN"))
VOCAB_FILE = "data/Bengali/vocab.txt"
WHISPER_MODEL = "bengaliAI/tugstugi_bengaliai-asr_whisper-medium"

# Model architecture (same as F5TTS_v1_Base)
model_cfg = dict(
    dim=1024,
    depth=22,
    heads=16,
    ff_mult=2,
    text_dim=512,
    text_mask_padding=True,
    qk_norm=None,
    conv_layers=4,
    pe_attn_head=None,
)

# Globals — models are loaded lazily on first use (load_models / init_bengali_asr).
ema_model = None
vocoder = None
bn_asr_model = None
bn_asr_processor = None
def load_models():
    """Build the Bengali F5-TTS model and vocos vocoder (lazy, idempotent)."""
    global ema_model, vocoder
    if ema_model is not None:
        return  # already initialized by an earlier call
    print("Loading Bengali TTS model...")
    vocab_char_map, vocab_size = get_tokenizer(VOCAB_FILE, "custom")
    # Assemble the pieces separately for readability.
    transformer = DiT(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels)
    mel_spec_kwargs = dict(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type="vocos",
    )
    model = CFM(
        transformer=transformer,
        mel_spec_kwargs=mel_spec_kwargs,
        odeint_kwargs=dict(method="euler"),
        vocab_char_map=vocab_char_map,
    ).to(device)
    # EMA weights from the downloaded checkpoint.
    ema_model = load_checkpoint(model, MODEL_CKPT, device, use_ema=True)
    print("Loading vocoder...")
    vocoder = load_vocoder(vocoder_name="vocos", is_local=False, device=device)
    print("Models loaded.")
def init_bengali_asr():
    """Load the Bengali Whisper ASR model and processor (lazy, idempotent)."""
    global bn_asr_model, bn_asr_processor
    if bn_asr_model is not None:
        return  # already loaded
    print("Loading Bengali ASR...")
    bn_asr_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL)
    model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL)
    bn_asr_model = model.to(device)
    # Fix outdated generation config
    bn_asr_model.generation_config = GenerationConfig.from_pretrained("openai/whisper-medium")
    print("Bengali ASR loaded.")
def transcribe_bengali(audio_path: str) -> str:
    """Transcribe a Bengali audio file with the Whisper ASR model."""
    init_bengali_asr()
    wav, sample_rate = torchaudio.load(audio_path)
    # Whisper expects 16 kHz mono input.
    if sample_rate != 16000:
        wav = torchaudio.transforms.Resample(sample_rate, 16000)(wav)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    features = bn_asr_processor(
        wav.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features.to(device)
    predicted_ids = bn_asr_model.generate(features, language="bn", task="transcribe")
    decoded = bn_asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)
    return decoded[0].strip()
def preprocess_ref_audio_text_bn(ref_audio, ref_text, show_info=print):
    """Wrapper that uses Bengali ASR instead of default whisper.

    Clips the reference audio to at most 15 s using the F5-TTS
    silence-based clipping, caches the processed file by content hash,
    and — when *ref_text* is blank — transcribes the clip with the
    Bengali Whisper model.

    Args:
        ref_audio: Path to the reference audio file.
        ref_text: Reference transcript; auto-transcribed if blank.
        show_info: Progress callback (defaults to print).

    Returns:
        Tuple of (processed_audio_path, ref_text).
    """
    # Use original preprocessing for audio clipping/silence
    from f5_tts.infer.utils_infer import (
        _ref_audio_cache,
        remove_silence_edges,
    )
    from pydub import AudioSegment, silence
    import hashlib
    show_info("Converting audio...")
    # Key the cache on raw file bytes so the same clip is processed only once.
    with open(ref_audio, "rb") as f:
        audio_hash = hashlib.md5(f.read()).hexdigest()
    if audio_hash in _ref_audio_cache:
        processed_audio = _ref_audio_cache[audio_hash]
    else:
        # delete=False: the temp wav must outlive this function — its path is cached
        # and later read by the TTS inference step.
        tempfile_kwargs = {"delete": False, "suffix": ".wav"}
        with tempfile.NamedTemporaryFile(**tempfile_kwargs) as f:
            temp_path = f.name
        aseg = AudioSegment.from_file(ref_audio)
        # Clip to 15s using silence detection
        non_silent_segs = silence.split_on_silence(
            aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
        )
        non_silent_wave = AudioSegment.silent(duration=0)
        for seg in non_silent_segs:
            # Keep at least 6 s of audio; stop before exceeding 15 s.
            if len(non_silent_wave) > 6000 and len(non_silent_wave + seg) > 15000:
                show_info("Audio over 15s, clipping.")
                break
            non_silent_wave += seg
        # Still too long: retry with a more aggressive silence split.
        if len(non_silent_wave) > 15000:
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for seg in non_silent_segs:
                if len(non_silent_wave) > 6000 and len(non_silent_wave + seg) > 15000:
                    break
                non_silent_wave += seg
        aseg = non_silent_wave
        if len(aseg) > 15000:
            aseg = aseg[:15000]
            show_info("Audio over 15s, hard clip.")
        # Trim silent edges, then pad 50 ms of silence for a clean tail.
        aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
        aseg.export(temp_path, format="wav")
        processed_audio = temp_path
        _ref_audio_cache[audio_hash] = processed_audio
    # Bengali transcription if no ref_text
    if not ref_text.strip():
        show_info("Transcribing with Bengali ASR...")
        ref_text = transcribe_bengali(processed_audio)
    # Ensure proper ending punctuation
    # NOTE(review): nesting reconstructed from flattened source — this fix is applied
    # unconditionally, matching upstream F5-TTS preprocess_ref_audio_text; confirm.
    if not ref_text.endswith(". ") and not ref_text.endswith("।"):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            # Default to the Bengali full stop (danda).
            ref_text += "। "
    print("ref_text:", ref_text)
    return processed_audio, ref_text
def generate_tts(ref_audio, gen_text, speed):
    """Generate Bengali speech for gen_text, cloning the voice in ref_audio.

    Returns a ((sample_rate, waveform), status) pair for the Gradio outputs,
    or (None, message) on validation failure or error.
    """
    # Input validation: bail out early with a user-facing message.
    if ref_audio is None:
        return None, "Please provide reference audio."
    if not gen_text.strip():
        return None, "Please enter text to generate."
    load_models()
    try:
        # Always auto-transcribe the reference (empty ref_text).
        processed_audio, processed_text = preprocess_ref_audio_text_bn(ref_audio, "")
        wave, sample_rate, _ = infer_process(
            processed_audio,
            processed_text,
            gen_text,
            ema_model,
            vocoder,
            mel_spec_type="vocos",
            speed=speed,
            device=device,
        )
        status = f"Generated with ref: '{processed_text[:50]}...'"
        return (sample_rate, wave), status
    except Exception as e:
        # Surface the failure in the UI rather than crashing the app.
        return None, f"Error: {str(e)}"
# Gradio UI: two-column layout — inputs on the left, results on the right.
with gr.Blocks(title="Bengali TTS") as demo:
    gr.Markdown("# Bengali Text-to-Speech")
    gr.Markdown("Upload or record Bengali audio (max 15s) as reference, then generate speech.")
    with gr.Row():
        with gr.Column():
            # Reference audio is passed to generate_tts as a file path.
            ref_audio = gr.Audio(
                label="Reference Audio (record or upload, max 15s)",
                type="filepath",
            )
            gen_text = gr.Textbox(
                label="Text to Generate (Bengali)",
                placeholder="Enter Bengali text here...",
                lines=3,
            )
            # Playback-speed multiplier forwarded to infer_process.
            speed = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Speed",
            )
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            # generate_tts returns (sample_rate, ndarray), hence type="numpy".
            output_audio = gr.Audio(label="Generated Audio", type="numpy")
            status = gr.Textbox(label="Status", interactive=False)
    # Wire the button to the generation function.
    generate_btn.click(
        fn=generate_tts,
        inputs=[ref_audio, gen_text, speed],
        outputs=[output_audio, status],
    )
if __name__ == "__main__":
    demo.launch()