| import os |
| import sys |
| import re |
| import tempfile |
| import numpy as np |
| from typing import Generator, Optional |
| from huggingface_hub import snapshot_download |
| from .model.voxcpm import VoxCPMModel, LoRAConfig |
|
|
class VoxCPM:
    """High-level TTS pipeline wrapping ``VoxCPMModel``.

    Adds optional prompt-audio denoising (ZipEnhancer), lazy text
    normalization, and LoRA adapter loading/toggling on top of the core
    model's generation API.
    """

    def __init__(self,
            voxcpm_model_path : str,
            zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
            enable_denoiser : bool = True,
            optimize: bool = True,
            lora_config: Optional[LoRAConfig] = None,
            lora_weights_path: Optional[str] = None,
        ):
        """Initialize VoxCPM TTS pipeline.

        Args:
            voxcpm_model_path: Local filesystem path to the VoxCPM model assets
                (weights, configs, etc.). Typically the directory returned by
                a prior download step.
            zipenhancer_model_path: ModelScope acoustic noise suppression model
                id or local path. If None, denoiser will not be initialized.
            enable_denoiser: Whether to initialize the denoiser pipeline.
            optimize: Whether to optimize the model with torch.compile. True by
                default, but can be disabled for debugging.
            lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
                provided without lora_config, a default config will be created.
            lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
                containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
        """
        print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}", file=sys.stderr)

        # Weights supplied without an explicit config: fall back to a default
        # adapter layout (LM + DiT enabled, projection layers disabled).
        if lora_weights_path is not None and lora_config is None:
            lora_config = LoRAConfig(
                enable_lm=True,
                enable_dit=True,
                enable_proj=False,
            )
            print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)

        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)

        if lora_weights_path is not None:
            print(f"Loading LoRA weights from: {lora_weights_path}", file=sys.stderr)
            loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
            print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)

        # Text normalizer is constructed lazily on first use (see _generate).
        self.text_normalizer = None
        if enable_denoiser and zipenhancer_model_path is not None:
            from .zipenhancer import ZipEnhancer
            self.denoiser = ZipEnhancer(zipenhancer_model_path)
        else:
            self.denoiser = None
        if optimize:
            # One throwaway generation so torch.compile captures its graphs
            # up front instead of on the first real request.
            print("Warm up VoxCPMModel...", file=sys.stderr)
            self.tts_model.generate(
                target_text="Hello, this is the first test sentence.",
                max_len=10,
            )

    @classmethod
    def from_pretrained(cls,
            hf_model_id: str = "openbmb/VoxCPM1.5",
            load_denoiser: bool = True,
            zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
            cache_dir: Optional[str] = None,
            local_files_only: bool = False,
            optimize: bool = True,
            lora_config: Optional[LoRAConfig] = None,
            lora_weights_path: Optional[str] = None,
            **kwargs,
        ):
        """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.

        Args:
            hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo")
                or a local directory path (used as-is, no download).
            load_denoiser: Whether to initialize the denoiser pipeline.
            optimize: Whether to optimize the model with torch.compile. True by
                default, but can be disabled for debugging.
            zipenhancer_model_id: Denoiser model id or path for ModelScope
                acoustic noise suppression.
            cache_dir: Custom cache directory for the snapshot.
            local_files_only: If True, only use local files and do not attempt
                to download.
            lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
                provided without lora_config, a default config will be created with
                enable_lm=True and enable_dit=True.
            lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
                containing lora_weights.ckpt). If provided, LoRA weights will be loaded
                after model initialization.
        Kwargs:
            Additional keyword arguments passed to the ``VoxCPM`` constructor.

        Returns:
            VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
            the downloaded snapshot directory (or the given local path).

        Raises:
            ValueError: If ``hf_model_id`` is empty or None.
        """
        repo_id = hf_model_id
        if not repo_id:
            raise ValueError("You must provide hf_model_id")

        # An existing local directory is used directly; anything else is
        # treated as a Hub repo id and resolved via snapshot_download.
        if os.path.isdir(repo_id):
            local_path = repo_id
        else:
            local_path = snapshot_download(
                repo_id=repo_id,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
            )

        return cls(
            voxcpm_model_path=local_path,
            zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
            enable_denoiser=load_denoiser,
            optimize=optimize,
            lora_config=lora_config,
            lora_weights_path=lora_weights_path,
            **kwargs,
        )

    def generate(self, *args, **kwargs) -> np.ndarray:
        """Synthesize speech and return the complete waveform at once.

        Accepts the same arguments as ``_generate`` (except ``streaming``).
        """
        return next(self._generate(*args, streaming=False, **kwargs))

    def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
        """Synthesize speech, yielding audio chunks as they are produced.

        Accepts the same arguments as ``_generate`` (except ``streaming``).
        """
        return self._generate(*args, streaming=True, **kwargs)

    def _generate(self,
            text : str,
            prompt_wav_path : str = None,
            prompt_text : str = None,
            cfg_value : float = 2.0,
            inference_timesteps : int = 10,
            min_len : int = 2,
            max_len : int = 4096,
            normalize : bool = False,
            denoise : bool = False,
            retry_badcase : bool = True,
            retry_badcase_max_times : int = 3,
            retry_badcase_ratio_threshold : float = 6.0,
            streaming: bool = False,
        ) -> Generator[np.ndarray, None, None]:
        """Synthesize speech for the given text and yield waveform array(s).

        If an external prompt (``prompt_wav_path`` + ``prompt_text``) is
        provided, a prompt cache is built from it and used for generation;
        otherwise generation runs without a fixed prompt cache.

        Args:
            text: Input text. Newlines are collapsed into spaces before
                generation.
            prompt_wav_path: Path to a reference audio file for prompting.
            prompt_text: Text content corresponding to the prompt audio.
            cfg_value: Guidance scale for the generation model.
            inference_timesteps: Number of inference steps.
            min_len: Minimum token length during generation.
            max_len: Maximum token length during generation.
            normalize: Whether to run text normalization before generation.
            denoise: Whether to denoise the prompt audio if a denoiser is
                available.
            retry_badcase: Whether to retry badcase.
            retry_badcase_max_times: Maximum number of times to retry badcase.
            retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
            streaming: Whether to yield a chunk per generation step.

        Returns:
            Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
            Yields audio chunks for each generation step if ``streaming=True``,
            otherwise yields a single array containing the final audio.

        Raises:
            ValueError: If ``text`` is not a non-empty string, or if exactly
                one of ``prompt_wav_path`` / ``prompt_text`` is provided.
            FileNotFoundError: If ``prompt_wav_path`` does not exist.
        """
        # Type check must come first: calling .strip() on a non-string would
        # raise AttributeError instead of the intended ValueError.
        if not isinstance(text, str) or not text.strip():
            raise ValueError("target text must be a non-empty string")

        if prompt_wav_path is not None:
            if not os.path.exists(prompt_wav_path):
                raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")

        # Prompt audio and prompt text are only meaningful together.
        if (prompt_wav_path is None) != (prompt_text is None):
            raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")

        # Flatten newlines and collapse runs of whitespace to single spaces.
        text = text.replace("\n", " ")
        text = re.sub(r'\s+', ' ', text)
        temp_prompt_wav_path = None

        try:
            if prompt_wav_path is not None and prompt_text is not None:
                if denoise and self.denoiser is not None:
                    # Denoise into a temp file; cleaned up in the finally block.
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
                        temp_prompt_wav_path = tmp_file.name
                    self.denoiser.enhance(prompt_wav_path, output_path=temp_prompt_wav_path)
                    prompt_wav_path = temp_prompt_wav_path
                fixed_prompt_cache = self.tts_model.build_prompt_cache(
                    prompt_wav_path=prompt_wav_path,
                    prompt_text=prompt_text
                )
            else:
                fixed_prompt_cache = None

            if normalize:
                # Lazy import/instantiation: the normalizer is only needed here.
                if self.text_normalizer is None:
                    from .utils.text_normalize import TextNormalizer
                    self.text_normalizer = TextNormalizer()
                text = self.text_normalizer.normalize(text)

            generate_result = self.tts_model._generate_with_prompt_cache(
                target_text=text,
                prompt_cache=fixed_prompt_cache,
                min_len=min_len,
                max_len=max_len,
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
                retry_badcase=retry_badcase,
                retry_badcase_max_times=retry_badcase_max_times,
                retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
                streaming=streaming,
            )

            for wav, _, _ in generate_result:
                yield wav.squeeze(0).cpu().numpy()

        finally:
            # Remove the denoised temp file even if generation failed.
            if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
                try:
                    os.unlink(temp_prompt_wav_path)
                except OSError:
                    pass

    def load_lora(self, lora_weights_path: str) -> tuple:
        """Load LoRA weights from a checkpoint file.

        Args:
            lora_weights_path: Path to LoRA weights (.pth file or directory
                containing lora_weights.ckpt).

        Returns:
            tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.

        Raises:
            RuntimeError: If model was not initialized with LoRA config.
        """
        if self.tts_model.lora_config is None:
            raise RuntimeError(
                "Cannot load LoRA weights: model was not initialized with LoRA config. "
                "Please reinitialize with lora_config or lora_weights_path parameter."
            )
        return self.tts_model.load_lora_weights(lora_weights_path)

    def unload_lora(self):
        """Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
        self.tts_model.reset_lora_weights()

    def set_lora_enabled(self, enabled: bool):
        """Enable or disable LoRA layers without unloading weights.

        Args:
            enabled: If True, LoRA layers are active; if False, only base model is used.
        """
        self.tts_model.set_lora_enabled(enabled)

    def get_lora_state_dict(self) -> dict:
        """Get current LoRA parameters state dict.

        Returns:
            dict: State dict containing all LoRA parameters (lora_A, lora_B).
        """
        return self.tts_model.get_lora_state_dict()

    @property
    def lora_enabled(self) -> bool:
        """Check if LoRA is currently configured."""
        return self.tts_model.lora_config is not None
|
|