Spaces:
Paused
Paused
| import spaces | |
| import os | |
| import torch | |
| import soundfile as sf | |
| import logging | |
| import gradio as gr | |
| import librosa | |
| import numpy as np | |
| from datetime import datetime | |
| import spaces | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# ------------------------------
# Logging
# ------------------------------
# Root logger: INFO and above, each record prefixed with time and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# ------------------------------
# Global Model
# ------------------------------
# Hub repository id handed to `from_pretrained` by load_model().
MODEL_ID = "rahul7star/mir-TTS"
# Process-wide cache; both stay None until load_model() fills them in.
MODEL = None
TOKENIZER = None
| # ------------------------------ | |
| # Helper Functions | |
| # ------------------------------ | |
def load_model():
    """Load the model and tokenizer on first use and cache them globally.

    The model is placed on CUDA when `torch.cuda.is_available()` reports a
    device and on CPU otherwise, so the app no longer crashes on hosts
    without a GPU. The module globals are only assigned once *both* objects
    have loaded, so a failure part-way through leaves the cache empty and the
    next call retries from scratch.

    NOTE(review): `spaces` is imported by this file but no `@spaces.GPU`
    decorator is applied anywhere visible; if this runs on ZeroGPU hardware,
    confirm GPU work happens inside a decorated function.

    Returns:
        A `(model, tokenizer)` tuple.
    """
    global MODEL, TOKENIZER
    if MODEL is None or TOKENIZER is None:
        logging.info(f"Loading model: {MODEL_ID}")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)
        # Publish only after both loads succeeded (no half-initialised cache).
        MODEL, TOKENIZER = model, tokenizer
        logging.info("Model loaded on %s", "GPU" if device == "cuda" else "CPU")
    return MODEL, TOKENIZER
def validate_audio_input(audio_path):
    """Validate a reference clip and write a normalised 16 kHz copy to disk.

    At most the first 30 seconds of the file are decoded. The clip is
    resampled to 16 kHz when needed, peak-normalised, and written to a
    uniquely named temporary WAV file.

    Args:
        audio_path: Path to the uploaded or recorded audio file.

    Returns:
        A `(temp_path, sample_rate)` tuple. The caller owns `temp_path` and
        is responsible for deleting it.

    Raises:
        ValueError: If the path is missing or not a regular file, or the
            audio is empty, shorter than 0.5 s, or has no usable signal
            (all-zero or non-finite samples).
    """
    import tempfile  # function-scope import: only this helper needs it

    if not audio_path or not os.path.isfile(audio_path):
        raise ValueError("Audio file not found")

    target_sr = 16000
    audio, sr = librosa.load(audio_path, sr=None, duration=30)
    if len(audio) == 0:
        raise ValueError("Audio is empty")
    # Minimum 0.5 seconds, measured at the file's native sample rate.
    if len(audio) < int(0.5 * sr):
        raise ValueError("Audio too short, must be >=0.5s")

    if sr != target_sr:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
        sr = target_sr

    # Peak-normalise. A silent clip has peak 0 and would otherwise divide by
    # zero and fill the output with NaNs.
    peak = float(np.max(np.abs(audio)))
    if not np.isfinite(peak) or peak == 0.0:
        raise ValueError("Audio is silent or contains invalid samples")
    audio = audio / peak

    # Unique name avoids collisions between uploads that share a basename;
    # the fixed .wav suffix keeps the output format independent of the
    # source file's extension.
    fd, temp_path = tempfile.mkstemp(prefix="processed_", suffix=".wav")
    os.close(fd)
    try:
        sf.write(temp_path, audio, samplerate=sr)
    except Exception:
        # Do not leave an empty temp file behind when the write fails.
        if os.path.exists(temp_path):
            os.remove(temp_path)
        raise
    return temp_path, sr
