# KaniTTS / app.py — Gradio text-to-speech demo for the Kani-TTS-370m model.
# (Hugging Face Space by jblast94; file revision 24c936f.)
import os

import gradio as gr
import torch

# `spaces` supplies the @spaces.GPU decorator on Hugging Face ZeroGPU Spaces.
# Fall back to a no-op decorator so the file stays importable off-Space.
try:
    import spaces
except ImportError:  # not running on a HF Space
    from types import SimpleNamespace

    spaces = SimpleNamespace(GPU=lambda fn: fn)
# You must use the exact same model name as your repo on the HF Hub.
MODEL_ID = "nineninesix/Kani-TTS-370m"
# --- Global variable to store loaded models ---
# Maps display name -> (model, config); populated lazily by load_models().
MODELS = {}
@spaces.GPU
def load_models():
    """Load the TTS model and its config into the global MODELS dict.

    Idempotent: a second call returns the already-populated mapping without
    reloading. Returns MODELS, a dict of display name -> (model, config).
    """
    global MODELS
    if not MODELS:
        print("Loading models into GPU memory...")
        # Imported lazily so import of this module stays cheap and the
        # @spaces.GPU context is active when weights are materialized.
        from transformers import AutoModel, AutoConfig

        model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
        config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
        MODELS = {
            "Kani TTS 370M": (model, config),
        }
        # The custom config may not define `speaker_id` at all; guard with
        # getattr so logging never raises AttributeError.
        speaker_map = getattr(config, "speaker_id", None) or {}
        print(f"Models loaded. Available speakers: {list(speaker_map.keys())}")
    return MODELS
# --- Define a separate function for updating the stats display ---
def update_stats_display():
    """Return a gr.Markdown component showing the agent's memory stats.

    NOTE(review): `agent` (a ConversationalAgent instance) is never defined in
    this file, so the original body raised NameError when called. Guard the
    lookup and show a placeholder instead of crashing the UI.
    """
    agent_obj = globals().get("agent")
    if agent_obj is None:
        # Agent was never wired up in this app — degrade gracefully.
        return gr.Markdown("### 📊 Memory Stats\nAgent is not initialized.")
    stats_text = agent_obj.get_memory_stats()
    return gr.Markdown(f"### 📊 Memory Stats\n{stats_text}")
def generate_speech(text: str, model_choice: str, speaker_display: str):
    """Generate speech audio for *text* using the selected model.

    Args:
        text: Text to synthesize; blank input is rejected.
        model_choice: Key into the global MODELS dict.
        speaker_display: Display name of the speaker, looked up in the model
            config's speaker_id map (may be empty/None for default voice).

    Returns:
        A ``(sample_rate, waveform)`` tuple suitable for ``gr.Audio``,
        or ``None`` on any failure (the click handler wires exactly one
        output component, so a single value is returned).
    """
    if not text.strip():
        print("Please enter text for speech generation.")
        return None
    # NOTE: the original `try:` had no except/finally clause (SyntaxError);
    # wrap the whole generation path and fail soft with None.
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        # Ensure models are loaded (lazy path if module-level load was skipped).
        if not MODELS:
            load_models()
        if model_choice not in MODELS:
            print(f"Model '{model_choice}' not found.")
            return None
        model_to_generate, cfg = MODELS[model_choice]
        # Configs are objects, not dicts — use getattr, not .get().
        speaker_map = getattr(cfg, "speaker_id", None) or {}
        speaker_id = speaker_map.get(speaker_display) if speaker_display else None
        print(f"Generating speech with {model_choice}...")
        # run_model is defined by the model's trust_remote_code implementation;
        # it returns (audio, tokens, time_report) — TODO confirm against repo.
        audio, _, time_report = model_to_generate.run_model(
            text=text,
            speaker_id=speaker_id,
            temperature=0.7,
            repetition_penalty=1.2,
            max_tokens=1024,
        )
        sample_rate = 22050  # fixed output rate of the KaniTTS vocoder
        print("Speech generation completed!")
        print(time_report)
        return (sample_rate, audio)
    except Exception as exc:
        print(f"Speech generation failed: {exc}")
        return None
# --- Create and configure the Gradio interface ---
MODELS = load_models()


def _speaker_choices(model_name):
    """Return the sorted speaker display names for *model_name*, or []."""
    entry = MODELS.get(model_name)
    if not entry:
        return []
    # Config may lack a speaker_id map entirely — guard with getattr.
    speaker_map = getattr(entry[1], "speaker_id", None) or {}
    return sorted(speaker_map.keys())


def _on_model_change(choice):
    """Repopulate (or hide) the speaker dropdown when the model changes."""
    speakers = _speaker_choices(choice)
    if speakers:
        return gr.update(choices=speakers, value=None, visible=True)
    return gr.update(visible=False)


with gr.Blocks(title="😻 KaniTTS - Text to Speech") as demo:
    gr.Markdown("# 😻 KaniTTS: Fast and Expressive Speech Generation Model")

    model_names = list(MODELS.keys())
    model_dropdown = gr.Dropdown(
        choices=model_names,
        # Guard against an empty MODELS dict instead of IndexError-ing.
        value=model_names[0] if model_names else None,
        label="Selected Model",
    )

    # --- Speaker selector (populated from the first model's config) ---
    all_speakers = _speaker_choices(model_names[0]) if model_names else []
    speaker_dropdown = gr.Dropdown(
        choices=all_speakers,
        value=None,
        label="Speaker",
        visible=True,
        allow_custom_value=True,
    )
    text_input = gr.Textbox(label="Text", lines=5)
    generate_btn = gr.Button("Generate Speech", variant="primary")
    audio_output = gr.Audio(label="Generated Audio", type="numpy")

    # --- Define the event to update the speakers when the model changes ---
    model_dropdown.change(
        fn=_on_model_change,
        inputs=[model_dropdown],
        outputs=[speaker_dropdown],
    )

    # --- Wire up the main generation button ---
    generate_btn.click(
        fn=generate_speech,
        inputs=[text_input, model_dropdown, speaker_dropdown],
        outputs=[audio_output],
    )

# --- show_api=True exposes the HTTP/Client API endpoints ---
demo.queue().launch(show_api=True)