import os
import sys
import time
import json
import logging
import argparse
import psutil
import torch
import torchaudio
from transformers import AutoProcessor, AutoModel


def _ensure_parent_dir(path: str) -> None:
    """Create the parent directory of *path* if it has one.

    ``os.makedirs("")`` raises FileNotFoundError, so paths that are bare
    filenames (no directory component) must be skipped rather than passed
    straight to makedirs.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def setup_logging():
    """
    Sets up a production-grade logger with a stream handler and file logging.

    Idempotent: handlers are attached only on the first call so repeated
    invocations (e.g. from both the engine and main()) do not duplicate output.

    Returns:
        logging.Logger: The configured logger instance.
    """
    logger = logging.getLogger("MOSS-TTS-Opt")
    if not logger.handlers:
        logger.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')
        # Stream Handler
        sh = logging.StreamHandler(sys.stdout)
        sh.setFormatter(formatter)
        logger.addHandler(sh)
        # File Handler
        os.makedirs("logs", exist_ok=True)
        fh = logging.FileHandler("logs/inference.log")
        fh.setFormatter(formatter)
        logger.addHandler(fh)
    return logger


class MOSSInferenceEngine:
    """
    A high-performance inference engine for MOSS-TTS optimized for CPU execution.

    This engine handles model loading with float32 enforcement, dynamic INT8
    quantization, and optimized audio generation specifically for CPU-only
    environments.
    """

    def __init__(self, model_id: str = "OpenMOSS-Team/MOSS-TTS", device: str = "cpu"):
        """
        Initializes the inference engine.

        Args:
            model_id (str): Hugging Face model repository ID.
            device (str): Device to run inference on (default is 'cpu').
        """
        self.model_id = model_id
        self.device = device
        self.model = None
        self.processor = None
        self.logger = setup_logging()
        # Optimize CPU threading for PyTorch.
        # os.cpu_count() may return None on exotic platforms; fall back to 1
        # so torch.set_num_threads never receives None.
        self.threads = os.cpu_count() or 1
        torch.set_num_threads(self.threads)
        self.logger.info(f"Engine: Initialized with {self.threads} CPU threads.")

    def load(self, trust_remote_code: bool = True):
        """
        Loads the model and processor from the Hugging Face Hub.

        Enforces float32 to ensure compatibility with CPU quantization and
        avoid dtype mismatches.

        Args:
            trust_remote_code (bool): Whether to trust remote code from the
                model repository.

        Raises:
            Exception: Re-raises any failure from the Hub download / model
                construction after logging it.
        """
        self.logger.info(f"Engine: Loading model and processor: {self.model_id}")
        start_time = time.time()
        try:
            self.processor = AutoProcessor.from_pretrained(
                self.model_id, trust_remote_code=trust_remote_code
            )
            # Implementation Note: We explicitly use torch_dtype=torch.float32 to avoid
            # BFloat16/Float16 weight mismatches during torch.ao.quantization.quantize_dynamic calls on CPU.
            self.model = AutoModel.from_pretrained(
                self.model_id,
                trust_remote_code=trust_remote_code,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True
            ).to(self.device)
            # Defensive cast to ensure all parameters are indeed float32
            self.model = self.model.float()
            self.model.eval()
            self.logger.info(f"Engine: Load complete in {time.time() - start_time:.2f}s")
        except Exception as e:
            # logger.exception preserves the traceback in the log file,
            # unlike logger.error with just the message.
            self.logger.exception(f"Engine: Model loading failed: {e}")
            raise

    def quantize(self, mode: str = "int8"):
        """
        Applies a dynamic quantization strategy to the model.

        Uses the current torch.ao.quantization namespace (torch.quantization
        is a deprecated alias).

        Args:
            mode (str): Quantization strategy - 'fp32' (none), 'int8' (full),
                or 'selective'.
        """
        if mode == "fp32":
            self.logger.info("Engine: Operating in FP32 mode (No quantization).")
            return
        start_q = time.time()
        if mode == "int8":
            self.logger.info("Engine: Applying full Dynamic INT8 quantization to Linear layers...")
            self.model = torch.ao.quantization.quantize_dynamic(
                self.model, {torch.nn.Linear}, dtype=torch.qint8
            )
        elif mode == "selective":
            self.logger.info("Engine: Applying selective Dynamic INT8 quantization (Backbone only)...")
            # Target the heavy language model backbone
            if hasattr(self.model, 'language_model'):
                self.model.language_model = torch.ao.quantization.quantize_dynamic(
                    self.model.language_model, {torch.nn.Linear}, dtype=torch.qint8
                )
            # Target the output heads if present
            if hasattr(self.model, 'lm_heads'):
                self.model.lm_heads = torch.ao.quantization.quantize_dynamic(
                    self.model.lm_heads, {torch.nn.Linear}, dtype=torch.qint8
                )
        self.logger.info(f"Engine: Quantization ({mode}) completed in {time.time() - start_q:.2f}s.")

    def generate(self, text: str, max_new_tokens: int = 50, output_wav: str = None) -> dict:
        """
        Synthesizes speech from text and saves the output to a WAV file.

        Args:
            text (str): Input text to synthesize.
            max_new_tokens (int): Maximum generation length.
            output_wav (str): File path to save the generated audio.

        Returns:
            dict: Latency and output metadata ({"latency_ms": float}).
        """
        self.logger.info(f"Engine: Generating for text sample: '{text[:50]}...'")
        conversations = [{"role": "user", "content": text}]
        inputs = self.processor(conversations=conversations, return_tensors="pt").to(self.device)
        start_inf = time.time()
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        latency = (time.time() - start_inf) * 1000
        self.logger.info(f"Engine: Generation finished in {latency:.2f}ms")
        if output_wav:
            self._save_audio(outputs, output_wav)
        return {"latency_ms": latency}

    def _save_audio(self, outputs, output_path: str):
        """Helper to extract and save audio from model outputs.

        Accepts a raw tensor, a dict with a "waveform" key, or an object
        exposing a .waveform attribute; normalizes the waveform to the
        [channel, time] layout torchaudio.save expects. Errors are logged,
        never raised (best-effort save).
        """
        try:
            waveform = None
            if isinstance(outputs, torch.Tensor):
                waveform = outputs
            elif isinstance(outputs, dict) and "waveform" in outputs:
                waveform = outputs["waveform"]
            elif hasattr(outputs, "waveform"):
                waveform = outputs.waveform

            if waveform is not None:
                waveform = waveform.detach().cpu().float()
                if waveform.dim() == 1:
                    # [time] -> [1, time]
                    waveform = waveform.unsqueeze(0)
                elif waveform.dim() == 3:
                    # Case: [batch, channel, time] -> [channel, time]
                    waveform = waveform.squeeze(0)
                # Retrieve sample rate from model config or default to 24000
                sr = getattr(self.model.config, "sampling_rate", 24000)
                # Bare filenames have an empty dirname; makedirs("") would raise.
                _ensure_parent_dir(output_path)
                torchaudio.save(output_path, waveform, sr)
                self.logger.info(f"Engine: Audio saved to {output_path}")
            else:
                self.logger.warning("Engine: No waveform found in model outputs.")
        except Exception as e:
            self.logger.exception(f"Engine: Audio saving error: {e}")


def get_current_ram():
    """Calculates the current process RAM usage in MB."""
    return psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)


def main():
    """Main entry point for the CLI tool."""
    parser = argparse.ArgumentParser(description="Production-grade MOSS-TTS Optimizer for CPU")
    parser.add_argument("--mode", type=str, choices=["fp32", "int8", "selective"], default="fp32",
                        help="Quantization mode (fp32, int8, selective).")
    parser.add_argument("--text", type=str,
                        default="Validating the optimized CPU inference pipeline for MOSS TTS.",
                        help="Text string to synthesize.")
    parser.add_argument("--output_json", type=str, default="results/metrics.json",
                        help="Path to save performance metrics (JSON).")
    parser.add_argument("--output_wav", type=str, default="outputs/generated_audio.wav",
                        help="Path to save the generated audio (WAV).")
    args = parser.parse_args()

    logger = setup_logging()
    initial_ram = get_current_ram()

    try:
        engine = MOSSInferenceEngine()
        load_start = time.time()
        engine.load()
        load_time = time.time() - load_start
        engine.quantize(mode=args.mode)
        peak_ram = get_current_ram()

        # Adjust wav path to include mode
        wav_path = args.output_wav.replace(".wav", f"_{args.mode}.wav")
        res = engine.generate(args.text, output_wav=wav_path)

        final_stats = {
            "mode": args.mode,
            "load_time_sec": load_time,
            "peak_ram_mb": peak_ram,
            "ram_usage_delta_mb": peak_ram - initial_ram,
            "latency_ms": res["latency_ms"]
        }
        # Guard against a bare filename (empty dirname) for the metrics path.
        _ensure_parent_dir(args.output_json)
        with open(args.output_json, "w") as f:
            json.dump(final_stats, f, indent=4)
        logger.info(f"Success: Mode={args.mode} | RAM={peak_ram:.2f}MB | Latency={res['latency_ms']:.2f}ms")
    except Exception as e:
        logger.exception(f"Execution failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()