Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| # Import alias module before outetts to setup whisper redirection | |
| import alias as _alias | |
| import outetts | |
| import json | |
| import tempfile | |
| import hashlib | |
| import os | |
| import re | |
| from typing import Optional | |
| from llama_cpp.llama import LlamaGrammar | |
| from outetts.version.interface import InterfaceLLAMACPP | |
| from outetts.models.info import MODEL_INFO | |
| from outetts.utils import helpers | |
| from huggingface_hub import hf_hub_download | |
| import torch | |
| from transformers import BitsAndBytesConfig | |
| import spaces | |
| import numpy as np | |
| from collections import OrderedDict | |
# Every OuteTTS model, keyed by its string value (these strings are the UI
# dropdown choices) and mapping to the corresponding outetts.Models member.
MODELS = {member.value: member for member in outetts.Models.__members__.values()}
# GGUF quantization to download for specific models; any model not listed here
# falls back to the Q8_0 default chosen in get_interface.
MODEL_QUANTIZATION = dict.fromkeys(
    (
        outetts.Models.VERSION_0_1_SIZE_350M,
        outetts.Models.VERSION_0_2_SIZE_500M,
        outetts.Models.VERSION_0_3_SIZE_500M,
    ),
    outetts.LlamaCppQuantization.FP16,
)
# Speaker profiles already built from uploaded audio, stored as JSON strings, so
# uploading the same audio again does not trigger another transcription.
speaker_cache = {}
# Word separator used inside the prompt's <|text_start|>...<|text_end|> section,
# per interface version: V1/V2 use a dedicated token, V3 a plain space.
SPLIT_SYMBOL = dict.fromkeys(
    (outetts.InterfaceVersion.V1, outetts.InterfaceVersion.V2),
    '<|space|>',
)
SPLIT_SYMBOL[outetts.InterfaceVersion.V3] = ' '
# Characters that must be escaped inside a GBNF "string literal".
_GBNF_LITERAL_ESCAPES = str.maketrans({
    '\\': '\\\\',
    '"': '\\"',
    '\n': '\\n',
    '\r': '\\r',
    '\t': '\\t',
})
# Characters that must be escaped inside a GBNF [character class]. '-' and '^'
# are hex-escaped: a raw '-' would form a range and a leading '^' would negate it.
_GBNF_CLASS_ESCAPES = str.maketrans({
    '\\': '\\\\',
    '[': '\\[',
    ']': '\\]',
    '-': '\\x2D',
    '^': '\\x5E',
    '\n': '\\n',
    '\r': '\\r',
    '\t': '\\t',
})


def word_to_grammar(word):
    """Return a GBNF expression that matches *word* in the generated output.

    Pure-ASCII words become a quoted string literal that must match exactly.
    Words containing any non-ASCII character become a character class of the
    word's distinct characters (first-seen order) repeated with ``+``, which is
    a looser match. Characters that are special in GBNF are escaped, so user
    text such as ``say"hi`` or ``é-a`` cannot break the grammar or change its
    meaning.

    Args:
        word: one word of the text to be spoken (may be empty).

    Returns:
        A GBNF expression string.
    """
    if word.isascii():
        return f'"{word.translate(_GBNF_LITERAL_ESCAPES)}"'
    distinct_chars = ''.join(dict.fromkeys(word))
    return f'[{distinct_chars.translate(_GBNF_CLASS_ESCAPES)}]+'
# Patch InterfaceLLAMACPP so generation is constrained by a per-prompt grammar
# (ggml_generate below), keeping a handle on the stock method for delegation.
# Guarded so that re-executing this module (e.g. an app reload while outetts stays
# imported) does not capture the already-patched method as the "original", which
# would make ggml_generate call itself forever.
if not hasattr(InterfaceLLAMACPP, "_orig_generate"):
    InterfaceLLAMACPP._orig_generate = InterfaceLLAMACPP._generate
