PrahaTTS-ML / app.py
trysem's picture
Update app.py
4d2241c verified
Raw
History Blame Contribute Delete
5.52 kB
import subprocess
import sys
import os
import tempfile
from huggingface_hub import hf_hub_download
# --- 1. PRE-FLIGHT: BYPASS BUILD ISOLATION ---
def pre_flight_setup():
    """Install chatterbox and its build prerequisites if chatterbox is not importable.

    When ``import chatterbox`` fails, three ``pip install`` commands are run in
    order with the current interpreter: the build prerequisites, ``pkuseg``
    with ``--no-build-isolation``, and finally ``chatterbox-tts``.

    Raises:
        subprocess.CalledProcessError: if any pip invocation exits non-zero;
            the remaining installs are then skipped.
    """
    try:
        import chatterbox  # noqa: F401  (imported only to probe availability)
    except ImportError:
        print("Applying pkuseg build isolation bypass...")
        pip_install = [sys.executable, "-m", "pip", "install"]
        # Order matters: pkuseg is built without isolation, so the build
        # prerequisites are installed into the environment first.
        install_steps = (
            ["numpy<2.0.0", "cython", "wheel"],
            ["--no-build-isolation", "pkuseg==0.0.25"],
            ["chatterbox-tts>=0.1.7"],
        )
        for step_args in install_steps:
            subprocess.check_call(pip_install + step_args)
# Run the dependency bootstrap now, before the chatterbox imports further down.
pre_flight_setup()
# --- 2. MAIN APPLICATION ---
import gradio as gr
import torch
import torch.nn as nn
import torchaudio as ta
from peft import PeftModel
from chatterbox.tts import ChatterboxTTS
from chatterbox.models.tokenizers import EnTokenizer
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face Hub repository that the tokenizer file and LoRA adapter are loaded from.
repo_id = "Praha-Labs/PrahaTTS-ML"
def _resize_text_embedding(old_emb, vocab_size):
    """Return an ``nn.Embedding`` with ``vocab_size`` rows that keeps ``old_emb``'s weights.

    Rows present in both the old and the new table are copied over; any extra
    rows keep PyTorch's default initialisation. The new layer is created with
    the dtype and device of the old weights.
    """
    new_emb = nn.Embedding(
        vocab_size,
        old_emb.embedding_dim,
        device=old_emb.weight.device,
        dtype=old_emb.weight.dtype,
    )
    shared_rows = min(vocab_size, old_emb.num_embeddings)
    with torch.no_grad():
        new_emb.weight[:shared_rows].copy_(old_emb.weight[:shared_rows])
    return new_emb


def _resize_text_head(old_head, vocab_size):
    """Return an ``nn.Linear`` with ``vocab_size`` outputs that keeps ``old_head``'s weights.

    Output rows (and bias entries, when the old layer has a bias) shared by
    both layers are copied over; the rest keep PyTorch's default initialisation.
    """
    has_bias = old_head.bias is not None
    new_head = nn.Linear(
        old_head.in_features,
        vocab_size,
        bias=has_bias,
        device=old_head.weight.device,
        dtype=old_head.weight.dtype,
    )
    shared_rows = min(vocab_size, old_head.out_features)
    with torch.no_grad():
        new_head.weight[:shared_rows].copy_(old_head.weight[:shared_rows])
        if has_bias:
            new_head.bias[:shared_rows].copy_(old_head.bias[:shared_rows])
    return new_head


def load_model(vocab_size=2573):
    """Build the TTS model: base Chatterbox + custom tokenizer + LoRA adapter.

    Steps, in order:
      1. Load the pretrained Chatterbox model on ``device``.
      2. Download ``tokenizer_indic.json`` from ``repo_id`` and install it as
         the model's tokenizer (best effort: a failure is printed and the
         base tokenizer is kept).
      3. Resize the ``text_emb`` / ``text_head`` layers to ``vocab_size`` so
         their shapes match what the adapter expects.
      4. Load the PEFT adapter from ``repo_id`` (best effort: a failure is
         printed and the un-adapted model is returned).

    Args:
        vocab_size: Number of text tokens the resized layers must handle.
            Defaults to 2573, the vocabulary size this app was written for.

    Returns:
        The Chatterbox model, with the adapter applied when loading succeeded.
    """
    print(f"Loading base Chatterbox model on {device}...")
    model = ChatterboxTTS.from_pretrained(device=device)

    print("Applying custom Indic tokenizer...")
    try:
        tokenizer_path = hf_hub_download(repo_id=repo_id, filename="tokenizer_indic.json")
        model.tokenizer = EnTokenizer(tokenizer_path)
    except Exception as e:
        print(f"Error during tokenizer inject: {e}")

    # The vocabulary layers must be resized before the adapter weights are
    # loaded. The pretrained rows are carried over instead of being replaced
    # by random values, so the model is still usable if the adapter fails to
    # load or does not ship weights for these layers.
    # NOTE(review): assumes token ids shared with the base vocabulary keep
    # their meaning in tokenizer_indic.json — confirm against the tokenizer.
    print(f"Resizing base model embeddings to handle vocab size of {vocab_size}...")
    target_layer = model.t3 if hasattr(model, 't3') else model

    if hasattr(target_layer, 'text_emb'):
        target_layer.text_emb = _resize_text_embedding(target_layer.text_emb, vocab_size)

    if hasattr(target_layer, 'text_head'):
        target_layer.text_head = _resize_text_head(target_layer.text_head, vocab_size)

    # Make sure the resized layers live on the inference device.
    target_layer.to(device)

    print("Loading LoRA adapter weights...")
    try:
        if hasattr(model, 't3'):
            model.t3 = PeftModel.from_pretrained(model.t3, repo_id)
        else:
            model = PeftModel.from_pretrained(model, repo_id)
        print("LoRA adapter loaded successfully.")
    except Exception as e:
        print(f"Failed to load PEFT adapter: {e}")

    return model
