# -*- coding: utf-8 -*-
"""Gradio GUI for the GPA (General Purpose Audio) system.

A single GPAInference model serves three tabs: speech-to-text (STT),
voice-cloning text-to-speech (TTS-A), and voice conversion (VC).
Model weights are fetched from the Hugging Face Hub at startup.
Initialization is kept at module level so the app also works when this
file is imported by a hosting runtime (e.g. HF Spaces) rather than run
as a script.
"""
import argparse
import os

import gradio as gr
import librosa
import soundfile as sf
import torch
from huggingface_hub import snapshot_download
from loguru import logger

from gpa_inference import GPAInference

# Configuration constants
MAX_AUDIO_DURATION = 30  # Max audio duration (seconds)
MAX_TEXT_LENGTH = 2048  # Max text length (characters)

# Global inference object placeholder; assigned after the model download below.
inference = None


def validate_audio_duration(audio_path):
    """Return (is_valid, duration_seconds) for the audio at *audio_path*.

    is_valid is False when the clip is longer than MAX_AUDIO_DURATION or
    cannot be decoded; an empty/None path is treated as valid with
    duration 0.
    """
    if not audio_path:
        return True, 0
    try:
        # sr=None keeps the native sample rate; we only need the duration here.
        y, sr = librosa.load(audio_path, sr=None)
        duration = len(y) / sr
        if duration > MAX_AUDIO_DURATION:
            logger.warning(
                f"Audio duration {duration:.2f}s exceeds limit {MAX_AUDIO_DURATION}s"
            )
            return False, duration
        return True, duration
    except Exception as e:
        logger.error(f"Error validating audio duration: {e}")
        return False, 0


def validate_text_length(text):
    """Return (is_valid, length); is_valid is False when *text* exceeds MAX_TEXT_LENGTH."""
    if not text:
        return True, 0
    text_len = len(text)
    if text_len > MAX_TEXT_LENGTH:
        logger.warning(f"Text length {text_len} exceeds limit {MAX_TEXT_LENGTH}")
        return False, text_len
    return True, text_len


def preprocess_audio(audio_path):
    """Resample *audio_path* to 16 kHz mono and cap it at MAX_AUDIO_DURATION.

    The processed waveform is written next to the input as "<name>_16k.wav"
    and that new path is returned. Returns None for an empty path, and falls
    back to the original path if processing fails.
    """
    if not audio_path:
        return None
    try:
        # Over-long clips are not rejected -- they are truncated below.
        is_valid, duration = validate_audio_duration(audio_path)
        if not is_valid:
            # BUG FIX: this message originally contained a literal newline
            # inside a single-quoted f-string, which is a SyntaxError.
            logger.warning(
                f"Audio duration {duration:.2f}s exceeds max limit "
                f"{MAX_AUDIO_DURATION}s. Truncating."
            )

        # librosa resamples to 16 kHz and downmixes to mono in one call.
        y, _ = librosa.load(audio_path, sr=16000, mono=True)

        # Truncate if exceeds max duration.
        max_samples = int(MAX_AUDIO_DURATION * 16000)
        if len(y) > max_samples:
            y = y[:max_samples]

        # Save to a sibling file so the original upload is left untouched.
        dir_name = os.path.dirname(audio_path)
        name, _ext = os.path.splitext(os.path.basename(audio_path))
        new_path = os.path.join(dir_name, f"{name}_16k.wav")
        sf.write(new_path, y, 16000)
        logger.info(f"Preprocessed audio saved to: {new_path}")
        return new_path
    except ValueError:
        # Propagate validation errors so callers can surface them to the UI.
        raise
    except Exception as e:
        logger.error(f"Error processing audio {audio_path}: {e}")
        return audio_path


# ======================== Interface Call Logic ========================


