wanglamao commited on
Commit
f1fdc79
·
1 Parent(s): f63959d

add max len limit

Browse files
Files changed (1) hide show
  1. app.py +126 -46
app.py CHANGED
@@ -6,18 +6,56 @@ import argparse
6
  import librosa
7
  import soundfile as sf
8
  from huggingface_hub import snapshot_download
 
9
 
10
  from gpa_inference import GPAInference
11
 
 
 
 
 
12
  # Global inference object placeholder
13
  inference = None
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def preprocess_audio(audio_path):
17
  """Ensure audio is 16kHz mono"""
18
  if not audio_path:
19
  return None
20
  try:
 
 
 
 
 
 
21
  # Load audio with librosa: automatically resamples to sr=16000 and converts to mono
22
  y, _ = librosa.load(audio_path, sr=16000, mono=True)
23
 
@@ -28,10 +66,13 @@ def preprocess_audio(audio_path):
28
  new_path = os.path.join(dir_name, f"{name}_16k.wav")
29
 
30
  sf.write(new_path, y, 16000)
31
- print(f"Preprocessed audio saved to: {new_path}")
32
  return new_path
 
 
 
33
  except Exception as e:
34
- print(f"Error processing audio {audio_path}: {e}")
35
  return audio_path
36
 
37
 
@@ -40,16 +81,22 @@ def preprocess_audio(audio_path):
40
  def process_stt(audio_path):
41
  global inference
42
  if inference is None:
43
- return "Model not initialized."
44
 
45
  if not audio_path:
46
- return "Please upload audio first."
47
 
48
- # Preprocess audio
49
- audio_path = preprocess_audio(audio_path)
 
50
 
51
- # Direct inference call
52
- return inference.run_stt(audio_path=audio_path, do_sample=False)
 
 
 
 
 
53
 
54
  def process_tts_a(text, ref_audio):
55
  global inference
@@ -59,20 +106,33 @@ def process_tts_a(text, ref_audio):
59
  if not text or not ref_audio:
60
  return None
61
 
62
- # Preprocess audio
63
- ref_audio = preprocess_audio(ref_audio)
64
-
65
- # Direct inference call - returns (sample_rate, audio_array)
66
- result = inference.run_tts(
67
- task="tts-a",
68
- output_filename="tts_output.wav",
69
- text=text,
70
- ref_audio_path=ref_audio,
71
- temperature=0.8,
72
- do_sample=True,
73
- )
74
- # Return tuple format for Gradio Audio component
75
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  def process_vc(src_audio, ref_audio):
78
  global inference
@@ -82,18 +142,25 @@ def process_vc(src_audio, ref_audio):
82
  if not src_audio or not ref_audio:
83
  return None
84
 
85
- # Preprocess audio
86
- src_audio = preprocess_audio(src_audio)
87
- ref_audio = preprocess_audio(ref_audio)
88
-
89
- # Direct inference call - returns (sample_rate, audio_array)
90
- result = inference.run_vc(
91
- source_audio_path=src_audio,
92
- ref_audio_path=ref_audio,
93
- output_filename="vc_output.wav",
94
- )
95
- # Return tuple format for Gradio Audio component
96
- return result
 
 
 
 
 
 
 
97
 
98
  # ======================== Gradio UI Layout ========================
99
 
