| import os, json, math, time, wave, shutil |
| from pathlib import Path |
| from dataclasses import dataclass |
| from typing import Any, Callable |
# Cap the OpenMP thread pool at two threads. This assignment sits before the
# numpy / onnxruntime / torch imports below — presumably so those libraries
# pick the value up when they initialise (confirm before reordering imports).
os.environ["OMP_NUM_THREADS"] = "2"
|
|
| import numpy as np |
| import onnxruntime as ort |
| import sentencepiece as spm |
| import torch |
| import torchaudio |
| import gradio as gr |
| from huggingface_hub import snapshot_download |
|
|
# Sampling modes accepted by _normalize_sample_mode / MossTtsRuntime.synthesize.
SAMPLE_MODE_GREEDY = "greedy"
SAMPLE_MODE_FIXED = "fixed"
SAMPLE_MODE_FULL = "full"
# NOTE(review): not referenced anywhere in the visible part of this file.
EXECUTION_PROVIDER_CPU = "cpu"


# Model and output locations; both can be overridden through the environment.
MODEL_DIR = Path(os.environ.get("MOSS_MODEL_DIR", "/app/models"))
OUTPUT_DIR = Path(os.environ.get("MOSS_OUTPUT_DIR", "/tmp/moss_output"))
# Import-time side effect: the output directory is created eagerly.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


# Character classes used by the text splitting helpers.
# NOTE(review): several characters appear twice inside these literals, which
# suggests full-width punctuation may have been normalised to ASCII somewhere
# upstream — confirm against the original source.
SENTENCE_END_PUNCTUATION = set(".!?。!?;;")
CLAUSE_SPLIT_PUNCTUATION = set(",,、;;::")
CLOSING_PUNCTUATION = set("\"'\"')]})】》」』")
# Locations, relative to the model directory, probed in order for the manifest.
MANIFEST_CANDIDATE_RELATIVE_PATHS = (
    "browser_poc_manifest.json",
    "MOSS-TTS-Nano-100M-ONNX/browser_poc_manifest.json",
    "MOSS-TTS-Nano-ONNX-CPU/browser_poc_manifest.json",
)
# Directory names that may appear in manifest paths, mapped to the directory
# names to try instead when the original path does not exist.
MODEL_DIR_ALIAS_MAP = {
    "MOSS-TTS-Nano-ONNX-CPU": "MOSS-TTS-Nano-100M-ONNX",
    "MOSS-Audio-Tokenizer-Nano-ONNX-CPU": "MOSS-Audio-Tokenizer-Nano-ONNX",
}
# Hugging Face repositories downloaded by ensure_models().
DEFAULT_TTS_REPO = "OpenMOSS-Team/MOSS-TTS-Nano-100M-ONNX"
DEFAULT_CODEC_REPO = "OpenMOSS-Team/MOSS-Audio-Tokenizer-Nano-ONNX"
# NOTE(review): unused in the visible code, and the "short" value is larger
# than the "long" one — presumably seconds of silence between chunks; confirm.
DEFAULT_INTER_CHUNK_PAUSE_SHORT = 0.40
DEFAULT_INTER_CHUNK_PAUSE_LONG = 0.24
|
|
|
|
| def _argmax(values): |
| return int(np.argmax(values)) |
|
|
|
|
def _normalize_sample_mode(raw, do_sample=True):
    """Map a user-supplied sample mode onto one of the SAMPLE_MODE_* values.

    A recognised mode (after trimming whitespace) is returned unchanged.
    Anything else falls back to greedy when ``do_sample`` is false and to the
    fixed mode otherwise.
    """
    requested = str(raw or "").strip()
    known_modes = (SAMPLE_MODE_GREEDY, SAMPLE_MODE_FIXED, SAMPLE_MODE_FULL)
    if requested in known_modes:
        return requested
    return SAMPLE_MODE_FIXED if do_sample else SAMPLE_MODE_GREEDY
|
|
|
|
| def _softmax(values): |
| mx = float(np.max(values)) |
| shifted = np.asarray(values - mx, dtype=np.float64) |
| exps = np.exp(shifted) |
| return exps / np.sum(exps, dtype=np.float64) |
|
|
|
|
| def _sample_from_scores(values, *, do_sample, temperature, top_k, top_p, rng): |
| if not do_sample: |
| return _argmax(values) |
| scores = np.asarray(values, dtype=np.float32).copy() / float(temperature) |
| if top_k > 0 and top_k < scores.shape[0]: |
| threshold = float(np.sort(scores)[::-1][top_k - 1]) |
| scores[scores < threshold] = float("-inf") |
| if top_p > 0 and top_p < 1: |
| indexed = list(enumerate(scores.tolist())) |
| indexed.sort(key=lambda x: x[1], reverse=True) |
| sorted_scores = np.asarray([x[1] for x in indexed], dtype=np.float32) |
| sorted_probs = _softmax(sorted_scores) |
| remove_mask = [False] * len(indexed) |
| cumulative = 0.0 |
| for i, p in enumerate(sorted_probs): |
| cumulative += float(p) |
| if cumulative > float(top_p): |
| remove_mask[i] = True |
| for i in range(len(remove_mask) - 1, 0, -1): |
| remove_mask[i] = remove_mask[i - 1] |
| if remove_mask: |
| remove_mask[0] = False |
| for i, rm in enumerate(remove_mask): |
| if rm: |
| scores[indexed[i][0]] = float("-inf") |
| probs = _softmax(scores) |
| rv = float(rng.random()) |
| for i, p in enumerate(probs): |
| rv -= float(p) |
| if rv <= 0: |
| return int(i) |
| return _argmax(scores) |
|
|
|
|
| def _apply_repetition_penalty(values, prev_ids, penalty): |
| if not prev_ids or penalty == 1.0: |
| return values |
| result = values.copy() |
| for tid in set(int(x) for x in prev_ids): |
| if 0 <= tid < result.shape[0]: |
| result[tid] = result[tid] * penalty if result[tid] < 0 else result[tid] / penalty |
| return result |
|
|
|
|
| def _argmax_with_repetition_penalty(values, prev_set, penalty): |
| best_idx, best_val = 0, float("-inf") |
| apply = bool(prev_set) and penalty != 1.0 |
| for i, v in enumerate(values): |
| s = float(v) |
| if apply and i in prev_set: |
| s = s * penalty if s < 0 else s / penalty |
| if s > best_val: |
| best_val, best_idx = s, i |
| return int(best_idx) |
|
|
|
|
def _sample_assistant_text_token(text_logits, manifest, gen_defaults, rng):
    """Choose between the assistant audio-slot token and the audio-end token.

    Only the two logits for those token ids (taken from
    ``manifest["tts_config"]``) take part in the draw; the winning token id is
    returned as an ``int``.
    """
    tts_cfg = manifest["tts_config"]
    candidate_ids = np.asarray(
        [
            int(tts_cfg["audio_assistant_slot_token_id"]),
            int(tts_cfg["audio_end_token_id"]),
        ],
        dtype=np.int32,
    )
    candidate_scores = text_logits[candidate_ids]
    sampling_enabled = bool(gen_defaults["do_sample"])
    temperature = float(gen_defaults["text_temperature"])
    # top_k can never exceed the number of candidates.
    effective_top_k = min(int(gen_defaults["text_top_k"]), int(candidate_scores.shape[0]))
    nucleus = float(gen_defaults["text_top_p"])
    chosen = _sample_from_scores(
        candidate_scores,
        do_sample=sampling_enabled,
        temperature=temperature,
        top_k=effective_top_k,
        top_p=nucleus,
        rng=rng,
    )
    return int(candidate_ids[chosen])
