# SingingSDS / interface.py
# jhansss's picture
# Update interface.py
# d91c8af verified
import time
import uuid
import gradio as gr
import spaces
import yaml
from characters import CHARACTERS
from pipeline import SingingDialoguePipeline
# Module-level pipeline instance shared by the functions below; it stays
# None until _ensure_pipeline() builds it on first use.
pipe = None
def _ensure_pipeline(config):
    """Build the module-level pipeline on first use (GPU worker context).

    The global ``pipe`` is constructed from *config* only while it is still
    ``None``. Once it exists, later calls return immediately, so a different
    *config* passed afterwards is not applied by this function.
    """
    global pipe
    if pipe is not None:
        return
    pipe = SingingDialoguePipeline(config)
@spaces.GPU(duration=120)
def run_pipeline(audio_path, config, svs_model_info, character_prompt, current_voice):
    """Run one dialogue turn through the pipeline and package the UI updates.

    Args:
        audio_path: Path to the user's recording. A falsy value means there
            is nothing to process and the pipeline is not touched.
        config: Configuration handed to ``_ensure_pipeline``.
        svs_model_info: Mapping describing the selected SVS model; only its
            ``"lang"`` entry is read here.
        character_prompt: Forwarded unchanged to ``pipe.run``.
        current_voice: Forwarded unchanged to ``pipe.run``.

    Returns:
        A 3-tuple ``(log_update, audio_update, results)``. The first two are
        ``gr.update`` objects for the log textbox and the audio player; the
        third is the object returned by ``pipe.run``, or ``None`` when
        *audio_path* is falsy (in which case both updates carry ``None``).
    """
    if not audio_path:
        return gr.update(value=None), gr.update(value=None), None

    _ensure_pipeline(config)

    # Bare file name (no directory): timestamp plus a short random suffix so
    # that successive runs do not reuse the same output name.
    stamp = int(time.time())
    suffix = uuid.uuid4().hex[:8]
    output_name = f"audio_{stamp}_{suffix}.wav"

    outcome = pipe.run(
        audio_path,
        svs_model_info["lang"],
        character_prompt,
        current_voice,
        output_audio_path=output_name,
    )

    asr_text = outcome["asr_text"]
    llm_text = outcome["llm_text"]
    log_update = gr.update(value=f"ASR: {asr_text}\nLLM: {llm_text}")
    audio_update = gr.update(value=outcome["output_audio_path"])
    return log_update, audio_update, outcome
@spaces.GPU(duration=120)
def update_metrics(audio_path, config, results_data):
    """Evaluate the last response and format the metrics for display.

    Args:
        audio_path: Path passed as the first argument to ``pipe.evaluate``.
        config: Configuration handed to ``_ensure_pipeline``.
        results_data: Mapping expanded as keyword arguments into
            ``pipe.evaluate``; an optional ``"metrics"`` entry is merged on
            top of the evaluation result.

    Returns:
        A ``gr.update`` whose value is one ``"name: value"`` line per metric,
        or an empty string when either *audio_path* or *results_data* is
        falsy.
    """
    if not (audio_path and results_data):
        return gr.update(value="")

    _ensure_pipeline(config)
    metrics = pipe.evaluate(audio_path, **results_data)
    # Entries already stored under "metrics" override same-named keys from
    # the fresh evaluation.
    metrics.update(results_data.get("metrics", {}))
    report = "\n".join(f"{name}: {value}" for name, value in metrics.items())
    return gr.update(value=report)
@spaces.GPU(duration=120)
def update_asr_model_in_pipeline(config, asr_model):
    """Switch the pipeline's ASR model and echo the selection to the UI.

    The pipeline is created from *config* first if it does not exist yet.
    Returns a ``gr.update`` carrying *asr_model* as its value.
    """
    _ensure_pipeline(config)
    pipe.set_asr_model(asr_model)
    selection_update = gr.update(value=asr_model)
    return selection_update
@spaces.GPU(duration=120)
def update_llm_model_in_pipeline(config, llm_model):
    """Switch the pipeline's LLM and echo the selection to the UI.

    The pipeline is created from *config* first if it does not exist yet.
    Returns a ``gr.update`` carrying *llm_model* as its value.
    """
    _ensure_pipeline(config)
    pipe.set_llm_model(llm_model)
    selection_update = gr.update(value=llm_model)
    return selection_update
@spaces.GPU(duration=120)
def update_svs_model_in_pipeline(config, svs_model_path):
    """Point the pipeline at a different SVS model.

    The pipeline is created from *config* first if it does not exist yet.
    Unlike the sibling setters, the returned ``gr.update`` carries no value.
    """
    _ensure_pipeline(config)
    pipe.set_svs_model(svs_model_path)
    no_change = gr.update()
    return no_change
@spaces.GPU(duration=120)
def update_melody_source_in_pipeline(config, melody_source):
    """Switch the pipeline's melody controller and echo the selection.

    The pipeline is created from *config* first if it does not exist yet.
    Returns a ``gr.update`` carrying *melody_source* as its value.
    """
    _ensure_pipeline(config)
    pipe.set_melody_controller(melody_source)
    selection_update = gr.update(value=melody_source)
    return selection_update