def process_stt(audio_path):
    """Gradio callback: transcribe *audio_path*; returns text or an error string."""
    if inference is None:
        return "Model not initialized"
    if not audio_path:
        return "Please upload audio file first"
    try:
        # Preprocess audio
        audio_path = preprocess_audio(audio_path)
        # Greedy decoding (do_sample=False) for deterministic transcription.
        return inference.run_stt(audio_path=audio_path, do_sample=False)
    except ValueError as ve:
        return f"Error: {str(ve)}"
    except Exception as e:
        logger.error(f"STT processing error: {e}")
        return f"Processing failed: {str(e)}"


def process_tts_a(text, ref_audio):
    """Gradio callback: synthesize *text* in the voice of *ref_audio*.

    Returns a (sample_rate, audio_array) tuple for the gr.Audio component,
    or None on missing input / failure.
    """
    if inference is None:
        return None
    if not text or not ref_audio:
        return None
    try:
        # Validate text length; over-long text is truncated, not rejected.
        is_valid, text_len = validate_text_length(text)
        if not is_valid:
            # BUG FIX: this message originally contained a literal newline
            # inside a single-quoted f-string, which is a SyntaxError.
            logger.warning(
                f"Text length {text_len} exceeds max limit "
                f"{MAX_TEXT_LENGTH}. Truncating."
            )
            text = text[:MAX_TEXT_LENGTH]

        # Preprocess audio
        ref_audio = preprocess_audio(ref_audio)

        # Direct inference call - returns (sample_rate, audio_array)
        return inference.run_tts(
            task="tts-a",
            output_filename="tts_output.wav",
            text=text,
            ref_audio_path=ref_audio,
            temperature=0.8,
            do_sample=True,
        )
    except ValueError as ve:
        logger.error(f"TTS validation failed: {ve}")
        return None
    except Exception as e:
        logger.error(f"TTS processing error: {e}")
        return None


def process_vc(src_audio, ref_audio):
    """Gradio callback: re-voice *src_audio* with the timbre of *ref_audio*.

    Returns a (sample_rate, audio_array) tuple for the gr.Audio component,
    or None on missing input / failure.
    """
    if inference is None:
        return None
    if not src_audio or not ref_audio:
        return None
    try:
        # Preprocess both inputs to 16 kHz mono.
        src_audio = preprocess_audio(src_audio)
        ref_audio = preprocess_audio(ref_audio)

        # Direct inference call - returns (sample_rate, audio_array)
        return inference.run_vc(
            source_audio_path=src_audio,
            ref_audio_path=ref_audio,
            output_filename="vc_output.wav",
        )
    except ValueError as ve:
        logger.error(f"VC validation failed: {ve}")
        return None
    except Exception as e:
        logger.error(f"VC processing error: {e}")
        return None


# ======================== Gradio UI Layout ========================

# Use a soft, premium theme with indigo/slate colors to replace the default orange
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="slate",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
)

with gr.Blocks(
    title="General Purpose Audio System",
    theme=theme,
) as demo:
    gr.Markdown(
        "# GPA: One Model for Speech Recognition, Text-to-Speech, and Voice Conversion"
    )
    gr.HTML(
        """
        """
    )
    with gr.Tabs():
        # --- TTS-A Tab ---
        with gr.TabItem("👤 Text to Speech (TTS)"):
            with gr.Row():
                with gr.Column():
                    ttsa_text = gr.Textbox(
                        label="Synthesis Text",
                        placeholder=f"Enter text to synthesize (max {MAX_TEXT_LENGTH} chars)...",
                        value="Hello, I am generated by voice cloning.",
                        lines=3,
                        max_lines=10,
                    )
                    ttsa_ref = gr.Audio(
                        label=f"Reference Audio (Voice Source) - Max {MAX_AUDIO_DURATION}s",
                        type="filepath",
                    )
                ttsa_output = gr.Audio(label="Synthesis Result")
            ttsa_btn = gr.Button("Synthesize Now", variant="primary")
            ttsa_btn.click(
                process_tts_a, inputs=[ttsa_text, ttsa_ref], outputs=ttsa_output
            )

        # --- VC Tab ---
        with gr.TabItem("🎭 Voice Conversion (VC)"):
            with gr.Row():
                with gr.Column():
                    vc_src = gr.Audio(
                        label=f"Source Audio (Content Source) - Max {MAX_AUDIO_DURATION}s",
                        type="filepath",
                    )
                    vc_ref = gr.Audio(
                        label=f"Reference Audio (Voice Source) - Max {MAX_AUDIO_DURATION}s",
                        type="filepath",
                    )
                vc_output = gr.Audio(label="Conversion Result")
            vc_btn = gr.Button("Start Conversion", variant="primary")
            vc_btn.click(process_vc, inputs=[vc_src, vc_ref], outputs=vc_output)

        # --- STT Tab ---
        with gr.TabItem("🎙️ Speech to Text (STT)"):
            with gr.Row():
                stt_input = gr.Audio(
                    label=f"Input Audio - Max {MAX_AUDIO_DURATION}s",
                    type="filepath",
                )
                stt_output = gr.Textbox(
                    label="Recognition Result",
                    placeholder="Recognition result will be displayed here in real-time...",
                    lines=5,
                )
            stt_btn = gr.Button("Start Recognition", variant="primary")
            stt_btn.click(process_stt, inputs=stt_input, outputs=stt_output)