def ggml_generate(self, input_ids, config):
    """Grammar-constrained replacement for ``InterfaceLLAMACPP._generate``.

    Decodes the prompt and takes the text between ``<|text_start|>`` and the
    following ``<|text_end|>``. Any leading portion up to and including the last
    line of the special-token-free prompt is dropped (presumably the speaker's
    reference transcript; confirm). The remaining words are turned into a
    llama.cpp GBNF grammar stored in ``config.additional_gen_config["grammar"]``
    for V2 and V3 interfaces; other versions get no grammar. Prompts without a
    complete text section are passed straight to the original generator.

    Args:
        self: the ``InterfaceLLAMACPP`` instance whose method is patched.
        input_ids: prompt token ids.
        config: generation config; ``config.additional_gen_config`` is mutated.

    Returns:
        Whatever the original ``_generate`` returns.
    """
    text_start, text_end = '<|text_start|>', '<|text_end|>'
    tokenizer = self.prompt_processor.tokenizer
    version = self.config.interface_version
    split = SPLIT_SYMBOL.get(version, ' ')
    prompt = tokenizer.decode(input_ids, skip_special_tokens=False)

    start = prompt.find(text_start)
    # Look for the end marker only after the start marker.
    end = prompt.find(text_end, start + len(text_start)) if start != -1 else -1
    if start == -1 or end == -1:
        # No (complete) text section: nothing to constrain; avoids the ValueError
        # the unguarded .index() lookup used to raise on a missing end marker.
        return self._orig_generate(input_ids, config)

    prompt_no_special = tokenizer.decode(input_ids, skip_special_tokens=True).strip()
    speaker_text_last = prompt_no_special.split('\n')[-1]
    text = prompt[start + len(text_start):end]
    # NOTE(review): uses the FIRST occurrence of speaker_text_last in the text
    # section; if that string also appears earlier, the cut point is wrong.
    if speaker_text_last in text:
        cut = text.index(speaker_text_last) + len(speaker_text_last)
        gen_text = text[cut:].strip(split)
    else:
        gen_text = text
    words = [word_to_grammar(word) for word in gen_text.split(split)]

    if version == outetts.InterfaceVersion.V2:
        # NOTE(review): words are joined with ' audioBlock ', so the last word is
        # followed directly by audioEnd with no audioBlock of its own (in V3 every
        # word gets one). Confirm this matches the V2 audio format.
        config.additional_gen_config["grammar"] = LlamaGrammar.from_string(f"""\
root ::= NL? {' audioBlock '.join(words)} audioEnd NL EOS?
audioBlock ::= TIME CODE* space NL?
TEXT ::= [A-Za-z0-9 .,?!]+
EOS ::= "<|im_end|>"
emotionStart ::= "<|emotion_start|>"
emotionEnd ::= "<|emotion_end|>"
audioEnd ::= "<|audio_end|>"
space ::= "<|space|>"
WORD ::= {' | '.join(words)}
NL ::= [\\n]
TIME ::= "<|t_" DECIMAL "|>"
CODE ::= "<|" DIGITS "|>"
DIGITS ::= [0-9]+
DECIMAL ::= [0-9]+ "." [0-9]+
punch ::= "<|" [a-z_]+ "|>"
""")
    elif version == outetts.InterfaceVersion.V3:
        config.additional_gen_config["grammar"] = LlamaGrammar.from_string(f"""\
root ::= leadWord wordBlock* audioEnd NL EOS?
leadWord ::= WORD audioBlock
wordBlock ::= wordStart WORD audioBlock
audioBlock ::= codeBlock wordEnd NL?
codeBlock ::= features TIME energy spectralCentroid pitch CODE CODES*
TEXT ::= [A-Za-z0-9.,!?]+
EOS ::= "<|im_end|>"
audioEnd ::= "<|audio_end|>"
wordStart ::= "<|word_start|>"
wordEnd ::= "<|word_end|>"
features ::= "<|features|>"
energy ::= "<|energy_" DIGITS "|>"
spectralCentroid ::= "<|spectral_centroid_" DIGITS "|>"
pitch ::= "<|pitch_" DIGITS "|>"
WORD ::= {' | '.join(words)}
NL ::= [\\n]
TIME ::= "<|t_" DECIMAL "|>"
CODE ::= "<|code|>"
CODES ::= CODE1 CODE2
CODE1 ::= "<|c1_" DIGITS "|>"
CODE2 ::= "<|c2_" DIGITS "|>"
DIGITS ::= [0-9]+
DECIMAL ::= [0-9]+ "." [0-9]+
""")
    return self._orig_generate(input_ids, config)
