|
|
import gradio as gr
|
|
|
import os
|
|
|
import argparse
|
|
|
import torch
|
|
|
import logging
|
|
|
import threading
|
|
|
from datetime import datetime
|
|
|
import torchaudio
|
|
|
import librosa
|
|
|
import soundfile as sf
|
|
|
|
|
|
|
|
|
# Optional ZeroGPU support: on Hugging Face Spaces the real `spaces` package
# supplies the GPU-allocation decorator. Everywhere else we install a no-op
# stand-in so `@spaces.GPU(...)` decorations keep working unchanged.
try:
    import spaces
except ImportError:
    ZEROGPU_AVAILABLE = False

    class spaces:
        """Drop-in replacement exposing a no-op ``GPU`` decorator factory."""

        @staticmethod
        def GPU(duration=10):
            # Mirror the real API shape: called with config, returns a decorator.
            def decorator(func):
                # No GPU to allocate locally — hand the function back untouched.
                return func
            return decorator
else:
    ZEROGPU_AVAILABLE = True
|
|
|
|
|
|
|
|
|
from tokenizer import StepAudioTokenizer
|
|
|
from tts import StepAudioTTS
|
|
|
from model_loader import ModelSource
|
|
|
from config.edit_config import get_supported_edit_types
|
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
encoder = None
|
|
|
common_tts_engine = None
|
|
|
args_global = None
|
|
|
_model_lock = threading.Lock()
|
|
|
|
|
|
def initialize_models():
    """Initialize models on first GPU call (ZeroGPU optimization: load inside GPU context).

    Lazily constructs the module-level ``encoder`` (StepAudioTokenizer) and
    ``common_tts_engine`` (StepAudioTTS) exactly once, guarded by
    ``_model_lock`` so concurrent first requests cannot double-load.
    Requires ``args_global`` to have been populated by the entry point.

    Raises:
        RuntimeError: if ``args_global`` has not been set yet.
        Exception: re-raises any model-loading failure after logging it.
    """
    global encoder, common_tts_engine, args_global

    # Fast path: already initialized — skip the lock entirely.
    if common_tts_engine is not None:
        return

    with _model_lock:
        # Double-checked locking: another thread may have finished loading
        # while this one was waiting on the lock.
        if common_tts_engine is not None:
            return

        if args_global is None:
            raise RuntimeError("Global args not set. Cannot initialize models.")

        try:
            logger.info("π Initializing models inside GPU context (first call)...")

            # Map the CLI string (restricted by argparse `choices`) to the
            # ModelSource enum understood by the loaders.
            source_mapping = {
                "auto": ModelSource.AUTO,
                "local": ModelSource.LOCAL,
                "modelscope": ModelSource.MODELSCOPE,
                "huggingface": ModelSource.HUGGINGFACE
            }
            model_source = source_mapping[args_global.model_source]

            # Audio tokenizer: converts speech into discrete tokens for the TTS engine.
            encoder = StepAudioTokenizer(
                os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
                model_source=model_source,
                funasr_model_id=args_global.tokenizer_model_id
            )
            logger.info("β StepAudioTokenizer loaded")

            # TTS / editing engine built on top of the tokenizer.
            common_tts_engine = StepAudioTTS(
                os.path.join(args_global.model_path, "Step-Audio-EditX"),
                encoder,
                model_source=model_source,
                tts_model_id=args_global.tts_model_id
            )
            logger.info("β StepCommonAudioTTS loaded")
            print("Models initialized inside GPU context.")

            if ZEROGPU_AVAILABLE:
                logger.info("π‘ Models loaded inside GPU context - ready for inference")
            else:
                logger.info("π‘ Models loaded - ready for inference")

        except Exception as e:
            logger.error(f"β Error loading models: {e}")
            raise
|
|
|
|
|
|
def get_model_config():
    """Get model configuration without initializing GPU models"""
    # Guard clause: configuration is only meaningful once CLI args are published.
    if args_global is None:
        raise RuntimeError("Global args not set. Cannot get model config.")

    cfg = args_global
    encoder_path = os.path.join(cfg.model_path, "Step-Audio-Tokenizer")
    tts_path = os.path.join(cfg.model_path, "Step-Audio-EditX")

    # Plain-dict snapshot of everything the loaders need; no GPU work happens here.
    return {
        "encoder_path": encoder_path,
        "tts_path": tts_path,
        "model_source": cfg.model_source,
        "tokenizer_model_id": cfg.tokenizer_model_id,
        "tts_model_id": cfg.tts_model_id
    }