@@ -132,11 +199,15 @@ with gr.Blocks(
132
  with gr.Column():
133
  ttsa_text = gr.Textbox(
134
  label="Synthesis Text",
135
- placeholder="Enter the text you want to synthesize...",
136
  value="Hello, I am generated by voice cloning.",
137
  lines=3,
 
 
 
 
 
138
  )
139
- ttsa_ref = gr.Audio(label="Reference Audio (Voice Source)", type="filepath")
140
  ttsa_output = gr.Audio(label="Synthesis Result")
141
  ttsa_btn = gr.Button("Synthesize Now", variant="primary")
142
  ttsa_btn.click(process_tts_a, inputs=[ttsa_text, ttsa_ref], outputs=ttsa_output)
@@ -162,8 +233,14 @@ with gr.Blocks(
162
  with gr.TabItem("🎭 Voice Conversion (VC)"):
163
  with gr.Row():
164
  with gr.Column():
165
- vc_src = gr.Audio(label="Source Audio (Content Source)", type="filepath")
166
- vc_ref = gr.Audio(label="Reference Audio (Voice Source)", type="filepath")
 
 
 
 
 
 
167
  vc_output = gr.Audio(label="Conversion Result")
168
  vc_btn = gr.Button("Start Conversion", variant="primary")
169
  vc_btn.click(process_vc, inputs=[vc_src, vc_ref], outputs=vc_output)
@@ -171,7 +248,10 @@ with gr.Blocks(
171
  # --- STT Tab ---
172
  with gr.TabItem("🎙️ Speech to Text (STT)"):
173
  with gr.Row():
174
- stt_input = gr.Audio(label="Input Audio", type="filepath")
 
 
 
175
  stt_output = gr.Textbox(
176
  label="Recognition Result",
177
  placeholder="Recognition result will be displayed here in real-time...",
@@ -226,14 +306,14 @@ def parse_args():
226
  args = parse_args()
227
 
228
  # Download model from Hugging Face Hub
229
- print(f"Downloading model from {args.hf_model_id}...")
230
  model_base_path = snapshot_download(
231
  repo_id=args.hf_model_id,
232
  cache_dir=args.cache_dir,
233
  resume_download=True,
234
  )
235
  # model_base_path = ""
236
- print(f"Model downloaded to: {model_base_path}")
237
 
238
  # Construct actual paths from downloaded model
239
  tokenizer_path = args.tokenizer_path or os.path.join(
@@ -248,11 +328,11 @@ gpa_model_path = args.gpa_model_path or model_base_path
248
  # Instantiate Model
249
  device = "cuda" if torch.cuda.is_available() else "cpu"
250
 
251
- print(f"Initializing GPA Inference System on {device}...")
252
- print(f"Tokenizer path: {tokenizer_path}")
253
- print(f"Text tokenizer path: {text_tokenizer_path}")
254
- print(f"BiCodec tokenizer path: {bicodec_tokenizer_path}")
255
- print(f"GPA model path: {gpa_model_path}")
256
 
257
  # Use None for output_dir to enable temporary directory in HF Spaces
258
  inference = GPAInference(
 
6
  import librosa
7
  import soundfile as sf
8
  from huggingface_hub import snapshot_download
9
+ from loguru import logger
10
 
11
  from gpa_inference import GPAInference
12
 
13
+ # Configuration constants
14
+ MAX_AUDIO_DURATION = 30 # Max audio duration (seconds)
15
+ MAX_TEXT_LENGTH = 2048 # Max text length (characters)
16
+
17
  # Global inference object placeholder
18
  inference = None
19
 
20
 
21
+ def validate_audio_duration(audio_path):
22
+ """Validate if audio duration exceeds limit"""
23
+ if not audio_path:
24
+ return True, 0
25
+ try:
26
+ y, sr = librosa.load(audio_path, sr=None)
27
+ duration = len(y) / sr
28
+ if duration > MAX_AUDIO_DURATION:
29
+ logger.warning(f"Audio duration {duration:.2f}s exceeds limit {MAX_AUDIO_DURATION}s")
30
+ return False, duration
31
+ return True, duration
32
+ except Exception as e:
33
+ logger.error(f"Error validating audio duration: {e}")
34
+ return False, 0
35
+
36
+
37
+ def validate_text_length(text):
38
+ """Validate if text length exceeds limit"""
39
+ if not text:
40
+ return True, 0
41
+ text_len = len(text)
42
+ if text_len > MAX_TEXT_LENGTH:
43
+ logger.warning(f"Text length {text_len} exceeds limit {MAX_TEXT_LENGTH}")
44
+ return False, text_len
45
+ return True, text_len
46
+
47
+
48
  def preprocess_audio(audio_path):
49
  """Ensure audio is 16kHz mono"""
50
  if not audio_path:
51
  return None
52
  try:
53
+ # Validate audio duration
54
+ is_valid, duration = validate_audio_duration(audio_path)
55
+ if not is_valid:
56
+ logger.error(f"Audio duration {duration:.2f}s exceeds max limit {MAX_AUDIO_DURATION}s")
57
+ raise ValueError(f"Audio duration cannot exceed {MAX_AUDIO_DURATION}s, current is {duration:.2f}s")
58
+
59
  # Load audio with librosa: automatically resamples to sr=16000 and converts to mono
60
  y, _ = librosa.load(audio_path, sr=16000, mono=True)
61
 
 
66
  new_path = os.path.join(dir_name, f"{name}_16k.wav")
67
 
68
  sf.write(new_path, y, 16000)
69
+ logger.info(f"Preprocessed audio saved to: {new_path}")
70
  return new_path
71
+ except ValueError as ve:
72
+ # Re-raise validation error
73
+ raise ve
74
  except Exception as e:
75
+ logger.error(f"Error processing audio {audio_path}: {e}")
76
  return audio_path
77
 
78
 
 
81
  def process_stt(audio_path):
82
  global inference
83
  if inference is None:
84
+ return "Model not initialized"
85
 
86
  if not audio_path:
87
+ return "Please upload audio file first"
88
 
89
+ try:
90
+ # Preprocess audio
91
+ audio_path = preprocess_audio(audio_path)
92
 
93
+ # Direct inference call
94
+ return inference.run_stt(audio_path=audio_path, do_sample=False)
95
+ except ValueError as ve:
96
+ return f"Error: {str(ve)}"
97
+ except Exception as e:
98
+ logger.error(f"STT processing error: {e}")
99
+ return f"Processing failed: {str(e)}"
100
 
101
  def process_tts_a(text, ref_audio):
102
  global inference
 
106
  if not text or not ref_audio:
107
  return None
108
 
109
+ try:
110
+ # Validate text length
111
+ is_valid, text_len = validate_text_length(text)
112
+ if not is_valid:
113
+ logger.error(f"Text length {text_len} exceeds max limit {MAX_TEXT_LENGTH}")
114
+ raise ValueError(f"Text length cannot exceed {MAX_TEXT_LENGTH} chars, current is {text_len} chars")
115
+
116
+ # Preprocess audio
117
+ ref_audio = preprocess_audio(ref_audio)
118
+
119
+ # Direct inference call - returns (sample_rate, audio_array)
120
+ result = inference.run_tts(
121
+ task="tts-a",
122
+ output_filename="tts_output.wav",
123
+ text=text,
124
+ ref_audio_path=ref_audio,
125
+ temperature=0.8,
126
+ do_sample=True,
127
+ )
128
+ # Return tuple format for Gradio Audio component
129
+ return result
130
+ except ValueError as ve:
131
+ logger.error(f"TTS validation failed: {ve}")
132
+ return None
133
+ except Exception as e:
134
+ logger.error(f"TTS processing error: {e}")
135
+ return None
136
 
137
  def process_vc(src_audio, ref_audio):
138
  global inference
 
142
  if not src_audio or not ref_audio:
143
  return None
144
 
145
+ try:
146
+ # Preprocess audio
147
+ src_audio = preprocess_audio(src_audio)
148
+ ref_audio = preprocess_audio(ref_audio)
149
+
150
+ # Direct inference call - returns (sample_rate, audio_array)
151
+ result = inference.run_vc(
152
+ source_audio_path=src_audio,
153
+ ref_audio_path=ref_audio,
154
+ output_filename="vc_output.wav",
155
+ )
156
+ # Return tuple format for Gradio Audio component
157
+ return result
158
+ except ValueError as ve:
159
+ logger.error(f"VC validation failed: {ve}")
160
+ return None
161
+ except Exception as e:
162
+ logger.error(f"VC processing error: {e}")
163
+ return None
164
 
165
  # ======================== Gradio UI Layout ========================
166
 
 
199
  with gr.Column():
200
  ttsa_text = gr.Textbox(
201
  label="Synthesis Text",
202
+ placeholder=f"Enter text to synthesize (max {MAX_TEXT_LENGTH} chars)...",
203
  value="Hello, I am generated by voice cloning.",
204
  lines=3,
205
+ max_lines=10,
206
+ )
207
+ ttsa_ref = gr.Audio(
208
+ label=f"Reference Audio (Voice Source) - Max {MAX_AUDIO_DURATION}s",
209
+ type="filepath"
210
  )
 
211
  ttsa_output = gr.Audio(label="Synthesis Result")
212
  ttsa_btn = gr.Button("Synthesize Now", variant="primary")
213
  ttsa_btn.click(process_tts_a, inputs=[ttsa_text, ttsa_ref], outputs=ttsa_output)
 
233
  with gr.TabItem("🎭 Voice Conversion (VC)"):
234
  with gr.Row():
235
  with gr.Column():
236
+ vc_src = gr.Audio(
237
+ label=f"Source Audio (Content Source) - Max {MAX_AUDIO_DURATION}s",
238
+ type="filepath"
239
+ )
240
+ vc_ref = gr.Audio(
241
+ label=f"Reference Audio (Voice Source) - Max {MAX_AUDIO_DURATION}s",
242
+ type="filepath"
243
+ )
244
  vc_output = gr.Audio(label="Conversion Result")
245
  vc_btn = gr.Button("Start Conversion", variant="primary")
246
  vc_btn.click(process_vc, inputs=[vc_src, vc_ref], outputs=vc_output)
 
248
  # --- STT Tab ---
249
  with gr.TabItem("🎙️ Speech to Text (STT)"):
250
  with gr.Row():
251
+ stt_input = gr.Audio(
252
+ label=f"Input Audio - Max {MAX_AUDIO_DURATION}s",
253
+ type="filepath"
254
+ )
255
  stt_output = gr.Textbox(
256
  label="Recognition Result",
257
  placeholder="Recognition result will be displayed here in real-time...",
 
306
  args = parse_args()
307
 
308
  # Download model from Hugging Face Hub
309
+ logger.info(f"Downloading model from {args.hf_model_id}...")
310
  model_base_path = snapshot_download(
311
  repo_id=args.hf_model_id,
312
  cache_dir=args.cache_dir,
313
  resume_download=True,
314
  )
315
  # model_base_path = ""
316
+ logger.info(f"Model downloaded to: {model_base_path}")
317
 
318
  # Construct actual paths from downloaded model
319
  tokenizer_path = args.tokenizer_path or os.path.join(
 
328
  # Instantiate Model
329
  device = "cuda" if torch.cuda.is_available() else "cpu"
330
 
331
+ logger.info(f"Initializing GPA Inference System on {device}...")
332
+ logger.info(f"Tokenizer path: {tokenizer_path}")
333
+ logger.info(f"Text tokenizer path: {text_tokenizer_path}")
334
+ logger.info(f"BiCodec tokenizer path: {bicodec_tokenizer_path}")
335
+ logger.info(f"GPA model path: {gpa_model_path}")
336
 
337
  # Use None for output_dir to enable temporary directory in HF Spaces
338
  inference = GPAInference(