File size: 12,266 Bytes
528efee
 
 
 
 
 
 
f4bb8a5
f1fdc79
528efee
 
 
f1fdc79
 
 
 
528efee
 
 
 
f1fdc79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528efee
 
 
 
 
f1fdc79
 
 
71d4610
f1fdc79
528efee
 
 
71d4610
 
 
 
 
528efee
 
 
 
 
 
 
f1fdc79
528efee
f1fdc79
 
 
528efee
f1fdc79
528efee
 
 
 
 
 
 
 
f1fdc79
528efee
 
f1fdc79
528efee
f1fdc79
 
 
528efee
f1fdc79
 
 
 
 
 
 
528efee
 
 
 
 
 
 
 
 
f1fdc79
 
 
 
71d4610
 
f1fdc79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528efee
 
 
 
 
 
 
 
 
f1fdc79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528efee
 
 
 
 
 
 
 
 
 
 
71de750
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528efee
 
 
 
 
 
 
f63959d
 
f1fdc79
f63959d
 
f1fdc79
 
 
 
 
f63959d
528efee
 
 
 
4536e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71de750
528efee
 
 
 
f1fdc79
 
 
 
 
 
 
 
528efee
 
 
 
8220abd
 
 
f1fdc79
 
 
 
8220abd
 
 
 
 
 
 
528efee
 
 
 
 
 
f4bb8a5
528efee
f4bb8a5
 
528efee
 
f4bb8a5
528efee
f4bb8a5
 
528efee
 
f4bb8a5
528efee
f4bb8a5
 
528efee
 
f4bb8a5
528efee
f4bb8a5
 
528efee
 
f4bb8a5
528efee
f4bb8a5
 
528efee
 
f4bb8a5
528efee
f4bb8a5
 
528efee
 
 
 
 
 
f4bb8a5
f1fdc79
f4bb8a5
 
 
 
 
f63959d
f1fdc79
f4bb8a5
 
 
 
 
 
 
 
 
 
 
528efee
8d62622
 
f1fdc79
 
 
 
 
528efee
f4bb8a5
528efee
f4bb8a5
 
 
 
 
8d62622
528efee
 
 
15197ca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
# -*- coding: utf-8 -*-
import gradio as gr
import os
import torch
import argparse
import librosa
import soundfile as sf
from huggingface_hub import snapshot_download
from loguru import logger

from gpa_inference import GPAInference

# Configuration constants (used for input validation and in the UI labels below)
MAX_AUDIO_DURATION: int = 30  # Max audio duration (seconds); preprocess_audio truncates longer inputs
MAX_TEXT_LENGTH: int = 2048   # Max text length (characters); process_tts_a truncates longer text

# Global inference object placeholder. Stays None until the GPAInference
# instance is created at the bottom of this script; each process_* handler
# checks for None and returns early while the model is not ready.
inference: "GPAInference | None" = None


def validate_audio_duration(audio_path):
    """Check whether an audio file is within ``MAX_AUDIO_DURATION`` seconds.

    The duration is read from the file header via ``soundfile`` when possible,
    so long files are not decoded just to be measured. Formats libsndfile
    cannot open fall back to a full ``librosa.load`` decode.

    Args:
        audio_path: Path to the audio file. Falsy values (``None`` / ``""``)
            mean "no audio" and pass validation.

    Returns:
        tuple: ``(is_valid, duration_seconds)``. ``is_valid`` is False both
        when the audio is too long and when it cannot be read; in the latter
        case the reported duration is 0.
    """
    if not audio_path:
        return True, 0
    try:
        duration = _audio_duration_seconds(audio_path)
    except Exception as e:
        logger.error(f"Error validating audio duration: {e}")
        return False, 0
    if duration > MAX_AUDIO_DURATION:
        logger.warning(f"Audio duration {duration:.2f}s exceeds limit {MAX_AUDIO_DURATION}s")
        return False, duration
    return True, duration


