# Source of a Hugging Face Space app (file size: 5,517 bytes; the Space was
# reported as "Sleeping" when this copy was captured).
import subprocess
import sys
import os
import tempfile
from huggingface_hub import hf_hub_download
# --- 1. PRE-FLIGHT: BYPASS BUILD ISOLATION ---
def pre_flight_setup():
    """Make sure the ``chatterbox`` package is importable, installing it if not.

    When ``import chatterbox`` succeeds this is a no-op. Otherwise three pip
    installs are run, in order, with the current interpreter: the build
    prerequisites (``numpy<2.0.0``, ``cython``, ``wheel``), then
    ``pkuseg==0.0.25`` with build isolation disabled, then ``chatterbox-tts``.

    Raises:
        subprocess.CalledProcessError: If any pip invocation exits non-zero.
    """
    try:
        import chatterbox  # noqa: F401 -- imported only to probe availability
    except ImportError:
        import importlib

        print("Applying pkuseg build isolation bypass...")
        pip_install = [sys.executable, "-m", "pip", "install"]
        # Presumably installed first so that they are already present when
        # pkuseg is built without an isolated build environment.
        subprocess.check_call([*pip_install, "numpy<2.0.0", "cython", "wheel"])
        subprocess.check_call([*pip_install, "--no-build-isolation", "pkuseg==0.0.25"])
        subprocess.check_call([*pip_install, "chatterbox-tts>=0.1.7"])
        # The packages were added while this interpreter was already running;
        # drop the import finders' cached directory listings so that later
        # imports in this process can see them.
        importlib.invalidate_caches()
# Run the dependency check right away: the `chatterbox` imports further down
# in this file need the package to be installed before they execute.
pre_flight_setup()
# --- 2. MAIN APPLICATION ---
import gradio as gr
import torch
import torch.nn as nn
import torchaudio as ta
from peft import PeftModel
from chatterbox.tts import ChatterboxTTS
from chatterbox.models.tokenizers import EnTokenizer
# Hugging Face Hub repository that hosts the custom tokenizer and the adapter.
repo_id = "Praha-Labs/PrahaTTS-ML"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def _resize_text_embedding(old_emb, vocab_size):
    """Return an ``nn.Embedding`` with ``vocab_size`` rows that keeps ``old_emb``'s rows."""
    new_emb = nn.Embedding(
        vocab_size,
        old_emb.embedding_dim,
        device=old_emb.weight.device,
        dtype=old_emb.weight.dtype,
    )
    kept = min(vocab_size, old_emb.num_embeddings)
    with torch.no_grad():
        new_emb.weight[:kept].copy_(old_emb.weight[:kept])
    return new_emb


def _resize_text_head(old_head, vocab_size):
    """Return an ``nn.Linear`` with ``vocab_size`` outputs that keeps ``old_head``'s rows."""
    has_bias = old_head.bias is not None
    new_head = nn.Linear(
        old_head.in_features,
        vocab_size,
        bias=has_bias,
        device=old_head.weight.device,
        dtype=old_head.weight.dtype,
    )
    kept = min(vocab_size, old_head.out_features)
    with torch.no_grad():
        new_head.weight[:kept].copy_(old_head.weight[:kept])
        if has_bias:
            new_head.bias[:kept].copy_(old_head.bias[:kept])
    return new_head


def load_model(vocab_size=2573):
    """Load the base Chatterbox model and apply the tokenizer and LoRA adapter.

    Steps, in order:
      1. Load the pretrained ``ChatterboxTTS`` model on ``device``.
      2. Download ``tokenizer_indic.json`` from ``repo_id`` and install it as
         the model's tokenizer (best effort: a failure is printed and the
         original tokenizer is kept).
      3. Resize the ``text_emb`` / ``text_head`` layers of ``model.t3`` (or of
         the model itself when it has no ``t3`` attribute) to ``vocab_size``
         entries so that the adapter weights fit.
      4. Wrap the resized module with the PEFT adapter from ``repo_id`` (best
         effort: a failure is printed and the un-adapted model is returned).

    Args:
        vocab_size: Number of text tokens the resized layers must handle.
            Defaults to 2573, the Malayalam vocabulary size.

    Returns:
        The loaded model, adapted when the adapter could be loaded.
    """
    print(f"Loading base Chatterbox model on {device}...")
    model = ChatterboxTTS.from_pretrained(device=device)

    print("Applying custom Indic tokenizer...")
    try:
        tokenizer_path = hf_hub_download(repo_id=repo_id, filename="tokenizer_indic.json")
        model.tokenizer = EnTokenizer(tokenizer_path)
    except Exception as e:
        # Deliberately best effort: report the problem and keep going with the
        # tokenizer that shipped with the base model.
        print(f"Error during tokenizer inject: {e}")

    # --- CRITICAL FIX: MANUALLY RESIZE PYTORCH EMBEDDINGS ---
    # The base model's vocabulary layers must match the new vocab size before
    # the adapter weights are loaded.
    print(f"Resizing base model embeddings to handle vocab size of {vocab_size}...")
    target_layer = model.t3 if hasattr(model, 't3') else model

    # Copy the pretrained rows into the resized layers rather than starting
    # from a random initialisation, so entries that already existed in the
    # base model keep their learned values; only the added rows are new.
    if hasattr(target_layer, 'text_emb'):
        target_layer.text_emb = _resize_text_embedding(target_layer.text_emb, vocab_size)
    if hasattr(target_layer, 'text_head'):
        target_layer.text_head = _resize_text_head(target_layer.text_head, vocab_size)

    # Send resized layers to the correct device
    target_layer.to(device)

    print("Loading LoRA adapter weights...")
    try:
        if hasattr(model, 't3'):
            model.t3 = PeftModel.from_pretrained(model.t3, repo_id)
        else:
            model = PeftModel.from_pretrained(model, repo_id)
        print("LoRA adapter loaded successfully.")
    except Exception as e:
        # Best effort as well: the app still starts, just without the adapter.
        print(f"Failed to load PEFT adapter: {e}")

    return model