|
|
|
|
def _sample_audio_token(audio_logits, prev_ids, prev_set, gen_defaults, rng):
    """Pick one audio token for a single codebook channel.

    In greedy mode (``do_sample`` false) the repetition penalty is folded into
    an arg-max over ``prev_set``. Otherwise the logits are penalised using
    ``prev_ids`` and sampled with the audio temperature / top-k / top-p
    settings from ``gen_defaults``.
    """
    penalty = float(gen_defaults["audio_repetition_penalty"])
    sampling_enabled = bool(gen_defaults["do_sample"])
    if not sampling_enabled:
        return _argmax_with_repetition_penalty(audio_logits, prev_set, penalty)
    adjusted = _apply_repetition_penalty(audio_logits, prev_ids, penalty)
    sampler_options = {
        "do_sample": True,
        "temperature": float(gen_defaults["audio_temperature"]),
        "top_k": int(gen_defaults["audio_top_k"]),
        "top_p": float(gen_defaults["audio_top_p"]),
        "rng": rng,
    }
    return _sample_from_scores(adjusted, **sampler_options)
|
|
|
|
| def _flatten3d(nested): |
| d0, d1, d2 = len(nested), len(nested[0]), len(nested[0][0]) |
| data = np.zeros((d0 * d1 * d2,), dtype=np.int32) |
| off = 0 |
| for i in range(d0): |
| for j in range(d1): |
| for k in range(d2): |
| data[off] = int(nested[i][j][k]) |
| off += 1 |
| return data, [d0, d1, d2] |
|
|
|
|
| def _flatten2d(nested): |
| d0, d1 = len(nested), len(nested[0]) |
| data = np.zeros((d0 * d1,), dtype=np.int32) |
| off = 0 |
| for i in range(d0): |
| for j in range(d1): |
| data[off] = int(nested[i][j]) |
| off += 1 |
| return data, [d0, d1] |
|
|
|
|
| def _extract_last_hidden(hs): |
| if hs.ndim == 2: |
| return hs.astype(np.float32, copy=False) |
| return hs[:, -1, :].astype(np.float32, copy=False) |
|
|
|
|
| def _slice_channel_major_audio(audio, start=0, end=None): |
| ch = int(audio.shape[1]) |
| total = int(audio.shape[2]) |
| s = max(0, int(start)) |
| e = total if end is None else max(s, min(int(end), total)) |
| return [audio[0, c, s:e].astype(np.float32, copy=False) for c in range(ch)] |
|
|
|
|
| def _contains_cjk(text): |
| for c in str(text or ""): |
| if "\u4e00" <= c <= "\u9fff" or "\u3400" <= c <= "\u4dbf" or "\u3040" <= c <= "\u30ff" or "\uac00" <= c <= "\ud7af": |
| return True |
| return False |
|
|
|
|
| def _prepare_text_for_sentence_chunking(text): |
| t = str(text or "").strip() |
| if not t: |
| raise ValueError("Text prompt cannot be empty.") |
| t = t.replace("\r", " ").replace("\n", " ") |
| while " " in t: |
| t = t.replace(" ", " ") |
| if _contains_cjk(t): |
| if t[-1] not in SENTENCE_END_PUNCTUATION: |
| t += "。" |
| return t |
| if t[:1].islower(): |
| t = t[:1].upper() + t[1:] |
| if t[-1].isalnum(): |
| t += "." |
| if len([x for x in t.split() if x]) < 5: |
| t = f" {t}" |
| return t |
|
|
|
|
| def _split_by_punct(text, punct): |
| sentences, cur, i = [], [], 0 |
| while i < len(text): |
| c = text[i] |
| cur.append(c) |
| if c in punct: |
| la = i + 1 |
| while la < len(text) and text[la] in CLOSING_PUNCTUATION: |
| cur.append(text[la]) |
| la += 1 |
| s = "".join(cur).strip() |
| if s: |
| sentences.append(s) |
| cur.clear() |
| while la < len(text) and text[la].isspace(): |
| la += 1 |
| i = la |
| continue |
| i += 1 |
| tail = "".join(cur).strip() |
| if tail: |
| sentences.append(tail) |
| return sentences |
|
|
|
|
| def _merge_audio_channels(channels): |
| if not channels: |
| return np.zeros((0, 1), dtype=np.float32) |
| if len(channels) == 1: |
| return np.asarray(channels[0], dtype=np.float32).reshape(-1, 1) |
| ml = min(int(c.shape[0]) for c in channels) |
| return np.stack([np.asarray(c[:ml], dtype=np.float32) for c in channels], axis=1) |
|
|
|
|
| def _concat_waveforms(wfs): |
| if not wfs: |
| return np.zeros((0, 1), dtype=np.float32) |
| ne = [w for w in wfs if w.size > 0] |
| if not ne: |
| return np.zeros((0, max(1, int(wfs[0].shape[1]) if wfs[0].ndim > 1 and wfs[0].shape[1] > 0 else 1)), dtype=np.float32) |
| return np.concatenate(ne, axis=0) |
|
|
|
|
| def _write_wav(path, waveform, sr): |
| p = Path(path).expanduser().resolve() |
| p.parent.mkdir(parents=True, exist_ok=True) |
| audio = np.asarray(waveform, dtype=np.float32) |
| if audio.ndim == 1: |
| audio = audio.reshape(-1, 1) |
| pcm16 = np.round(np.clip(audio, -1.0, 1.0) * 32767.0).astype(np.int16) |
| with wave.open(str(p), "wb") as f: |
| f.setnchannels(int(pcm16.shape[1])) |
| f.setsampwidth(2) |
| f.setframerate(int(sr)) |
| f.writeframes(pcm16.tobytes()) |
| return p |
|
|
|
|
@dataclass
class CodecStreamingSession:
    """Stateful wrapper around the step-wise codec decoder session.

    The recurrent inputs described under ``codec_meta["streaming_decode"]``
    are kept in ``state_feeds`` and replaced with the matching session outputs
    after every ``run_frames`` call.
    """

    codec_meta: dict
    session: ort.InferenceSession

    def __post_init__(self):
        streaming_cfg = self.codec_meta.get("streaming_decode", {})
        self.transformer_specs = list(streaming_cfg.get("transformer_offsets", []))
        streaming_cfg = self.codec_meta.get("streaming_decode", {})
        self.attention_specs = list(streaming_cfg.get("attention_caches", []))
        self.state_feeds = {}
        self.reset()

    def reset(self):
        """Re-create every state input: zeros everywhere, -1 for cached positions."""
        feeds = self.state_feeds = {}
        for spec in self.transformer_specs:
            feeds[str(spec["input_name"])] = np.zeros(tuple(spec["shape"]), dtype=np.int32)
        for spec in self.attention_specs:
            cache_shape = tuple(spec["cache_shape"])
            feeds[str(spec["offset_input_name"])] = np.zeros(
                tuple(spec["offset_shape"]), dtype=np.int32
            )
            feeds[str(spec["cached_keys_input_name"])] = np.zeros(cache_shape, dtype=np.float32)
            feeds[str(spec["cached_values_input_name"])] = np.zeros(cache_shape, dtype=np.float32)
            feeds[str(spec["cached_positions_input_name"])] = np.full(
                tuple(spec["positions_shape"]), -1, dtype=np.int32
            )

    def run_frames(self, frame_rows):
        """Decode a batch of code frames and advance the streaming state.

        Rows shorter than the configured number of quantizers are padded with
        zeros; extra entries are ignored. Returns ``None`` for an empty batch,
        otherwise ``(audio, audio_length)`` taken from the session outputs.
        """
        if not frame_rows:
            return None
        quantizers = int(self.codec_meta["codec_config"]["num_quantizers"])
        frame_count = len(frame_rows)
        codes = np.zeros((1, frame_count, quantizers), dtype=np.int32)
        for row_index, row in enumerate(frame_rows):
            usable = min(len(row), quantizers)
            codes[0, row_index, :usable] = [int(code) for code in row[:usable]]
        feeds = {
            "audio_codes": codes,
            "audio_code_lengths": np.asarray([frame_count], dtype=np.int32),
        }
        feeds.update(self.state_feeds)
        raw_outputs = self.session.run(None, feeds)
        output_names = [meta.name for meta in self.session.get_outputs()]
        named = dict(zip(output_names, raw_outputs, strict=True))
        # Feed each state output back as the matching input of the next call.
        for spec in self.transformer_specs:
            self.state_feeds[str(spec["input_name"])] = named[str(spec["output_name"])]
        for spec in self.attention_specs:
            for kind in ("offset", "cached_keys", "cached_values", "cached_positions"):
                source = named[str(spec[f"{kind}_output_name"])]
                self.state_feeds[str(spec[f"{kind}_input_name"])] = source
        audio_length = int(named["audio_lengths"].reshape(-1)[0])
        return named["audio"], audio_length
