import os
import argparse
import re
import tempfile

import numpy as np
import soundfile as sf
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, WhisperFeatureExtractor

from models.bicodec_tokenizer.spark_tokenizer import SparkTokenizer
from models.bicodec_tokenizer.spark_detokenizer import SparkDeTokenizer
from models.glm_speech_tokenizer.speech_token_extractor import SpeechTokenExtractor
from models.glm_speech_tokenizer.modeling_whisper import WhisperVQEncoder
from data_utils.audio_dataset_ark_audio import ark_infer_processor
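

# Pipeline overview (as implemented below): the GLM speech tokenizer
# (WhisperVQEncoder + SpeechTokenExtractor) turns input audio into discrete
# speech tokens, the causal LM generates text and/or <|bicodec_semantic_N|>
# tokens, and the BiCodec detokenizer reconstructs a 16 kHz waveform from the
# generated semantic tokens plus speaker "global" tokens taken from a
# reference audio clip.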
class GPAInference:
    def __init__(
        self,
        tokenizer_path,
        text_tokenizer_path,
        bicodec_tokenizer_path,
        gpa_model_path,
        output_dir=None,
        device=None,
    ):
        self.tokenizer_path = tokenizer_path
        self.text_tokenizer_path = text_tokenizer_path
        self.bicodec_tokenizer_path = bicodec_tokenizer_path
        self.gpa_model_path = gpa_model_path
        # Use a temporary directory if output_dir is None
        if output_dir is None:
            self.output_dir = tempfile.mkdtemp()
            print(f"Using temporary output directory: {self.output_dir}")
        else:
            self.output_dir = output_dir
            os.makedirs(self.output_dir, exist_ok=True)
        # Fall back to automatic device selection if none is given
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        print(f"Using device: {self.device}")
        self._load_models()

    def _load_models(self):
        print("Loading tokenizers...")
        feature_extractor = WhisperFeatureExtractor.from_pretrained(self.tokenizer_path)
        audio_model = (
            WhisperVQEncoder.from_pretrained(self.tokenizer_path).eval().to(self.device)
        )
        self.glm_tokenizer = SpeechTokenExtractor(
            model=audio_model, feature_extractor=feature_extractor, device=self.device
        )
        self.text_tokenizer = AutoTokenizer.from_pretrained(
            self.text_tokenizer_path, trust_remote_code=True
        )
        self.bicodec_tokenizer = SparkTokenizer(
            model_path=self.bicodec_tokenizer_path, device=self.device
        )
        self.bicodec_detokenizer = SparkDeTokenizer(
            model_path=self.bicodec_tokenizer_path, device=self.device
        )
        self.processor = ark_infer_processor(
            glm_tokenizer=self.glm_tokenizer,
            bicodec_tokenizer=self.bicodec_tokenizer,
            text_tokenizer=self.text_tokenizer,
            device=self.device,
            audio_path_name="audio",
        )
        print("Loading model...")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.gpa_model_path, trust_remote_code=True
        ).to(self.device)
        self.model.eval()  # inference only: disable dropout etc.

    def generate(self, inputs, **kwargs):
        """
        Base generation method that accepts dynamic generation parameters.
        """
        # Add a batch dimension and move every input to the target device
        for k in inputs:
            if isinstance(inputs[k], (list, np.ndarray)):
                inputs[k] = torch.tensor(inputs[k]).unsqueeze(0).to(self.device)
            elif isinstance(inputs[k], torch.Tensor):
                inputs[k] = inputs[k].unsqueeze(0).to(self.device)
        # Default generation config
        generation_config = {
            "max_new_tokens": 1000,
            "do_sample": False,
            "eos_token_id": self.text_tokenizer.convert_tokens_to_ids("<|im_end|>"),
        }
        # Override defaults with any passed kwargs
        generation_config.update(kwargs)
        # Drop keys that were passed as None (e.g. unset CLI arguments)
        generation_config = {
            k: v for k, v in generation_config.items() if v is not None
        }
        print(f"Generation config: {generation_config}")
        outputs = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **generation_config,
        )
        return outputs
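
    # Note: callers can override any default per call, e.g. (illustrative)
    #   inference.generate(inputs, do_sample=True, temperature=0.7, top_p=0.9)
    # where `inputs` comes from processor.process_input(...).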

    def run_stt(self, audio_path, **kwargs):
        if not audio_path:
            raise ValueError("audio_path is required for STT")
        print("\n--- Speech to Text (STT) ---")
        inputs = self.processor.process_input(
            task="stt",
            audio_path=audio_path,
        )
        # Recommended hyperparameters for STT; caller-supplied kwargs take precedence
        gen_kwargs = {
            "max_new_tokens": 512,
            "do_sample": False,
        }
        gen_kwargs.update(kwargs)
        outputs = self.generate(inputs, **gen_kwargs)
        text = self.text_tokenizer.decode(outputs[0].tolist())
        # Strip the prompt and special tokens from the decoded sequence
        if "<|start_content|>" in text:
            return (
                text.split("<|start_content|>")[1]
                .replace("<|im_end|>", "")
                .replace("<|end_content|>", "")
            )
        else:
            return text.replace("<|im_end|>", "")
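
    # The decoded STT output is expected to look roughly like (illustrative):
    #   "...<|start_content|>transcribed text<|end_content|><|im_end|>"
    # so the transcript is the span after <|start_content|> with the closing
    # special tokens removed, as implemented above.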

    def run_tts(self, task, output_filename, text, ref_audio_path, **kwargs):
        """
        kwargs: parameters for model.generate (temperature, top_p, etc.);
        they override the recommended defaults below.
        """
        if not text:
            raise ValueError("text is required for TTS")
        # Check the ref_audio_path requirement based on the task
        if task == "tts-a" and not ref_audio_path:
            raise ValueError(f"ref_audio_path is required for {task}")
        # Recommended hyperparameters for TTS; caller-supplied kwargs take precedence
        gen_kwargs = {
            "max_new_tokens": 512,
            "temperature": 0.2,
            "repetition_penalty": 1.2,
            "do_sample": True,
        }
        gen_kwargs.update(kwargs)
        print(f"\n--- {task.upper()} ---")
        # Processor-specific args (e.g. emotion, pitch) would be passed here
        inputs = self.processor.process_input(
            task=task,
            ref_audio_path=ref_audio_path,
            text=text,
        )
        outputs = self.generate(inputs, **gen_kwargs)
        text_output = self.text_tokenizer.decode(outputs[0].tolist())
        if "<|end_content|>" in text_output:
            content = text_output.split("<|end_content|>")[1]
        else:
            print("Warning: <|end_content|> not found")
            content = text_output
        # Collect the generated BiCodec semantic token ids
        audio_ids = re.findall(r"<\|bicodec_semantic_(\d+)\|>", content)
        audio_list = [int(x) for x in audio_ids]
        # Speaker "global" tokens come from the reference audio; fall back to
        # zeros when no reference is given
        if ref_audio_path:
            global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])[
                "global_tokens"
            ]
        else:
            global_tokens = torch.zeros((1, 32), dtype=torch.long).to(self.device)
        req = {
            "global_tokens": global_tokens,
            "semantic_tokens": torch.tensor(audio_list).unsqueeze(0).to(self.device),
        }
        out = self.bicodec_detokenizer.detokenize(**req)
        reconstructed_wav = out.detach().cpu().float().squeeze().numpy()
        # Simple DC offset removal
        if reconstructed_wav.size > 0:
            reconstructed_wav -= reconstructed_wav.mean()
        output_path = os.path.join(self.output_dir, output_filename)
        sf.write(output_path, reconstructed_wav, 16000)
        print(f"Saved output to {output_path}")
        return 16000, reconstructed_wav
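
    # Example call (hypothetical local files):
    #   sr, wav = inference.run_tts(
    #       task="tts-a",
    #       output_filename="demo_tts.wav",
    #       text="Hello there.",
    #       ref_audio_path="ref_speaker.wav",
    #   )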

    def run_vc(
        self,
        source_audio_path,
        ref_audio_path,
        output_filename="output_gpa_vc.wav",
        **kwargs,
    ):
        if not source_audio_path:
            raise ValueError("source_audio_path is required for VC")
        if not ref_audio_path:
            raise ValueError("ref_audio_path is required for VC")
        print("\n--- Voice Conversion (VC) ---")
        output_path = os.path.join(self.output_dir, output_filename)
        inputs = self.processor.process_input(
            task="vc",
            audio_path=source_audio_path,
            ref_audio_path=ref_audio_path,
        )
        outputs = self.generate(inputs, **kwargs)
        text_output = self.text_tokenizer.decode(outputs[0].tolist())
        if "<|end_content|>" in text_output:
            content = text_output.split("<|end_content|>")[1]
        else:
            content = text_output
        audio_ids = re.findall(r"<\|bicodec_semantic_(\d+)\|>", content)
        audio_list = [int(x) for x in audio_ids]
        global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])[
            "global_tokens"
        ]
        req = {
            "global_tokens": global_tokens,
            "semantic_tokens": torch.tensor(audio_list).unsqueeze(0).to(self.device),
        }
        out = self.bicodec_detokenizer.detokenize(**req)
        reconstructed_wav = out.detach().cpu().float().squeeze().numpy()
        # Simple DC offset removal
        if reconstructed_wav.size > 0:
            reconstructed_wav -= reconstructed_wav.mean()
        sf.write(output_path, reconstructed_wav, 16000)
        print(f"Saved VC output to {output_path}")
        return 16000, reconstructed_wav
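
    # Example call (hypothetical local files): the converted speech keeps the
    # content of source_audio_path with the voice of ref_audio_path.
    #   sr, wav = inference.run_vc(
    #       source_audio_path="source_speech.wav",
    #       ref_audio_path="target_speaker.wav",
    #   )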


def parse_args():
    parser = argparse.ArgumentParser(description="GPA Inference Script")
    # Paths
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="/nasdata/model/gpa/glm-4-voice-tokenizer",
        help="Path to the GLM-4-Voice speech tokenizer",
    )
    parser.add_argument(
        "--text_tokenizer_path",
        type=str,
        default="/nasdata/model/gpa",
        help="Path to the text tokenizer",
    )
    parser.add_argument(
        "--bicodec_tokenizer_path",
        type=str,
        default="/nasdata/model/gpa/BiCodec/",
        help="Path to the BiCodec tokenizer",
    )
    parser.add_argument(
        "--gpa_model_path",
        type=str,
        default="/nasdata/model/gpa",
        help="Path to the GPA model",
    )
    # Audio inputs
    parser.add_argument(
        "--ref_audio_path", type=str, default=None, help="Reference audio path"
    )
    parser.add_argument(
        "--src_audio_path", type=str, default=None, help="Source audio path for VC/STT"
    )
    # Output
    parser.add_argument(
        "--output_dir", type=str, default=".", help="Directory to save output files"
    )
    # Device
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser.add_argument(
        "--device",
        type=str,
        default=default_device,
        help="Device to use (e.g., cuda:0, cpu)",
    )
    # Task
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        choices=["stt", "tts-a", "vc"],
        help="Task to run",
    )
    # TTS input
    parser.add_argument("--text", type=str, default=None, help="Text for TTS")
    return parser.parse_args()


def main():
    args = parse_args()
    # Ensure the output directory exists
    os.makedirs(args.output_dir, exist_ok=True)
    inference = GPAInference(
        tokenizer_path=args.tokenizer_path,
        text_tokenizer_path=args.text_tokenizer_path,
        bicodec_tokenizer_path=args.bicodec_tokenizer_path,
        gpa_model_path=args.gpa_model_path,
        output_dir=args.output_dir,
        device=args.device,
    )
    if args.task == "stt":
        if not args.src_audio_path:
            raise ValueError("--src_audio_path is required for the STT task")
        result = inference.run_stt(audio_path=args.src_audio_path)
        print("STT Result:", result)
    elif args.task == "tts-a":
        inference.run_tts(
            task="tts-a",
            output_filename="output_gpa_tts_a.wav",
            text=args.text,
            ref_audio_path=args.ref_audio_path,
        )
    elif args.task == "vc":
        inference.run_vc(
            source_audio_path=args.src_audio_path,
            ref_audio_path=args.ref_audio_path,
            output_filename="output_gpa_vc.wav",
        )


if __name__ == "__main__":
    main()
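
# Example invocations (assuming this file is saved as, e.g., gpa_inference.py,
# a hypothetical name, and that the model/audio paths exist locally):
#   python gpa_inference.py --task stt --src_audio_path input.wav
#   python gpa_inference.py --task tts-a --text "Hello." --ref_audio_path ref.wav
#   python gpa_inference.py --task vc --src_audio_path src.wav --ref_audio_path ref.wav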