wanglamao committed
Commit f4bb8a5 · 1 Parent(s): 239bcb6

fix model path

Files changed (2):
  1. app.py +58 -32
  2. gpa_inference.py +96 -26
app.py CHANGED

@@ -5,6 +5,7 @@ import torch
 import argparse
 import librosa
 import soundfile as sf
+from huggingface_hub import snapshot_download

 from gpa_inference import GPAInference

@@ -61,8 +62,8 @@ def process_tts_a(text, ref_audio):
     # Preprocess audio
     ref_audio = preprocess_audio(ref_audio)

-    # Direct inference call
-    return inference.run_tts(
+    # Direct inference call - returns (sample_rate, audio_array)
+    result = inference.run_tts(
         task="tts-a",
         output_filename="tts_output.wav",
         text=text,
@@ -70,6 +71,8 @@ def process_tts_a(text, ref_audio):
         temperature=0.8,
         do_sample=True,
     )
+    # Return tuple format for Gradio Audio component
+    return result

 def process_vc(src_audio, ref_audio):
     global inference
@@ -83,12 +86,14 @@ def process_vc(src_audio, ref_audio):
     src_audio = preprocess_audio(src_audio)
     ref_audio = preprocess_audio(ref_audio)

-    # Direct inference call
-    return inference.run_vc(
+    # Direct inference call - returns (sample_rate, audio_array)
+    result = inference.run_vc(
         source_audio_path=src_audio,
         ref_audio_path=ref_audio,
         output_filename="vc_output.wav",
     )
+    # Return tuple format for Gradio Audio component
+    return result

 # ======================== Gradio UI Layout ========================

@@ -139,42 +144,40 @@ def parse_args():

     # Model Paths
     parser.add_argument(
-        "--tokenizer_path",
+        "--hf_model_id",
         type=str,
-        default="/data3/gpa_ckpt/gpa_final/glm-4-voice-tokenizer",
-        help="Path to GLM4 tokenizer",
+        default="AutoArk-AI/GPA",
+        help="Hugging Face model ID to download",
     )
     parser.add_argument(
-        "--text_tokenizer_path",
+        "--cache_dir",
         type=str,
-        default="/data3/gpa_ckpt/gpa_final",
-        help="Path to text tokenizer",
+        default="./models",
+        help="Directory to cache downloaded models",
     )
     parser.add_argument(
-        "--bicodec_tokenizer_path",
+        "--tokenizer_path",
         type=str,
-        default="/data3/gpa_ckpt/gpa_final/BiCodec/",
-        help="Path to BiCodec tokenizer",
+        default=None,
+        help="Path to GLM4 tokenizer (if None, will use downloaded model)",
    )
     parser.add_argument(
-        "--gpa_model_path",
+        "--text_tokenizer_path",
         type=str,
-        default="/data3/gpa_ckpt/gpa_final",
-        help="Path to GPA model",
+        default=None,
+        help="Path to text tokenizer (if None, will use downloaded model)",
     )
-
-    # System Config
     parser.add_argument(
-        "--output_dir",
+        "--bicodec_tokenizer_path",
         type=str,
-        default="./output_gui",
-        help="Directory to save output files",
+        default=None,
+        help="Path to BiCodec tokenizer (if None, will use downloaded model)",
     )
     parser.add_argument(
-        "--device",
+        "--gpa_model_path",
         type=str,
-        default="cuda" if torch.cuda.is_available() else "cpu",
-        help="Device to use",
+        default=None,
+        help="Path to GPA model (if None, will use downloaded model)",
     )

     # Server Config
@@ -189,18 +192,41 @@ def parse_args():

 args = parse_args()

+# Download model from Hugging Face Hub
+print(f"Downloading model from {args.hf_model_id}...")
+model_base_path = snapshot_download(
+    repo_id=args.hf_model_id,
+    cache_dir=args.cache_dir,
+    resume_download=True,
+)
+print(f"Model downloaded to: {model_base_path}")
+
+# Construct actual paths from downloaded model
+tokenizer_path = args.tokenizer_path or os.path.join(
+    model_base_path, "glm-4-voice-tokenizer"
+)
+text_tokenizer_path = args.text_tokenizer_path or model_base_path
+bicodec_tokenizer_path = args.bicodec_tokenizer_path or os.path.join(
+    model_base_path, "BiCodec"
+)
+gpa_model_path = args.gpa_model_path or model_base_path
+
 # Instantiate Model
 print(f"Initializing GPA Inference System on {args.device}...")
-os.makedirs(args.output_dir, exist_ok=True)
+print(f"Tokenizer path: {tokenizer_path}")
+print(f"Text tokenizer path: {text_tokenizer_path}")
+print(f"BiCodec tokenizer path: {bicodec_tokenizer_path}")
+print(f"GPA model path: {gpa_model_path}")

+# Use None for output_dir to enable temporary directory in HF Spaces
 inference = GPAInference(
-    tokenizer_path=args.tokenizer_path,
-    text_tokenizer_path=args.text_tokenizer_path,
-    bicodec_tokenizer_path=args.bicodec_tokenizer_path,
-    gpa_model_path=args.gpa_model_path,
-    output_dir=args.output_dir,
-    device=args.device,
+    tokenizer_path=tokenizer_path,
+    text_tokenizer_path=text_tokenizer_path,
+    bicodec_tokenizer_path=bicodec_tokenizer_path,
+    gpa_model_path=gpa_model_path,
+    output_dir=None,  # Will use temporary directory
+    device="cuda" if torch.cuda.is_available() else "cpu",
 )

 # Launch Gradio Demo
-demo.queue().launch()
+demo.queue().launch(server_name=args.server_name, server_port=args.server_port)
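
The net effect in app.py is that checkpoints are no longer read from hard-coded /data3/gpa_ckpt paths; they are resolved from a Hugging Face snapshot, with any explicitly passed CLI path taking precedence over the corresponding snapshot subfolder. A minimal sketch of that resolution logic, using the repo ID and subfolder names from the diff above (resolve_model_paths is an illustrative helper, not part of this commit):

import os
from huggingface_hub import snapshot_download

def resolve_model_paths(hf_model_id="AutoArk-AI/GPA", cache_dir="./models",
                        tokenizer_path=None, text_tokenizer_path=None,
                        bicodec_tokenizer_path=None, gpa_model_path=None):
    # snapshot_download fetches the whole repo once and returns its local
    # root; repeated calls reuse the cache under cache_dir.
    base = snapshot_download(repo_id=hf_model_id, cache_dir=cache_dir)
    # Explicit paths win; otherwise fall back to the snapshot layout.
    return {
        "tokenizer_path": tokenizer_path or os.path.join(base, "glm-4-voice-tokenizer"),
        "text_tokenizer_path": text_tokenizer_path or base,
        "bicodec_tokenizer_path": bicodec_tokenizer_path or os.path.join(base, "BiCodec"),
        "gpa_model_path": gpa_model_path or base,
    }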
gpa_inference.py CHANGED

@@ -3,6 +3,7 @@ import argparse
 import torch
 import soundfile as sf
 import re
+import tempfile
 from transformers import AutoTokenizer, AutoModelForCausalLM, WhisperFeatureExtractor
 import numpy as np

@@ -14,13 +15,31 @@ from models.glm_speech_tokenizer.modeling_whisper import WhisperVQEncoder

 from data_utils.audio_dataset_ark_audio import ark_infer_processor

+
 class GPAInference:
-    def __init__(self, tokenizer_path, text_tokenizer_path, bicodec_tokenizer_path, gpa_model_path, output_dir, device):
+
+    def __init__(
+        self,
+        tokenizer_path,
+        text_tokenizer_path,
+        bicodec_tokenizer_path,
+        gpa_model_path,
+        output_dir=None,
+        device=None,
+    ):
         self.tokenizer_path = tokenizer_path
         self.text_tokenizer_path = text_tokenizer_path
         self.bicodec_tokenizer_path = bicodec_tokenizer_path
         self.gpa_model_path = gpa_model_path
-        self.output_dir = output_dir
+
+        # Use temporary directory if output_dir is None
+        if output_dir is None:
+            self.output_dir = tempfile.mkdtemp()
+            print(f"Using temporary output directory: {self.output_dir}")
+        else:
+            self.output_dir = output_dir
+            os.makedirs(self.output_dir, exist_ok=True)
+
         self.device = device

         print(f"Using device: {self.device}")
@@ -29,15 +48,22 @@ class GPAInference:
     def _load_models(self):
         print("Loading tokenizers...")
         feature_extractor = WhisperFeatureExtractor.from_pretrained(self.tokenizer_path)
-        audio_model = WhisperVQEncoder.from_pretrained(self.tokenizer_path).eval().to(self.device)
-        self.glm_tokenizer = SpeechTokenExtractor(model=audio_model, feature_extractor=feature_extractor, device=self.device)
+        audio_model = (
+            WhisperVQEncoder.from_pretrained(self.tokenizer_path).eval().to(self.device)
+        )
+        self.glm_tokenizer = SpeechTokenExtractor(
+            model=audio_model, feature_extractor=feature_extractor, device=self.device
+        )
         self.text_tokenizer = AutoTokenizer.from_pretrained(
-            self.text_tokenizer_path,
-            trust_remote_code=True
+            self.text_tokenizer_path, trust_remote_code=True
         )

-        self.bicodec_tokenizer = SparkTokenizer(model_path=self.bicodec_tokenizer_path, device=self.device)
-        self.bicodec_detokenizer = SparkDeTokenizer(model_path=self.bicodec_tokenizer_path, device=self.device)
+        self.bicodec_tokenizer = SparkTokenizer(
+            model_path=self.bicodec_tokenizer_path, device=self.device
+        )
+        self.bicodec_detokenizer = SparkDeTokenizer(
+            model_path=self.bicodec_tokenizer_path, device=self.device
+        )
         self.processor = ark_infer_processor(
             glm_tokenizer=self.glm_tokenizer,
             bicodec_tokenizer=self.bicodec_tokenizer,
@@ -48,8 +74,7 @@ class GPAInference:

         print("Loading model...")
         self.model = AutoModelForCausalLM.from_pretrained(
-            self.gpa_model_path,
-            trust_remote_code=True
+            self.gpa_model_path, trust_remote_code=True
         ).to(self.device)

     def generate(self, inputs, **kwargs):
@@ -73,13 +98,15 @@ class GPAInference:
         generation_config.update(kwargs)

         # Remove keys that might be None if passed from args mistakenly
-        generation_config = {k: v for k, v in generation_config.items() if v is not None}
+        generation_config = {
+            k: v for k, v in generation_config.items() if v is not None
+        }
         print(f"Generation config: {generation_config}")

         outputs = self.model.generate(
             input_ids=inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
-            **generation_config
+            **generation_config,
         )
         return outputs

@@ -105,9 +132,13 @@ class GPAInference:
         text = self.text_tokenizer.decode(outputs[0].tolist())

         if "<|start_content|>" in text:
-            return text.split("<|start_content|>")[1].replace("<|im_end|>","").replace("<|end_content|>","")
+            return (
+                text.split("<|start_content|>")[1]
+                .replace("<|im_end|>", "")
+                .replace("<|end_content|>", "")
+            )
         else:
-            return text.replace("<|im_end|>","")
+            return text.replace("<|im_end|>", "")

     def run_tts(self, task, output_filename, text, ref_audio_path, **kwargs):
         """
@@ -129,12 +160,11 @@ class GPAInference:
         }

         print(f"\n--- {task.upper()} ---")
-        output_path = os.path.join(self.output_dir, output_filename)

         # Pass processor specific args (e.g. emotion, pitch) here
         inputs = self.processor.process_input(
-            task=task,
-            ref_audio_path=ref_audio_path,
+            task=task,
+            ref_audio_path=ref_audio_path,
             text=text,
         )

@@ -154,7 +184,9 @@ class GPAInference:
         audio_list = [int(x) for x in audio_ids]

         if ref_audio_path:
-            global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
+            global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])[
+                "global_tokens"
+            ]
         else:
             global_tokens = torch.zeros((1, 32), dtype=torch.long).to(self.device)

@@ -168,6 +200,7 @@ class GPAInference:
         if reconstructed_wav.size > 0:
             reconstructed_wav -= reconstructed_wav.mean()

+        output_path = os.path.join(self.output_dir, output_filename)
         sf.write(output_path, reconstructed_wav, 16000)
         print(f"Saved output to {output_path}")
         return 16000, reconstructed_wav
@@ -204,7 +237,9 @@ class GPAInference:
         audio_ids = re.findall(r"<\|bicodec_semantic_(\d+)\|>", content)
         audio_list = [int(x) for x in audio_ids]

-        global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
+        global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])[
+            "global_tokens"
+        ]

         req = {
             "global_tokens": global_tokens,
@@ -224,10 +259,30 @@ def parse_args():
     parser = argparse.ArgumentParser(description="GPA Inference Script")

     # Paths
-    parser.add_argument("--tokenizer_path", type=str, default="/nasdata/model/gpa/glm-4-voice-tokenizer", help="Path to GLM4 tokenizer")
-    parser.add_argument("--text_tokenizer_path", type=str, default="/nasdata/model/gpa", help="Path to text tokenizer")
-    parser.add_argument("--bicodec_tokenizer_path", type=str, default="/nasdata/model/gpa/BiCodec/", help="Path to BiCodec tokenizer")
-    parser.add_argument("--gpa_model_path", type=str, default="/nasdata/model/gpa", help="Path to GPA model")
+    parser.add_argument(
+        "--tokenizer_path",
+        type=str,
+        default="/nasdata/model/gpa/glm-4-voice-tokenizer",
+        help="Path to GLM4 tokenizer",
+    )
+    parser.add_argument(
+        "--text_tokenizer_path",
+        type=str,
+        default="/nasdata/model/gpa",
+        help="Path to text tokenizer",
+    )
+    parser.add_argument(
+        "--bicodec_tokenizer_path",
+        type=str,
+        default="/nasdata/model/gpa/BiCodec/",
+        help="Path to BiCodec tokenizer",
+    )
+    parser.add_argument(
+        "--gpa_model_path",
+        type=str,
+        default="/nasdata/model/gpa",
+        help="Path to GPA model",
+    )

     # Audio inputs
     parser.add_argument(
@@ -238,20 +293,34 @@ def parse_args():
     )

     # Output
-    parser.add_argument("--output_dir", type=str, default=".", help="Directory to save output files")
+    parser.add_argument(
+        "--output_dir", type=str, default=".", help="Directory to save output files"
+    )

     # Device
     default_device = "cuda" if torch.cuda.is_available() else "cpu"
-    parser.add_argument("--device", type=str, default=default_device, help="Device to use (e.g., cuda:0, cpu)")
+    parser.add_argument(
+        "--device",
+        type=str,
+        default=default_device,
+        help="Device to use (e.g., cuda:0, cpu)",
+    )

     # Task
-    parser.add_argument("--task", type=str, required=True, choices=["stt", "tts-a", "vc"], help="Task to run")
+    parser.add_argument(
+        "--task",
+        type=str,
+        required=True,
+        choices=["stt", "tts-a", "vc"],
+        help="Task to run",
+    )

     # TTS Inputs (Processor Arguments)
     parser.add_argument("--text", type=str, default=None, help="Text for TTS")

     return parser.parse_args()

+
 def main():
     args = parse_args()

@@ -289,5 +358,6 @@ def main():
         output_filename="output_gpa_vc.wav",
     )

+
 if __name__ == "__main__":
     main()
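
With these changes, GPAInference can be constructed without an output_dir (it falls back to tempfile.mkdtemp(), and run_tts now joins the output path only at write time). A hedged smoke-test sketch, assuming the snapshot layout used in app.py, that run_tts returns (sample_rate, waveform) as shown above, and that the processor accepts ref_audio_path=None (the zero global-tokens branch); the filename and text are illustrative:

import torch
from huggingface_hub import snapshot_download
from gpa_inference import GPAInference

# Assumed snapshot layout mirroring app.py; "AutoArk-AI/GPA" is the repo ID
# used there, and the subfolder names come from this diff.
base = snapshot_download(repo_id="AutoArk-AI/GPA", cache_dir="./models")

inference = GPAInference(
    tokenizer_path=f"{base}/glm-4-voice-tokenizer",
    text_tokenizer_path=base,
    bicodec_tokenizer_path=f"{base}/BiCodec",
    gpa_model_path=base,
    output_dir=None,  # new default: a fresh tempfile.mkdtemp() directory
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# run_tts writes <output_dir>/<output_filename> and returns (16000, wav);
# with ref_audio_path=None, zero-valued global tokens are substituted.
sr, wav = inference.run_tts(
    task="tts-a",
    output_filename="smoke_test.wav",
    text="Hello from the GPA smoke test.",
    ref_audio_path=None,
)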