|
|
|
|
| def _resolve_stream_decode_frame_budget(emitted_total, sr, first_audio_at): |
| if not first_audio_at or sr <= 0: |
| return 1 |
| elapsed = max(0.0, time.perf_counter() - first_audio_at) |
| lead = emitted_total / float(sr) - elapsed |
| if not first_audio_at or lead < 0.20: |
| return 1 |
| if lead < 0.55: |
| return 2 |
| if lead < 1.10: |
| return 4 |
| return 8 |
|
|
|
|
class MossTtsRuntime:
    """CPU inference runtime that turns text into audio with ONNX sessions.

    The constructor locates ``browser_poc_manifest.json`` under ``model_dir``,
    loads the TTS and codec metadata it points to, loads the SentencePiece
    tokenizer and creates one ONNX Runtime session per model file.
    ``synthesize`` is the high-level entry point; the remaining public methods
    expose the individual stages.
    """

    def __init__(self, model_dir, thread_count=2, max_new_frames=375):
        """Load manifest, metadata, tokenizer and sessions from ``model_dir``.

        Args:
            model_dir: directory searched for the manifest.
            thread_count: intra-op threads per session (at least 1).
            max_new_frames: when not ``None``, overrides
                ``generation_defaults["max_new_frames"]`` from the manifest.

        Raises:
            FileNotFoundError: if no manifest is found.
        """
        self.model_dir = Path(model_dir).expanduser().resolve()
        self.thread_count = max(1, int(thread_count))
        self.manifest_path = self._find_manifest()
        self.manifest_dir = self.manifest_path.parent
        self.manifest = json.loads(self.manifest_path.read_text("utf-8"))
        if max_new_frames is not None:
            self.manifest["generation_defaults"]["max_new_frames"] = int(max_new_frames)
        # Fixed seed: sampling is reproducible for a fresh runtime instance.
        self.rng = np.random.default_rng(1234)
        model_files = self.manifest["model_files"]
        self.tts_meta_path = self._resolve_path(model_files["tts_meta"])
        self.codec_meta_path = self._resolve_path(model_files["codec_meta"])
        self.tts_meta = json.loads(self.tts_meta_path.read_text("utf-8"))
        self.codec_meta = json.loads(self.codec_meta_path.read_text("utf-8"))
        tok_path = str(self._resolve_path(model_files.get("tokenizer_model", "tokenizer.model")))
        self.sp = spm.SentencePieceProcessor(model_file=tok_path)
        self.sessions = self._create_sessions()
        self.codec_stream = CodecStreamingSession(
            self.codec_meta, self.sessions["codec_decode_step"]
        )

    # ------------------------------------------------------------------
    # Setup helpers
    # ------------------------------------------------------------------

    def _find_manifest(self):
        """Return the first existing manifest among the candidate locations."""
        for relative in MANIFEST_CANDIDATE_RELATIVE_PATHS:
            candidate = (self.model_dir / relative).resolve()
            if candidate.is_file():
                return candidate
        raise FileNotFoundError(f"browser_poc_manifest.json not found under {self.model_dir}")

    def _resolve_path(self, rel):
        """Resolve a manifest-relative path, trying directory aliases as fallback.

        If the direct path does not exist, each alias in ``MODEL_DIR_ALIAS_MAP``
        whose old name occurs as a path component is substituted and the first
        existing result wins. When nothing exists the direct path is returned
        so the caller fails with a meaningful file name.
        """
        resolved = (self.manifest_dir / Path(rel)).resolve()
        if resolved.exists():
            return resolved
        normalized = str(rel).replace("\\", "/")
        for old, new in MODEL_DIR_ALIAS_MAP.items():
            if f"/{old}/" in f"/{normalized}/":
                rewritten = (self.manifest_dir / Path(normalized.replace(old, new))).resolve()
                if rewritten.exists():
                    return rewritten
        return resolved

    def _session(self, p):
        """Create a CPU-only inference session for the model file ``p``."""
        opts = ort.SessionOptions()
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.intra_op_num_threads = self.thread_count
        opts.inter_op_num_threads = 1
        return ort.InferenceSession(str(p), sess_options=opts, providers=["CPUExecutionProvider"])

    def _create_sessions(self):
        """Open all required sessions plus the optional accelerated ones."""
        tts_dir = self.tts_meta_path.parent
        codec_dir = self.codec_meta_path.parent
        tts_files = self.tts_meta["files"]
        codec_files = self.codec_meta["files"]
        sessions = {
            "prefill": self._session(tts_dir / tts_files["prefill"]),
            "decode": self._session(tts_dir / tts_files["decode_step"]),
            "local_decoder": self._session(tts_dir / tts_files["local_decoder"]),
            "codec_encode": self._session(codec_dir / codec_files["encode"]),
            "codec_decode": self._session(codec_dir / codec_files["decode_full"]),
            "codec_decode_step": self._session(codec_dir / codec_files["decode_step"]),
        }
        # These graphs are only loaded when the metadata lists a file for them.
        for key in ("local_greedy_frame", "local_fixed_sampled_frame", "local_cached_step"):
            filename = tts_files.get(key)
            if filename:
                sessions[key] = self._session(tts_dir / filename)
        return sessions

    def _run_named(self, session_key, feeds):
        """Run ``self.sessions[session_key]`` and return outputs keyed by name."""
        session = self.sessions[session_key]
        outputs = session.run(None, feeds)
        names = [meta.name for meta in session.get_outputs()]
        return dict(zip(names, outputs, strict=True))

    def _build_repetition_seen_mask(self, prev_sets):
        """Build the ``(1, n_vq, codebook_size)`` int32 mask of already-used tokens."""
        codebook_size = int(self.tts_meta["model_config"]["audio_codebook_sizes"][0])
        nvq = int(self.manifest["tts_config"]["n_vq"])
        mask = np.zeros((1, nvq, codebook_size), dtype=np.int32)
        for channel, seen in enumerate(prev_sets):
            for token in seen:
                if 0 <= token < codebook_size:
                    mask[0, channel, token] = 1
        return mask

    def _clamped_uniform(self):
        """Draw from the runtime RNG, clamped to ``[0, 0.99999994]``."""
        return min(0.99999994, max(0.0, float(self.rng.random())))

    # ------------------------------------------------------------------
    # Text and prompt preparation
    # ------------------------------------------------------------------

    def list_builtin_voices(self):
        """Return a list copy of the manifest's ``builtin_voices`` entries."""
        return list(self.manifest["builtin_voices"])

    def encode_text(self, text):
        """Tokenize ``text`` with SentencePiece and return plain ``int`` ids."""
        return [int(t) for t in self.sp.encode(str(text or ""), out_type=int)]

    def count_text_tokens(self, text):
        """Return the number of SentencePiece tokens in ``text``."""
        return len(self.encode_text(text))

    def _load_ref_audio(self, path):
        """Load reference audio as a ``(1, channels, samples)`` float32 array.

        The audio is resampled to the codec sample rate and converted to the
        codec channel count (mono is repeated, multi-channel is averaged to
        mono).

        Raises:
            ValueError: for channel conversions other than those two.
        """
        wf, sr = torchaudio.load(str(Path(path).expanduser().resolve()))
        wf = wf.to(torch.float32)
        target_sr = int(self.codec_meta["codec_config"]["sample_rate"])
        target_channels = int(self.codec_meta["codec_config"]["channels"])
        if sr != target_sr:
            wf = torchaudio.functional.resample(wf, sr, target_sr)
        current_channels = int(wf.shape[0])
        if current_channels == target_channels:
            pass
        elif current_channels == 1 and target_channels > 1:
            wf = wf.repeat(target_channels, 1)
        elif current_channels > 1 and target_channels == 1:
            wf = wf.mean(dim=0, keepdim=True)
        else:
            raise ValueError(
                f"Unsupported channel conversion: {current_channels} -> {target_channels}"
            )
        return wf.unsqueeze(0).detach().cpu().numpy().astype(np.float32, copy=False)

    def encode_ref_audio(self, path):
        """Encode a reference audio file into a list of per-frame code rows."""
        wf = self._load_ref_audio(path)
        sample_count = int(wf.shape[-1])
        named = self._run_named(
            "codec_encode",
            {"waveform": wf, "input_lengths": np.asarray([sample_count], dtype=np.int32)},
        )
        audio_codes = np.asarray(named["audio_codes"], dtype=np.int32)
        code_length = int(np.asarray(named["audio_code_lengths"]).reshape(-1)[0])
        quantizers = int(self.codec_meta["codec_config"]["num_quantizers"])
        return [
            [int(audio_codes[0, frame, q]) for q in range(quantizers)]
            for frame in range(code_length)
        ]

    def resolve_prompt_codes(self, *, voice, prompt_audio_path):
        """Return the prompt audio codes for a reference file or built-in voice.

        A non-empty ``prompt_audio_path`` takes precedence. Otherwise ``voice``
        (or the first built-in voice when ``voice`` is falsy) is looked up in
        the manifest.

        Raises:
            ValueError: if the voice is unknown, or no voice was given and the
                manifest lists none.
        """
        if prompt_audio_path:
            return self.encode_ref_audio(prompt_audio_path)
        voices = self.list_builtin_voices()
        if not voice and not voices:
            raise ValueError("No built-in voices are available and no prompt audio was given.")
        wanted = str(voice or voices[0]["voice"])
        row = next((entry for entry in voices if entry["voice"] == wanted), None)
        if row is None:
            raise ValueError(f"Built-in voice not found: {wanted}")
        return list(row["prompt_audio_codes"])

    def build_text_rows(self, token_ids):
        """Return one input row per text token: the id followed by audio padding."""
        cfg = self.manifest["tts_config"]
        row_width = int(cfg["n_vq"]) + 1
        rows = []
        for token in token_ids:
            row = [int(cfg["audio_pad_token_id"])] * row_width
            row[0] = int(token)
            rows.append(row)
        return rows

    def build_audio_prefix_rows(self, codes, slot_id=None):
        """Return one input row per code frame: slot token followed by its codes.

        ``slot_id`` defaults to the user audio-slot token. Code rows longer
        than ``n_vq`` are truncated; shorter ones are padded with the audio
        pad token.
        """
        cfg = self.manifest["tts_config"]
        row_width = int(cfg["n_vq"]) + 1
        slot = int(cfg["audio_user_slot_token_id"] if slot_id is None else slot_id)
        rows = []
        for code_row in codes:
            row = [int(cfg["audio_pad_token_id"])] * row_width
            row[0] = slot
            for i in range(min(len(code_row), row_width - 1)):
                row[i + 1] = int(code_row[i])
            rows.append(row)
        return rows

    def build_request_rows(self, codes, text_ids):
        """Assemble the full prefill input: prompt prefix, reference audio, text.

        Returns a dict with ``inputIds`` (list of rows) and ``attentionMask``
        (a single row of ones, one per input row).
        """
        cfg = self.manifest["tts_config"]
        templates = self.manifest["prompt_templates"]
        audio_start = int(cfg["audio_start_token_id"])
        prefix = [*templates["user_prompt_prefix_token_ids"], audio_start]
        suffix = [
            int(cfg["audio_end_token_id"]),
            *templates["user_prompt_after_reference_token_ids"],
            *text_ids,
            *templates["assistant_prompt_prefix_token_ids"],
            audio_start,
        ]
        rows = [
            *self.build_text_rows(prefix),
            *self.build_audio_prefix_rows(codes),
            *self.build_text_rows(suffix),
        ]
        return {"inputIds": rows, "attentionMask": [[1 for _ in rows]]}

    # ------------------------------------------------------------------
    # Low-level session wrappers
    # ------------------------------------------------------------------

    def run_local_decoder(self, gh, text_tid, frame_prefix):
        """Run the uncached local decoder; return ``(text_logits, audio_logits)``.

        ``frame_prefix`` holds the audio tokens already chosen for the current
        frame; it is padded with the audio pad token to ``n_vq - 1`` entries.
        """
        cfg = self.manifest["tts_config"]
        nvq = int(cfg["n_vq"])
        prefix = np.full((1, nvq - 1), int(cfg["audio_pad_token_id"]), dtype=np.int32)
        for i in range(min(len(frame_prefix), nvq - 1)):
            prefix[0, i] = int(frame_prefix[i])
        named = self._run_named(
            "local_decoder",
            {
                "global_hidden": gh.astype(np.float32, copy=False),
                "text_token_id": np.asarray([int(text_tid)], dtype=np.int32),
                "audio_prefix_token_ids": prefix,
            },
        )
        return named["text_logits"].reshape(-1), named["audio_logits"]

    def create_empty_local_past(self):
        """Return zero-length key/value caches for every local decoder layer."""
        model_cfg = self.tts_meta["model_config"]
        layers = int(model_cfg["local_layers"])
        heads = int(model_cfg["local_heads"])
        head_dim = int(model_cfg["local_head_dim"])
        return {
            name: np.zeros((1, 0, heads, head_dim), dtype=np.float32)
            for layer in range(layers)
            for name in (f"local_past_key_{layer}", f"local_past_value_{layer}")
        }

    def run_local_cached_step(self, gh, *, text_tid, audio_tid, ch_idx, step_type, past_vl, past):
        """Run one step of the KV-cached local decoder.

        Returns ``(text_logits, audio_logits, new_past)`` where ``new_past`` is
        the present-state outputs renamed to the matching past-input names.
        """
        named = self._run_named(
            "local_cached_step",
            {
                "global_hidden": gh.astype(np.float32, copy=False),
                "text_token_id": np.asarray([int(text_tid)], dtype=np.int32),
                "audio_token_id": np.asarray([int(audio_tid)], dtype=np.int32),
                "channel_index": np.asarray([int(ch_idx)], dtype=np.int32),
                "step_type": np.asarray([int(step_type)], dtype=np.int32),
                "past_valid_lengths": np.asarray([int(past_vl)], dtype=np.int32),
                **past,
            },
        )
        # The first two listed outputs are skipped; the rest are cache tensors.
        cache_names = self.tts_meta["onnx"]["local_cached_output_names"][2:]
        new_past = {
            name.replace("local_present_", "local_past_"): named[name] for name in cache_names
        }
        return named["text_logits"].reshape(-1), named["audio_logits"], new_past

    def run_local_greedy_frame(self, gh, *, prev_sets, rep_penalty):
        """Run the fused greedy frame graph; return ``(should_continue, tokens)``."""
        named = self._run_named(
            "local_greedy_frame",
            {
                "global_hidden": gh.astype(np.float32, copy=False),
                "repetition_seen_mask": self._build_repetition_seen_mask(prev_sets),
                "repetition_penalty": np.asarray([float(rep_penalty)], dtype=np.float32),
            },
        )
        should_continue = bool(int(np.asarray(named["should_continue"]).reshape(-1)[0]))
        tokens = np.asarray(named["frame_token_ids"]).reshape(-1).astype(np.int32, copy=False)
        return should_continue, [int(token) for token in tokens.tolist()]

    def run_local_fixed_sampled_frame(self, gh, *, prev_sets):
        """Run the fused sampling frame graph; return ``(should_continue, tokens)``.

        The graph receives its randomness as inputs: one uniform draw for the
        assistant token and one per codebook channel, in that order.
        """
        nvq = int(self.manifest["tts_config"]["n_vq"])
        mask = self._build_repetition_seen_mask(prev_sets)
        assistant_u = np.asarray([self._clamped_uniform()], dtype=np.float32)
        audio_u = np.asarray([[self._clamped_uniform() for _ in range(nvq)]], dtype=np.float32)
        named = self._run_named(
            "local_fixed_sampled_frame",
            {
                "global_hidden": gh.astype(np.float32, copy=False),
                "repetition_seen_mask": mask,
                "assistant_random_u": assistant_u,
                "audio_random_u": audio_u,
            },
        )
        tokens = np.asarray(named["frame_token_ids"]).reshape(-1).astype(np.int32, copy=False)
        should_continue = bool(int(np.asarray(named["should_continue"]).reshape(-1)[0]))
        return should_continue, [int(token) for token in tokens.tolist()]

    def slice_audio_channel_logits(self, alogits, ci):
        """Return the logits of channel ``ci`` from a flattened view of ``alogits``."""
        per_channel = int(alogits.shape[-1])
        flat = alogits.reshape(-1)
        return flat[ci * per_channel:(ci + 1) * per_channel]

    def decode_full_audio(self, frames):
        """Decode all frames in one codec call.

        Returns ``(channels, length)`` where ``channels`` is a list with one
        float32 array per channel; ``([], 0)`` for no frames.
        """
        if not frames:
            return [], 0
        codes, dims = _flatten3d([frames])
        named = self._run_named(
            "codec_decode",
            {
                "audio_codes": codes.reshape(dims),
                "audio_code_lengths": np.asarray([len(frames)], dtype=np.int32),
            },
        )
        audio_length = int(named["audio_lengths"].reshape(-1)[0])
        return _slice_channel_major_audio(named["audio"], 0, audio_length), audio_length

    # ------------------------------------------------------------------
    # Frame generation
    # ------------------------------------------------------------------

    def _sample_channels_cached(self, gh, gd, prev_by_ch, prev_set_by_ch):
        """Generate one frame with the KV-cached local decoder, or ``None`` at end."""
        cfg = self.manifest["tts_config"]
        local_past = self.create_empty_local_past()
        valid = 0
        text_logits, _, local_past = self.run_local_cached_step(
            gh, text_tid=0, audio_tid=0, ch_idx=0, step_type=0, past_vl=valid, past=local_past
        )
        valid += 1
        next_text = _sample_assistant_text_token(text_logits, self.manifest, gd, self.rng)
        if next_text != int(cfg["audio_assistant_slot_token_id"]):
            return None
        _, audio_logits, local_past = self.run_local_cached_step(
            gh, text_tid=next_text, audio_tid=0, ch_idx=0, step_type=1, past_vl=valid, past=local_past
        )
        valid += 1
        frame = []
        for channel in range(int(cfg["n_vq"])):
            if channel > 0:
                # Feed the token just chosen for the previous channel.
                _, audio_logits, local_past = self.run_local_cached_step(
                    gh,
                    text_tid=0,
                    audio_tid=frame[-1],
                    ch_idx=channel - 1,
                    step_type=2,
                    past_vl=valid,
                    past=local_past,
                )
                valid += 1
            logits = self.slice_audio_channel_logits(audio_logits, channel).astype(
                np.float32, copy=False
            )
            token = _sample_audio_token(
                logits, prev_by_ch[channel], prev_set_by_ch[channel], gd, self.rng
            )
            frame.append(token)
            prev_by_ch[channel].append(token)
            prev_set_by_ch[channel].add(token)
        return frame

    def _sample_channels_uncached(self, gh, gd, prev_by_ch, prev_set_by_ch):
        """Generate one frame with the plain local decoder, or ``None`` at end."""
        cfg = self.manifest["tts_config"]
        text_logits, _ = self.run_local_decoder(gh, 0, [])
        next_text = _sample_assistant_text_token(text_logits, self.manifest, gd, self.rng)
        if next_text != int(cfg["audio_assistant_slot_token_id"]):
            return None
        frame = []
        for channel in range(int(cfg["n_vq"])):
            _, audio_logits = self.run_local_decoder(gh, next_text, frame)
            logits = self.slice_audio_channel_logits(audio_logits, channel).astype(
                np.float32, copy=False
            )
            token = _sample_audio_token(
                logits, prev_by_ch[channel], prev_set_by_ch[channel], gd, self.rng
            )
            frame.append(token)
            prev_by_ch[channel].append(token)
            prev_set_by_ch[channel].add(token)
        return frame

    def _next_frame(self, gh, gd, prev_by_ch, prev_set_by_ch):
        """Produce the next frame of audio tokens, or ``None`` when generation ends.

        The path is chosen in this order: fused greedy graph (when available
        and sampling is off), fused fixed-sampling graph (when available and
        the sample mode is "fixed"), KV-cached local decoder, plain local
        decoder. Token history is updated for every emitted token.
        """
        fused = None
        if "local_greedy_frame" in self.sessions and not bool(gd["do_sample"]):
            fused = self.run_local_greedy_frame(
                gh,
                prev_sets=prev_set_by_ch,
                rep_penalty=float(gd["audio_repetition_penalty"]),
            )
        elif (
            "local_fixed_sampled_frame" in self.sessions
            # .get: "sample_mode" is only written by synthesize(), so a direct
            # caller of generate_audio_frames may not have set it.
            and gd.get("sample_mode") == SAMPLE_MODE_FIXED
        ):
            fused = self.run_local_fixed_sampled_frame(gh, prev_sets=prev_set_by_ch)
        if fused is not None:
            should_continue, frame = fused
            if not should_continue:
                return None
            for channel, token in enumerate(frame):
                prev_by_ch[channel].append(token)
                prev_set_by_ch[channel].add(token)
            return frame
        if "local_cached_step" in self.sessions:
            return self._sample_channels_cached(gh, gd, prev_by_ch, prev_set_by_ch)
        return self._sample_channels_uncached(gh, gd, prev_by_ch, prev_set_by_ch)

    def generate_audio_frames(self, req_rows, on_frame=None):
        """Autoregressively generate audio code frames for a prepared request.

        Args:
            req_rows: dict from ``build_request_rows``.
            on_frame: optional callback invoked after each frame as
                ``on_frame(all_frames_so_far, step_index, frame)``.

        Returns:
            List of frames, each a list of ``n_vq`` audio token ids. Generation
            stops at the end token or after ``max_new_frames`` frames.
        """
        gd = self.manifest["generation_defaults"]
        cfg = self.manifest["tts_config"]
        nvq = int(cfg["n_vq"])
        row_width = nvq + 1

        input_ids, id_dims = _flatten3d([req_rows["inputIds"]])
        mask, mask_dims = _flatten2d(req_rows["attentionMask"])
        named = self._run_named(
            "prefill",
            {"input_ids": input_ids.reshape(id_dims), "attention_mask": mask.reshape(mask_dims)},
        )
        gh = _extract_last_hidden(named["global_hidden"])
        past_valid = sum(int(x) for x in req_rows["attentionMask"][0])
        past = {
            name.replace("present_", "past_"): named[name]
            for name in self.tts_meta["onnx"]["prefill_output_names"][1:]
        }

        gen_frames = []
        prev_by_ch = [[] for _ in range(nvq)]
        prev_set_by_ch = [set() for _ in range(nvq)]

        for step in range(int(gd["max_new_frames"])):
            frame = self._next_frame(gh, gd, prev_by_ch, prev_set_by_ch)
            if frame is None:
                break
            gen_frames.append(frame)

            # Feed the finished frame back through the global decoder.
            next_row = np.full((1, 1, row_width), int(cfg["audio_pad_token_id"]), dtype=np.int32)
            next_row[0, 0, 0] = int(cfg["audio_assistant_slot_token_id"])
            for i, token in enumerate(frame):
                next_row[0, 0, i + 1] = int(token)
            feeds = {
                "input_ids": next_row,
                "past_valid_lengths": np.asarray([past_valid], dtype=np.int32),
            }
            for input_name in self.tts_meta["onnx"]["decode_input_names"][2:]:
                feeds[input_name] = past[input_name]
            decoded = self._run_named("decode", feeds)
            gh = _extract_last_hidden(decoded["global_hidden"])
            past_valid += 1
            past = {
                name.replace("present_", "past_"): decoded[name]
                for name in self.tts_meta["onnx"]["decode_output_names"][1:]
            }
            if on_frame is not None:
                on_frame(gen_frames, step, frame)
        return gen_frames

    # ------------------------------------------------------------------
    # Audio decoding
    # ------------------------------------------------------------------

    def decode_full_audio_safe(self, frames):
        """Decode frames to a ``(samples, channels)`` waveform with a fallback.

        The one-shot codec decoder is tried first. If it raises, the failure
        is logged and the frames are decoded eight at a time through the
        streaming decoder instead; the streaming state is reset before and
        after.
        """
        try:
            channel_arrays, _ = self.decode_full_audio(frames)
            return _merge_audio_channels(channel_arrays)
        except Exception as exc:
            import logging

            logging.warning("full codec decode failed, falling back: %s", exc)
        self.codec_stream.reset()
        channel_count = int(self.codec_meta["codec_config"]["channels"])
        merged = [[] for _ in range(channel_count)]
        try:
            for start in range(0, len(frames), 8):
                decoded = self.codec_stream.run_frames(frames[start:start + 8])
                if decoded is None:
                    continue
                audio, audio_length = decoded
                if audio_length <= 0:
                    continue
                for channel in range(channel_count):
                    merged[channel].append(
                        np.asarray(audio[0, channel, :audio_length], dtype=np.float32)
                    )
        finally:
            self.codec_stream.reset()
        return _merge_audio_channels(
            [
                np.concatenate(parts) if parts else np.zeros((0,), dtype=np.float32)
                for parts in merged
            ]
        )

    def split_text_chunks(self, text, max_tokens=75):
        """Split ``text`` into pieces of at most ``max_tokens`` tokens each.

        For every piece the longest prefix that fits is found by binary search
        on the character count. The cut is then moved back to just after the
        last punctuation mark or space within the final 24 characters of that
        prefix, if there is one. Text that already fits is returned as a
        single stripped piece; empty text gives an empty list.
        """
        stripped = str(text or "").strip()
        if not stripped:
            return []
        remaining = stripped
        pieces = []
        preferred = set(CLAUSE_SPLIT_PUNCTUATION) | set(SENTENCE_END_PUNCTUATION) | {" "}
        while remaining:
            if self.count_text_tokens(remaining) <= max_tokens:
                pieces.append(remaining)
                break
            # `remaining` is always stripped, so every non-empty prefix starts
            # with a non-space character and its stripped form is never empty.
            lo, hi, best = 1, len(remaining), 1
            while lo <= hi:
                mid = (lo + hi) // 2
                if self.count_text_tokens(remaining[:mid].strip()) <= max_tokens:
                    best = mid
                    lo = mid + 1
                else:
                    hi = mid - 1
            cut = best
            window = remaining[:best]
            for pos in range(len(window) - 1, max(-1, len(window) - 25), -1):
                if window[pos] in preferred:
                    cut = pos + 1
                    break
            piece = remaining[:cut].strip()
            if not piece:
                piece = remaining[:best].strip()
                cut = best
            pieces.append(piece)
            remaining = remaining[cut:].strip()
        return pieces if len(pieces) > 1 else [stripped]

    # ------------------------------------------------------------------
    # High-level entry point
    # ------------------------------------------------------------------

    def _synthesize_streaming(self, req, sr):
        """Generate frames and decode them incrementally as they arrive.

        Returns ``(frames, waveform)``. The streaming codec state is reset
        before generation and again afterwards, even when generation fails.
        """
        pending = []
        emitted = []
        emitted_total = 0
        first_at = None
        channel_count = int(self.codec_meta["codec_config"]["channels"])
        self.codec_stream.reset()

        def decode_pending(force):
            nonlocal emitted_total, first_at
            if not pending:
                return
            budget = max(1, _resolve_stream_decode_frame_budget(emitted_total, sr, first_at))
            if not force and len(pending) < budget:
                return
            take = len(pending) if force else min(len(pending), budget)
            chunk = pending[:take]
            del pending[:take]
            decoded = self.codec_stream.run_frames(chunk)
            if decoded is None:
                return
            audio, audio_length = decoded
            if audio_length <= 0:
                return
            if first_at is None:
                first_at = time.perf_counter()
            emitted_total += audio_length
            emitted.append(
                _merge_audio_channels(
                    [audio[0, channel, :audio_length] for channel in range(channel_count)]
                )
            )

        def on_frame(_all_frames, _step, frame):
            pending.append(list(frame))
            decode_pending(False)

        try:
            frames = self.generate_audio_frames(req, on_frame=on_frame)
            # Flush whatever is still queued once generation has finished.
            decode_pending(True)
        finally:
            self.codec_stream.reset()
        return frames, _concat_waveforms(emitted)

    def synthesize(self, *, text, voice=None, prompt_audio_path=None, sample_mode="fixed", do_sample=True, streaming=True, max_new_frames=375, output_path=None):
        """Synthesize ``text`` to a WAV file.

        Args:
            text: text to speak.
            voice: built-in voice name; ignored when ``prompt_audio_path`` is set.
            prompt_audio_path: optional reference audio file for voice cloning.
            sample_mode: "greedy", "fixed" or "full"; unknown values are
                normalised using ``do_sample``.
            do_sample: only used to pick a mode when ``sample_mode`` is unknown.
            streaming: decode audio incrementally while frames are generated
                instead of in one pass at the end.
            max_new_frames: upper bound on generated frames.
            output_path: destination file; defaults to ``OUTPUT_DIR/output.wav``,
                which is overwritten on every call.

        Returns:
            Dict with ``audio_path``, ``sample_rate`` and ``frames`` (count).

        Note:
            The chosen mode and frame limit are written into the shared
            ``generation_defaults`` of the manifest and persist across calls.
        """
        gd = self.manifest["generation_defaults"]
        gd["max_new_frames"] = int(max_new_frames)
        mode = _normalize_sample_mode(sample_mode, do_sample)
        gd["sample_mode"] = mode
        gd["do_sample"] = mode != SAMPLE_MODE_GREEDY

        codes = self.resolve_prompt_codes(voice=voice, prompt_audio_path=prompt_audio_path)
        token_ids = self.encode_text(text)
        req = self.build_request_rows(codes, token_ids)
        sr = int(self.codec_meta["codec_config"]["sample_rate"])

        if streaming:
            frames, waveform = self._synthesize_streaming(req, sr)
        else:
            frames = self.generate_audio_frames(req)
            waveform = self.decode_full_audio_safe(frames)

        out_path = OUTPUT_DIR / "output.wav" if output_path is None else Path(output_path)
        _write_wav(out_path, waveform, sr)
        return {"audio_path": str(out_path), "sample_rate": sr, "frames": len(frames)}
