File size: 5,517 Bytes
10303b0
 
821a89d
4d2241c
821a89d
 
4d2241c
10303b0
 
 
 
 
 
 
 
 
 
 
4d2241c
a12a098
 
4d2241c
a12a098
 
8d1af86
10303b0
e82747c
a12a098
 
4d2241c
a12a098
 
9175bc3
 
439306b
821a89d
9175bc3
821a89d
9175bc3
 
 
a12a098
4d2241c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a12a098
 
4d2241c
 
 
 
a12a098
 
 
 
 
 
 
 
 
8d1af86
a12a098
 
 
 
 
 
 
 
 
 
8d1af86
a12a098
 
 
 
 
 
 
 
 
 
 
10303b0
a12a098
 
 
 
 
 
 
 
 
8d1af86
a12a098
8d1af86
a12a098
 
5142e44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import subprocess
import sys
import os
import tempfile
from huggingface_hub import hf_hub_download

# --- 1. PRE-FLIGHT: BYPASS BUILD ISOLATION ---
def pre_flight_setup():
    """Install ``chatterbox-tts`` at runtime when it cannot be imported.

    If ``import chatterbox`` succeeds, nothing is done. Otherwise three pip
    commands are run with the current interpreter, in this order:

    1. ``numpy<2.0.0``, ``cython`` and ``wheel``;
    2. ``pkuseg==0.0.25`` with ``--no-build-isolation``, so its build runs
       against the packages installed in step 1 instead of a fresh isolated
       build environment;
    3. ``chatterbox-tts>=0.1.7``.

    Raises:
        subprocess.CalledProcessError: if any pip invocation exits non-zero.
    """
    try:
        import chatterbox  # noqa: F401  (imported only to probe availability)
    except ImportError:
        import importlib

        print("Applying pkuseg build isolation bypass...")
        pip_install = [sys.executable, "-m", "pip", "install"]
        subprocess.check_call([*pip_install, "numpy<2.0.0", "cython", "wheel"])
        subprocess.check_call([*pip_install, "--no-build-isolation", "pkuseg==0.0.25"])
        subprocess.check_call([*pip_install, "chatterbox-tts>=0.1.7"])
        # The packages were installed while this interpreter was already
        # running; drop the import finders' cached directory listings so the
        # imports that follow can see them.
        importlib.invalidate_caches()

# Run at import time, before the `chatterbox` imports further down this file,
# so the package is installed by the time those imports execute.
pre_flight_setup()

# --- 2. MAIN APPLICATION ---
import gradio as gr
import torch
import torch.nn as nn
import torchaudio as ta
from peft import PeftModel

from chatterbox.tts import ChatterboxTTS
from chatterbox.models.tokenizers import EnTokenizer

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face Hub repository that the tokenizer file and the LoRA adapter
# are downloaded from.
repo_id = "Praha-Labs/PrahaTTS-ML"

def _resize_text_embedding(old_emb, vocab_size):
    """Return an ``nn.Embedding`` with ``vocab_size`` rows seeded from ``old_emb``."""
    new_emb = nn.Embedding(
        vocab_size,
        old_emb.embedding_dim,
        device=old_emb.weight.device,
        dtype=old_emb.weight.dtype,
    )
    # Keep the pretrained vectors for every index both tables share; only the
    # extra rows stay at their fresh initialisation.
    shared = min(vocab_size, old_emb.num_embeddings)
    with torch.no_grad():
        new_emb.weight[:shared].copy_(old_emb.weight[:shared])
    return new_emb


def _resize_text_head(old_head, vocab_size):
    """Return an ``nn.Linear`` with ``vocab_size`` outputs seeded from ``old_head``."""
    has_bias = old_head.bias is not None
    new_head = nn.Linear(
        old_head.in_features,
        vocab_size,
        bias=has_bias,
        device=old_head.weight.device,
        dtype=old_head.weight.dtype,
    )
    # Same idea as for the embedding: copy the overlapping output rows.
    shared = min(vocab_size, old_head.out_features)
    with torch.no_grad():
        new_head.weight[:shared].copy_(old_head.weight[:shared])
        if has_bias:
            new_head.bias[:shared].copy_(old_head.bias[:shared])
    return new_head


