# GPA_DEMO / app.py
# wanglamao
# add trim logic
# 71d4610
# -*- coding: utf-8 -*-
import gradio as gr
import os
import torch
import argparse
import librosa
import soundfile as sf
from huggingface_hub import snapshot_download
from loguru import logger
from gpa_inference import GPAInference
# Configuration constants: input limits enforced by the handlers below.
MAX_AUDIO_DURATION = 30  # Longest audio clip kept, in seconds; preprocess_audio truncates beyond this
MAX_TEXT_LENGTH = 2048  # Longest synthesis text kept, in characters; process_tts_a truncates beyond this
# Module-level handle to the inference engine. It is None until the startup
# code at the bottom of this file assigns a GPAInference instance; every
# process_* handler checks for None before using it.
inference = None
def validate_audio_duration(audio_path):
    """Check the length of an audio file against ``MAX_AUDIO_DURATION``.

    Returns a ``(is_valid, duration_seconds)`` pair:

    * falsy ``audio_path`` -> ``(True, 0)`` (nothing to validate);
    * readable file -> ``(duration <= limit, duration)``, with a warning
      logged when the limit is exceeded;
    * any error while loading or measuring -> ``(False, 0)``, with the error
      logged.
    """
    if audio_path:
        try:
            # sr=None keeps the file's native sample rate, so the duration
            # is simply sample count divided by that rate.
            samples, sample_rate = librosa.load(audio_path, sr=None)
            seconds = len(samples) / sample_rate
            over_limit = seconds > MAX_AUDIO_DURATION
            if over_limit:
                logger.warning(f"Audio duration {seconds:.2f}s exceeds limit {MAX_AUDIO_DURATION}s")
            return (not over_limit), seconds
        except Exception as e:
            logger.error(f"Error validating audio duration: {e}")
            return False, 0
    return True, 0
def validate_text_length(text):
    """Check the length of ``text`` against ``MAX_TEXT_LENGTH``.

    Returns a ``(is_valid, length)`` pair. Falsy input (``None`` or an empty
    string) is reported as valid with length 0. When the text is longer than
    the limit a warning is logged and ``is_valid`` is ``False``.
    """
    if not text:
        return True, 0
    n_chars = len(text)
    too_long = n_chars > MAX_TEXT_LENGTH
    if too_long:
        logger.warning(f"Text length {n_chars} exceeds limit {MAX_TEXT_LENGTH}")
    return (not too_long), n_chars
def preprocess_audio(audio_path):
    """Convert an audio file to 16 kHz mono WAV, truncated to the duration limit.

    The file is decoded exactly once (resampled to 16 kHz and downmixed to
    mono by ``librosa.load``). If the result is longer than
    ``MAX_AUDIO_DURATION`` seconds it is cut to that length and a warning is
    logged. The processed audio is written next to the input as
    ``<name>_16k.wav`` so the original upload is left untouched.

    Args:
        audio_path: Path of the input audio file; may be ``None`` or empty.

    Returns:
        ``None`` for falsy input, the path of the new ``_16k.wav`` file on
        success, or the original ``audio_path`` unchanged if processing
        failed (best-effort fallback; the error is logged).

    Raises:
        ValueError: Propagated unchanged from the load/write calls, because
            the calling handlers catch ``ValueError`` separately.
    """
    if not audio_path:
        return None

    target_sr = 16000
    try:
        # Single decode: resample to target_sr and convert to mono in one go.
        samples, _ = librosa.load(audio_path, sr=target_sr, mono=True)

        # Decide "too long" from the decoded samples themselves, so the
        # warning is only logged when a truncation really happens.
        max_samples = int(MAX_AUDIO_DURATION * target_sr)
        if len(samples) > max_samples:
            duration = len(samples) / target_sr
            logger.warning(
                f"Audio duration {duration:.2f}s exceeds max limit {MAX_AUDIO_DURATION}s. Truncating."
            )
            samples = samples[:max_samples]

        # Write to a sibling file instead of overwriting the input.
        stem, _ext = os.path.splitext(os.path.basename(audio_path))
        new_path = os.path.join(os.path.dirname(audio_path), f"{stem}_16k.wav")
        sf.write(new_path, samples, target_sr)
        logger.info(f"Preprocessed audio saved to: {new_path}")
        return new_path
    except ValueError:
        # Callers handle ValueError themselves; bare raise keeps the traceback.
        raise
    except Exception as e:
        logger.error(f"Error processing audio {audio_path}: {e}")
        return audio_path