# Install the grammar-constrained generator in place of the stock llama.cpp one.
setattr(InterfaceLLAMACPP, "_generate", ggml_generate)
def get_file_hash(file_path):
    """Return the hex MD5 digest of the file at *file_path*.

    Used as a speaker-profile cache key. The file is read in 4096-byte chunks,
    so it is never loaded into memory all at once.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as stream:
        while True:
            piece = stream.read(4096)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()
def try_ggml_model(model: outetts.Models, quantization: outetts.LlamaCppQuantization):
    """Download *model*'s GGUF weights and return a llama.cpp ``ModelConfig`` for them.

    The file ``<model>-<quantization>.gguf`` is fetched from the Hub repo
    ``OuteAI/<model>-GGUF`` into ``<outetts cache dir>/gguf``. The tokenizer comes
    from ``OuteAI/<model>``. Up to 99 layers are requested on the GPU and chunked
    generation is used. Entries of ``MODEL_INFO[model]`` are passed through as
    extra keyword arguments.

    Raises:
        KeyError if *model* is missing from MODEL_INFO, or whatever
        ``hf_hub_download`` raises; get_interface catches failures here and
        falls back to the HF backend.
    """
    info = MODEL_INFO[model]
    weights_path = hf_hub_download(
        repo_id=f"OuteAI/{model.value}-GGUF",
        filename=f"{model.value}-{quantization.value}.gguf",
        local_dir=os.path.join(helpers.get_cache_dir(), "gguf"),
        local_files_only=False,
    )
    # NOTE: GUIDED_WORDS generation for InterfaceVersion.V3 models is a possible
    # alternative that is currently disabled; every model uses CHUNKED.
    return outetts.ModelConfig(
        model_path=weights_path,
        tokenizer_path=f"OuteAI/{model.value}",
        backend=outetts.Backend.LLAMACPP,
        n_gpu_layers=99,
        verbose=False,
        device=None,
        dtype=None,
        additional_model_config={},
        audio_codec_path=None,
        generation_type=outetts.GenerationType.CHUNKED,
        **info,
    )
def get_interface(model_name: str):
    """Get interface instance for the model (no caching to avoid CUDA memory issues).

    The llama.cpp backend with GGUF weights is tried first. The quantization comes
    from MODEL_QUANTIZATION and defaults to Q8_0. If that fails, the Hugging Face
    transformers backend is used instead: 4-bit bitsandbytes with
    ``device_map="auto"`` when CUDA is available, plain CPU otherwise. A new
    interface is constructed on every call.

    Args:
        model_name: a key of MODELS.

    Returns:
        A new ``outetts.Interface``.

    Raises:
        KeyError: if *model_name* is not a key of MODELS.
    """
    model = MODELS[model_name]
    try:
        quantization = MODEL_QUANTIZATION.get(model, outetts.LlamaCppQuantization.Q8_0)
        config = try_ggml_model(model, quantization)
    except Exception as exc:
        # Deliberate best-effort fallback. Catch Exception, not everything, so
        # Ctrl-C / SystemExit still propagate, and report why GGUF was skipped.
        print(f"⚠️ GGUF backend unavailable for {model_name} ({exc!r}); falling back to HF backend")
        has_cuda = torch.cuda.is_available()
        model_config = MODEL_INFO[model]
        config = outetts.ModelConfig(
            model_path=f"OuteAI/{model_name}",
            tokenizer_path=f"OuteAI/{model_name}",
            backend=outetts.Backend.HF,
            additional_model_config={
                "device_map": "auto" if has_cuda else "cpu",
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    llm_int8_enable_fp32_cpu_offload=True
                ) if has_cuda else None,
            },
            **model_config
        )
    # Initialize the interface
    return outetts.Interface(config=config)
def get_or_create_speaker(interface, audio_file):
    """Return the speaker profile for *audio_file*, creating and caching it on a miss.

    Profiles live in ``speaker_cache`` as JSON strings. The key combines the
    interface version and the file's MD5, so each hit decodes to a fresh dict.
    A miss transcribes the audio with Whisper ``large-v3-turbo``, on CUDA if
    available and on CPU otherwise.

    Returns:
        The speaker profile, or a string starting with "❌" if creating the
        profile raised an exception.
    """
    digest = get_file_hash(audio_file)
    key = f"{interface.config.interface_version}_{digest}"
    display_name = os.path.basename(audio_file)

    cached_json = speaker_cache.get(key)
    if cached_json is not None:
        print(f"✅ Using cached speaker profile for {display_name}")
        return json.loads(cached_json)

    whisper_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"🔄 Creating new speaker profile for {display_name}")
    try:
        profile = interface.create_speaker(audio_file, whisper_model="large-v3-turbo", whisper_device=whisper_device)
        speaker_cache[key] = json.dumps(profile)
        print(f"💾 Cached speaker profile ({len(speaker_cache)} total cached)")
    except Exception as exc:
        return f"❌ Error creating speaker profile: {str(exc)}"
    return profile
def create_speaker_and_generate(model_name, audio_file, test_text: Optional[str] = None, temperature: float = 0.4):
    """Create speaker from audio and optionally generate test audio.

    Gradio handler.

    Args:
        model_name: key of MODELS selecting the OuteTTS model.
        audio_file: path to the reference audio, or None when nothing was uploaded.
        test_text: optional text to synthesize with the new speaker; blank means skip.
        temperature: sampling temperature used for generation.

    Returns:
        ``(profile_text, audio_path)``:
        - ``profile_text`` is the speaker profile as indented JSON, or a
          user-facing message / "❌" error string.
        - ``audio_path`` is the path of a generated .wav file, or None.
    """
    # NOTE(review): `spaces` is imported but no function in this file is decorated
    # with @spaces.GPU; on ZeroGPU hardware this handler may run without a GPU. Confirm.
    if audio_file is None:
        # Return default values for startup/caching purposes
        return "Please upload an audio file to create a speaker profile.", None
    # Get interface (no caching to avoid CUDA memory issues)
    interface = get_interface(model_name)
    # Get or create speaker profile (with caching)
    speaker_result = get_or_create_speaker(interface, audio_file)
    # get_or_create_speaker reports failures as a "❌ ..." string
    if isinstance(speaker_result, str) and speaker_result.startswith("❌"):
        return speaker_result, None
    speaker_json = json.dumps(speaker_result, indent=2, ensure_ascii=False)
    generated_audio = None
    if test_text and test_text.strip():
        output = interface.generate(
            config=outetts.GenerationConfig(
                text=test_text,
                speaker=speaker_result,
                sampler_config=outetts.SamplerConfig(
                    temperature=temperature
                ),
                max_length=MODEL_INFO[MODELS[model_name]]["max_seq_length"]
            )
        )
        # Reserve a unique, persistent .wav path and close our handle BEFORE writing
        # to it by name: reopening a still-open NamedTemporaryFile fails on Windows.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            generated_audio = f.name
        output.save(generated_audio)
    return speaker_json, generated_audio
example_text = "Hello, this is a test of the OuteTTS speaker profile."

# --- Gradio UI --------------------------------------------------------------
# Dropdown options are the model value strings from MODELS; the last one is the default.
_model_choices = list(MODELS.keys())

# Input order must match create_speaker_and_generate's parameters:
# model_name, audio_file, test_text, temperature.
_input_components = [
    gr.Dropdown(
        choices=_model_choices,
        value=_model_choices[-1],
        label="Select OuteTTS Model",
        info="Choose the model variant to use",
    ),
    # NOTE(review): the 20-second limit in this label is not enforced in this file.
    gr.Audio(
        label="Upload Reference Audio (Max 20 seconds)",
        type="filepath",
        sources=["upload", "microphone"],
    ),
    gr.Textbox(
        label="Test Text (Optional)",
        placeholder="Enter text to generate speech (leave empty to only create speaker profile)...",
        lines=3,
        value=None,
    ),
    gr.Slider(
        minimum=0.1,
        maximum=1.0,
        step=0.1,
        value=0.4,
        label="Temperature",
        info="Controls randomness in generation",
    ),
]

# Outputs match the handler's (profile_text, audio_path) return tuple.
_output_components = [
    gr.Textbox(label="Speaker Profile (JSON)", lines=15, max_lines=20, show_copy_button=True),
    gr.Audio(label="Generated Test Audio (if text provided)", type="filepath"),
]

# Each example row: model, reference audio (none), test text, temperature.
_examples = [
    ["OuteTTS-1.0-0.6B", None, example_text, 0.2],
    ["OuteTTS-0.3-500M", None, example_text, 0.2],
]

demo = gr.Interface(
    fn=create_speaker_and_generate,
    inputs=_input_components,
    outputs=_output_components,
    title="🎙️ OuteTTS Speaker Creator",
    description="Create and manage speaker profiles for OuteTTS text-to-speech synthesis. Upload audio to create a speaker profile, and optionally provide test text to generate sample audio.",
    theme=gr.themes.Soft(),
    examples=_examples,
    cache_examples=False,
    flagging_mode="never",
)
if __name__ == "__main__":
    # Listen on all interfaces at port 7860 (the HuggingFace Spaces setup).
    # No public share link; API docs and detailed errors are shown in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_api": True,
        "show_error": True,
    }
    demo.launch(**launch_options)