class GradioInterface:
    """Gradio front end for the SingingSDS pipeline.

    The instance keeps the current selections (character, SVS model, voice,
    melody source), the mutable default configuration and the results of the
    last pipeline run. Every event handler registered in
    ``create_interface`` is a bound method of the same instance, so they all
    read and write that one set of attributes.
    """

    def __init__(self, options_config: str, default_config: str):
        """Load both YAML files and derive the initial selections.

        Args:
            options_config: Path to a YAML file. Its ``asr_models``,
                ``llm_models``, ``svs_models`` and ``melody_sources`` lists
                populate the radio groups; each entry is read via ``"id"``
                and ``"name"``.
            default_config: Path to a YAML file supplying the initial
                ``character``, ``language``, ``svs_model``, ``asr_model``,
                ``llm_model`` and ``melody_source`` values.

        Raises:
            KeyError: If a required key, the derived SVS model id, the
                default character or its default voice is missing.
        """
        self.options = self.load_config(options_config)
        self.svs_model_map = {
            model["id"]: model for model in self.options["svs_models"]
        }
        self.default_config = self.load_config(default_config)
        self.character_info = CHARACTERS
        self.current_character = self.default_config["character"]
        # The SVS model id is built as "<language>-<svs_model>" and must be
        # one of the ids present in svs_model_map.
        self.current_svs_model = (
            f"{self.default_config['language']}-{self.default_config['svs_model']}"
        )
        # Defined up front so the attribute exists before the first melody
        # change; .get() keeps construction from failing on a missing key.
        self.current_melody_source = self.default_config.get("melody_source")
        self.current_voice = self._lookup_voice(
            self.character_info[self.current_character].default_voice
        )
        self.results = None

    def load_config(self, path: str):
        """Parse the YAML file at *path* and return the loaded object."""
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)

    def _lookup_voice(self, voice_name):
        """Return the current SVS model's ``voices`` entry for *voice_name*.

        Raises:
            KeyError: If the current SVS model has no such voice.
        """
        return self.svs_model_map[self.current_svs_model]["voices"][voice_name]

    def create_interface(self) -> gr.Blocks:
        """Build the Blocks UI and wire every widget to its handler.

        Returns:
            The populated ``gr.Blocks``. If anything raises during
            construction, the traceback is printed and an empty
            ``gr.Blocks()`` is returned instead.
        """
        try:
            # NOTE(review): the container nesting below was inferred from the
            # order of the calls; confirm it against the rendered layout.
            with gr.Blocks(title="SingingSDS") as demo:
                gr.Markdown("# SingingSDS: Role-Playing Singing Spoken Dialogue System")
                with gr.Row():
                    with gr.Column(scale=1):
                        character_image = gr.Image(
                            self.character_info[self.current_character].image_path,
                            label="Character",
                            show_label=False,
                        )
                    with gr.Column(scale=2):
                        mic_input = gr.Audio(
                            sources=["microphone", "upload"],
                            type="filepath",
                            label="Speak to the character",
                        )
                        interaction_log = gr.Textbox(
                            label="Interaction Log", lines=3, interactive=False
                        )
                        audio_output = gr.Audio(
                            label="Character's Response",
                            type="filepath",
                            autoplay=True,
                            interactive=False,
                        )
                        with gr.Row():
                            metrics_button = gr.Button(
                                "Evaluate Metrics", variant="secondary"
                            )
                            metrics_output = gr.Textbox(
                                label="Evaluation Results", lines=3, interactive=False
                            )

                gr.Markdown("## Configuration")
                with gr.Row():
                    with gr.Column():
                        character_radio = gr.Radio(
                            label="Character Role",
                            choices=list(self.character_info.keys()),
                            value=self.default_config["character"],
                        )
                        with gr.Row():
                            asr_radio = gr.Radio(
                                label="ASR Model",
                                choices=[
                                    (model["name"], model["id"])
                                    for model in self.options["asr_models"]
                                ],
                                value=self.default_config["asr_model"],
                            )
                        with gr.Row():
                            llm_radio = gr.Radio(
                                label="LLM Model",
                                choices=[
                                    (model["name"], model["id"])
                                    for model in self.options["llm_models"]
                                ],
                                value=self.default_config["llm_model"],
                            )
                    with gr.Column():
                        with gr.Row():
                            melody_radio = gr.Radio(
                                label="Melody Source",
                                choices=[
                                    (source["name"], source["id"])
                                    for source in self.options["melody_sources"]
                                ],
                                value=self.default_config["melody_source"],
                            )
                        with gr.Row():
                            svs_radio = gr.Radio(
                                label="SVS Model",
                                choices=[
                                    (model["name"], model["id"])
                                    for model in self.options["svs_models"]
                                ],
                                value=self.current_svs_model,
                            )
                        with gr.Row():
                            voice_radio = gr.Radio(
                                label="Singing voice",
                                choices=list(
                                    self.svs_model_map[self.current_svs_model][
                                        "voices"
                                    ].keys()
                                ),
                                value=self.character_info[
                                    self.current_character
                                ].default_voice,
                            )

                # Configuration widgets: each change is pushed into the
                # instance state (and, for models, into the pipeline).
                character_radio.change(
                    fn=self.update_character,
                    inputs=character_radio,
                    outputs=[character_image, voice_radio],
                )
                asr_radio.change(
                    fn=self.update_asr_model, inputs=asr_radio, outputs=asr_radio
                )
                llm_radio.change(
                    fn=self.update_llm_model, inputs=llm_radio, outputs=llm_radio
                )
                svs_radio.change(
                    fn=self.update_svs_model,
                    inputs=svs_radio,
                    outputs=[svs_radio, voice_radio],
                )
                melody_radio.change(
                    fn=self.update_melody_source,
                    inputs=melody_radio,
                    outputs=melody_radio,
                )
                voice_radio.change(
                    fn=self.update_voice, inputs=voice_radio, outputs=voice_radio
                )

                # A new recording/upload triggers a pipeline run; metrics are
                # computed only on demand via the button.
                mic_input.change(
                    fn=self._run_pipeline_wrapper,
                    inputs=mic_input,
                    outputs=[interaction_log, audio_output],
                )
                metrics_button.click(
                    fn=self._update_metrics_wrapper,
                    inputs=audio_output,
                    outputs=[metrics_output],
                )
                gr.Markdown(
                    "<div style='text-align: right; font-size: 12px; color: #666; margin-top: 20px;'>Yaoyin character illustration by Zihe Zhou</div>",
                    elem_classes="footer",
                )
            return demo
        except Exception:
            # Best effort: report the failure and hand back an empty UI
            # rather than propagating the error to the caller.
            import traceback

            print(traceback.format_exc())
            return gr.Blocks()

    def update_character(self, character):
        """Select *character* and switch to that character's default voice.

        Returns:
            Two ``gr.update`` objects: the character image path and the
            voice radio value.

        Raises:
            KeyError: If *character* is unknown or its default voice is not
                offered by the current SVS model.
        """
        self.current_character = character
        character_voice = self.character_info[self.current_character].default_voice
        self.current_voice = self._lookup_voice(character_voice)
        return gr.update(value=self.character_info[character].image_path), gr.update(
            value=character_voice
        )

    def update_asr_model(self, asr_model):
        """Record *asr_model* in the config and apply it to the pipeline."""
        self.default_config["asr_model"] = asr_model
        return update_asr_model_in_pipeline(self.default_config, asr_model)

    def update_llm_model(self, llm_model):
        """Record *llm_model* in the config and apply it to the pipeline."""
        self.default_config["llm_model"] = llm_model
        return update_llm_model_in_pipeline(self.default_config, llm_model)

    def update_svs_model(self, svs_model):
        """Select SVS model id *svs_model* and reset the voice choices.

        The voice falls back to the current character's default voice as
        offered by the new model, and the model's ``model_path`` is written
        to ``default_config["svs_model"]`` before being pushed to the
        pipeline.

        Returns:
            Two ``gr.update`` objects: the SVS radio value, and the voice
            radio with the new model's voice names as choices.

        Raises:
            KeyError: If *svs_model* is unknown or lacks the character's
                default voice.
        """
        self.current_svs_model = svs_model
        character_voice = self.character_info[self.current_character].default_voice
        self.current_voice = self._lookup_voice(character_voice)
        svs_model_path = self.svs_model_map[self.current_svs_model]["model_path"]
        self.default_config["svs_model"] = svs_model_path
        # The pipeline-side helper's return value is not used here; this
        # method builds its own updates for both radios.
        update_svs_model_in_pipeline(self.default_config, svs_model_path)
        return (
            gr.update(value=svs_model),
            gr.update(
                choices=list(
                    self.svs_model_map[self.current_svs_model]["voices"].keys()
                ),
                value=character_voice,
            ),
        )

    def update_melody_source(self, melody_source):
        """Record *melody_source* and apply it to the pipeline."""
        self.current_melody_source = melody_source
        self.default_config["melody_source"] = melody_source
        return update_melody_source_in_pipeline(self.default_config, melody_source)

    def update_voice(self, voice):
        """Resolve voice name *voice* for the current SVS model and store it.

        Raises:
            KeyError: If the current SVS model has no voice named *voice*.
        """
        self.current_voice = self._lookup_voice(voice)
        return gr.update(value=voice)

    def _run_pipeline_wrapper(self, audio_path):
        """Run the pipeline with the current selections.

        Truthy results are kept in ``self.results`` for a later metrics
        request; only the log and audio updates are returned to Gradio.
        """
        log_update, audio_update, pipeline_results = run_pipeline(
            audio_path,
            self.default_config,
            self.svs_model_map[self.current_svs_model],
            self.character_info[self.current_character].prompt,
            self.current_voice,
        )
        if pipeline_results:
            self.results = pipeline_results
        return log_update, audio_update

    def _update_metrics_wrapper(self, audio_path):
        """Evaluate *audio_path* against the stored results of the last run."""
        return update_metrics(audio_path, self.default_config, self.results or {})