import os
import tempfile
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
# Optional HF Spaces GPU decorator
try:
    import spaces

    gpu_decorator = spaces.GPU
except ImportError:

    def gpu_decorator(fn):
        return fn
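# On ZeroGPU Spaces, GPU work is only permitted inside functions wrapped with
# @spaces.GPU (here, generate_tts); outside Spaces the decorator is a no-op.
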
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig
from f5_tts.model import CFM, DiT
from f5_tts.infer.utils_infer import (
    device,
    load_checkpoint,
    load_vocoder,
    preprocess_ref_audio_text,
    infer_process,
    target_sample_rate,
    hop_length,
    n_fft,
    win_length,
    n_mel_channels,
)
from f5_tts.model.utils import get_tokenizer
# Config
MODEL_CKPT = hf_hub_download(
    os.environ.get("MODEL_REPO_ID"), "model_slim.pt", token=os.environ.get("HF_TOKEN")
)
VOCAB_FILE = "data/Bengali/vocab.txt"
WHISPER_MODEL = "bengaliAI/tugstugi_bengaliai-asr_whisper-medium"
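# MODEL_REPO_ID and HF_TOKEN are read from the environment (Space secrets),
# e.g. MODEL_REPO_ID="your-username/bengali-f5-tts" (placeholder value);
# hf_hub_download raises if MODEL_REPO_ID is unset.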
# Model architecture (same as F5TTS_v1_Base)
model_cfg = dict(
    dim=1024,
    depth=22,
    heads=16,
    ff_mult=2,
    text_dim=512,
    text_mask_padding=True,
    qk_norm=None,
    conv_layers=4,
    pe_attn_head=None,
)
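# text_num_embeds and mel_dim are not set here; they are filled in at load time
# from the tokenizer vocab size and the mel config (see load_models below).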
# Globals
ema_model = None
vocoder = None
bn_asr_model = None
bn_asr_processor = None
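# Models are constructed lazily on the first request (note the early-return
# guards below) rather than at import time.
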
def load_models():
    global ema_model, vocoder
    if ema_model is not None:
        return
    print("Loading Bengali TTS model...")
    vocab_char_map, vocab_size = get_tokenizer(VOCAB_FILE, "custom")
    model = CFM(
        transformer=DiT(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type="vocos",
        ),
        odeint_kwargs=dict(method="euler"),
        vocab_char_map=vocab_char_map,
    ).to(device)
    ema_model = load_checkpoint(model, MODEL_CKPT, device, use_ema=True)
    print("Loading vocoder...")
    vocoder = load_vocoder(vocoder_name="vocos", is_local=False, device=device)
    print("Models loaded.")
def init_bengali_asr():
    global bn_asr_model, bn_asr_processor
    if bn_asr_model is not None:
        return
    print("Loading Bengali ASR...")
    bn_asr_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL)
    bn_asr_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL).to(device)
    # The fine-tuned checkpoint ships an outdated generation config; borrow the
    # base whisper-medium config so the language/task arguments used below work
    # with recent transformers versions.
    bn_asr_model.generation_config = GenerationConfig.from_pretrained("openai/whisper-medium")
    print("Bengali ASR loaded.")
def transcribe_bengali(audio_path: str) -> str:
    init_bengali_asr()
    waveform, sr = torchaudio.load(audio_path)
    # Whisper expects 16 kHz mono input
    if sr != 16000:
        waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    input_features = bn_asr_processor(
        waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features.to(device)
    predicted_ids = bn_asr_model.generate(input_features, language="bn", task="transcribe")
    text = bn_asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return text.strip()
def preprocess_ref_audio_text_bn(ref_audio, ref_text, show_info=print):
    """Wrapper around the F5-TTS preprocessing that transcribes with the
    Bengali ASR model instead of the default Whisper."""
    # Reuse the original clipping/silence-removal helpers
    from f5_tts.infer.utils_infer import (
        _ref_audio_cache,
        remove_silence_edges,
    )
    from pydub import AudioSegment, silence
    import hashlib

    show_info("Converting audio...")
    with open(ref_audio, "rb") as f:
        audio_hash = hashlib.md5(f.read()).hexdigest()
    if audio_hash in _ref_audio_cache:
        processed_audio = _ref_audio_cache[audio_hash]
    else:
        tempfile_kwargs = {"delete": False, "suffix": ".wav"}
        with tempfile.NamedTemporaryFile(**tempfile_kwargs) as f:
            temp_path = f.name
        aseg = AudioSegment.from_file(ref_audio)
        # Pass 1: clip to 15 s by splitting at long (>= 1 s) silences
        non_silent_segs = silence.split_on_silence(
            aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
        )
        non_silent_wave = AudioSegment.silent(duration=0)
        for seg in non_silent_segs:
            if len(non_silent_wave) > 6000 and len(non_silent_wave + seg) > 15000:
                show_info("Audio over 15s, clipping.")
                break
            non_silent_wave += seg
        # Pass 2: if still over 15 s, retry with a more aggressive split
        if len(non_silent_wave) > 15000:
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for seg in non_silent_segs:
                if len(non_silent_wave) > 6000 and len(non_silent_wave + seg) > 15000:
                    break
                non_silent_wave += seg
        aseg = non_silent_wave
        # Last resort: hard clip at 15 s if no usable silence was found
        if len(aseg) > 15000:
            aseg = aseg[:15000]
            show_info("Audio over 15s, hard clip.")
        aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
        aseg.export(temp_path, format="wav")
        processed_audio = temp_path
        _ref_audio_cache[audio_hash] = processed_audio
    # Fall back to Bengali ASR when no reference text is supplied
    if not ref_text.strip():
        show_info("Transcribing with Bengali ASR...")
        ref_text = transcribe_bengali(processed_audio)
    # Ensure the reference text ends with terminal punctuation
    # (defaulting to the Bengali danda "।")
    if not ref_text.endswith(". ") and not ref_text.endswith("।"):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            ref_text += "। "
    print("ref_text:", ref_text)
    return processed_audio, ref_text
@gpu_decorator
def generate_tts(ref_audio, gen_text, speed):
    if ref_audio is None:
        return None, "Please provide reference audio."
    if not gen_text.strip():
        return None, "Please enter text to generate."
    load_models()
    try:
        ref_audio_processed, ref_text_processed = preprocess_ref_audio_text_bn(ref_audio, "")
        audio, sr, _ = infer_process(
            ref_audio_processed,
            ref_text_processed,
            gen_text,
            ema_model,
            vocoder,
            mel_spec_type="vocos",
            speed=speed,
            device=device,
        )
        return (sr, audio), f"Generated with ref: '{ref_text_processed[:50]}...'"
    except Exception as e:
        return None, f"Error: {str(e)}"
# Gradio UI
with gr.Blocks(title="Bengali TTS") as demo:
    gr.Markdown("# Bengali Text-to-Speech")
    gr.Markdown("Upload or record Bengali audio (max 15s) as reference, then generate speech.")
    with gr.Row():
        with gr.Column():
            ref_audio = gr.Audio(
                label="Reference Audio (record or upload, max 15s)",
                type="filepath",
            )
            gen_text = gr.Textbox(
                label="Text to Generate (Bengali)",
                placeholder="Enter Bengali text here...",
                lines=3,
            )
            speed = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Speed",
            )
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            output_audio = gr.Audio(label="Generated Audio", type="numpy")
            status = gr.Textbox(label="Status", interactive=False)

    generate_btn.click(
        fn=generate_tts,
        inputs=[ref_audio, gen_text, speed],
        outputs=[output_audio, status],
    )
if __name__ == "__main__":
    demo.launch()
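
# To run locally, export MODEL_REPO_ID (and HF_TOKEN if the checkpoint repo is
# private), then run `python app.py`; Gradio serves the UI on
# http://localhost:7860 by default.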