def _audio_duration_seconds(audio_path):
    """Return the length of *audio_path* in seconds (header read, decode fallback)."""
    try:
        # Header-only read: cost does not grow with the file length.
        return sf.info(audio_path).duration
    except Exception:
        # libsndfile could not open the file (e.g. a compressed format it lacks
        # support for). librosa has extra decoding backends; if it fails too,
        # its exception propagates to the caller.
        y, sr = librosa.load(audio_path, sr=None)
        return len(y) / sr


def validate_text_length(text):
    """Report whether *text* fits within ``MAX_TEXT_LENGTH`` characters.

    Returns:
        tuple: ``(within_limit, character_count)``. Missing or empty text is
        considered valid with a count of 0. A warning is logged whenever the
        limit is exceeded.
    """
    # Nothing to measure for missing/empty input.
    if not text:
        return True, 0

    n_chars = len(text)
    too_long = n_chars > MAX_TEXT_LENGTH
    if too_long:
        logger.warning(f"Text length {n_chars} exceeds limit {MAX_TEXT_LENGTH}")
    return not too_long, n_chars


def preprocess_audio(audio_path, target_sr=16000):
    """Resample an audio file to mono ``target_sr`` Hz and cap its length.

    The file is decoded once with librosa (resampled and down-mixed to mono).
    It is truncated to ``MAX_AUDIO_DURATION`` seconds if needed, then written
    as a WAV next to the input, named ``<name>_<target_sr // 1000>k.wav``
    (``<name>_16k.wav`` by default).

    Args:
        audio_path: Path to the input audio. Falsy values return ``None``.
        target_sr: Output sample rate in Hz (default 16000).

    Returns:
        One of:
        - the path of the processed WAV;
        - ``None`` for falsy input;
        - the original ``audio_path`` if processing failed with anything other
          than ``ValueError`` (the error is logged).

    Raises:
        ValueError: Propagated unchanged so callers can report it to the user.
    """
    if not audio_path:
        return None
    try:
        # Single decode: librosa resamples to target_sr and down-mixes to mono.
        y, _ = librosa.load(audio_path, sr=target_sr, mono=True)

        # Truncate if it exceeds the max duration. The length is measured on
        # the decoded samples, so the warning and the truncation always agree.
        max_samples = int(MAX_AUDIO_DURATION * target_sr)
        if len(y) > max_samples:
            logger.warning(
                f"Audio duration {len(y) / target_sr:.2f}s exceeds max limit "
                f"{MAX_AUDIO_DURATION}s. Truncating."
            )
            y = y[:max_samples]

        # Save processed audio to a new file rather than overwriting the upload.
        name, _ext = os.path.splitext(os.path.basename(audio_path))
        new_path = os.path.join(
            os.path.dirname(audio_path), f"{name}_{target_sr // 1000}k.wav"
        )

        sf.write(new_path, y, target_sr)
        logger.info(f"Preprocessed audio saved to: {new_path}")
        return new_path
    except ValueError:
        # Let validation-type errors reach the caller; a bare raise keeps the
        # original traceback.
        raise
    except Exception as e:
        logger.error(f"Error processing audio {audio_path}: {e}")
        return audio_path


# ======================== Interface Call Logic ========================

def process_stt(audio_path):
    """Gradio handler: transcribe an uploaded audio file.

    Args:
        audio_path: Filepath from the ``gr.Audio(type="filepath")`` input, or
            ``None`` when nothing was uploaded.

    Returns:
        str: The result of ``inference.run_stt``, or a human-readable status
        or error message to show in the output textbox.
    """
    # `inference` is only read here, so no `global` declaration is needed.
    if inference is None:
        return "Model not initialized"

    if not audio_path:
        return "Please upload audio file first"

    try:
        # Preprocess audio (resample/down-mix and cap length).
        audio_path = preprocess_audio(audio_path)

        # Direct inference call. NOTE(review): do_sample=False presumably
        # selects non-sampling (deterministic) decoding; confirm in GPAInference.
        return inference.run_stt(audio_path=audio_path, do_sample=False)
    except ValueError as ve:
        return f"Error: {str(ve)}"
    except Exception as e:
        # logger.exception also records the traceback for unexpected failures.
        logger.exception(f"STT processing error: {e}")
        return f"Processing failed: {str(e)}"

