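"""Gradio Space for creating OuteTTS speaker profiles.

Upload a short reference clip to build a speaker profile (transcribed with
Whisper) and optionally synthesize test audio with it. Generation through the
llama.cpp backend is constrained with a GBNF grammar so the model can only
emit the words of the prompt text.
"""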
import gradio as gr
# Import the alias module before outetts to set up the whisper redirection
import alias as _alias
import outetts
import json
import tempfile
import hashlib
import os
from typing import Optional
from llama_cpp.llama import LlamaGrammar
from outetts.version.interface import InterfaceLLAMACPP
from outetts.models.info import MODEL_INFO
from outetts.utils import helpers
from huggingface_hub import hf_hub_download
import torch
from transformers import BitsAndBytesConfig
import spaces
from collections import OrderedDict
# All OuteTTS model variants exposed by the library, keyed by model name
MODELS = {m.value: m for m in outetts.Models}
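# Preferred GGUF quantization per model; anything not listed here falls back
# to Q8_0 (see get_interface below).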
MODEL_QUANTIZATION = {
outetts.Models.VERSION_0_1_SIZE_350M: outetts.LlamaCppQuantization.FP16,
outetts.Models.VERSION_0_2_SIZE_500M: outetts.LlamaCppQuantization.FP16,
outetts.Models.VERSION_0_3_SIZE_500M: outetts.LlamaCppQuantization.FP16,
}
# Cache for speaker profiles to avoid re-transcribing the same audio
speaker_cache = {}
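# Token that separates words in the prompt text, per interface version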
SPLIT_SYMBOL = {
outetts.InterfaceVersion.V1: '<|space|>',
outetts.InterfaceVersion.V2: '<|space|>',
outetts.InterfaceVersion.V3: ' ',
}
def word_to_grammar(word):
    """Render a word as a GBNF terminal: a quoted literal for pure-ASCII words,
    otherwise a character class over the word's unique characters."""
    if all(ord(c) < 128 for c in word):
        # Escape backslashes and quotes so the word is a valid GBNF string literal
        escaped = word.replace('\\', '\\\\').replace('"', '\\"')
        return f'"{escaped}"'
    return f'[{"".join(OrderedDict.fromkeys(word))}]+'
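# e.g. word_to_grammar("hello") -> '"hello"', while a word containing
# non-ASCII characters becomes a character class such as '[こんにちは]+';
# looser than an exact literal, but robust for multi-byte scripts.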
# Monkey-patch InterfaceLLAMACPP: keep a reference to the original _generate
# and wrap it with a grammar-constrained version.
InterfaceLLAMACPP._orig_generate = InterfaceLLAMACPP._generate
def ggml_generate(self, input_ids, config):
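    """Constrain llama.cpp generation with a GBNF grammar built from the words
    of the prompt's <|text_start|>...<|text_end|> span, then delegate to the
    original _generate. Only interface versions V2 and V3 get a grammar;
    anything else passes through unconstrained."""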
tokenizer = self.prompt_processor.tokenizer
split = SPLIT_SYMBOL.get(self.config.interface_version, ' ')
prompt = tokenizer.decode(input_ids, skip_special_tokens=False)
prompt_no_special = tokenizer.decode(input_ids, skip_special_tokens=True).strip()
if '<|text_start|>' not in prompt:
return self._orig_generate(input_ids, config)
    speaker_text_last = prompt_no_special.split('\n')[-1]
    tag = '<|text_start|>'
    text = prompt[prompt.index(tag) + len(tag):prompt.index('<|text_end|>')]
    # Constrain only the words that still need to be generated, i.e. the text
    # after the reference speaker's transcript (if present)
    gen_text = text[text.index(speaker_text_last) + len(speaker_text_last):].strip(split) if speaker_text_last in text else text
words = [word_to_grammar(word) for word in gen_text.split(split)]
if self.config.interface_version == outetts.InterfaceVersion.V2:
config.additional_gen_config["grammar"] = LlamaGrammar.from_string(f"""\
root ::= NL? {' audioBlock '.join(words)} audioBlock audioEnd NL EOS?
audioBlock ::= TIME CODE* space NL?
TEXT ::= [A-Za-z0-9 .,?!]+
EOS ::= "<|im_end|>"
emotionStart ::= "<|emotion_start|>"
emotionEnd ::= "<|emotion_end|>"
audioEnd ::= "<|audio_end|>"
space ::= "<|space|>"
WORD ::= {' | '.join(words)}
NL ::= [\\n]
TIME ::= "<|t_" DECIMAL "|>"
CODE ::= "<|" DIGITS "|>"
DIGITS ::= [0-9]+
DECIMAL ::= [0-9]+ "." [0-9]+
punct ::= "<|" [a-z_]+ "|>"
""")
elif self.config.interface_version == outetts.InterfaceVersion.V3:
config.additional_gen_config["grammar"] = LlamaGrammar.from_string(f"""\
root ::= leadWord wordBlock* audioEnd NL EOS?
leadWord ::= WORD audioBlock
wordBlock ::= wordStart WORD audioBlock
audioBlock ::= codeBlock wordEnd NL?
codeBlock ::= features TIME energy spectralCentroid pitch CODE CODES*
TEXT ::= [A-Za-z0-9.,!?]+
EOS ::= "<|im_end|>"
audioEnd ::= "<|audio_end|>"
wordStart ::= "<|word_start|>"
wordEnd ::= "<|word_end|>"
features ::= "<|features|>"
energy ::= "<|energy_" DIGITS "|>"
spectralCentroid ::= "<|spectral_centroid_" DIGITS "|>"
pitch ::= "<|pitch_" DIGITS "|>"
WORD ::= {' | '.join(words)}
NL ::= [\\n]
TIME ::= "<|t_" DECIMAL "|>"
CODE ::= "<|code|>"
CODES ::= CODE1 CODE2
CODE1 ::= "<|c1_" DIGITS "|>"
CODE2 ::= "<|c2_" DIGITS "|>"
DIGITS ::= [0-9]+
DECIMAL ::= [0-9]+ "." [0-9]+
""")
return self._orig_generate(input_ids, config)
InterfaceLLAMACPP._generate = ggml_generate
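# Illustration (not executed): for gen_text "hi there" on a V3 model the
# emitted grammar contains
#   WORD ::= "hi" | "there"
# so every sampled word must come from the prompt text, and each must be
# followed by a well-formed feature/time/code block before <|audio_end|>.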
def get_file_hash(file_path):
"""Calculate MD5 hash of a file for caching purposes."""
hash_md5 = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def try_ggml_model(model: outetts.Models, quantization: outetts.LlamaCppQuantization):
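    """Download a GGUF build of the model from the Hugging Face Hub and build a
    llama.cpp ModelConfig for it. Raises if no GGUF build is available, letting
    the caller fall back to the transformers backend."""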
model_config = MODEL_INFO[model]
repo = f"OuteAI/{model.value}-GGUF"
filename = f"{model.value}-{quantization.value}.gguf"
model_path = hf_hub_download(
repo_id=repo,
filename=filename,
local_dir=os.path.join(helpers.get_cache_dir(), "gguf"),
local_files_only=False
)
generation_type = outetts.GenerationType.CHUNKED
# if model_config['interface_version'] == outetts.InterfaceVersion.V3:
# generation_type = outetts.GenerationType.GUIDED_WORDS
return outetts.ModelConfig(
model_path=model_path,
tokenizer_path=f"OuteAI/{model.value}",
backend=outetts.Backend.LLAMACPP,
n_gpu_layers=99,
verbose=False,
device=None,
dtype=None,
additional_model_config={},
audio_codec_path=None,
generation_type=generation_type,
**model_config
)
def get_interface(model_name: str):
"""Get interface instance for the model (no caching to avoid CUDA memory issues)."""
model = MODELS[model_name]
try:
quantization = MODEL_QUANTIZATION.get(model, outetts.LlamaCppQuantization.Q8_0)
config = try_ggml_model(model, quantization)
    except Exception:
        # GGUF download or setup failed: fall back to the transformers backend
        has_cuda = torch.cuda.is_available()
model_config = MODEL_INFO[model]
config = outetts.ModelConfig(
model_path=f"OuteAI/{model_name}",
tokenizer_path=f"OuteAI/{model_name}",
backend=outetts.Backend.HF,
additional_model_config={
"device_map": "auto" if has_cuda else "cpu",
"quantization_config": BitsAndBytesConfig(
load_in_4bit=True,
llm_int8_enable_fp32_cpu_offload=True
) if has_cuda else None,
},
**model_config
)
# Initialize the interface
interface = outetts.Interface(config=config)
return interface
def get_or_create_speaker(interface, audio_file):
"""Get speaker from cache or create new one if not cached."""
# Calculate file hash for caching
file_hash = get_file_hash(audio_file)
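    # Key on the interface version too: speaker profiles are not
    # interchangeable across versions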
cache_key = f"{interface.config.interface_version}_{file_hash}"
# Check if speaker profile is already cached
if cache_key in speaker_cache:
print(f"✅ Using cached speaker profile for {os.path.basename(audio_file)}")
return json.loads(speaker_cache[cache_key])
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create new speaker profile
print(f"🔄 Creating new speaker profile for {os.path.basename(audio_file)}")
try:
speaker = interface.create_speaker(audio_file, whisper_model="large-v3-turbo", whisper_device=device)
# Cache the speaker profile
speaker_cache[cache_key] = json.dumps(speaker)
print(f"💾 Cached speaker profile ({len(speaker_cache)} total cached)")
return speaker
except Exception as e:
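        # Surface the error as a string so the Gradio UI can display it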
return f"❌ Error creating speaker profile: {str(e)}"
@spaces.GPU
def create_speaker_and_generate(model_name, audio_file, test_text: Optional[str] = None, temperature: float = 0.4):
"""Create speaker from audio and optionally generate test audio."""
if audio_file is None:
# Return default values for startup/caching purposes
return "Please upload an audio file to create a speaker profile.", None
# Get interface (no caching to avoid CUDA memory issues)
interface = get_interface(model_name)
# Get or create speaker profile (with caching)
speaker_result = get_or_create_speaker(interface, audio_file)
# Check if speaker_result is an error message
if isinstance(speaker_result, str) and speaker_result.startswith("❌"):
return speaker_result, None
# Convert speaker dict to formatted JSON
speaker_json = json.dumps(speaker_result, indent=2, ensure_ascii=False)
# Generate test audio if text is provided
generated_audio = None
if test_text and test_text.strip():
output = interface.generate(
config=outetts.GenerationConfig(
text=test_text,
speaker=speaker_result,
sampler_config=outetts.SamplerConfig(
temperature=temperature
),
max_length=MODEL_INFO[MODELS[model_name]]["max_seq_length"]
)
)
# Save to temporary file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
output.save(f.name)
generated_audio = f.name
return speaker_json, generated_audio
example_text = "Hello, this is a test of the OuteTTS speaker profile."
# Create the Gradio interface
demo = gr.Interface(
fn=create_speaker_and_generate,
inputs=[
gr.Dropdown(
choices=list(MODELS.keys()),
value=list(MODELS.keys())[-1],
label="Select OuteTTS Model",
info="Choose the model variant to use"
),
gr.Audio(
label="Upload Reference Audio (Max 20 seconds)",
type="filepath",
sources=["upload", "microphone"]
),
gr.Textbox(
label="Test Text (Optional)",
placeholder="Enter text to generate speech (leave empty to only create speaker profile)...",
lines=3,
value=None
),
gr.Slider(
minimum=0.1,
maximum=1.0,
step=0.1,
value=0.4,
label="Temperature",
info="Controls randomness in generation"
)
],
outputs=[
gr.Textbox(
label="Speaker Profile (JSON)",
lines=15,
max_lines=20,
show_copy_button=True
),
gr.Audio(
label="Generated Test Audio (if text provided)",
type="filepath"
)
],
title="🎙️ OuteTTS Speaker Creator",
description="Create and manage speaker profiles for OuteTTS text-to-speech synthesis. Upload audio to create a speaker profile, and optionally provide test text to generate sample audio.",
theme=gr.themes.Soft(),
examples=[
["OuteTTS-1.0-0.6B", None, example_text, 0.2],
["OuteTTS-0.3-500M", None, example_text, 0.2],
],
cache_examples=False,
flagging_mode="never"
)
if __name__ == "__main__":
# Launch with optimized configuration for HuggingFace Spaces
demo.launch(
server_name="0.0.0.0", # Allow external connections
server_port=7860,
share=False, # Set to True if you want a public link
show_api=True, # Show API documentation
show_error=True # Show detailed error messages
)