# -*- coding: utf-8 -*-
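"""Gradio demo app for GPA (General Purpose Audio): one model serving
speech-to-text, reference-based (voice-cloning) text-to-speech, and voice
conversion. Intended to run as a Hugging Face Space; model weights are
fetched from the Hub at startup via snapshot_download()."""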
import gradio as gr
import os
import torch
import argparse
import librosa
import soundfile as sf
from huggingface_hub import snapshot_download
from loguru import logger

from gpa_inference import GPAInference

# Configuration constants
MAX_AUDIO_DURATION = 30  # Max audio duration (seconds)
MAX_TEXT_LENGTH = 2048  # Max text length (characters)

# Global inference object placeholder
inference = None


def validate_audio_duration(audio_path):
    """Validate whether the audio duration exceeds the limit."""
    if not audio_path:
        return True, 0
    try:
        y, sr = librosa.load(audio_path, sr=None)
        duration = len(y) / sr
        if duration > MAX_AUDIO_DURATION:
            logger.warning(f"Audio duration {duration:.2f}s exceeds limit {MAX_AUDIO_DURATION}s")
            return False, duration
        return True, duration
    except Exception as e:
        logger.error(f"Error validating audio duration: {e}")
        return False, 0


def validate_text_length(text):
    """Validate whether the text length exceeds the limit."""
    if not text:
        return True, 0
    text_len = len(text)
    if text_len > MAX_TEXT_LENGTH:
        logger.warning(f"Text length {text_len} exceeds limit {MAX_TEXT_LENGTH}")
        return False, text_len
    return True, text_len
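# Both validators return an (is_valid, measured_value) tuple and never raise;
# callers decide whether to truncate the input or reject the request.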
def preprocess_audio(audio_path):
    """Ensure audio is 16 kHz mono; truncate to MAX_AUDIO_DURATION if needed."""
    if not audio_path:
        return None
    try:
        # Validate audio duration
        is_valid, duration = validate_audio_duration(audio_path)
        if not is_valid:
            logger.warning(
                f"Audio duration {duration:.2f}s exceeds max limit {MAX_AUDIO_DURATION}s. Truncating."
            )
        # Load audio with librosa: automatically resamples to sr=16000 and converts to mono
        y, _ = librosa.load(audio_path, sr=16000, mono=True)
        # Truncate if it exceeds the max duration
        max_samples = int(MAX_AUDIO_DURATION * 16000)
        if len(y) > max_samples:
            y = y[:max_samples]
        # Save processed audio to a new file to avoid conflicts
        dir_name = os.path.dirname(audio_path)
        base_name = os.path.basename(audio_path)
        name, ext = os.path.splitext(base_name)
        new_path = os.path.join(dir_name, f"{name}_16k.wav")
        sf.write(new_path, y, 16000)
        logger.info(f"Preprocessed audio saved to: {new_path}")
        return new_path
    except ValueError:
        # Re-raise validation errors so callers can surface them
        raise
    except Exception as e:
        logger.error(f"Error processing audio {audio_path}: {e}")
        return audio_path
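# Note: preprocess_audio() writes "<name>_16k.wav" next to the input file and
# returns that path; on unexpected load/write errors it falls back to the
# original path so inference can still be attempted.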
# ======================== Interface Call Logic ========================
def process_stt(audio_path):
    global inference
    if inference is None:
        return "Model not initialized"
    if not audio_path:
        return "Please upload an audio file first"
    try:
        # Preprocess audio
        audio_path = preprocess_audio(audio_path)
        # Direct inference call
        return inference.run_stt(audio_path=audio_path, do_sample=False)
    except ValueError as ve:
        return f"Error: {str(ve)}"
    except Exception as e:
        logger.error(f"STT processing error: {e}")
        return f"Processing failed: {str(e)}"
def process_tts_a(text, ref_audio):
    global inference
    if inference is None:
        return None
    if not text or not ref_audio:
        return None
    try:
        # Validate text length
        is_valid, text_len = validate_text_length(text)
        if not is_valid:
            logger.warning(
                f"Text length {text_len} exceeds max limit {MAX_TEXT_LENGTH}. Truncating."
            )
            text = text[:MAX_TEXT_LENGTH]
        # Preprocess audio
        ref_audio = preprocess_audio(ref_audio)
        # Direct inference call - returns (sample_rate, audio_array)
        result = inference.run_tts(
            task="tts-a",
            output_filename="tts_output.wav",
            text=text,
            ref_audio_path=ref_audio,
            temperature=0.8,
            do_sample=True,
        )
        # Return tuple format for the Gradio Audio component
        return result
    except ValueError as ve:
        logger.error(f"TTS validation failed: {ve}")
        return None
    except Exception as e:
        logger.error(f"TTS processing error: {e}")
        return None
def process_vc(src_audio, ref_audio):
    global inference
    if inference is None:
        return None
    if not src_audio or not ref_audio:
        return None
    try:
        # Preprocess audio
        src_audio = preprocess_audio(src_audio)
        ref_audio = preprocess_audio(ref_audio)
        # Direct inference call - returns (sample_rate, audio_array)
        result = inference.run_vc(
            source_audio_path=src_audio,
            ref_audio_path=ref_audio,
            output_filename="vc_output.wav",
        )
        # Return tuple format for the Gradio Audio component
        return result
    except ValueError as ve:
        logger.error(f"VC validation failed: {ve}")
        return None
    except Exception as e:
        logger.error(f"VC processing error: {e}")
        return None
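# All three handlers fail soft: STT returns an error string for the Textbox,
# while TTS and VC return None so the Audio output is simply left empty
# instead of raising inside the Gradio event loop.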
# ======================== Gradio UI Layout ========================
# Use a soft, premium theme with indigo/slate colors to replace the default orange
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="slate",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
)

with gr.Blocks(
    title="General Purpose Audio System",
    theme=theme,
) as demo:
    gr.Markdown(
        "# GPA: One Model for Speech Recognition, Text-to-Speech, and Voice Conversion"
    )
    gr.HTML(
        """
        <div style="display: flex; flex-wrap: nowrap; gap: 8px; overflow-x: auto;">
            <a href="https://arxiv.org/abs/2601.10770"><img src="https://img.shields.io/badge/ArXiv-2601.10770-b31b1b?style=for-the-badge&logo=arxiv" alt="ArXiv"></a>
            <a href="https://autoark.github.io/GPA/"><img src="https://img.shields.io/badge/Demo-GitHub%20Pages-blue?style=for-the-badge&logo=github" alt="Demo"></a>
            <a href="https://huggingface.co/AutoArk-AI/GPA"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-yellow?style=for-the-badge" alt="Hugging Face"></a>
            <a href="https://huggingface.co/spaces/AutoArk-AI/GPA_DEMO"><img src="https://img.shields.io/badge/🎮%20Interactive%20Demo-Try%20It!-blue?style=for-the-badge" alt="Interactive Demo"></a>
            <a href="https://www.modelscope.cn/models/AutoArk/GPA"><img src="https://img.shields.io/badge/🤖%20ModelScope-Models-purple?style=for-the-badge" alt="ModelScope"></a>
        </div>
        """
    )
    with gr.Tabs():
        # --- TTS-A Tab ---
        with gr.TabItem("👤 Text to Speech (TTS)"):
            with gr.Row():
                with gr.Column():
                    ttsa_text = gr.Textbox(
                        label="Synthesis Text",
                        placeholder=f"Enter text to synthesize (max {MAX_TEXT_LENGTH} chars)...",
                        value="Hello, I am generated by voice cloning.",
                        lines=3,
                        max_lines=10,
                    )
                    ttsa_ref = gr.Audio(
                        label=f"Reference Audio (Voice Source) - Max {MAX_AUDIO_DURATION}s",
                        type="filepath",
                    )
                    ttsa_output = gr.Audio(label="Synthesis Result")
                    ttsa_btn = gr.Button("Synthesize Now", variant="primary")
            ttsa_btn.click(process_tts_a, inputs=[ttsa_text, ttsa_ref], outputs=ttsa_output)
            # gr.Examples(
            #     examples=[
            #         [
            #             "Hello, I am generated by voice cloning.",
            #             "examples/tts/01/prompt.wav",
            #         ],
            #         [
            #             "Welcome to the General Purpose Audio System.",
            #             "examples/tts/02/prompt.wav",
            #         ],
            #     ],
            #     inputs=[ttsa_text, ttsa_ref],
            #     outputs=ttsa_output,
            #     fn=process_tts_a,
            #     cache_examples=True,
            # )
        # --- VC Tab ---
        with gr.TabItem("🎭 Voice Conversion (VC)"):
            with gr.Row():
                with gr.Column():
                    vc_src = gr.Audio(
                        label=f"Source Audio (Content Source) - Max {MAX_AUDIO_DURATION}s",
                        type="filepath",
                    )
                    vc_ref = gr.Audio(
                        label=f"Reference Audio (Voice Source) - Max {MAX_AUDIO_DURATION}s",
                        type="filepath",
                    )
                    vc_output = gr.Audio(label="Conversion Result")
                    vc_btn = gr.Button("Start Conversion", variant="primary")
            vc_btn.click(process_vc, inputs=[vc_src, vc_ref], outputs=vc_output)
        # --- STT Tab ---
        with gr.TabItem("🎙️ Speech to Text (STT)"):
            with gr.Row():
                stt_input = gr.Audio(
                    label=f"Input Audio - Max {MAX_AUDIO_DURATION}s",
                    type="filepath",
                )
                stt_output = gr.Textbox(
                    label="Recognition Result",
                    placeholder="The recognition result will be displayed here...",
                    lines=5,
                )
            stt_btn = gr.Button("Start Recognition", variant="primary")
            stt_btn.click(process_stt, inputs=stt_input, outputs=stt_output)
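# Note: Button.click(fn, inputs, outputs) passes component values to the
# handler positionally, in the order listed in `inputs`, and routes the
# return value to `outputs`; this is the standard Gradio event-wiring pattern.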
def parse_args():
    parser = argparse.ArgumentParser(description="GPA Audio System GUI")
    # Model Paths
    parser.add_argument(
        "--hf_model_id",
        type=str,
        default="AutoArk-AI/GPA",
        help="Hugging Face model ID to download",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default="./models",
        help="Directory to cache downloaded models",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="Path to GLM4 tokenizer (if None, will use the downloaded model)",
    )
    parser.add_argument(
        "--text_tokenizer_path",
        type=str,
        default=None,
        help="Path to text tokenizer (if None, will use the downloaded model)",
    )
    parser.add_argument(
        "--bicodec_tokenizer_path",
        type=str,
        default=None,
        help="Path to BiCodec tokenizer (if None, will use the downloaded model)",
    )
    parser.add_argument(
        "--gpa_model_path",
        type=str,
        default=None,
        help="Path to GPA model (if None, will use the downloaded model)",
    )
    return parser.parse_args()
args = parse_args()

# Download model from Hugging Face Hub
logger.info(f"Downloading model from {args.hf_model_id}...")
model_base_path = snapshot_download(
    repo_id=args.hf_model_id,
    cache_dir=args.cache_dir,
    resume_download=True,
)
# model_base_path = ""
logger.info(f"Model downloaded to: {model_base_path}")

# Construct actual paths from the downloaded model
tokenizer_path = args.tokenizer_path or os.path.join(
    model_base_path, "glm-4-voice-tokenizer"
)
text_tokenizer_path = args.text_tokenizer_path or model_base_path
bicodec_tokenizer_path = args.bicodec_tokenizer_path or os.path.join(
    model_base_path, "BiCodec"
)
gpa_model_path = args.gpa_model_path or model_base_path
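# Any CLI override takes precedence; otherwise paths resolve inside the
# downloaded snapshot ("glm-4-voice-tokenizer" and "BiCodec" subdirectories,
# with the snapshot root serving as the text tokenizer and GPA model path).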
# Instantiate Model
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Initializing GPA Inference System on {device}...")
logger.info(f"Tokenizer path: {tokenizer_path}")
logger.info(f"Text tokenizer path: {text_tokenizer_path}")
logger.info(f"BiCodec tokenizer path: {bicodec_tokenizer_path}")
logger.info(f"GPA model path: {gpa_model_path}")

# Use None for output_dir to enable a temporary directory in HF Spaces
inference = GPAInference(
    tokenizer_path=tokenizer_path,
    text_tokenizer_path=text_tokenizer_path,
    bicodec_tokenizer_path=bicodec_tokenizer_path,
    gpa_model_path=gpa_model_path,
    output_dir=None,  # Will use a temporary directory
    device=device,
)
# Launch Gradio Demo
demo.queue().launch()
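# queue() enables request queuing, which is advisable for long-running GPU
# inference on Spaces. For local testing you could also pass standard launch
# options, e.g. demo.queue().launch(server_name="0.0.0.0", server_port=7860).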