|
|
|
|
def ensure_models():
    """Download the TTS and codec model files into ``MODEL_DIR`` when missing.

    A repository is fetched only if its marker JSON file is absent from its
    target folder. If the download produced a nested folder with the same
    name as the target, its entries are moved up one level unless an entry of
    that name already exists there.
    """
    repositories = (
        (
            DEFAULT_TTS_REPO,
            "MOSS-TTS-Nano-100M-ONNX",
            "browser_poc_manifest.json",
            ["*.onnx", "*.data", "*.json", "tokenizer.model"],
        ),
        (
            DEFAULT_CODEC_REPO,
            "MOSS-Audio-Tokenizer-Nano-ONNX",
            "codec_browser_onnx_meta.json",
            ["*.onnx", "*.data", "*.json"],
        ),
    )
    for repo_id, folder_name, marker_name, patterns in repositories:
        target_dir = MODEL_DIR / folder_name
        if (target_dir / marker_name).is_file():
            continue
        target_dir.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id,
            local_dir=str(target_dir),
            local_dir_use_symlinks=False,
            allow_patterns=patterns,
        )
        nested_dir = target_dir / folder_name
        if not nested_dir.is_dir():
            continue
        for entry in nested_dir.iterdir():
            destination = target_dir / entry.name
            if not destination.exists():
                shutil.move(str(entry), str(destination))
