# GPA_DEMO / gpa_inference.py
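"""GPA inference demo covering speech-to-text (stt), voice-cloned
text-to-speech (tts-a), and voice conversion (vc) with the GPA model and the
GLM speech / BiCodec tokenizers.

Example invocations (the .wav names and text are placeholders; model paths
fall back to the defaults declared in parse_args):

    python gpa_inference.py --task stt --src_audio_path input.wav
    python gpa_inference.py --task tts-a --text "Hello there." --ref_audio_path ref.wav
    python gpa_inference.py --task vc --src_audio_path input.wav --ref_audio_path ref.wav
"""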
import os
import argparse
import torch
import soundfile as sf
import re
import tempfile
from transformers import AutoTokenizer, AutoModelForCausalLM, WhisperFeatureExtractor
import numpy as np
from models.bicodec_tokenizer.spark_tokenizer import SparkTokenizer
from models.bicodec_tokenizer.spark_detokenizer import SparkDeTokenizer
from models.glm_speech_tokenizer.speech_token_extractor import SpeechTokenExtractor
from models.glm_speech_tokenizer.modeling_whisper import WhisperVQEncoder
from data_utils.audio_dataset_ark_audio import ark_infer_processor
class GPAInference:
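    """End-to-end inference wrapper for the GPA model.

    Loads the GLM speech tokenizer, text tokenizer, BiCodec
    tokenizer/detokenizer, and the causal LM once in __init__, then exposes
    run_stt / run_tts / run_vc for the individual tasks.
    """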
def __init__(
self,
tokenizer_path,
text_tokenizer_path,
bicodec_tokenizer_path,
gpa_model_path,
output_dir=None,
device=None,
):
self.tokenizer_path = tokenizer_path
self.text_tokenizer_path = text_tokenizer_path
self.bicodec_tokenizer_path = bicodec_tokenizer_path
self.gpa_model_path = gpa_model_path
# Use temporary directory if output_dir is None
if output_dir is None:
self.output_dir = tempfile.mkdtemp()
print(f"Using temporary output directory: {self.output_dir}")
else:
self.output_dir = output_dir
os.makedirs(self.output_dir, exist_ok=True)
        # Fall back to CUDA when available if no device is given
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
self._load_models()
def _load_models(self):
print("Loading tokenizers...")
feature_extractor = WhisperFeatureExtractor.from_pretrained(self.tokenizer_path)
audio_model = (
WhisperVQEncoder.from_pretrained(self.tokenizer_path).eval().to(self.device)
)
self.glm_tokenizer = SpeechTokenExtractor(
model=audio_model, feature_extractor=feature_extractor, device=self.device
)
self.text_tokenizer = AutoTokenizer.from_pretrained(
self.text_tokenizer_path, trust_remote_code=True
)
self.bicodec_tokenizer = SparkTokenizer(
model_path=self.bicodec_tokenizer_path, device=self.device
)
self.bicodec_detokenizer = SparkDeTokenizer(
model_path=self.bicodec_tokenizer_path, device=self.device
)
self.processor = ark_infer_processor(
glm_tokenizer=self.glm_tokenizer,
bicodec_tokenizer=self.bicodec_tokenizer,
text_tokenizer=self.text_tokenizer,
device=self.device,
audio_path_name="audio",
)
print("Loading model...")
self.model = AutoModelForCausalLM.from_pretrained(
self.gpa_model_path, trust_remote_code=True
).to(self.device)
    def generate(self, inputs, **kwargs):
        """
        Base generation method. Adds a batch dimension to the processor
        outputs, applies default generation settings, and lets kwargs
        override them per call.
        """
for k in inputs:
if isinstance(inputs[k], (list, np.ndarray)):
inputs[k] = torch.tensor(inputs[k]).unsqueeze(0).to(self.device)
elif isinstance(inputs[k], torch.Tensor):
inputs[k] = inputs[k].unsqueeze(0).to(self.device)
# Default generation config
generation_config = {
"max_new_tokens": 1000,
"do_sample": False,
"eos_token_id": self.text_tokenizer.convert_tokens_to_ids("<|im_end|>"),
}
# Override defaults with any passed kwargs
generation_config.update(kwargs)
        # Drop None-valued entries so they don't clobber the model's generation defaults
generation_config = {
k: v for k, v in generation_config.items() if v is not None
}
print(f"Generation config: {generation_config}")
outputs = self.model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**generation_config,
)
return outputs
def run_stt(self, audio_path, **kwargs):
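        """Transcribe audio_path to text; kwargs override the recommended
        generation defaults below."""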
if not audio_path:
raise ValueError("audio_path is required for STT")
print("\n--- Speech to Text (STT) ---")
inputs = self.processor.process_input(
task="stt",
audio_path=audio_path,
)
        # Recommended hyperparameters for STT; caller-supplied kwargs take precedence
        gen_kwargs = {
            "max_new_tokens": 512,
            "do_sample": False,
        }
        gen_kwargs.update(kwargs)
        outputs = self.generate(inputs, **gen_kwargs)
text = self.text_tokenizer.decode(outputs[0].tolist())
if "<|start_content|>" in text:
return (
text.split("<|start_content|>")[1]
.replace("<|im_end|>", "")
.replace("<|end_content|>", "")
)
else:
return text.replace("<|im_end|>", "")
def run_tts(self, task, output_filename, text, ref_audio_path, **kwargs):
"""
gen_kwargs: dict, parameters for model.generate (temp, top_p, etc.)
"""
if not text:
raise ValueError("text is required for TTS")
# Check ref_audio_path requirement based on task
if task == "tts-a" and not ref_audio_path:
raise ValueError(f"ref_audio_path is required for {task}")
        # Recommended hyperparameters for TTS; caller-supplied kwargs take precedence
        gen_kwargs = {
            "max_new_tokens": 512,
            "temperature": 0.2,
            "repetition_penalty": 1.2,
            "do_sample": True,
        }
        gen_kwargs.update(kwargs)
print(f"\n--- {task.upper()} ---")
# Pass processor specific args (e.g. emotion, pitch) here
inputs = self.processor.process_input(
task=task,
ref_audio_path=ref_audio_path,
text=text,
)
        # Pass generation-specific args (e.g. temperature) here
        outputs = self.generate(inputs, **gen_kwargs)
text_output = self.text_tokenizer.decode(outputs[0].tolist())
if "<|end_content|>" in text_output:
content = text_output.split("<|end_content|>")[1]
else:
print("Warning: <|end_content|> not found")
content = text_output
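        # Collect the BiCodec semantic token ids emitted as <|bicodec_semantic_N|> tags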
audio_ids = re.findall(r"<\|bicodec_semantic_(\d+)\|>", content)
audio_list = [int(x) for x in audio_ids]
if ref_audio_path:
global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])[
"global_tokens"
]
else:
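            # No reference audio: fall back to all-zero global (speaker) tokens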
global_tokens = torch.zeros((1, 32), dtype=torch.long).to(self.device)
req = {
"global_tokens": global_tokens,
"semantic_tokens": torch.tensor(audio_list).unsqueeze(0).to(self.device),
}
out = self.bicodec_detokenizer.detokenize(**req)
reconstructed_wav = out.detach().cpu().float().squeeze().numpy()
# Simple DC offset removal
if reconstructed_wav.size > 0:
reconstructed_wav -= reconstructed_wav.mean()
output_path = os.path.join(self.output_dir, output_filename)
sf.write(output_path, reconstructed_wav, 16000)
print(f"Saved output to {output_path}")
return 16000, reconstructed_wav
def run_vc(
self,
source_audio_path,
ref_audio_path,
output_filename="output_gpa_vc.wav",
**kwargs,
):
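        """Re-synthesize source_audio_path with the speaker identity of
        ref_audio_path; kwargs are forwarded to model.generate."""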
if not source_audio_path:
raise ValueError("source_audio_path is required for VC")
if not ref_audio_path:
raise ValueError("ref_audio_path is required for VC")
print("\n--- Voice Conversion (VC) ---")
output_path = os.path.join(self.output_dir, output_filename)
inputs = self.processor.process_input(
task="vc",
audio_path=source_audio_path,
ref_audio_path=ref_audio_path,
)
outputs = self.generate(inputs, **kwargs)
text_output = self.text_tokenizer.decode(outputs[0].tolist())
if "<|end_content|>" in text_output:
content = text_output.split("<|end_content|>")[1]
else:
content = text_output
audio_ids = re.findall(r"<\|bicodec_semantic_(\d+)\|>", content)
audio_list = [int(x) for x in audio_ids]
global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])[
"global_tokens"
]
req = {
"global_tokens": global_tokens,
"semantic_tokens": torch.tensor(audio_list).unsqueeze(0).to(self.device),
}
out = self.bicodec_detokenizer.detokenize(**req)
reconstructed_wav = out.detach().cpu().float().squeeze().numpy()
if reconstructed_wav.size > 0:
reconstructed_wav -= reconstructed_wav.mean()
sf.write(output_path, reconstructed_wav, 16000)
print(f"Saved VC output to {output_path}")
return 16000, reconstructed_wav
def parse_args():
parser = argparse.ArgumentParser(description="GPA Inference Script")
# Paths
parser.add_argument(
"--tokenizer_path",
type=str,
default="/nasdata/model/gpa/glm-4-voice-tokenizer",
help="Path to GLM4 tokenizer",
)
parser.add_argument(
"--text_tokenizer_path",
type=str,
default="/nasdata/model/gpa",
help="Path to text tokenizer",
)
parser.add_argument(
"--bicodec_tokenizer_path",
type=str,
default="/nasdata/model/gpa/BiCodec/",
help="Path to BiCodec tokenizer",
)
parser.add_argument(
"--gpa_model_path",
type=str,
default="/nasdata/model/gpa",
help="Path to GPA model",
)
# Audio inputs
parser.add_argument(
"--ref_audio_path", type=str, default=None, help="Reference audio path"
)
parser.add_argument(
"--src_audio_path", type=str, default=None, help="Source audio path for VC/STT"
)
# Output
parser.add_argument(
"--output_dir", type=str, default=".", help="Directory to save output files"
)
# Device
default_device = "cuda" if torch.cuda.is_available() else "cpu"
parser.add_argument(
"--device",
type=str,
default=default_device,
help="Device to use (e.g., cuda:0, cpu)",
)
# Task
parser.add_argument(
"--task",
type=str,
required=True,
choices=["stt", "tts-a", "vc"],
help="Task to run",
)
# TTS Inputs (Processor Arguments)
parser.add_argument("--text", type=str, default=None, help="Text for TTS")
return parser.parse_args()
def main():
args = parse_args()
# Ensure output directory exists
os.makedirs(args.output_dir, exist_ok=True)
inference = GPAInference(
tokenizer_path=args.tokenizer_path,
text_tokenizer_path=args.text_tokenizer_path,
bicodec_tokenizer_path=args.bicodec_tokenizer_path,
gpa_model_path=args.gpa_model_path,
output_dir=args.output_dir,
device=args.device,
)
if args.task == "stt":
        if not args.src_audio_path:
            raise ValueError("--src_audio_path is required for the stt task.")
        result = inference.run_stt(audio_path=args.src_audio_path)
print("STT Result:", result)
    elif args.task == "tts-a":
        if not args.text:
            raise ValueError("--text is required for the tts-a task.")
        if not args.ref_audio_path:
            raise ValueError("--ref_audio_path is required for the tts-a task.")
inference.run_tts(
task="tts-a",
output_filename="output_gpa_tts_a.wav",
text=args.text,
ref_audio_path=args.ref_audio_path,
)
elif args.task == "vc":
inference.run_vc(
source_audio_path=args.src_audio_path,
ref_audio_path=args.ref_audio_path,
output_filename="output_gpa_vc.wav",
)
if __name__ == "__main__":
main()