File size: 6,157 Bytes
3974e31
679c1f9
 
 
ec6c1c0
af2562e
 
 
679c1f9
af2562e
 
 
 
 
 
 
 
 
 
 
 
 
 
679c1f9
af2562e
 
679c1f9
af2562e
 
 
ec6c1c0
af2562e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3974e31
af2562e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f33e8f9
af2562e
 
 
679c1f9
af2562e
 
 
679c1f9
af2562e
679c1f9
af2562e
 
 
679c1f9
af2562e
 
 
 
 
f33e8f9
af2562e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
679c1f9
ec6c1c0
af2562e
 
ec6c1c0
af2562e
 
 
 
 
ec6c1c0
af2562e
 
 
 
 
 
ec6c1c0
679c1f9
 
 
 
af2562e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import spaces
import os
import torch
import soundfile as sf
import logging
import gradio as gr
import librosa
import numpy as np
from datetime import datetime
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM

# ------------------------------
# Logging
# ------------------------------
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# ------------------------------
# Global Model
# ------------------------------
# Identifier passed to `from_pretrained` for both the model and the tokenizer.
MODEL_ID = "rahul7star/mir-TTS"
# Shared model/tokenizer instances; left as None here and filled in lazily by
# load_model() on its first call.
MODEL = None
TOKENIZER = None