# ======================== Interface Call Logic ========================
def process_stt(audio_path):
    """Gradio handler for the Speech-to-Text tab.

    Returns a string for the result textbox: a fixed message when the model
    is not loaded or no audio was given, an error description when
    preprocessing or inference fails, and otherwise whatever
    ``inference.run_stt`` returns for the preprocessed file.
    """
    # The model check comes first so a missing model is reported even when
    # no audio was supplied.
    if inference is None:
        return "Model not initialized"
    if not audio_path:
        return "Please upload audio file first"

    try:
        prepared_path = preprocess_audio(audio_path)
        transcript = inference.run_stt(audio_path=prepared_path, do_sample=False)
    except ValueError as ve:
        return f"Error: {str(ve)}"
    except Exception as e:
        logger.error(f"STT processing error: {e}")
        return f"Processing failed: {str(e)}"
    return transcript
def process_tts_a(text, ref_audio):
    """Gradio handler for the Text-to-Speech tab.

    Synthesizes ``text`` using ``ref_audio`` as the reference voice. Text
    longer than ``MAX_TEXT_LENGTH`` is truncated (with a warning) and the
    reference audio is run through ``preprocess_audio`` first.

    Returns the value of ``inference.run_tts`` (the original author notes it
    is a ``(sample_rate, audio_array)`` tuple -- not verifiable from this
    file), or ``None`` when the model is not loaded, an input is missing, or
    any error occurs (errors are logged).
    """
    if inference is None or not text or not ref_audio:
        return None

    try:
        fits, n_chars = validate_text_length(text)
        if not fits:
            logger.warning(f"Text length {n_chars} exceeds max limit {MAX_TEXT_LENGTH}. Truncating.")
            text = text[:MAX_TEXT_LENGTH]

        voice_prompt_path = preprocess_audio(ref_audio)

        # The result is handed straight to the Gradio Audio output component.
        return inference.run_tts(
            task="tts-a",
            output_filename="tts_output.wav",
            text=text,
            ref_audio_path=voice_prompt_path,
            temperature=0.8,
            do_sample=True,
        )
    except ValueError as ve:
        logger.error(f"TTS validation failed: {ve}")
    except Exception as e:
        logger.error(f"TTS processing error: {e}")
    return None
def process_vc(src_audio, ref_audio):
    """Gradio handler for the Voice Conversion tab.

    Both the source (content) audio and the reference (voice) audio are run
    through ``preprocess_audio`` and then passed to ``inference.run_vc``.

    Returns the value of ``inference.run_vc`` (the original author notes it
    is a ``(sample_rate, audio_array)`` tuple -- not verifiable from this
    file), or ``None`` when the model is not loaded, either input is missing,
    or any error occurs (errors are logged).
    """
    if inference is None or not (src_audio and ref_audio):
        return None

    try:
        # Source is preprocessed before the reference, as in the original.
        content_path = preprocess_audio(src_audio)
        voice_path = preprocess_audio(ref_audio)

        # The result is handed straight to the Gradio Audio output component.
        return inference.run_vc(
            source_audio_path=content_path,
            ref_audio_path=voice_path,
            output_filename="vc_output.wav",
        )
    except ValueError as ve:
        logger.error(f"VC validation failed: {ve}")
    except Exception as e:
        logger.error(f"VC processing error: {e}")
    return None
