# optimize_tts.py — MOSS-TTS CPU inference optimizer.
# Provenance: Hugging Face repo upload "Initial model upload with complete
# configuration and weights" (commit 53a0ef9, verified, by daksh-neo).
import os
import sys
import time
import json
import logging
import argparse
import psutil
import torch
import torchaudio
from transformers import AutoProcessor, AutoModel
def setup_logging(log_dir: str = "logs") -> logging.Logger:
    """
    Set up a production-grade logger with a stream handler and file logging.

    Handlers are attached only once, so repeated calls (e.g. from both the
    engine constructor and main()) do not produce duplicate log lines.

    Args:
        log_dir (str): Directory to hold the log file; created if missing.
            Defaults to "logs" to preserve the original behavior.

    Returns:
        logging.Logger: The configured logger instance.
    """
    logger = logging.getLogger("MOSS-TTS-Opt")
    if not logger.handlers:
        logger.setLevel(logging.INFO)
        # Don't forward records to the root logger: if the host application
        # configures root logging too, every message would print twice.
        logger.propagate = False
        formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')
        # Stream handler: mirror all messages to stdout.
        sh = logging.StreamHandler(sys.stdout)
        sh.setFormatter(formatter)
        logger.addHandler(sh)
        # File handler: persistent record at <log_dir>/inference.log.
        os.makedirs(log_dir, exist_ok=True)
        fh = logging.FileHandler(os.path.join(log_dir, "inference.log"))
        fh.setFormatter(formatter)
        logger.addHandler(fh)
    return logger
class MOSSInferenceEngine:
    """
    A high-performance inference engine for MOSS-TTS optimized for CPU execution.

    This engine handles model loading with float32 enforcement, dynamic INT8
    quantization, and optimized audio generation specifically for CPU-only
    environments.
    """

    def __init__(self, model_id: str = "OpenMOSS-Team/MOSS-TTS", device: str = "cpu"):
        """
        Initialize the inference engine.

        Args:
            model_id (str): Hugging Face model repository ID.
            device (str): Device to run inference on (default is 'cpu').
        """
        self.model_id = model_id
        self.device = device
        self.model = None
        self.processor = None
        self.logger = setup_logging()
        # Optimize CPU threading for PyTorch. os.cpu_count() may return None
        # on some platforms, which torch.set_num_threads would reject.
        self.threads = os.cpu_count() or 1
        torch.set_num_threads(self.threads)
        self.logger.info(f"Engine: Initialized with {self.threads} CPU threads.")

    def load(self, trust_remote_code: bool = True):
        """
        Load the model and processor from the Hugging Face Hub.

        Enforces float32 to ensure compatibility with CPU quantization and
        avoid dtype mismatches.

        Args:
            trust_remote_code (bool): Whether to trust remote code from the
                model repository.

        Raises:
            Exception: Re-raises any failure from the Hub download or model
                construction after logging it.
        """
        self.logger.info(f"Engine: Loading model and processor: {self.model_id}")
        start_time = time.time()
        try:
            self.processor = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=trust_remote_code)
            # Implementation Note: We explicitly use torch_dtype=torch.float32 to avoid
            # BFloat16/Float16 weight mismatches during torch.ao.quantization.quantize_dynamic calls on CPU.
            self.model = AutoModel.from_pretrained(
                self.model_id,
                trust_remote_code=trust_remote_code,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True
            ).to(self.device)
            # Defensive cast to ensure all parameters are indeed float32
            self.model = self.model.float()
            self.model.eval()
            self.logger.info(f"Engine: Load complete in {time.time() - start_time:.2f}s")
        except Exception as e:
            self.logger.error(f"Engine: Model loading failed: {e}")
            raise

    def quantize(self, mode: str = "int8"):
        """
        Apply a dynamic quantization strategy to the model in place.

        Args:
            mode (str): Quantization strategy - 'fp32' (none), 'int8' (full),
                or 'selective'. Any other value is logged and ignored.
        """
        if mode == "fp32":
            self.logger.info("Engine: Operating in FP32 mode (No quantization).")
            return
        start_q = time.time()
        # Use the torch.ao namespace: torch.quantization.* is a deprecated
        # alias for torch.ao.quantization.* in current PyTorch releases.
        if mode == "int8":
            self.logger.info("Engine: Applying full Dynamic INT8 quantization to Linear layers...")
            self.model = torch.ao.quantization.quantize_dynamic(
                self.model, {torch.nn.Linear}, dtype=torch.qint8
            )
        elif mode == "selective":
            self.logger.info("Engine: Applying selective Dynamic INT8 quantization (Backbone only)...")
            # Target the heavy language model backbone
            if hasattr(self.model, 'language_model'):
                self.model.language_model = torch.ao.quantization.quantize_dynamic(
                    self.model.language_model, {torch.nn.Linear}, dtype=torch.qint8
                )
            # Target the output heads if present
            if hasattr(self.model, 'lm_heads'):
                self.model.lm_heads = torch.ao.quantization.quantize_dynamic(
                    self.model.lm_heads, {torch.nn.Linear}, dtype=torch.qint8
                )
        else:
            # Previously an unknown mode fell through and logged "completed";
            # surface the mistake instead of silently doing nothing.
            self.logger.warning(f"Engine: Unknown quantization mode '{mode}'; model left unchanged.")
            return
        self.logger.info(f"Engine: Quantization ({mode}) completed in {time.time() - start_q:.2f}s.")

    def generate(self, text: str, max_new_tokens: int = 50, output_wav: str = None) -> dict:
        """
        Synthesize speech from text and optionally save the output to a WAV file.

        Args:
            text (str): Input text to synthesize.
            max_new_tokens (int): Maximum generation length.
            output_wav (str): File path to save the generated audio; skipped
                when None.

        Returns:
            dict: Latency metadata ({"latency_ms": float}).
        """
        self.logger.info(f"Engine: Generating for text sample: '{text[:50]}...'")
        conversations = [{"role": "user", "content": text}]
        inputs = self.processor(conversations=conversations, return_tensors="pt").to(self.device)
        start_inf = time.time()
        # no_grad: inference only — skip autograd bookkeeping to save RAM/time.
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        latency = (time.time() - start_inf) * 1000
        self.logger.info(f"Engine: Generation finished in {latency:.2f}ms")
        if output_wav:
            self._save_audio(outputs, output_wav)
        return {"latency_ms": latency}

    def _save_audio(self, outputs, output_path: str):
        """Extract a waveform from model outputs and save it as a WAV file.

        Best-effort: any failure is logged rather than raised, so a save
        error never aborts the benchmark run.
        """
        try:
            # The remote-code model's output type isn't fixed here; probe the
            # common container shapes for a waveform tensor.
            waveform = None
            if isinstance(outputs, torch.Tensor):
                waveform = outputs
            elif isinstance(outputs, dict) and "waveform" in outputs:
                waveform = outputs["waveform"]
            elif hasattr(outputs, "waveform"):
                waveform = outputs.waveform
            if waveform is not None:
                waveform = waveform.detach().cpu().float()
                # torchaudio.save expects [channels, time].
                if waveform.dim() == 1:
                    waveform = waveform.unsqueeze(0)
                elif waveform.dim() == 3:  # Case: [batch, channel, time]
                    waveform = waveform.squeeze(0)
                # Retrieve sample rate from model config or default to 24000
                sr = getattr(self.model.config, "sampling_rate", 24000)
                # dirname is "" for bare filenames; os.makedirs("") would raise.
                out_dir = os.path.dirname(output_path)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                torchaudio.save(output_path, waveform, sr)
                self.logger.info(f"Engine: Audio saved to {output_path}")
            else:
                self.logger.warning("Engine: No waveform found in model outputs.")
        except Exception as e:
            self.logger.error(f"Engine: Audio saving error: {e}")