|
|
|
|
|
|
def get_gpu_duration(audio_input, text_input, target_text, task_type, task_info):
    """Dynamic GPU duration based on whether models need initialization"""
    global common_tts_engine

    # First call pays the model-loading cost (300s); afterwards inference
    # alone fits comfortably in a 120s allocation. The unused parameters
    # exist because ZeroGPU passes the decorated function's arguments here.
    return 300 if common_tts_engine is None else 120
|
|
|
|
|
|
@spaces.GPU(duration=get_gpu_duration)
def process_audio_with_gpu(audio_input, text_input, target_text, task_type, task_info):
    """Process audio using GPU (models are loaded inside GPU context to avoid main process errors).

    Args:
        audio_input: path to the source audio file.
        text_input: transcript of the source audio.
        target_text: text to synthesize (clone) or target transcript (edit).
        task_type: "clone" for voice cloning, anything else is treated as an edit task.
        task_info: sub-task selector forwarded to the edit engine (may be None).

    Returns:
        Tuple of (output_audio, sample_rate) as produced by the engine.

    Raises:
        Exception: re-raises any engine failure after logging it.
    """
    global common_tts_engine

    if common_tts_engine is None:
        print("Initializing common_tts_engine inside GPU context...")
        logger.info("π― GPU allocated for 300s (first call with model loading)...")
        initialize_models()
        # FIX: this log message was previously a string literal broken across
        # two source lines (an unescaped newline inside a plain string is a
        # syntax error); rejoined into a single-line message.
        logger.info("β Models loaded successfully inside GPU context")
    else:
        print("common_tts_engine already initialized.")
        logger.info("π― GPU allocated for 120s (inference with loaded models)...")

    try:
        if task_type == "clone":
            output_audio, sr = common_tts_engine.clone(audio_input, text_input, target_text)
        else:
            output_audio, sr = common_tts_engine.edit(audio_input, text_input, task_type, task_info, target_text)

        # FIX: same rejoin as above — literal previously spanned two lines.
        logger.info("β Audio processing completed")
        return output_audio, sr

    except Exception as e:
        logger.error(f"β Audio processing failed: {e}")
        raise
|
|
|
|
|
|
|
|
|
|
|
|
def save_audio(audio_type, audio_data, sr, tmp_dir):
    """Save audio data to a temporary file with timestamp.

    The file lands at ``<tmp_dir>/<audio_type>/<YYYY-MM-DD_HH-MM-SS>.wav``;
    the directory is created on demand. Returns the saved path.
    """
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_path = os.path.join(tmp_dir, audio_type, f"{current_time}.wav")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    try:
        # Torch tensors go through torchaudio; numpy arrays through soundfile.
        if isinstance(audio_data, torch.Tensor):
            writer = torchaudio.save
        else:
            writer = sf.write
        writer(save_path, audio_data, sr)
        logger.debug(f"Audio saved to: {save_path}")
        return save_path
    except Exception as e:
        logger.error(f"Failed to save audio: {e}")
        raise
