"""Extract MegaTTS3 WaveVAE latents from an audio file and save them as .npy."""
import argparse
import os
import sys

import librosa
import numpy as np
import torch

# Root of the MegaTTS3 checkout; it must be on sys.path so the ``tts`` package
# can be imported. Overridable via the MEGATTS3_ROOT environment variable.
MEGATTS3_ROOT = os.environ.get('MEGATTS3_ROOT', '/mnt/data/MegaTTS3')
# Default checkpoint directory passed to MegaTTS3DiTInfer.
DEFAULT_CKPT_ROOT = os.path.join(MEGATTS3_ROOT, 'checkpoints')

# Avoid appending the same entry repeatedly if this module is re-imported.
if MEGATTS3_ROOT not in sys.path:
    sys.path.append(MEGATTS3_ROOT)

try:
    from tts.infer_cli import MegaTTS3DiTInfer, hparams
except ImportError as e:
    print(f"Failed to import MegaTTS3DiTInfer and hparams: {e}", file=sys.stderr)
    sys.exit(1)

# Number of zero samples appended after the win_size alignment padding
# (0.5 s at the default 24 kHz sample rate).
_TAIL_PAD_SAMPLES = 12000


def _load_padded_wav(audio_path, sample_rate):
    """Load *audio_path* with librosa at *sample_rate* and zero-pad its tail.

    Padding happens in two steps: first the length is extended so that
    ``len(wav) % hparams['win_size'] == win_size - 1`` (skipped if it already
    is), then ``_TAIL_PAD_SAMPLES`` zeros are appended.

    :param audio_path: Path to the input audio file.
    :param sample_rate: Target sample rate passed to ``librosa.load``.
    :return: The padded waveform as a float32 numpy array.
    """
    wav, _ = librosa.load(audio_path, sr=sample_rate)
    ws = hparams['win_size']
    remainder = len(wav) % ws
    if remainder < ws - 1:
        wav = np.pad(wav, (0, ws - 1 - remainder), mode='constant', constant_values=0.0)
    wav = np.pad(wav, (0, _TAIL_PAD_SAMPLES), mode='constant', constant_values=0.0)
    return wav.astype(np.float32)


def generate_npy_file(audio_path, output_npy_path, model, sample_rate=24000):
    """
    Generate and save a .npy file containing the latent representation of an audio file.

    :param audio_path: Path to the input audio file (e.g., .wav, .mp3).
    :param output_npy_path: Path where the .npy file will be saved. A bare
        filename (no directory component) writes into the current directory.
    :param model: Instance of MegaTTS3DiTInfer with a loaded WaveVAE encoder
        (``model.has_vae_encoder`` must be truthy).
    :param sample_rate: Sample rate for audio (default: 24000).
    :return: True if successful, False otherwise. Errors are printed to
        stderr rather than raised.
    """
    try:
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Input audio file not found: {audio_path}")
        # Check the encoder before creating directories or decoding audio,
        # so nothing is done when the latent cannot be produced anyway.
        if not model.has_vae_encoder:
            raise ValueError("WaveVAE encoder model is not available. Cannot generate .npy file.")

        # os.path.dirname() is '' for a bare filename like 'out.npy', and
        # os.makedirs('') raises, so only create a directory when one is named.
        output_dir = os.path.dirname(output_npy_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        wav = _load_padded_wav(audio_path, sample_rate)
        # Add a leading batch dimension and move to the model's device.
        wav_tensor = torch.from_numpy(wav)[None].to(model.device)
        with torch.inference_mode():
            vae_latent = model.wavvae.encode_latent(wav_tensor)

        np.save(output_npy_path, vae_latent.cpu().numpy())
        return True
    except Exception as e:
        print(f"Error generating .npy file: {e}", file=sys.stderr)
        return False


def extract_vae_features(input_wav, output_npy, ckpt_root=DEFAULT_CKPT_ROOT):
    """
    Initialize the model, generate the .npy file, and release the model.

    :param input_wav: Path to the input WAV file.
    :param output_npy: Path where the .npy file will be saved.
    :param ckpt_root: Checkpoint directory passed to MegaTTS3DiTInfer
        (defaults to DEFAULT_CKPT_ROOT).
    :return: True if successful, False otherwise.
    """
    # Fail fast: loading the checkpoints is expensive, so check the input first.
    if not os.path.exists(input_wav):
        print(f"Error in extract_vae_features: Input audio file not found: {input_wav}",
              file=sys.stderr)
        return False

    model = None
    try:
        model = MegaTTS3DiTInfer(ckpt_root=ckpt_root)
        return generate_npy_file(input_wav, output_npy, model)
    except Exception as e:
        print(f"Error in extract_vae_features: {e}", file=sys.stderr)
        return False
    finally:
        # Runs on every path: drop references to the sub-models so their
        # memory can be reclaimed, then release PyTorch's cached CUDA blocks.
        if model is not None:
            for attr in ('wavvae', 'dur_model', 'dit', 'g2p_model', 'aligner_lm'):
                setattr(model, attr, None)
        del model
        torch.cuda.empty_cache()


def main():
    """Parse CLI arguments, run the extraction, and return a process exit code."""
    parser = argparse.ArgumentParser(description="Extract VAE features from a WAV file and save as .npy")
    parser.add_argument('--input_wav', type=str, required=True, help='输入WAV文件路径 (Path to input WAV file)')
    parser.add_argument('--output_npy', type=str, required=True, help='输出NPY文件路径 (Path to output NPY file)')
    parser.add_argument('--ckpt_root', type=str, default=DEFAULT_CKPT_ROOT,
                        help='Checkpoint directory for MegaTTS3DiTInfer')
    args = parser.parse_args()

    success = extract_vae_features(args.input_wav, args.output_npy, ckpt_root=args.ckpt_root)
    if success:
        print("特征提取完成! (Feature extraction completed!)")
    else:
        print("特征提取失败 (Feature extraction failed)")
    # Non-zero exit status lets shell scripts and pipelines detect failure.
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())