def process_tts_a(text, ref_audio):
    """Gradio handler: synthesize *text* in the voice of *ref_audio* (task "tts-a").

    Args:
        text: Text to synthesize. It is truncated to ``MAX_TEXT_LENGTH``
            characters.
        ref_audio: Filepath of the reference (voice source) recording.

    Returns:
        The value of ``inference.run_tts``, or ``None`` in any of these cases
        (failures are logged):
        - the model is not ready;
        - an input is missing or the text is blank;
        - synthesis fails.
        NOTE(review): the return value is expected to be a
        ``(sample_rate, audio_array)`` tuple for the gr.Audio output; confirm
        against GPAInference.run_tts.
    """
    # `inference` is only read here, so no `global` declaration is needed.
    if inference is None:
        return None

    # Whitespace-only text has nothing to synthesize; treat it as empty input.
    if not text or not text.strip() or not ref_audio:
        return None

    try:
        # Validate text length
        is_valid, text_len = validate_text_length(text)
        if not is_valid:
            logger.warning(f"Text length {text_len} exceeds max limit {MAX_TEXT_LENGTH}. Truncating.")
            text = text[:MAX_TEXT_LENGTH]

        # Preprocess audio
        ref_audio = preprocess_audio(ref_audio)

        # NOTE(review): the output filename is fixed and shared by every
        # request. If GPAInference writes it to a shared directory, concurrent
        # requests may overwrite each other's file; confirm.
        return inference.run_tts(
            task="tts-a",
            output_filename="tts_output.wav",
            text=text,
            ref_audio_path=ref_audio,
            temperature=0.8,
            do_sample=True,
        )
    except ValueError as ve:
        logger.error(f"TTS validation failed: {ve}")
        return None
    except Exception as e:
        # logger.exception also records the traceback for unexpected failures.
        logger.exception(f"TTS processing error: {e}")
        return None

def process_vc(src_audio, ref_audio):
    """Gradio handler: voice conversion.

    The content comes from *src_audio* and the voice from *ref_audio*.

    Args:
        src_audio: Filepath of the source (content source) recording.
        ref_audio: Filepath of the reference (voice source) recording.

    Returns:
        The value of ``inference.run_vc``, or ``None`` when the model is not
        ready, an input is missing, or conversion fails (failures are logged).
        NOTE(review): expected to be ``(sample_rate, audio_array)`` for the
        gr.Audio output; confirm against GPAInference.run_vc.
    """
    # `inference` is only read here, so no `global` declaration is needed.
    if inference is None:
        return None

    if not src_audio or not ref_audio:
        return None

    try:
        # Preprocess both inputs the same way before inference.
        src_audio = preprocess_audio(src_audio)
        ref_audio = preprocess_audio(ref_audio)

        return inference.run_vc(
            source_audio_path=src_audio,
            ref_audio_path=ref_audio,
            output_filename="vc_output.wav",
        )
    except ValueError as ve:
        logger.error(f"VC validation failed: {ve}")
        return None
    except Exception as e:
        # logger.exception also records the traceback for unexpected failures.
        logger.exception(f"VC processing error: {e}")
        return None

# ======================== Gradio UI Layout ========================

# Use a soft, premium theme with indigo/slate colors to replace the default orange
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="slate",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
)