|
|
|
|
|
|
|
|
|
class EditxTab:
    """Audio editing and voice cloning interface tab.

    Owns the Gradio widgets for the "Editx" tab and the callbacks behind the
    CLONE / EDIT / Clear History buttons. Per-session conversation state is a
    dict with two lists:

    - ``history_messages``: records shown in the chatbot (dicts with
      edit_type/edit_info/source_text/target_text/raw_wave/edit_wave).
    - ``history_audio``: ``(sample_rate, audio_numpy, text)`` tuples so EDIT
      can stack further edits on the most recent output.
    """

    def __init__(self, args):
        # args: parsed CLI namespace (needs at least `tmp_dir` for temp saves).
        self.args = args
        # Task choices for the "Task" dropdown, taken from the edit config.
        self.edit_type_list = list(get_supported_edit_types().keys())
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

    def history_messages_to_show(self, messages):
        """Convert message history to gradio chatbot format.

        Each stored record expands to four chatbot messages: the user's task
        description + source audio, then the assistant's target text + edited
        audio. Audio tuples are ``(sample_rate, numpy_array)`` as Gradio expects.
        """
        show_msgs = []
        for message in messages:
            edit_type = message['edit_type']
            edit_info = message['edit_info']
            source_text = message['source_text']
            target_text = message['target_text']
            raw_audio_part = message['raw_wave']
            edit_audio_part = message['edit_wave']
            # "type" or "type-subtype" label depending on whether a sub-task was chosen.
            type_str = f"{edit_type}-{edit_info}" if edit_info is not None else f"{edit_type}"
            show_msgs.extend([
                {"role": "user", "content": f"δ»»ε‘η±»εοΌ{type_str}\nζζ¬οΌ{source_text}"},
                {"role": "user", "content": gr.Audio(value=raw_audio_part, interactive=False)},
                {"role": "assistant", "content": f"θΎεΊι³ι’οΌ\nζζ¬οΌ{target_text}"},
                {"role": "assistant", "content": gr.Audio(value=edit_audio_part, interactive=False)}
            ])
        return show_msgs

    def generate_clone(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
        """Generate cloned audio (models are loaded on first GPU call).

        CLONE always starts a fresh conversation, so history is cleared first.
        Returns ``(chatbot_messages, state)``; on validation or engine failure
        the chatbot gets a single error message and state is returned unchanged.
        """
        self.logger.info("Starting voice cloning process")
        # CLONE resets the session: any stacked edits are discarded.
        state['history_audio'] = []
        state['history_messages'] = []

        # --- Input validation guards ---
        if not prompt_text_input or prompt_text_input.strip() == "":
            error_msg = "[Error] Uploaded text cannot be empty."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state
        if not prompt_audio_input:
            error_msg = "[Error] Uploaded audio cannot be empty."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state
        if not generated_text or generated_text.strip() == "":
            error_msg = "[Error] Clone content cannot be empty."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state
        if edit_type != "clone":
            error_msg = "[Error] CLONE button must use clone task."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state

        try:
            output_audio, output_sr = process_audio_with_gpu(
                prompt_audio_input, prompt_text_input, generated_text, "clone", edit_info
            )

            if output_audio is not None and output_sr is not None:
                # Normalize engine output to a 1-D numpy array for Gradio.
                if isinstance(output_audio, torch.Tensor):
                    audio_numpy = output_audio.cpu().numpy().squeeze()
                else:
                    audio_numpy = output_audio

                # Reload the prompt audio so the chat can show it alongside the
                # result. NOTE(review): librosa.load resamples to its default
                # sr (22050) here — presumably acceptable for display; confirm.
                input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input)

                cur_assistant_msg = {
                    "edit_type": edit_type,
                    "edit_info": edit_info,
                    "source_text": prompt_text_input,
                    "target_text": generated_text,
                    "raw_wave": (input_sample_rate, input_audio_data_numpy),
                    "edit_wave": (output_sr, audio_numpy),
                }
                # Keep the output so subsequent EDIT clicks can build on it.
                state["history_audio"].append((output_sr, audio_numpy, generated_text))
                state["history_messages"].append(cur_assistant_msg)

                show_msgs = self.history_messages_to_show(state["history_messages"])
                self.logger.info("Voice cloning completed successfully")
                return show_msgs, state
            else:
                error_msg = "[Error] Clone failed"
                self.logger.error(error_msg)
                return [{"role": "user", "content": error_msg}], state

        except Exception as e:
            error_msg = f"[Error] Clone failed: {str(e)}"
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state

    def generate_edit(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
        """Generate edited audio (models are loaded on first GPU call).

        Unlike CLONE, EDIT does not clear history: with prior outputs present
        it edits the most recent generated audio (stacking effects), otherwise
        it edits the uploaded prompt audio directly.
        """
        self.logger.info("Starting audio editing process")

        if not prompt_audio_input:
            error_msg = "[Error] Uploaded audio cannot be empty."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state

        try:
            if len(state["history_audio"]) == 0:
                # First edit in the session: operate on the uploaded audio.
                audio_to_edit = prompt_audio_input
                text_to_use = prompt_text_input
                self.logger.debug("Using prompt audio, no history found")
            else:
                # Stack on the latest generated audio; it lives in memory, so
                # persist it to a temp wav because the engine expects a path.
                sample_rate, audio_numpy, previous_text = state["history_audio"][-1]
                temp_path = save_audio("temp", audio_numpy, sample_rate, self.args.tmp_dir)
                audio_to_edit = temp_path
                text_to_use = previous_text
                self.logger.debug(f"Using previous audio from history, count: {len(state['history_audio'])}")

            # Only para-linguistic editing may change the text content; every
            # other task keeps the source transcript unchanged.
            if edit_type not in {"paralinguistic"}:
                generated_text = text_to_use

            output_audio, output_sr = process_audio_with_gpu(
                audio_to_edit, text_to_use, generated_text, edit_type, edit_info
            )

            if output_audio is not None and output_sr is not None:
                # Normalize engine output to a 1-D numpy array for Gradio.
                if isinstance(output_audio, torch.Tensor):
                    audio_numpy = output_audio.cpu().numpy().squeeze()
                else:
                    audio_numpy = output_audio

                # "Input" side of the chat record: the uploaded file on the
                # first round, otherwise the previous round's output (history
                # has not been appended to yet at this point).
                if len(state["history_audio"]) == 0:
                    input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input)
                else:
                    input_sample_rate, input_audio_data_numpy, _ = state["history_audio"][-1]

                cur_assistant_msg = {
                    "edit_type": edit_type,
                    "edit_info": edit_info,
                    "source_text": text_to_use,
                    "target_text": generated_text,
                    "raw_wave": (input_sample_rate, input_audio_data_numpy),
                    "edit_wave": (output_sr, audio_numpy),
                }
                state["history_audio"].append((output_sr, audio_numpy, generated_text))
                state["history_messages"].append(cur_assistant_msg)

                show_msgs = self.history_messages_to_show(state["history_messages"])
                self.logger.info("Audio editing completed successfully")
                return show_msgs, state
            else:
                error_msg = "[Error] Edit failed"
                self.logger.error(error_msg)
                return [{"role": "user", "content": error_msg}], state

        except Exception as e:
            error_msg = f"[Error] Edit failed: {str(e)}"
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state

    def clear_history(self, state):
        """Clear conversation history and empty the chatbot display."""
        state["history_messages"] = []
        state["history_audio"] = []
        return [], state

    def init_state(self):
        """Initialize conversation state (fresh, empty history)."""
        return {
            "history_messages": [],
            "history_audio": []
        }

    def register_components(self):
        """Register gradio components - maintaining exact layout from original.

        Left column: model name, prompt text, input audio, target text.
        Right column: task/sub-task dropdowns above the history chatbot.
        Below: CLONE / EDIT buttons and Clear History, then usage notes.
        Components are stored on ``self`` so ``register_events`` can wire them.
        """
        with gr.Tab("Editx"):
            with gr.Row():
                with gr.Column():
                    self.model_input = gr.Textbox(label="Model Name", value="Step-Audio-EditX", scale=1)
                    self.prompt_text_input = gr.Textbox(label="Prompt Text", value="", scale=1)
                    self.prompt_audio_input = gr.Audio(
                        sources=["upload", "microphone"],
                        format="wav",
                        type="filepath",
                        label="Input Audio",
                    )
                    self.generated_text = gr.Textbox(label="Target Text", lines=1, max_lines=200, max_length=1000)
                with gr.Column():
                    with gr.Row():
                        self.edit_type = gr.Dropdown(label="Task", choices=self.edit_type_list, value="clone")
                        self.edit_info = gr.Dropdown(label="Sub-task", choices=[], value=None)
                    self.chat_box = gr.Chatbot(label="History", type="messages", height=480*1)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        self.button_tts = gr.Button("CLONE", variant="primary")
                        self.button_edit = gr.Button("EDIT", variant="primary")
                with gr.Column():
                    self.clean_history_submit = gr.Button("Clear History", variant="primary")

            gr.Markdown("---")
            gr.Markdown("""
            **Button Description:**
            - CLONE: Synthesizes audio based on uploaded audio and text, only used for clone mode, will clear history information when used.
            - EDIT: Edits based on uploaded audio, or continues to stack edit effects based on the previous round of generated audio.
            """)
            gr.Markdown("""
            **Operation Workflow:**
            - Upload the audio to be edited on the left side and fill in the corresponding text content of the audio;
            - If the task requires modifying text content (such as clone, para-linguistic), fill in the text to be synthesized in the "clone text" field. For all other tasks, keep the uploaded audio text content unchanged;
            - Select tasks and subtasks on the right side (some tasks have no subtasks, such as vad, etc.);
            - Click the "CLONE" or "EDIT" button on the left side, and audio will be generated in the dialog box on the right side.
            """)
            gr.Markdown("""
            **Para-linguistic Description:**
            - Supported tags include: [Breathing] [Laughter] [Surprise-oh] [Confirmation-en] [Uhm] [Surprise-ah] [Surprise-wa] [Sigh] [Question-ei] [Dissatisfaction-hnn]
            - Example:
                - Fill in "clone text" field: "Great, the weather is so nice today." Click the "CLONE" button to get audio.
                - Change "clone text" field to: "Great[Laughter], the weather is so nice today[Surprise-ah]." Click the "EDIT" button to get para-linguistic audio.
            """)

    def register_events(self):
        """Register event handlers wiring buttons/dropdowns to callbacks."""
        # Per-session state object shared by all callbacks below.
        state = gr.State(self.init_state())

        self.button_tts.click(self.generate_clone,
                              inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
                              outputs=[self.chat_box, state])
        self.button_edit.click(self.generate_edit,
                               inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
                               outputs=[self.chat_box, state])

        self.clean_history_submit.click(self.clear_history, inputs=[state], outputs=[self.chat_box, state])
        # Changing the main task repopulates the sub-task choices.
        self.edit_type.change(
            fn=self.update_edit_info,
            inputs=self.edit_type,
            outputs=self.edit_info,
        )

    def update_edit_info(self, category):
        """Update sub-task dropdown based on main task selection.

        Defaults to the first sub-task, or None when the task has none.
        """
        category_items = get_supported_edit_types()
        choices = category_items.get(category, [])
        value = None if len(choices) == 0 else choices[0]
        return gr.Dropdown(label="Sub-task", choices=choices, value=value)