# ======================== Gradio UI Layout ========================
# NOTE(review): this file arrived with its indentation stripped; the nesting
# below (which components sit inside each Row / Column / Tab) is reconstructed
# from statement order -- confirm against the deployed layout.
# Use a soft, premium theme with indigo/slate colors to replace the default orange
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="slate",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
)
# Build the page: title, a row of badge links, then one tab per task.
with gr.Blocks(
    title="General Purpose Audio System",
    theme=theme,
) as demo:
    gr.Markdown(
        "# GPA: One Model for Speech Recognition, Text-to-Speech, and Voice Conversion"
    )
    # Badge links (paper, project page, model hubs, hosted demo). The string
    # literal is kept flush-left so its content stays exactly as in the source.
    gr.HTML(
        """
<div style="display: flex; flex-wrap: nowrap; gap: 8px; overflow-x: auto;">
<a href="https://arxiv.org/abs/2601.10770"><img src="https://img.shields.io/badge/ArXiv-2601.10770-b31b1b?style=for-the-badge&logo=arxiv" alt="ArXiv"></a>
<a href="https://autoark.github.io/GPA/"><img src="https://img.shields.io/badge/Demo-GitHub%20Pages-blue?style=for-the-badge&logo=github" alt="Demo"></a>
<a href="https://huggingface.co/AutoArk-AI/GPA"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-yellow?style=for-the-badge" alt="Hugging Face"></a>
<a href="https://huggingface.co/spaces/AutoArk-AI/GPA_DEMO"><img src="https://img.shields.io/badge/🎮%20Interactive%20Demo-Try%20It!-blue?style=for-the-badge" alt="Interactive Demo"></a>
<a href="https://www.modelscope.cn/models/AutoArk/GPA"><img src="https://img.shields.io/badge/🤖%20ModelScope-Models-purple?style=for-the-badge" alt="ModelScope"></a>
</div>
"""
    )
    with gr.Tabs():
        # --- TTS-A Tab ---
        # Inputs: text + reference voice clip; output: synthesized audio.
        with gr.TabItem("👤 Text to Speech (TTS)"):
            with gr.Row():
                with gr.Column():
                    ttsa_text = gr.Textbox(
                        label="Synthesis Text",
                        placeholder=f"Enter text to synthesize (max {MAX_TEXT_LENGTH} chars)...",
                        value="Hello, I am generated by voice cloning.",
                        lines=3,
                        max_lines=10,
                    )
                    # type="filepath": the handler receives a path on disk.
                    ttsa_ref = gr.Audio(
                        label=f"Reference Audio (Voice Source) - Max {MAX_AUDIO_DURATION}s",
                        type="filepath"
                    )
                ttsa_output = gr.Audio(label="Synthesis Result")
            ttsa_btn = gr.Button("Synthesize Now", variant="primary")
            ttsa_btn.click(process_tts_a, inputs=[ttsa_text, ttsa_ref], outputs=ttsa_output)
            # Example gallery, currently disabled.
            # gr.Examples(
            # examples=[
            # [
            # "Hello, I am generated by voice cloning.",
            # "examples/tts/01/prompt.wav",
            # ],
            # [
            # "Welcome to the General Purpose Audio System.",
            # "examples/tts/02/prompt.wav",
            # ],
            # ],
            # inputs=[ttsa_text, ttsa_ref],
            # outputs=ttsa_output,
            # fn=process_tts_a,
            # cache_examples=True,
            # )
        # --- VC Tab ---
        # Inputs: source (content) clip + reference (voice) clip; output: converted audio.
        with gr.TabItem("🎭 Voice Conversion (VC)"):
            with gr.Row():
                with gr.Column():
                    vc_src = gr.Audio(
                        label=f"Source Audio (Content Source) - Max {MAX_AUDIO_DURATION}s",
                        type="filepath"
                    )
                    vc_ref = gr.Audio(
                        label=f"Reference Audio (Voice Source) - Max {MAX_AUDIO_DURATION}s",
                        type="filepath"
                    )
                vc_output = gr.Audio(label="Conversion Result")
            vc_btn = gr.Button("Start Conversion", variant="primary")
            vc_btn.click(process_vc, inputs=[vc_src, vc_ref], outputs=vc_output)
        # --- STT Tab ---
        # Input: audio clip; output: text returned by process_stt.
        with gr.TabItem("🎙️ Speech to Text (STT)"):
            with gr.Row():
                stt_input = gr.Audio(
                    label=f"Input Audio - Max {MAX_AUDIO_DURATION}s",
                    type="filepath"
                )
                stt_output = gr.Textbox(
                    label="Recognition Result",
                    placeholder="Recognition result will be displayed here in real-time...",
                    lines=5,
                )
            stt_btn = gr.Button("Start Recognition", variant="primary")
            stt_btn.click(process_stt, inputs=stt_input, outputs=stt_output)