# Load the model once, at import time; synthesize_audio() reads this
# module-level instance on every request.
tts_model = load_model()
def synthesize_audio(text, ref_audio, exaggeration, cfg_weight):
    """Synthesize speech for ``text`` and write it to a temporary WAV file.

    Args:
        text: Text to speak. ``None``, empty or whitespace-only input is
            rejected with a status message.
        ref_audio: Path of a reference recording passed to the model as the
            audio prompt, or a falsy value to pass no prompt.
        exaggeration: Forwarded to ``tts_model.generate`` as ``exaggeration``.
        cfg_weight: Forwarded to ``tts_model.generate`` as ``cfg_weight``.

    Returns:
        A ``(wav_path, status)`` tuple. ``wav_path`` is the path of the
        generated file, or ``None`` when nothing was generated; ``status`` is
        a human-readable message. Errors are reported through ``status``
        instead of being raised.
    """
    if not text or not text.strip():
        return None, "Please enter some text."

    audio_prompt_path = ref_audio if ref_audio else None
    out_path = None
    try:
        wav = tts_model.generate(
            text,
            audio_prompt_path=audio_prompt_path,
            exaggeration=exaggeration,
            cfg_weight=cfg_weight,
        )
        # Only reserve a unique file name here and close the handle straight
        # away, so no descriptor is leaked and the file is not held open
        # while it is written below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_out:
            out_path = temp_out.name
        ta.save(out_path, wav.cpu(), tts_model.sr)
        return out_path, "Generation successful!"
    except Exception as e:
        # Do not leave an empty or partial file behind when saving failed.
        if out_path is not None:
            try:
                os.remove(out_path)
            except OSError:
                pass  # Cleanup is best effort; the original error is what matters.
        return None, f"Generation Error: {str(e)}"
# Define the Gradio Interface
# Introductory text shown under the page heading.
_intro_markdown = (
    "This Space runs the [Praha-Labs/PrahaTTS-ML](https://huggingface.co/Praha-Labs/PrahaTTS-ML) model. "
    "It is a Malayalam LoRA adapter built on top of ResembleAI's Chatterbox non-turbo TTS model. \n\n"
    "**Note**: Provide up to 5-10 seconds of clear reference audio for voice cloning capabilities."
)

# Both advanced sliders share the same range, default value and step.
_unit_slider_kwargs = dict(minimum=0.0, maximum=1.0, value=0.5, step=0.05)

with gr.Blocks(title="PrahaTTS-ML: Malayalam TTS", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🗣️ PrahaTTS-ML: Malayalam LoRA Adapter")
    gr.Markdown(_intro_markdown)

    with gr.Row():
        # Left column: everything the user provides.
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text (Malayalam)",
                lines=4,
                placeholder="നമസ്കാരം, നിങ്ങൾക്കെങ്ങനെയുണ്ട്?",
            )
            ref_audio_input = gr.Audio(
                label="Reference Voice Audio (Optional, for Voice Cloning)",
                type="filepath",
            )
            with gr.Accordion("Advanced Voice Controls", open=False):
                exaggeration_slider = gr.Slider(
                    label="Emotion Exaggeration",
                    info="Lower for monotone, higher for dramatic/expressive",
                    **_unit_slider_kwargs,
                )
                cfg_slider = gr.Slider(
                    label="CFG Weight",
                    info="Lower if speech is too fast, higher to strictly mimic the reference voice",
                    **_unit_slider_kwargs,
                )
            generate_btn = gr.Button("Synthesize Speech", variant="primary")

        # Right column: the generated audio and a status line.
        with gr.Column():
            audio_output = gr.Audio(label="Generated Output", interactive=False)
            status_output = gr.Textbox(label="Status Logging", interactive=False)

    # The handler's (path, status) return value maps onto the two outputs.
    generate_btn.click(
        fn=synthesize_audio,
        inputs=[text_input, ref_audio_input, exaggeration_slider, cfg_slider],
        outputs=[audio_output, status_output],
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    demo.launch()