| | import os |
| | import sys |
| | import time |
| | import json |
| | import logging |
| | import argparse |
| | import psutil |
| | import torch |
| | import torchaudio |
| | from transformers import AutoProcessor, AutoModel |
| |
|
def setup_logging():
    """
    Configure and return the shared application logger.

    The logger emits to stdout and to ``logs/inference.log``. Handlers are
    attached only on the first call, so repeated calls return the already
    configured instance without duplicating output.

    Returns:
        logging.Logger: The configured logger instance.
    """
    log = logging.getLogger("MOSS-TTS-Opt")
    if log.handlers:
        # Already configured on a previous call — reuse as-is.
        return log

    log.setLevel(logging.INFO)
    fmt = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')

    console = logging.StreamHandler(sys.stdout)
    console.setFormatter(fmt)
    log.addHandler(console)

    os.makedirs("logs", exist_ok=True)
    file_out = logging.FileHandler("logs/inference.log")
    file_out.setFormatter(fmt)
    log.addHandler(file_out)

    return log
| |
|
class MOSSInferenceEngine:
    """
    A high-performance inference engine for MOSS-TTS optimized for CPU execution.

    This engine handles model loading with float32 enforcement, dynamic INT8 quantization,
    and optimized audio generation specifically for CPU-only environments.
    """

    def __init__(self, model_id: str = "OpenMOSS-Team/MOSS-TTS", device: str = "cpu"):
        """
        Initializes the inference engine.

        Args:
            model_id (str): Hugging Face model repository ID.
            device (str): Device to run inference on (default is 'cpu').
        """
        self.model_id = model_id
        self.device = device
        self.model = None
        self.processor = None
        self.logger = setup_logging()

        # os.cpu_count() can return None on some platforms; fall back to 1 so
        # torch.set_num_threads() never receives an invalid argument.
        self.threads = os.cpu_count() or 1
        torch.set_num_threads(self.threads)
        self.logger.info(f"Engine: Initialized with {self.threads} CPU threads.")

    def load(self, trust_remote_code: bool = True):
        """
        Loads the model and processor from the Hugging Face Hub.
        Enforces float32 to ensure compatibility with CPU quantization and avoid dtype mismatches.

        Args:
            trust_remote_code (bool): Whether to trust remote code from the model repository.

        Raises:
            Exception: Re-raised after logging if download or model init fails.
        """
        self.logger.info(f"Engine: Loading model and processor: {self.model_id}")
        start_time = time.time()

        try:
            self.processor = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=trust_remote_code)

            self.model = AutoModel.from_pretrained(
                self.model_id,
                trust_remote_code=trust_remote_code,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True
            ).to(self.device)

            # Belt-and-braces: some remote-code models ignore torch_dtype,
            # so force float32 again before switching to eval mode.
            self.model = self.model.float()
            self.model.eval()
            self.logger.info(f"Engine: Load complete in {time.time() - start_time:.2f}s")
        except Exception as e:
            self.logger.error(f"Engine: Model loading failed: {e}")
            raise

    def quantize(self, mode: str = "int8"):
        """
        Applies a dynamic quantization strategy to the model.

        Args:
            mode (str): Quantization strategy - 'fp32' (none), 'int8' (full), or 'selective'.

        Raises:
            ValueError: If `mode` is not one of the supported strategies.
        """
        if mode == "fp32":
            self.logger.info("Engine: Operating in FP32 mode (No quantization).")
            return

        # BUGFIX: previously an unknown mode fell through and logged a bogus
        # "completed" message while doing nothing. Fail loudly instead.
        if mode not in ("int8", "selective"):
            raise ValueError(f"Unknown quantization mode: {mode!r}")

        start_q = time.time()
        if mode == "int8":
            self.logger.info("Engine: Applying full Dynamic INT8 quantization to Linear layers...")
            self.model = torch.quantization.quantize_dynamic(
                self.model, {torch.nn.Linear}, dtype=torch.qint8
            )
        else:  # mode == "selective"
            self.logger.info("Engine: Applying selective Dynamic INT8 quantization (Backbone only)...")
            # Quantize only submodules that exist; attribute names assume the
            # remote model exposes 'language_model' / 'lm_heads' — TODO confirm
            # against the model's remote code.
            if hasattr(self.model, 'language_model'):
                self.model.language_model = torch.quantization.quantize_dynamic(
                    self.model.language_model, {torch.nn.Linear}, dtype=torch.qint8
                )
            if hasattr(self.model, 'lm_heads'):
                self.model.lm_heads = torch.quantization.quantize_dynamic(
                    self.model.lm_heads, {torch.nn.Linear}, dtype=torch.qint8
                )
        self.logger.info(f"Engine: Quantization ({mode}) completed in {time.time() - start_q:.2f}s.")

    def generate(self, text: str, max_new_tokens: int = 50, output_wav: str = None) -> dict:
        """
        Synthesizes speech from text and optionally saves the output to a WAV file.

        Args:
            text (str): Input text to synthesize.
            max_new_tokens (int): Maximum generation length.
            output_wav (str): Optional file path to save the generated audio.

        Returns:
            dict: Latency metadata of the form {"latency_ms": float}.
        """
        self.logger.info(f"Engine: Generating for text sample: '{text[:50]}...'")

        conversations = [{"role": "user", "content": text}]
        inputs = self.processor(conversations=conversations, return_tensors="pt").to(self.device)

        start_inf = time.time()
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        latency = (time.time() - start_inf) * 1000

        self.logger.info(f"Engine: Generation finished in {latency:.2f}ms")

        if output_wav:
            self._save_audio(outputs, output_wav)

        return {"latency_ms": latency}

    def _save_audio(self, outputs, output_path: str):
        """Helper to extract and save audio from model outputs.

        Best-effort: failures are logged but never raised, so a save problem
        does not discard an otherwise successful generation.
        """
        try:
            # The waveform may arrive as a bare tensor, a dict entry, or an
            # attribute on a model-output object, depending on the remote code.
            waveform = None
            if isinstance(outputs, torch.Tensor):
                waveform = outputs
            elif isinstance(outputs, dict) and "waveform" in outputs:
                waveform = outputs["waveform"]
            elif hasattr(outputs, "waveform"):
                waveform = outputs.waveform

            if waveform is None:
                self.logger.warning("Engine: No waveform found in model outputs.")
                return

            # torchaudio.save expects a 2-D (channels, samples) CPU float tensor.
            waveform = waveform.detach().cpu().float()
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            elif waveform.dim() == 3:
                waveform = waveform.squeeze(0)

            sr = getattr(self.model.config, "sampling_rate", 24000)
            # BUGFIX: os.makedirs("") raises FileNotFoundError when the target
            # path has no directory component; only create a dir if one exists.
            out_dir = os.path.dirname(output_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            torchaudio.save(output_path, waveform, sr)
            self.logger.info(f"Engine: Audio saved to {output_path}")
        except Exception as e:
            self.logger.error(f"Engine: Audio saving error: {e}")
| |
|
def get_current_ram():
    """Return this process's resident-set memory usage in megabytes."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 ** 2)
| |
|
def main():
    """Main entry point for the CLI tool.

    Parses command-line arguments, loads and optionally quantizes the model,
    runs one generation, and writes performance metrics to a JSON file.
    Exits with status 1 on any failure.
    """
    parser = argparse.ArgumentParser(description="Production-grade MOSS-TTS Optimizer for CPU")
    parser.add_argument("--mode", type=str, choices=["fp32", "int8", "selective"], default="fp32",
                        help="Quantization mode (fp32, int8, selective).")
    parser.add_argument("--text", type=str, default="Validating the optimized CPU inference pipeline for MOSS TTS.",
                        help="Text string to synthesize.")
    parser.add_argument("--output_json", type=str, default="results/metrics.json",
                        help="Path to save performance metrics (JSON).")
    parser.add_argument("--output_wav", type=str, default="outputs/generated_audio.wav",
                        help="Path to save the generated audio (WAV).")
    args = parser.parse_args()

    logger = setup_logging()
    initial_ram = get_current_ram()

    try:
        engine = MOSSInferenceEngine()

        load_start = time.time()
        engine.load()
        load_time = time.time() - load_start

        engine.quantize(mode=args.mode)
        peak_ram = get_current_ram()

        # BUGFIX: str.replace(".wav", ...) mangled paths containing ".wav"
        # mid-string (e.g. "my.wave.wav"); splice the mode suffix before the
        # extension instead.
        root, ext = os.path.splitext(args.output_wav)
        wav_path = f"{root}_{args.mode}{ext or '.wav'}"
        res = engine.generate(args.text, output_wav=wav_path)

        final_stats = {
            "mode": args.mode,
            "load_time_sec": load_time,
            "peak_ram_mb": peak_ram,
            "ram_usage_delta_mb": peak_ram - initial_ram,
            "latency_ms": res["latency_ms"]
        }

        # BUGFIX: os.makedirs("") raises when output_json is a bare filename;
        # only create the directory when the path actually has one.
        json_dir = os.path.dirname(args.output_json)
        if json_dir:
            os.makedirs(json_dir, exist_ok=True)
        with open(args.output_json, "w") as f:
            json.dump(final_stats, f, indent=4)

        logger.info(f"Success: Mode={args.mode} | RAM={peak_ram:.2f}MB | Latency={res['latency_ms']:.2f}ms")
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception(f"Execution failed: {e}")
        sys.exit(1)
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|