# The handlers wired up below read the module-level `inference` object at
# call time, so the layout can be built before the model is loaded.
with gr.Blocks(
    title="General Purpose Audio System",
    theme=theme,
) as demo:
    gr.Markdown(
        "# GPA: One Model for Speech Recognition, Text-to-Speech, and Voice Conversion"
    )
    # Badge labels percent-encode their emoji (UTF-8) so the URLs stay ASCII.
    gr.HTML(
        """
        <div style="display: flex; flex-wrap: nowrap; gap: 8px; overflow-x: auto;">
            <a href="https://arxiv.org/abs/2601.10770"><img src="https://img.shields.io/badge/ArXiv-2601.10770-b31b1b?style=for-the-badge&logo=arxiv" alt="ArXiv"></a>
            <a href="https://autoark.github.io/GPA/"><img src="https://img.shields.io/badge/Demo-GitHub%20Pages-blue?style=for-the-badge&logo=github" alt="Demo"></a>
            <a href="https://huggingface.co/AutoArk-AI/GPA"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-yellow?style=for-the-badge" alt="Hugging Face"></a>
            <a href="https://huggingface.co/spaces/AutoArk-AI/GPA_DEMO"><img src="https://img.shields.io/badge/%F0%9F%8E%AE%20Interactive%20Demo-Try%20It!-blue?style=for-the-badge" alt="Interactive Demo"></a>
            <a href="https://www.modelscope.cn/models/AutoArk/GPA"><img src="https://img.shields.io/badge/%F0%9F%A4%96%20ModelScope-Models-purple?style=for-the-badge" alt="ModelScope"></a>
        </div>
        """
    )

    with gr.Tabs():

        # --- TTS-A Tab ---
        with gr.TabItem("👤 Text to Speech (TTS)"):
            with gr.Row():
                with gr.Column():
                    ttsa_text = gr.Textbox(
                        label="Synthesis Text",
                        placeholder=f"Enter text to synthesize (max {MAX_TEXT_LENGTH} chars)...",
                        value="Hello, I am generated by voice cloning.",
                        lines=3,
                        max_lines=10,
                    )
                    ttsa_ref = gr.Audio(
                        label=f"Reference Audio (Voice Source) - Max {MAX_AUDIO_DURATION}s",
                        type="filepath"
                    )
                ttsa_output = gr.Audio(label="Synthesis Result")
            ttsa_btn = gr.Button("Synthesize Now", variant="primary")
            ttsa_btn.click(process_tts_a, inputs=[ttsa_text, ttsa_ref], outputs=ttsa_output)

            # gr.Examples(
            #     examples=[
            #         [
            #             "Hello, I am generated by voice cloning.",
            #             "examples/tts/01/prompt.wav",
            #         ],
            #         [
            #             "Welcome to the General Purpose Audio System.",
            #             "examples/tts/02/prompt.wav",
            #         ],
            #     ],
            #     inputs=[ttsa_text, ttsa_ref],
            #     outputs=ttsa_output,
            #     fn=process_tts_a,
            #     cache_examples=True,
            # )

        # --- VC Tab ---
        with gr.TabItem("🎭 Voice Conversion (VC)"):
            with gr.Row():
                with gr.Column():
                    vc_src = gr.Audio(
                        label=f"Source Audio (Content Source) - Max {MAX_AUDIO_DURATION}s",
                        type="filepath"
                    )
                    vc_ref = gr.Audio(
                        label=f"Reference Audio (Voice Source) - Max {MAX_AUDIO_DURATION}s",
                        type="filepath"
                    )
                vc_output = gr.Audio(label="Conversion Result")
            vc_btn = gr.Button("Start Conversion", variant="primary")
            vc_btn.click(process_vc, inputs=[vc_src, vc_ref], outputs=vc_output)

        # --- STT Tab ---
        with gr.TabItem("🎙️ Speech to Text (STT)"):
            with gr.Row():
                stt_input = gr.Audio(
                    label=f"Input Audio - Max {MAX_AUDIO_DURATION}s",
                    type="filepath"
                )
                stt_output = gr.Textbox(
                    label="Recognition Result",
                    placeholder="Recognition result will be displayed here in real-time...",
                    lines=5,
                )
            stt_btn = gr.Button("Start Recognition", variant="primary")
            stt_btn.click(process_stt, inputs=stt_input, outputs=stt_output)