# ------------------------------
# Helper Functions
# ------------------------------
def load_model():
    """Return the shared ``(model, tokenizer)`` pair, loading it on first use.

    Both objects are created with ``from_pretrained(MODEL_ID)`` and cached in
    the module globals ``MODEL`` / ``TOKENIZER``, so later calls return the
    same instances. The model is moved to CUDA when a GPU is available and
    otherwise stays on the CPU (the previous hard-coded ``.cuda()`` crashed on
    machines without a GPU).

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    global MODEL, TOKENIZER
    if MODEL is None or TOKENIZER is None:
        logging.info(f"Loading model: {MODEL_ID}")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load into locals first and publish both globals together, so a
        # failure part-way (e.g. while fetching the tokenizer) never leaves
        # the cache half-initialised.
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        MODEL, TOKENIZER = model, tokenizer
        logging.info("Model loaded on %s", "GPU" if device == "cuda" else "CPU")
    return MODEL, TOKENIZER

def validate_audio_input(audio_path, target_sr=16000, min_duration=0.5, max_duration=30):
    """Validate a reference clip and write a normalised copy to a temp WAV file.

    The clip is loaded with librosa at its native rate. Only the first
    ``max_duration`` seconds are read. The clip is then checked for being
    non-empty, at least ``min_duration`` seconds long and not completely
    silent. Finally it is resampled to ``target_sr`` and peak-normalised to 1.0.

    Args:
        audio_path: Path to the input audio file.
        target_sr: Sample rate of the written copy (default 16 kHz).
        min_duration: Minimum accepted clip length in seconds.
        max_duration: Number of seconds read from the start of the file.

    Returns:
        tuple: ``(temp_path, sr)``. ``temp_path`` is a freshly created WAV
        file; the caller is responsible for deleting it. ``sr`` is its
        sample rate.

    Raises:
        ValueError: If the file is missing, empty, too short or silent.
    """
    import tempfile  # local import: only this helper needs it

    if not audio_path or not os.path.exists(audio_path):
        raise ValueError("Audio file not found")
    audio, sr = librosa.load(audio_path, sr=None, duration=max_duration)
    if len(audio) == 0:
        raise ValueError("Audio is empty")
    if len(audio) < int(min_duration * sr):
        raise ValueError(f"Audio too short, must be >={min_duration}s")
    if sr != target_sr:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
        sr = target_sr

    peak = np.max(np.abs(audio))
    # An all-zero clip would divide by zero and turn every sample into NaN.
    if peak == 0:
        raise ValueError("Audio is silent")
    audio = audio / peak

    # Unique WAV temp file. The old fixed "processed_<basename>" path could
    # collide between requests with the same upload name, and reusing the
    # upload's extension could ask soundfile for a format it cannot write.
    fd, temp_path = tempfile.mkstemp(prefix="processed_", suffix=".wav")
    os.close(fd)
    try:
        sf.write(temp_path, audio, samplerate=sr)
    except Exception:
        os.remove(temp_path)
        raise
    return temp_path, sr

# ------------------------------
# Core Generation Function
# ------------------------------
@spaces.GPU()
def generate_speech(text, prompt_audio_path):
    """Generate speech for ``text`` conditioned on a reference clip.

    NOTE: placeholder pipeline. The reference clip is validated and the text
    is run through ``model.generate``. No TTS codec is wired in yet, so the
    returned waveform is 2 seconds of low-level random noise.

    Args:
        text: Text to synthesise; must contain non-whitespace characters.
        prompt_audio_path: Path to the reference audio file.

    Returns:
        tuple: ``(audio, sample_rate)``, a float32 waveform and its rate
        (48 kHz).

    Raises:
        ValueError: If the text is empty or the reference audio is invalid.
        Exception: Errors from model loading or generation are logged and
            re-raised unchanged.
    """
    output_sr = 48000
    processed_audio = None
    try:
        # Cheap input check first, so a bad request never triggers a model load.
        if not text or not text.strip():
            raise ValueError("Text is empty")

        model, tokenizer = load_model()

        # Preprocess audio (only validated for now; the codec will consume it).
        processed_audio, _ = validate_audio_input(prompt_audio_path)

        text_input_ids = tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(model.device)

        outputs = model.generate(
            **text_input_ids,
            max_new_tokens=512
        )
        # Keep only the newly generated tokens (drop the prompt prefix).
        generated_text = tokenizer.decode(outputs[0][text_input_ids["input_ids"].shape[-1]:])
        logging.debug(f"Generated text: {generated_text!r}")

        # TODO: integrate the TTS codec here; until then return placeholder audio.
        # The length is derived from output_sr so the clip really lasts 2 s at
        # the rate we report. Previously 16 kHz worth of samples was labelled
        # 48 kHz and played back in ~0.67 s.
        dummy_audio = np.random.rand(output_sr * 2).astype("float32") * 0.01
        return dummy_audio, output_sr

    except Exception as e:
        logging.exception(f"Generation error: {e}")
        raise
    finally:
        # Always remove the temp copy, including when generation fails.
        if processed_audio and os.path.exists(processed_audio):
            os.remove(processed_audio)

# ------------------------------
# Gradio Interface
# ------------------------------
def voice_clone_interface(text, prompt_audio_upload, prompt_audio_record):
    """Gradio callback: synthesise ``text`` in the reference voice and save it.

    The uploaded clip takes priority over the microphone recording. The
    result is written under ``outputs/`` as ``mir_tts_<timestamp>.wav``.

    Args:
        text: Text to synthesise.
        prompt_audio_upload: Filepath of the uploaded reference clip, or falsy.
        prompt_audio_record: Filepath of the recorded reference clip, or falsy.

    Returns:
        tuple: ``(output_path, status_message)``. ``output_path`` is ``None``
        when validation or generation fails, and the status explains why.
    """
    try:
        prompt_audio = prompt_audio_upload or prompt_audio_record
        if not prompt_audio:
            return None, "Upload or record reference audio first"
        # Guard against None before calling .strip() (e.g. programmatic calls).
        if not text or not text.strip():
            return None, "Enter text to synthesize"

        audio, sr = generate_speech(text, prompt_audio)

        # Save output. Microsecond resolution keeps two generations started in
        # the same second from overwriting each other's file.
        os.makedirs("outputs", exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        output_path = os.path.join("outputs", f"mir_tts_{timestamp}.wav")
        sf.write(output_path, audio, samplerate=sr)

        return output_path, "Generation successful!"

    except Exception as e:
        logging.exception(f"Voice clone error: {e}")
        return None, f"Error: {e}"

def build_interface():
    """Assemble the MiraTTS voice-cloning Gradio UI.

    Layout, top to bottom:
    - a title banner;
    - a row holding the reference-audio inputs (upload and microphone) next
      to the text box and the Generate button;
    - a row holding the generated-audio player and a status box;
    - a Clear All button.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """

    def reset_fields():
        # One value per output wired to the Clear All button, in the same order.
        return None, None, "", None, "Ready for new generation"

    with gr.Blocks(title="MiraTTS Voice Cloning") as app:
        gr.HTML("<h1 style='text-align:center;color:#2563eb;'>MiraTTS Voice Cloning</h1>")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Reference Audio")
                upload_audio = gr.Audio(sources="upload", type="filepath")
                mic_audio = gr.Audio(sources="microphone", type="filepath")
            with gr.Column():
                gr.Markdown("### Text Input")
                script_box = gr.Textbox(
                    placeholder="Enter text...",
                    lines=4,
                    value="Hello! This is a demonstration of MiraTTS",
                )
                run_button = gr.Button("Generate Speech", variant="primary")

        with gr.Row():
            result_audio = gr.Audio(label="Generated Speech", type="filepath", autoplay=True)
            status_box = gr.Textbox(label="Status", interactive=False)

        run_button.click(
            fn=voice_clone_interface,
            inputs=[script_box, upload_audio, mic_audio],
            outputs=[result_audio, status_box],
        )

        reset_button = gr.Button("Clear All", variant="secondary")
        reset_button.click(
            fn=reset_fields,
            outputs=[upload_audio, mic_audio, script_box, result_audio, status_box],
        )

    return app

if __name__ == "__main__":
    # Serve on every interface at port 7860, without a public share link.
    demo = build_interface()
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
    )