# Load the model once at import time; synthesize_audio reuses this single
# instance for every request.
tts_model = load_model()
def synthesize_audio(text, ref_audio, exaggeration, cfg_weight):
    """Generate speech for ``text`` and write it to a temporary WAV file.

    Args:
        text: Text to synthesise. ``None``, empty, or whitespace-only input
            is rejected without calling the model.
        ref_audio: Path to a reference audio file, or a falsy value to
            generate without an audio prompt.
        exaggeration: Forwarded to ``tts_model.generate`` as ``exaggeration``.
        cfg_weight: Forwarded to ``tts_model.generate`` as ``cfg_weight``.

    Returns:
        A ``(path, status)`` tuple. ``path`` is the generated WAV file on
        success and ``None`` otherwise; ``status`` is a human-readable message.
        Errors during generation are reported through ``status`` rather than
        raised.
    """
    # ``not text`` also covers None, which would otherwise crash on .strip().
    if not text or not text.strip():
        return None, "Please enter some text."

    audio_prompt_path = ref_audio if ref_audio else None
    out_path = None

    try:
        wav = tts_model.generate(
            text,
            audio_prompt_path=audio_prompt_path,
            exaggeration=exaggeration,
            cfg_weight=cfg_weight,
        )

        # Only the unique file name is needed: close the handle right away so
        # it is not leaked and the file can be reopened for writing.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_out:
            out_path = temp_out.name

        ta.save(out_path, wav.cpu(), tts_model.sr)
        return out_path, "Generation successful!"
    except Exception as e:
        # Do not leave an empty or partial temp file behind when saving failed.
        if out_path is not None:
            try:
                os.remove(out_path)
            except OSError:
                pass
        return None, f"Generation Error: {str(e)}"
# Gradio UI: text, reference audio and voice controls in the left column,
# generated audio and a status box in the right column.

# Both advanced sliders share the same 0..1 range, default and step.
_UNIT_SLIDER = dict(minimum=0.0, maximum=1.0, value=0.5, step=0.05)

# Introductory text shown under the page title.
_INTRO_MD = (
    "This Space runs the [Praha-Labs/PrahaTTS-ML](https://huggingface.co/Praha-Labs/PrahaTTS-ML) model. "
    "It is a Malayalam LoRA adapter built on top of ResembleAI's Chatterbox non-turbo TTS model. \n\n"
    "**Note**: Provide up to 5-10 seconds of clear reference audio for voice cloning capabilities."
)

with gr.Blocks(title="PrahaTTS-ML: Malayalam TTS", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🗣️ PrahaTTS-ML: Malayalam LoRA Adapter")
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        # Left column: everything the user provides.
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text (Malayalam)",
                lines=4,
                placeholder="നമസ്കാരം, നിങ്ങൾക്കെങ്ങനെയുണ്ട്?",
            )
            ref_audio_input = gr.Audio(
                label="Reference Voice Audio (Optional, for Voice Cloning)",
                type="filepath",
            )

            with gr.Accordion("Advanced Voice Controls", open=False):
                exaggeration_slider = gr.Slider(
                    label="Emotion Exaggeration",
                    info="Lower for monotone, higher for dramatic/expressive",
                    **_UNIT_SLIDER,
                )
                cfg_slider = gr.Slider(
                    label="CFG Weight",
                    info="Lower if speech is too fast, higher to strictly mimic the reference voice",
                    **_UNIT_SLIDER,
                )

            generate_btn = gr.Button("Synthesize Speech", variant="primary")

        # Right column: read-only results.
        with gr.Column():
            audio_output = gr.Audio(label="Generated Output", interactive=False)
            status_output = gr.Textbox(label="Status Logging", interactive=False)

    # The handler's (path, status) return value maps onto the two outputs in order.
    generate_btn.click(
        fn=synthesize_audio,
        inputs=[text_input, ref_audio_input, exaggeration_slider, cfg_slider],
        outputs=[audio_output, status_output],
    )

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()