def get_current_ram():
    """Return this process's resident set size (RSS) in megabytes."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 ** 2)
def main():
    """Main entry point for the CLI tool.

    Parses arguments, loads (and optionally quantizes) the model, runs one
    generation pass, and writes performance metrics to a JSON file.
    Exits with status 1 on any failure.
    """
    parser = argparse.ArgumentParser(description="Production-grade MOSS-TTS Optimizer for CPU")
    parser.add_argument("--mode", type=str, choices=["fp32", "int8", "selective"], default="fp32",
                        help="Quantization mode (fp32, int8, selective).")
    parser.add_argument("--text", type=str, default="Validating the optimized CPU inference pipeline for MOSS TTS.",
                        help="Text string to synthesize.")
    parser.add_argument("--output_json", type=str, default="results/metrics.json",
                        help="Path to save performance metrics (JSON).")
    parser.add_argument("--output_wav", type=str, default="outputs/generated_audio.wav",
                        help="Path to save the generated audio (WAV).")
    args = parser.parse_args()
    logger = setup_logging()
    initial_ram = get_current_ram()
    try:
        engine = MOSSInferenceEngine()
        load_start = time.time()
        engine.load()
        load_time = time.time() - load_start
        engine.quantize(mode=args.mode)
        peak_ram = get_current_ram()
        # Tag the wav path with the quantization mode. splitext is safer than
        # str.replace(".wav", ...), which silently no-ops on non-".wav" paths
        # and could match ".wav" mid-string.
        base, ext = os.path.splitext(args.output_wav)
        wav_path = f"{base}_{args.mode}{ext}"
        res = engine.generate(args.text, output_wav=wav_path)
        final_stats = {
            "mode": args.mode,
            "load_time_sec": load_time,
            "peak_ram_mb": peak_ram,
            "ram_usage_delta_mb": peak_ram - initial_ram,
            "latency_ms": res["latency_ms"]
        }
        # dirname is "" for a bare filename; os.makedirs("") would raise.
        json_dir = os.path.dirname(args.output_json)
        if json_dir:
            os.makedirs(json_dir, exist_ok=True)
        with open(args.output_json, "w") as f:
            json.dump(final_stats, f, indent=4)
        logger.info(f"Success: Mode={args.mode} | RAM={peak_ram:.2f}MB | Latency={res['latency_ms']:.2f}ms")
    except Exception as e:
        # logger.exception keeps the traceback for post-mortem debugging.
        logger.exception(f"Execution failed: {e}")
        sys.exit(1)
if __name__ == "__main__":
main()