# KaniTTS demo app (app.py) — Gradio Space for the nineninesix/Kani-TTS-370m model.
import os

import gradio as gr
import spaces  # provides the @spaces.GPU decorator used on generate_speech (was missing)
import torch
# You must use the exact same model name as your repo
# Hugging Face Hub repo id; downloaded by load_models() via from_pretrained().
MODEL_ID = "nineninesix/Kani-TTS-370m"
@spaces.GPU
def generate_speech(text: str, model_choice: str, speaker_display: str):
    """Synthesize speech for *text* with the selected model and speaker.

    Parameters
    ----------
    text : str
        Text to synthesize; blank input is rejected.
    model_choice : str
        Display name keying into the module-level ``MODELS`` registry.
    speaker_display : str
        Human-readable speaker name, mapped to the model's speaker id
        via the config's ``speaker_id`` table (may be empty/None).

    Returns
    -------
    tuple | None
        ``(sample_rate, audio)`` for the numpy ``gr.Audio`` output, or
        ``None`` when generation is not possible.  (The original body
        returned a second status value with no matching Gradio output
        component, and its ``try`` had no ``except`` clause at all — a
        SyntaxError; both are fixed here.)
    """
    if not text.strip():
        print("Please enter text for speech generation.")
        return None
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        # Look up the (model, config) pair registered by load_models().
        if model_choice not in MODELS:
            print(f"Model '{model_choice}' not found.")
            return None
        selected_model = MODELS[model_choice]

        # Map the display name to the model's speaker id. Attribute access
        # (getattr) is used for consistency with the rest of the file;
        # the original mixed dict-style cfg.get() with attribute access.
        cfg = selected_model[1]  # model config (may be None)
        speaker_map = (getattr(cfg, "speaker_id", None) or {}) if cfg is not None else {}
        if speaker_display and speaker_map:
            speaker_id = speaker_map.get(speaker_display)
        else:
            speaker_id = None

        print(f"Generating speech with {model_choice}...")
        # run_model returns (audio, <unused>, timing report).
        model_to_generate = selected_model[0]
        audio, _, time_report = model_to_generate.run_model(
            text=text,
            speaker_id=speaker_id,
            temperature=0.7,
            repetition_penalty=1.2,
            max_tokens=1024
        )
        # NOTE(review): assumed fixed output rate — confirm against the model card.
        sample_rate = 22050
        print("Speech generation completed!")
        print(time_report)
        return (sample_rate, audio)
    except Exception as e:
        # Surface the failure in the logs; return None so the Audio
        # component simply stays empty instead of crashing the app.
        print(f"Speech generation failed: {e}")
        return None
# Module-level registry of loaded models: {display name: (model, config)}.
# Initialized here so load_models() can test it safely on its first call
# (the original read the global before any assignment -> NameError).
MODELS = {}


def load_models():
    """Load the TTS model and its config once, caching them in ``MODELS``.

    Returns
    -------
    dict
        Mapping of display name to ``(model, config)`` tuples.
    """
    global MODELS
    if not MODELS:
        print("Loading models into GPU memory...")
        # Local import keeps module import cheap until models are needed.
        # AutoConfig was used but missing from the original import line.
        from transformers import AutoConfig, AutoModel

        model_path = MODEL_ID
        # Load both the main model and its config.
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        MODELS = {
            "Kani TTS 370M": (model, config),
        }
        # getattr guards against a config without a speaker_id table.
        speakers = getattr(config, "speaker_id", None)
        print(f"Models loaded. Available speakers: {list(speakers.keys()) if speakers else []}")
    return MODELS
# --- Gradio interface setup ---
MODELS = load_models()


def _speakers_for(choice):
    """Return the speaker display names for *choice*, or [] if none.

    Centralizes the registry lookup that the original repeated three
    times, and uses getattr so a config without a ``speaker_id`` table
    cannot raise AttributeError.
    """
    entry = MODELS.get(choice)
    cfg = entry[1] if entry else None
    speaker_map = getattr(cfg, "speaker_id", None) if cfg is not None else None
    return list(speaker_map.keys()) if speaker_map else []


with gr.Blocks(title="😻 KaniTTS - Text to Speech") as demo:
    gr.Markdown("# 😻 KaniTTS: Fast and Expressive Speech Generation Model")

    model_names = list(MODELS.keys())
    model_dropdown = gr.Dropdown(
        choices=model_names,
        # Guard against an empty registry (original indexed [0] blindly).
        value=model_names[0] if model_names else None,
        label="Selected Model"
    )

    # --- Speaker selector, populated from the initially selected model ---
    speaker_dropdown = gr.Dropdown(
        choices=_speakers_for(model_names[0]) if model_names else [],
        value=None,
        label="Speaker",
        visible=True,
        allow_custom_value=True
    )

    text_input = gr.Textbox(label="Text", lines=5)
    generate_btn = gr.Button("Generate Speech", variant="primary")
    audio_output = gr.Audio(label="Generated Audio", type="numpy")

    # --- Event handlers ---
    def _on_model_change(choice):
        """Refresh (or hide) the speaker list when the model changes."""
        speakers = _speakers_for(choice)
        if speakers:
            return gr.update(choices=speakers, value=None, visible=True)
        return gr.update(visible=False)

    model_dropdown.change(
        fn=_on_model_change,
        inputs=[model_dropdown],
        outputs=[speaker_dropdown]
    )

    generate_btn.click(
        fn=generate_speech,
        inputs=[text_input, model_dropdown, speaker_dropdown],
        outputs=[audio_output]
    )

# --- This is the API enabling line ---
demo.queue().launch(show_api=True)