def parse_args(argv=None):
    """Parse command-line options for the GPA Gradio app.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before. Passing a list
            lets tests or other scripts reuse the parser.

    Returns:
        argparse.Namespace with these attributes:
        - ``hf_model_id`` and ``cache_dir``;
        - the four optional path overrides ``tokenizer_path``,
          ``text_tokenizer_path``, ``bicodec_tokenizer_path`` and
          ``gpa_model_path``. Each defaults to ``None``, meaning "use the
          downloaded model".
    """
    parser = argparse.ArgumentParser(description="GPA Audio System GUI")

    # Model source and local download cache
    parser.add_argument(
        "--hf_model_id",
        type=str,
        default="AutoArk-AI/GPA",
        help="Hugging Face model ID to download",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default="./models",
        help="Directory to cache downloaded models",
    )

    # Optional per-component path overrides
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="Path to GLM4 tokenizer (if None, will use downloaded model)",
    )
    parser.add_argument(
        "--text_tokenizer_path",
        type=str,
        default=None,
        help="Path to text tokenizer (if None, will use downloaded model)",
    )
    parser.add_argument(
        "--bicodec_tokenizer_path",
        type=str,
        default=None,
        help="Path to BiCodec tokenizer (if None, will use downloaded model)",
    )
    parser.add_argument(
        "--gpa_model_path",
        type=str,
        default=None,
        help="Path to GPA model (if None, will use downloaded model)",
    )

    return parser.parse_args(argv)

args = parse_args()

# Only contact the Hugging Face Hub when at least one component path was not
# overridden on the command line. With all four paths supplied, the app can
# start without network access.
_path_overrides = (
    args.tokenizer_path,
    args.text_tokenizer_path,
    args.bicodec_tokenizer_path,
    args.gpa_model_path,
)
if all(_path_overrides):
    model_base_path = None
    logger.info("All model paths provided explicitly; skipping Hugging Face download.")
else:
    # Download model from Hugging Face Hub. `resume_download` is no longer
    # passed: it is deprecated in recent huggingface_hub releases, which
    # resume interrupted downloads automatically.
    logger.info(f"Downloading model from {args.hf_model_id}...")
    model_base_path = snapshot_download(
        repo_id=args.hf_model_id,
        cache_dir=args.cache_dir,
    )
    logger.info(f"Model downloaded to: {model_base_path}")

# Construct actual paths: explicit CLI overrides win; otherwise use locations
# inside the downloaded snapshot. `or` short-circuits, so os.path.join is
# never reached with model_base_path=None.
tokenizer_path = args.tokenizer_path or os.path.join(
    model_base_path, "glm-4-voice-tokenizer"
)
text_tokenizer_path = args.text_tokenizer_path or model_base_path
bicodec_tokenizer_path = args.bicodec_tokenizer_path or os.path.join(
    model_base_path, "BiCodec"
)
gpa_model_path = args.gpa_model_path or model_base_path

# Instantiate Model
device = "cuda" if torch.cuda.is_available() else "cpu"

logger.info(f"Initializing GPA Inference System on {device}...")
logger.info(f"Tokenizer path: {tokenizer_path}")
logger.info(f"Text tokenizer path: {text_tokenizer_path}")
logger.info(f"BiCodec tokenizer path: {bicodec_tokenizer_path}")
logger.info(f"GPA model path: {gpa_model_path}")

# Use None for output_dir to enable temporary directory in HF Spaces.
# Assigning the module-level `inference` here is what makes the process_*
# handlers stop returning their "model not ready" results.
inference = GPAInference(
    tokenizer_path=tokenizer_path,
    text_tokenizer_path=text_tokenizer_path,
    bicodec_tokenizer_path=bicodec_tokenizer_path,
    gpa_model_path=gpa_model_path,
    output_dir=None,  # Will use temporary directory
    device=device,
)

# Launch Gradio Demo
demo.queue().launch()