| # ------------------------------ | |
| # Core Generation Function | |
| # ------------------------------ | |
def generate_speech(text, prompt_audio_path):
    """Run the language model on `text` and return placeholder audio.

    The reference clip is validated and preprocessed, but in this demo it is
    not yet used to condition generation, and the decoded model output is not
    converted to a waveform: the function returns two seconds of low-level
    noise as a stand-in until a TTS codec is integrated.

    Args:
        text: Text to synthesise; must contain non-whitespace characters.
        prompt_audio_path: Path to the reference audio file.

    Returns:
        A `(audio, sample_rate)` tuple: a float32 NumPy array and the rate
        (48000) at which it should be played.

    Raises:
        ValueError: If `text` is empty or the reference audio is invalid.
        Exception: Any model/tokenizer error is logged and re-raised.
    """
    output_sr = 48000
    processed_audio = None
    try:
        # Cheap checks first, so bad input never triggers a model load.
        if not text or not text.strip():
            raise ValueError("Text is empty")

        # Preprocess audio (result is only validated for now, see docstring).
        processed_audio, _ = validate_audio_input(prompt_audio_path)

        model, tokenizer = load_model()

        # Simple text generation using tokens
        text_input_ids = tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
        outputs = model.generate(**text_input_ids, max_new_tokens=512)

        # Drop the prompt tokens so only newly generated ones are decoded.
        prompt_len = text_input_ids["input_ids"].shape[-1]
        generated_text = tokenizer.decode(outputs[0][prompt_len:])
        logging.info("Model produced %d characters of output", len(generated_text))

        # For demo, return placeholder audio.
        # You can integrate your TTS codec here.
        # The sample count is derived from the returned rate so the clip
        # really lasts two seconds when played back.
        dummy_audio = np.random.rand(output_sr * 2).astype("float32") * 0.01
        return dummy_audio, output_sr
    except Exception as e:
        logging.error(f"Generation error: {e}")
        raise
    finally:
        # Cleanup runs on success and on failure, so the temp file never leaks.
        if processed_audio and os.path.exists(processed_audio):
            os.remove(processed_audio)
| # ------------------------------ | |
| # Gradio Interface | |
| # ------------------------------ | |
def voice_clone_interface(text, prompt_audio_upload, prompt_audio_record):
    """Gradio callback: synthesise `text` and save the result as a WAV file.

    The uploaded clip takes precedence over the recorded one. Errors are
    never raised to the UI; they are logged and reported through the status
    string instead.

    Args:
        text: Text to synthesise (may be None or blank).
        prompt_audio_upload: File path of the uploaded reference clip, or None.
        prompt_audio_record: File path of the recorded reference clip, or None.

    Returns:
        A `(output_path, status_message)` tuple; `output_path` is None when
        nothing was generated.
    """
    try:
        prompt_audio = prompt_audio_upload or prompt_audio_record
        if not prompt_audio:
            return None, "Upload or record reference audio first"
        # `not text` also covers None, which has no .strip().
        if not text or not text.strip():
            return None, "Enter text to synthesize"

        audio, sr = generate_speech(text, prompt_audio)

        # Save output. Microseconds in the name keep two requests arriving
        # within the same second from overwriting each other's file.
        os.makedirs("outputs", exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        output_path = f"outputs/mir_tts_{timestamp}.wav"
        sf.write(output_path, audio, samplerate=sr)
        return output_path, "Generation successful!"
    except Exception as e:
        logging.error(f"Voice clone error: {e}")
        return None, f"Error: {e}"
def build_interface():
    """Assemble the Gradio Blocks UI and return it without launching.

    Layout: a title, a row with a reference-audio column (upload and
    microphone inputs) next to a text column (textbox and generate button),
    a row with the generated audio and a status box, and a "Clear All"
    button.

    NOTE(review): the nesting of rows and columns was reconstructed from
    source whose indentation was lost; confirm against the deployed layout.
    """

    def clear_all():
        # One value per output of the clear button, in wiring order:
        # upload, recording, text, generated audio, status.
        return None, None, "", None, "Ready for new generation"

    with gr.Blocks(title="MiraTTS Voice Cloning") as demo:
        gr.HTML("<h1 style='text-align:center;color:#2563eb;'>MiraTTS Voice Cloning</h1>")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Reference Audio")
                upload_box = gr.Audio(sources="upload", type="filepath")
                mic_box = gr.Audio(sources="microphone", type="filepath")
            with gr.Column():
                gr.Markdown("### Text Input")
                text_box = gr.Textbox(
                    placeholder="Enter text...",
                    lines=4,
                    value="Hello! This is a demonstration of MiraTTS",
                )
                run_button = gr.Button("Generate Speech", variant="primary")

        with gr.Row():
            result_audio = gr.Audio(label="Generated Speech", type="filepath", autoplay=True)
            status_box = gr.Textbox(label="Status", interactive=False)

        run_button.click(
            voice_clone_interface,
            inputs=[text_box, upload_box, mic_box],
            outputs=[result_audio, status_box],
        )

        reset_button = gr.Button("Clear All", variant="secondary")
        reset_button.click(
            clear_all,
            outputs=[upload_box, mic_box, text_box, result_audio, status_box],
        )

    return demo
if __name__ == "__main__":
    # Listen on every network interface at port 7860, with no public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo = build_interface()
    demo.launch(**launch_options)