|
|
|
|
# Process-wide runtime instance, created on first use by get_runtime().
runtime = None


def get_runtime():
    """Return the shared ``MossTtsRuntime``, building it on the first call.

    The first call makes sure the model files are present and then loads the
    runtime from ``MODEL_DIR``; later calls return the same instance.
    """
    global runtime
    if runtime is None:
        ensure_models()
        runtime = MossTtsRuntime(MODEL_DIR, thread_count=2, max_new_frames=375)
    return runtime
|
|
|
|
def synthesize_gradio(text, voice, audio_path, sample_mode, max_frames):
    """Gradio click handler: synthesize speech and report timing.

    Args:
        text: text to speak.
        voice: built-in voice name; ignored when ``audio_path`` is given.
        audio_path: optional reference audio file path for voice cloning.
        sample_mode: one of the sample modes offered in the UI.
        max_frames: upper bound on generated frames (converted to ``int``).

    Returns:
        ``(audio_path, info)`` — the generated WAV path and a one-line summary
        with the elapsed time, sample rate and frame count.
    """
    rt = get_runtime()
    use_reference = bool(audio_path)
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments while synthesis is running.
    started = time.perf_counter()
    result = rt.synthesize(
        text=text,
        voice=None if use_reference else voice,
        prompt_audio_path=audio_path if use_reference else None,
        sample_mode=sample_mode,
        do_sample=(sample_mode != SAMPLE_MODE_GREEDY),
        streaming=True,
        max_new_frames=int(max_frames),
    )
    elapsed = time.perf_counter() - started
    return result["audio_path"], f"Done in {elapsed:.1f}s | {result['sample_rate']}Hz | {result['frames']} frames"