def parse_args():
    """Parse CLI options selecting where model components are loaded from."""
    parser = argparse.ArgumentParser(description="GPA Audio System GUI")
    # Model Paths
    parser.add_argument(
        "--hf_model_id",
        type=str,
        default="AutoArk-AI/GPA",
        help="Hugging Face model ID to download",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default="./models",
        help="Directory to cache downloaded models",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="Path to GLM4 tokenizer (if None, will use downloaded model)",
    )
    parser.add_argument(
        "--text_tokenizer_path",
        type=str,
        default=None,
        help="Path to text tokenizer (if None, will use downloaded model)",
    )
    parser.add_argument(
        "--bicodec_tokenizer_path",
        type=str,
        default=None,
        help="Path to BiCodec tokenizer (if None, will use downloaded model)",
    )
    parser.add_argument(
        "--gpa_model_path",
        type=str,
        default=None,
        help="Path to GPA model (if None, will use downloaded model)",
    )
    return parser.parse_args()


# NOTE(review): startup deliberately stays at module level (no __main__ guard)
# because the output_dir comment below suggests an HF Spaces deployment, where
# this module is imported and `demo` is served by the runtime.
args = parse_args()

# Download model from Hugging Face Hub
logger.info(f"Downloading model from {args.hf_model_id}...")
model_base_path = snapshot_download(
    repo_id=args.hf_model_id,
    cache_dir=args.cache_dir,
    resume_download=True,  # NOTE: deprecated no-op on recent huggingface_hub
)
logger.info(f"Model downloaded to: {model_base_path}")

# Component paths default to subdirectories of the downloaded snapshot.
tokenizer_path = args.tokenizer_path or os.path.join(
    model_base_path, "glm-4-voice-tokenizer"
)
text_tokenizer_path = args.text_tokenizer_path or model_base_path
bicodec_tokenizer_path = args.bicodec_tokenizer_path or os.path.join(
    model_base_path, "BiCodec"
)
gpa_model_path = args.gpa_model_path or model_base_path

# Instantiate Model on GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Initializing GPA Inference System on {device}...")
logger.info(f"Tokenizer path: {tokenizer_path}")
logger.info(f"Text tokenizer path: {text_tokenizer_path}")
logger.info(f"BiCodec tokenizer path: {bicodec_tokenizer_path}")
logger.info(f"GPA model path: {gpa_model_path}")

# Use None for output_dir to enable temporary directory in HF Spaces
inference = GPAInference(
    tokenizer_path=tokenizer_path,
    text_tokenizer_path=text_tokenizer_path,
    bicodec_tokenizer_path=bicodec_tokenizer_path,
    gpa_model_path=gpa_model_path,
    output_dir=None,  # Will use temporary directory
    device=device,
)

# Launch Gradio Demo with request queueing enabled.
demo.queue().launch()