|
|
|
|
|
|
|
|
|
def launch_demo(args, editx_tab):
    """Launch the gradio demo.

    Builds the Blocks layout, registers the Editx tab's widgets and event
    handlers, then starts the queued server on the configured host/port.
    """
    custom_css = """
    :root {
        --font: "Helvetica Neue", Helvetica, Arial, sans-serif;
        --font-mono: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
    }
    """

    with gr.Blocks(theme=gr.themes.Soft(), title="ποΈ Step-Audio-EditX", css=custom_css) as demo:
        gr.Markdown("## ποΈ Step-Audio-EditX")
        gr.Markdown("Audio Editing and Zero-Shot Cloning using Step-Audio-EditX")

        editx_tab.register_components()
        editx_tab.register_events()

    demo.queue().launch(
        server_name=args.server_name,
        server_port=args.server_port,
        # `share` may be absent on hand-built namespaces; default to False.
        share=getattr(args, 'share', False),
    )
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|
parser = argparse.ArgumentParser(description="Step-Audio Edit Demo")
|
|
|
parser.add_argument("--model-path", type=str, default="stepfun-ai", help="Model path.")
|
|
|
parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
|
|
|
parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
|
|
|
parser.add_argument("--tmp-dir", type=str, default="/tmp/gradio", help="Save path.")
|
|
|
parser.add_argument("--share", action="store_true", help="Share gradio app.")
|
|
|
|
|
|
|
|
|
parser.add_argument(
|
|
|
"--model-source",
|
|
|
type=str,
|
|
|
default="huggingface",
|
|
|
choices=["auto", "local", "modelscope", "huggingface"],
|
|
|
help="Model source: auto (detect automatically), local, modelscope, or huggingface"
|
|
|
)
|
|
|
parser.add_argument(
|
|
|
"--tokenizer-model-id",
|
|
|
type=str,
|
|
|
default="dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online",
|
|
|
help="Tokenizer model ID for online loading"
|
|
|
)
|
|
|
parser.add_argument(
|
|
|
"--tts-model-id",
|
|
|
type=str,
|
|
|
default=None,
|
|
|
help="TTS model ID for online loading (if different from model-path)"
|
|
|
)
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
args_global = args
|
|
|
|
|
|
logger.info(f"Configuration loaded:")
|
|
|
logger.info(f"Model source: {args.model_source}")
|
|
|
logger.info(f"Model path: {args.model_path}")
|
|
|
logger.info(f"Tokenizer model ID: {args.tokenizer_model_id}")
|
|
|
if args.tts_model_id:
|
|
|
logger.info(f"TTS model ID: {args.tts_model_id}")
|
|
|
|
|
|
|
|
|
|
|
|
if ZEROGPU_AVAILABLE:
|
|
|
logger.info("π ZeroGPU detected - using dynamic GPU duration management!")
|
|
|
logger.info("π‘ First call: 300s (model loading), subsequent calls: 120s (inference only)")
|
|
|
else:
|
|
|
logger.info("π» Running in local mode - models will be loaded on first call")
|
|
|
|
|
|
|
|
|
editx_tab = EditxTab(args)
|
|
|
|
|
|
|
|
|
launch_demo(args, editx_tab) |