|
|
|
|
# Voice names offered in the dropdown below.
# NOTE(review): presumably these must match the "voice" entries of the
# manifest's builtin_voices, since an unknown name is rejected at synthesis
# time — confirm against the shipped manifest.
VOICES = ["Junhao", "Zhiming", "Weiguo", "Xiaoyu", "Yuewen", "Lingyu", "Trump", "Ava", "Bella", "Adam", "Nathan", "Soyo", "Saki", "Mortis", "Umiri", "Mei", "Anon", "Arisa"]


# Gradio UI: inputs in the left column, results in the right column.
# NOTE(review): the nesting of the layout contexts below is reconstructed from
# the order of the `with` statements — confirm against the original layout.
with gr.Blocks(title="MOSS-TTS-Nano ONNX") as demo:
    gr.Markdown("# MOSS-TTS-Nano-100M-ONNX\nCPU-only TTS with voice cloning. First run downloads ~730MB model.")
    with gr.Row():
        with gr.Column():
            text_in = gr.Textbox(label="Text", value="Hello, welcome to MOSS TTS Nano.", lines=3)
            with gr.Row():
                voice_in = gr.Dropdown(choices=VOICES, value="Junhao", label="Voice (overridden by ref audio)")
                # type="filepath": the handler receives a path string.
                ref_audio = gr.Audio(label="Reference Audio (optional, for voice cloning)", type="filepath")
            with gr.Row():
                sample_mode = gr.Dropdown(choices=["fixed", "greedy", "full"], value="fixed", label="Sample Mode")
                max_frames = gr.Slider(16, 750, value=375, step=1, label="Max Frames")
            btn = gr.Button("Synthesize", variant="primary")
        with gr.Column():
            audio_out = gr.Audio(label="Generated Audio", type="filepath")
            info_out = gr.Textbox(label="Info")
    # The input order must match the parameter order of synthesize_gradio.
    btn.click(fn=synthesize_gradio, inputs=[text_in, voice_in, ref_audio, sample_mode, max_frames], outputs=[audio_out, info_out])


if __name__ == "__main__":
    # Load the runtime before the server starts, so model download and session
    # creation happen up front rather than on the first request.
    get_runtime()
    demo.launch()