daksh-neo commited on
Commit
53a0ef9
·
verified ·
1 Parent(s): 07acd5c

Initial model upload with complete configuration and weights

Browse files
Files changed (4) hide show
  1. README.md +27 -0
  2. benchmark.py +72 -0
  3. optimize_tts.py +234 -0
  4. requirements.txt +80 -4
README.md CHANGED
@@ -24,6 +24,12 @@ language:
24
  - el
25
  - tr
26
  ---
 
 
 
 
 
 
27
  # MOSS-TTS Family
28
 
29
 
@@ -49,6 +55,27 @@ language:
49
  </div>
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  ## Overview
53
  MOSS‑TTS Family is an open‑source **speech and sound generation model family** from [MOSI.AI](https://mosi.cn/#hero) and the [OpenMOSS team](https://www.open-moss.com/). It is designed for **high‑fidelity**, **high‑expressiveness**, and **complex real‑world scenarios**, covering stable long‑form speech, multi‑speaker dialogue, voice/character design, environmental sound effects, and real‑time streaming TTS.
54
 
 
24
  - el
25
  - tr
26
  ---
27
+ # MOSS-TTS (CPU Optimized)
28
+
29
+ > **Notice**: This is the **CPU-Optimized** version of MOSS-TTS. It includes high-performance inference scripts and has been validated for efficient execution on CPU-only environments using dynamic quantization.
30
+
31
+ ---
32
+
33
  # MOSS-TTS Family
34
 
35
 
 
55
  </div>
56
 
57
 
58
+ ### CPU Optimized Inference
59
+ This version contains specific optimizations for CPU environments.
60
+
61
+ 1. **Installation**:
62
+ ```bash
63
+ pip install -r requirements.txt
64
+ ```
65
+
66
+ 2. **Run Optimized Inference**:
67
+ Use the `optimize_tts.py` script included in this repository:
68
+ ```bash
69
+ python optimize_tts.py --mode int8 --text "Generating speech on CPU."
70
+ ```
71
+
72
+ 3. **Optimization Details**:
73
+ - Runtime Dynamic INT8 Quantization.
74
+ - Forced Float32 for stability on CPU.
75
+ - Multi-threaded CPU performance scaling.
76
+
77
+ ---
78
+
79
  ## Overview
80
  MOSS‑TTS Family is an open‑source **speech and sound generation model family** from [MOSI.AI](https://mosi.cn/#hero) and the [OpenMOSS team](https://www.open-moss.com/). It is designed for **high‑fidelity**, **high‑expressiveness**, and **complex real‑world scenarios**, covering stable long‑form speech, multi‑speaker dialogue, voice/character design, environmental sound effects, and real‑time streaming TTS.
81
 
benchmark.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import sys
3
+ import os
4
+ import json
5
+ import logging
6
+
7
+ logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
8
+
9
def get_python_exe():
    """Detect the best python executable to use (prefers a local venv).

    Probes a project-local virtual environment in both the POSIX
    (``venv/bin/python3``) and Windows (``venv/Scripts/python.exe``) layouts,
    falling back to the interpreter currently running this script.

    Returns:
        str: Path to the python executable to invoke for subprocesses.
    """
    # Prefer an explicit project-local venv over the ambient interpreter.
    candidates = (
        os.path.join(os.getcwd(), "venv", "bin", "python3"),
        os.path.join(os.getcwd(), "venv", "Scripts", "python.exe"),  # Windows layout
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return sys.executable
16
+
17
def run_benchmark():
    """Benchmark every quantization mode of the optimizer script.

    Each mode runs in its own subprocess so peak-RAM measurements are
    isolated per run. Per-mode metrics are read back from the JSON files
    the optimizer writes, printed as a summary table, and persisted to
    ``results/benchmark_summary.json``.
    """
    python_exe = get_python_exe()
    logging.info(f"Using Python executable: {python_exe}")

    modes = ["fp32", "int8", "selective"]
    summary = {}

    os.makedirs("results", exist_ok=True)
    os.makedirs("outputs", exist_ok=True)

    # BUGFIX: the optimizer ships at the repo root (see README usage:
    # `python optimize_tts.py ...`); only fall back to the legacy src/ layout.
    script_path = "optimize_tts.py"
    if not os.path.exists(script_path):
        script_path = os.path.join("src", "optimize_tts.py")

    for mode in modes:
        logging.info(f"=== BENCHMARKING MODE: {mode} ===")
        output_file = f"results/results_{mode}.json"

        # Run in subprocess to ensure isolated memory measurement
        cmd = [
            python_exe, script_path,
            "--mode", mode,
            "--output_json", output_file,
            "--text", "This is a benchmarking sample for CPU optimized MOSS TTS. It tests end-to-end latency."
        ]

        try:
            # Capture stdout/stderr so failures can be reported below.
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode == 0:
                if os.path.exists(output_file):
                    with open(output_file, "r") as f:
                        summary[mode] = json.load(f)
                    logging.info(f"Success: {mode}")
                else:
                    logging.error(f"Output file {output_file} not found for mode {mode}")
            else:
                logging.error(f"Failed to benchmark {mode}. Return code: {result.returncode}")
                logging.error(f"STDERR: {result.stderr}")
        except Exception as e:
            logging.error(f"Subprocess error for {mode}: {e}")

    # Print Summary Table
    print("\n" + "="*80)
    print(f"{'Quantization Mode':<20} | {'RAM (MB)':<12} | {'Latency (ms)':<15} | {'Load (s)':<10}")
    print("-" * 80)
    for mode in modes:
        if mode in summary:
            data = summary[mode]
            print(f"{mode:<20} | {data['peak_ram_mb']:<12.2f} | {data['latency_ms']:<15.2f} | {data['load_time_sec']:<10.2f}")
        else:
            print(f"{mode:<20} | {'FAILED':<12} | {'N/A':<15} | {'N/A':<10}")
    print("="*80 + "\n")

    with open("results/benchmark_summary.json", "w") as f:
        json.dump(summary, f, indent=4)
    logging.info("Benchmark summary saved to results/benchmark_summary.json")

if __name__ == "__main__":
    run_benchmark()
optimize_tts.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import json
5
+ import logging
6
+ import argparse
7
+ import psutil
8
+ import torch
9
+ import torchaudio
10
+ from transformers import AutoProcessor, AutoModel
11
+
12
def setup_logging():
    """
    Configure and return the shared "MOSS-TTS-Opt" logger.

    On the first call this attaches a stdout stream handler and a file
    handler writing to logs/inference.log; later calls return the already
    configured logger untouched (no duplicate handlers).

    Returns:
        logging.Logger: The configured logger instance.
    """
    log = logging.getLogger("MOSS-TTS-Opt")
    if log.handlers:
        # Already configured by a previous call — reuse as-is.
        return log

    log.setLevel(logging.INFO)
    fmt = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')

    # Console output
    console = logging.StreamHandler(sys.stdout)
    console.setFormatter(fmt)
    log.addHandler(console)

    # Persistent file output
    os.makedirs("logs", exist_ok=True)
    file_out = logging.FileHandler("logs/inference.log")
    file_out.setFormatter(fmt)
    log.addHandler(file_out)

    return log
36
+
37
class MOSSInferenceEngine:
    """
    A high-performance inference engine for MOSS-TTS optimized for CPU execution.

    This engine handles model loading with float32 enforcement, dynamic INT8
    quantization, and optimized audio generation specifically for CPU-only
    environments.
    """

    def __init__(self, model_id: str = "OpenMOSS-Team/MOSS-TTS", device: str = "cpu"):
        """
        Initializes the inference engine.

        Args:
            model_id (str): Hugging Face model repository ID.
            device (str): Device to run inference on (default is 'cpu').
        """
        self.model_id = model_id
        self.device = device
        self.model = None
        self.processor = None
        self.logger = setup_logging()

        # Optimize CPU threading for PyTorch.
        # BUGFIX: os.cpu_count() can return None on exotic platforms, and
        # torch.set_num_threads(None) would raise — fall back to 1 thread.
        self.threads = os.cpu_count() or 1
        torch.set_num_threads(self.threads)
        self.logger.info(f"Engine: Initialized with {self.threads} CPU threads.")

    def load(self, trust_remote_code: bool = True):
        """
        Loads the model and processor from the Hugging Face Hub.
        Enforces float32 to ensure compatibility with CPU quantization and avoid dtype mismatches.

        Args:
            trust_remote_code (bool): Whether to trust remote code from the model repository.

        Raises:
            Exception: Re-raises any loading failure after logging it.
        """
        self.logger.info(f"Engine: Loading model and processor: {self.model_id}")
        start_time = time.time()

        try:
            self.processor = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=trust_remote_code)

            # Implementation Note: We explicitly use torch_dtype=torch.float32 to avoid
            # BFloat16/Float16 weight mismatches during torch.ao.quantization.quantize_dynamic calls on CPU.
            self.model = AutoModel.from_pretrained(
                self.model_id,
                trust_remote_code=trust_remote_code,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True
            ).to(self.device)

            # Defensive cast to ensure all parameters are indeed float32
            self.model = self.model.float()
            self.model.eval()
            self.logger.info(f"Engine: Load complete in {time.time() - start_time:.2f}s")
        except Exception as e:
            self.logger.error(f"Engine: Model loading failed: {e}")
            raise

    def quantize(self, mode: str = "int8"):
        """
        Applies a dynamic quantization strategy to the model.

        Args:
            mode (str): Quantization strategy - 'fp32' (none), 'int8' (full), or 'selective'.
        """
        if mode == "fp32":
            self.logger.info("Engine: Operating in FP32 mode (No quantization).")
            return

        start_q = time.time()
        # Use the torch.ao.quantization namespace directly (torch.quantization
        # is its deprecated alias) — consistent with the note in load().
        if mode == "int8":
            self.logger.info("Engine: Applying full Dynamic INT8 quantization to Linear layers...")
            self.model = torch.ao.quantization.quantize_dynamic(
                self.model, {torch.nn.Linear}, dtype=torch.qint8
            )
        elif mode == "selective":
            self.logger.info("Engine: Applying selective Dynamic INT8 quantization (Backbone only)...")
            # Target the heavy language model backbone
            if hasattr(self.model, 'language_model'):
                self.model.language_model = torch.ao.quantization.quantize_dynamic(
                    self.model.language_model, {torch.nn.Linear}, dtype=torch.qint8
                )
            # Target the output heads if present
            if hasattr(self.model, 'lm_heads'):
                self.model.lm_heads = torch.ao.quantization.quantize_dynamic(
                    self.model.lm_heads, {torch.nn.Linear}, dtype=torch.qint8
                )
        self.logger.info(f"Engine: Quantization ({mode}) completed in {time.time() - start_q:.2f}s.")

    def generate(self, text: str, max_new_tokens: int = 50, output_wav: str = None) -> dict:
        """
        Synthesizes speech from text and saves the output to a WAV file.

        Args:
            text (str): Input text to synthesize.
            max_new_tokens (int): Maximum generation length.
            output_wav (str): File path to save the generated audio (skipped if None).

        Returns:
            dict: Latency metadata ({"latency_ms": float}).
        """
        self.logger.info(f"Engine: Generating for text sample: '{text[:50]}...'")

        conversations = [{"role": "user", "content": text}]
        inputs = self.processor(conversations=conversations, return_tensors="pt").to(self.device)

        start_inf = time.time()
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        latency = (time.time() - start_inf) * 1000

        self.logger.info(f"Engine: Generation finished in {latency:.2f}ms")

        if output_wav:
            self._save_audio(outputs, output_wav)

        return {"latency_ms": latency}

    def _save_audio(self, outputs, output_path: str):
        """Helper to extract and save audio from model outputs.

        Accepts a raw tensor, a dict with a "waveform" key, or an object with
        a ``waveform`` attribute; normalizes shape to [channels, time] before
        saving. Errors are logged, not raised (best-effort save).
        """
        try:
            waveform = None
            if isinstance(outputs, torch.Tensor):
                waveform = outputs
            elif isinstance(outputs, dict) and "waveform" in outputs:
                waveform = outputs["waveform"]
            elif hasattr(outputs, "waveform"):
                waveform = outputs.waveform

            if waveform is not None:
                waveform = waveform.detach().cpu().float()
                if waveform.dim() == 1:
                    waveform = waveform.unsqueeze(0)
                elif waveform.dim() == 3:  # Case: [batch, channel, time]
                    waveform = waveform.squeeze(0)

                # Retrieve sample rate from model config or default to 24000
                sr = getattr(self.model.config, "sampling_rate", 24000)
                # BUGFIX: os.makedirs("") raises when the path has no directory
                # component (bare filename) — only create a dir if there is one.
                out_dir = os.path.dirname(output_path)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                torchaudio.save(output_path, waveform, sr)
                self.logger.info(f"Engine: Audio saved to {output_path}")
            else:
                self.logger.warning("Engine: No waveform found in model outputs.")
        except Exception as e:
            self.logger.error(f"Engine: Audio saving error: {e}")
181
+
182
def get_current_ram():
    """Return the resident set size of this process, in megabytes."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 ** 2)
185
+
186
def main():
    """Main entry point for the CLI tool.

    Parses arguments, loads and optionally quantizes the model, runs one
    generation, and writes performance metrics to a JSON file. Exits with
    status 1 on any failure.
    """
    parser = argparse.ArgumentParser(description="Production-grade MOSS-TTS Optimizer for CPU")
    parser.add_argument("--mode", type=str, choices=["fp32", "int8", "selective"], default="fp32",
                        help="Quantization mode (fp32, int8, selective).")
    parser.add_argument("--text", type=str, default="Validating the optimized CPU inference pipeline for MOSS TTS.",
                        help="Text string to synthesize.")
    parser.add_argument("--output_json", type=str, default="results/metrics.json",
                        help="Path to save performance metrics (JSON).")
    parser.add_argument("--output_wav", type=str, default="outputs/generated_audio.wav",
                        help="Path to save the generated audio (WAV).")
    args = parser.parse_args()

    logger = setup_logging()
    initial_ram = get_current_ram()

    try:
        engine = MOSSInferenceEngine()

        load_start = time.time()
        engine.load()
        load_time = time.time() - load_start

        engine.quantize(mode=args.mode)
        peak_ram = get_current_ram()

        # Adjust wav path to include mode
        wav_path = args.output_wav.replace(".wav", f"_{args.mode}.wav")
        res = engine.generate(args.text, output_wav=wav_path)

        final_stats = {
            "mode": args.mode,
            "load_time_sec": load_time,
            "peak_ram_mb": peak_ram,
            "ram_usage_delta_mb": peak_ram - initial_ram,
            "latency_ms": res["latency_ms"]
        }

        # BUGFIX: os.makedirs("") raises when --output_json is a bare filename
        # (dirname is empty) — only create the directory when one is present.
        json_dir = os.path.dirname(args.output_json)
        if json_dir:
            os.makedirs(json_dir, exist_ok=True)
        with open(args.output_json, "w") as f:
            json.dump(final_stats, f, indent=4)

        logger.info(f"Success: Mode={args.mode} | RAM={peak_ram:.2f}MB | Latency={res['latency_ms']:.2f}ms")
    except Exception as e:
        # Include the traceback so failures are diagnosable from the log file.
        logger.error(f"Execution failed: {e}", exc_info=True)
        sys.exit(1)

if __name__ == "__main__":
    main()
requirements.txt CHANGED
@@ -1,7 +1,83 @@
1
- transformers>=4.40.0
2
- torch
3
- torchaudio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  huggingface_hub
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  psutil
6
- accelerate>=0.26.0
 
7
  pypinyin
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Jinja2==3.1.6
2
+ MarkupSafe==3.0.3
3
+ PyYAML==6.0.3
4
+ Pygments==2.19.2
5
+ accelerate==1.12.0
6
+ accelerate>=0.26.0
7
+ annotated-doc==0.0.4
8
+ anyio==4.12.1
9
+ audioread==3.1.0
10
+ certifi==2026.1.4
11
+ cffi==2.0.0
12
+ charset-normalizer==3.4.4
13
+ click==8.3.1
14
+ cuda-bindings==12.9.4
15
+ cuda-pathfinder==1.3.4
16
+ decorator==5.2.1
17
+ filelock==3.24.3
18
+ fsspec==2026.2.0
19
+ h11==0.16.0
20
+ hf-xet==1.2.0
21
+ httpcore==1.0.9
22
+ httpx==0.28.1
23
  huggingface_hub
24
+ huggingface_hub==1.4.1
25
+ idna==3.11
26
+ joblib==1.5.3
27
+ lazy_loader==0.4
28
+ librosa==0.11.0
29
+ llvmlite==0.46.0
30
+ markdown-it-py==4.0.0
31
+ mdurl==0.1.2
32
+ mpmath==1.3.0
33
+ msgpack==1.1.2
34
+ networkx==3.6.1
35
+ numba==0.64.0
36
+ numpy==2.4.2
37
+ nvidia-cublas-cu12==12.8.4.1
38
+ nvidia-cuda-cupti-cu12==12.8.90
39
+ nvidia-cuda-nvrtc-cu12==12.8.93
40
+ nvidia-cuda-runtime-cu12==12.8.90
41
+ nvidia-cudnn-cu12==9.10.2.21
42
+ nvidia-cufft-cu12==11.3.3.83
43
+ nvidia-cufile-cu12==1.13.1.3
44
+ nvidia-curand-cu12==10.3.9.90
45
+ nvidia-cusolver-cu12==11.7.3.90
46
+ nvidia-cusparse-cu12==12.5.8.93
47
+ nvidia-cusparselt-cu12==0.7.1
48
+ nvidia-nccl-cu12==2.27.5
49
+ nvidia-nvjitlink-cu12==12.8.93
50
+ nvidia-nvshmem-cu12==3.4.5
51
+ nvidia-nvtx-cu12==12.8.90
52
+ packaging==26.0
53
+ platformdirs==4.9.2
54
+ pooch==1.9.0
55
  psutil
56
+ psutil==7.2.2
57
+ pycparser==3.0
58
  pypinyin
59
+ regex==2026.2.19
60
+ requests==2.32.5
61
+ rich==14.3.3
62
+ safetensors==0.7.0
63
+ scikit-learn==1.8.0
64
+ scipy==1.17.1
65
+ setuptools==82.0.0
66
+ shellingham==1.5.4
67
+ soundfile==0.13.1
68
+ soxr==1.0.0
69
+ sympy==1.14.0
70
+ threadpoolctl==3.6.0
71
+ tokenizers==0.22.2
72
+ torch
73
+ torch==2.10.0
74
+ torchaudio
75
+ torchaudio==2.10.0
76
+ tqdm==4.67.3
77
+ transformers==5.2.0
78
+ transformers>=4.40.0
79
+ triton==3.6.0
80
+ typer-slim==0.24.0
81
+ typer==0.24.1
82
+ typing_extensions==4.15.0
83
+ urllib3==2.6.3