def parse_args(argv=None):
    """Parse the command-line options of the GPA demo.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            zero-argument call did before this parameter existed. Passing an
            explicit list lets the function be used without touching the
            process arguments.

    Returns:
        An ``argparse.Namespace`` with ``hf_model_id``, ``cache_dir``,
        ``tokenizer_path``, ``text_tokenizer_path``,
        ``bicodec_tokenizer_path`` and ``gpa_model_path``. The four path
        options default to ``None``.
    """
    parser = argparse.ArgumentParser(description="GPA Audio System GUI")

    # Where the model snapshot comes from and where it is cached.
    parser.add_argument(
        "--hf_model_id",
        type=str,
        default="AutoArk-AI/GPA",
        help="Hugging Face model ID to download",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default="./models",
        help="Directory to cache downloaded models",
    )

    # Optional per-component path overrides; all four share the same shape,
    # so they are registered from a table instead of four copies.
    path_overrides = (
        ("--tokenizer_path", "GLM4 tokenizer"),
        ("--text_tokenizer_path", "text tokenizer"),
        ("--bicodec_tokenizer_path", "BiCodec tokenizer"),
        ("--gpa_model_path", "GPA model"),
    )
    for flag, component in path_overrides:
        parser.add_argument(
            flag,
            type=str,
            default=None,
            help=f"Path to {component} (if None, will use downloaded model)",
        )

    return parser.parse_args(argv)
# ---- Startup: runs at import time, in this order ----
args = parse_args()

# Fetch the model snapshot from the Hugging Face Hub into the cache directory.
logger.info(f"Downloading model from {args.hf_model_id}...")
model_base_path = snapshot_download(
    repo_id=args.hf_model_id, cache_dir=args.cache_dir, resume_download=True
)
logger.info(f"Model downloaded to: {model_base_path}")

# Every component path can be overridden on the command line; otherwise it
# falls back to a location inside the downloaded snapshot.
tokenizer_path = (
    args.tokenizer_path
    if args.tokenizer_path
    else os.path.join(model_base_path, "glm-4-voice-tokenizer")
)
text_tokenizer_path = (
    args.text_tokenizer_path if args.text_tokenizer_path else model_base_path
)
bicodec_tokenizer_path = (
    args.bicodec_tokenizer_path
    if args.bicodec_tokenizer_path
    else os.path.join(model_base_path, "BiCodec")
)
gpa_model_path = args.gpa_model_path if args.gpa_model_path else model_base_path

# Prefer the GPU when one is visible to torch.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

logger.info(f"Initializing GPA Inference System on {device}...")
logger.info(f"Tokenizer path: {tokenizer_path}")
logger.info(f"Text tokenizer path: {text_tokenizer_path}")
logger.info(f"BiCodec tokenizer path: {bicodec_tokenizer_path}")
logger.info(f"GPA model path: {gpa_model_path}")

# Replace the module-level placeholder that the process_* handlers read.
# output_dir=None: per the original author's note this makes GPAInference
# use a temporary directory (intended for HF Spaces) -- not verifiable here.
inference = GPAInference(
    tokenizer_path=tokenizer_path,
    text_tokenizer_path=text_tokenizer_path,
    bicodec_tokenizer_path=bicodec_tokenizer_path,
    gpa_model_path=gpa_model_path,
    output_dir=None,
    device=device,
)

# Enable request queuing, then start the Gradio server.
demo.queue().launch()