def load_model(vocab_size=2573):
    """Build the Chatterbox model with the custom tokenizer and LoRA adapter.

    Steps, in order:

    1. Load the base ``ChatterboxTTS`` model on ``device``.
    2. Download ``tokenizer_indic.json`` from ``repo_id`` and install it as
       ``model.tokenizer``. A failure is printed and loading continues.
    3. Resize ``text_emb`` / ``text_head`` (on ``model.t3`` when present,
       otherwise on ``model``) to ``vocab_size`` entries. Existing weights
       are copied into the resized layers, so the original entries are not
       replaced by random values.
    4. Wrap the target with the PEFT adapter from ``repo_id``. A failure is
       printed and the model is returned without the adapter.

    Args:
        vocab_size: Size of the text vocabulary the adapter expects. The
            default, 2573, is the Malayalam vocabulary size.

    Returns:
        The loaded model object.
    """
    print(f"Loading base Chatterbox model on {device}...")
    model = ChatterboxTTS.from_pretrained(device=device)

    print("Applying custom Indic tokenizer...")
    try:
        tokenizer_path = hf_hub_download(repo_id=repo_id, filename="tokenizer_indic.json")
        model.tokenizer = EnTokenizer(tokenizer_path)
    except Exception as e:
        # Best effort: report and keep going with whatever tokenizer is set.
        print(f"Error during tokenizer inject: {e}")

    # --- CRITICAL FIX: MANUALLY RESIZE PYTORCH EMBEDDINGS ---
    # The vocabulary-sized layers of the base model must match the new
    # vocabulary before the adapter weights are loaded.
    print(f"Resizing base model embeddings to handle vocab size of {vocab_size}...")

    target_layer = model.t3 if hasattr(model, 't3') else model

    if hasattr(target_layer, 'text_emb'):
        target_layer.text_emb = _resize_text_embedding(target_layer.text_emb, vocab_size)

    if hasattr(target_layer, 'text_head'):
        target_layer.text_head = _resize_text_head(target_layer.text_head, vocab_size)

    # Send resized layers to the correct device
    target_layer.to(device)

    print("Loading LoRA adapter weights...")
    try:
        if hasattr(model, 't3'):
            model.t3 = PeftModel.from_pretrained(model.t3, repo_id)
        else:
            model = PeftModel.from_pretrained(model, repo_id)
        print("LoRA adapter loaded successfully.")
    except Exception as e:
        # Best effort: the model is still returned, just without the adapter.
        print(f"Failed to load PEFT adapter: {e}")

    return model

# Initialize Model
# Loaded once at import time; synthesize_audio() reads this module-level
# instance on every call.
tts_model = load_model()

def synthesize_audio(text, ref_audio, exaggeration, cfg_weight):
    """Generate speech for ``text`` and save it to a temporary WAV file.

    Args:
        text: Text to synthesize. ``None``, empty or whitespace-only input
            is rejected with a status message instead of being synthesized.
        ref_audio: Path to a reference audio file, or a falsy value for no
            reference. Passed to the model as ``audio_prompt_path``.
        exaggeration: Forwarded unchanged to ``tts_model.generate``.
        cfg_weight: Forwarded unchanged to ``tts_model.generate``.

    Returns:
        A ``(path, status)`` tuple. ``path`` is the WAV file path on success
        and ``None`` otherwise; ``status`` is a human-readable message.
    """
    # Guard against None as well as blank input; None has no .strip().
    if text is None or not text.strip():
        return None, "Please enter some text."

    audio_prompt_path = ref_audio or None
    out_path = None

    try:
        wav = tts_model.generate(
            text,
            audio_prompt_path=audio_prompt_path,
            exaggeration=exaggeration,
            cfg_weight=cfg_weight,
        )

        # Only the unique file name is needed. Close the handle straight away
        # so it is not leaked and the file can be reopened for writing.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_out:
            out_path = temp_out.name
        ta.save(out_path, wav.cpu(), tts_model.sr)
        return out_path, "Generation successful!"

    except Exception as e:
        # Do not leave an empty or partial temp file behind on failure.
        if out_path is not None:
            try:
                os.remove(out_path)
            except OSError:
                pass
        return None, f"Generation Error: {str(e)}"

# Define the Gradio Interface
# Layout: a header, then one row with two columns -- inputs and controls on
# the left, generated audio and status text on the right.
with gr.Blocks(title="PrahaTTS-ML: Malayalam TTS", theme=gr.themes.Soft()) as demo:
    # Page header and short description of the model.
    gr.Markdown("# 🗣️ PrahaTTS-ML: Malayalam LoRA Adapter")
    gr.Markdown(
        "This Space runs the [Praha-Labs/PrahaTTS-ML](https://huggingface.co/Praha-Labs/PrahaTTS-ML) model. "
        "It is a Malayalam LoRA adapter built on top of ResembleAI's Chatterbox non-turbo TTS model. \n\n"
        "**Note**: Provide up to 5-10 seconds of clear reference audio for voice cloning capabilities."
    )
    
    with gr.Row():
        # Left column: text, optional reference audio, sliders, and the button.
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text (Malayalam)", 
                lines=4, 
                placeholder="നമസ്കാരം, നിങ്ങൾക്കെങ്ങനെയുണ്ട്?"
            )
            # type="filepath" hands the handler a file path, which
            # synthesize_audio forwards as the audio prompt path.
            ref_audio_input = gr.Audio(
                label="Reference Voice Audio (Optional, for Voice Cloning)", 
                type="filepath"
            )
            
            # Both sliders range over 0.0-1.0 and default to 0.5; the
            # accordion starts collapsed.
            with gr.Accordion("Advanced Voice Controls", open=False):
                exaggeration_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.05, 
                    label="Emotion Exaggeration", 
                    info="Lower for monotone, higher for dramatic/expressive"
                )
                cfg_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.05, 
                    label="CFG Weight", 
                    info="Lower if speech is too fast, higher to strictly mimic the reference voice"
                )
                
            generate_btn = gr.Button("Synthesize Speech", variant="primary")
            
        # Right column: read-only outputs filled in by synthesize_audio.
        with gr.Column():
            audio_output = gr.Audio(label="Generated Output", interactive=False)
            status_output = gr.Textbox(label="Status Logging", interactive=False)
            
    # The inputs are listed in the same order as synthesize_audio's
    # parameters; the two outputs receive its (path, status) return tuple.
    generate_btn.click(
        fn=synthesize_audio,
        inputs=[text_input, ref_audio_input, exaggeration_slider, cfg_slider],
        outputs=[audio_output, status_output]
    )

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()