"""Gradio front end for CPU-only MOSS-TTS-Nano ONNX inference.

The module downloads the ONNX model snapshots on first use, wraps the
ONNX Runtime sessions in ``MossTtsRuntime`` and exposes a small Gradio UI
that synthesizes speech from text with either a built-in voice or a
reference audio clip.
"""

import json
import logging
import math
import os
import shutil
import time
import wave
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable

# Set before the numeric libraries are imported (same order as the original).
os.environ["OMP_NUM_THREADS"] = "2"

import numpy as np
import onnxruntime as ort
import sentencepiece as spm
import torch
import torchaudio
import gradio as gr
from huggingface_hub import snapshot_download

SAMPLE_MODE_GREEDY = "greedy"
SAMPLE_MODE_FIXED = "fixed"
SAMPLE_MODE_FULL = "full"
EXECUTION_PROVIDER_CPU = "cpu"

MODEL_DIR = Path(os.environ.get("MOSS_MODEL_DIR", "/app/models"))
OUTPUT_DIR = Path(os.environ.get("MOSS_OUTPUT_DIR", "/tmp/moss_output"))
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# NOTE(review): the repeated ASCII characters in the three punctuation sets
# below look like full-width variants that were normalized away at some point;
# confirm against the upstream source before relying on CJK punctuation.
SENTENCE_END_PUNCTUATION = set(".!?。!?;;")
CLAUSE_SPLIT_PUNCTUATION = set(",,、;;::")
CLOSING_PUNCTUATION = set("\"'\"')]})】》」』")

MANIFEST_CANDIDATE_RELATIVE_PATHS = (
    "browser_poc_manifest.json",
    "MOSS-TTS-Nano-100M-ONNX/browser_poc_manifest.json",
    "MOSS-TTS-Nano-ONNX-CPU/browser_poc_manifest.json",
)
MODEL_DIR_ALIAS_MAP = {
    "MOSS-TTS-Nano-ONNX-CPU": "MOSS-TTS-Nano-100M-ONNX",
    "MOSS-Audio-Tokenizer-Nano-ONNX-CPU": "MOSS-Audio-Tokenizer-Nano-ONNX",
}
DEFAULT_TTS_REPO = "OpenMOSS-Team/MOSS-TTS-Nano-100M-ONNX"
DEFAULT_CODEC_REPO = "OpenMOSS-Team/MOSS-Audio-Tokenizer-Nano-ONNX"
DEFAULT_INTER_CHUNK_PAUSE_SHORT = 0.40
DEFAULT_INTER_CHUNK_PAUSE_LONG = 0.24


def _run_named(session, feeds):
    """Run ``session`` with ``feeds`` and return its outputs keyed by output name."""
    outputs = session.run(None, feeds)
    names = [o.name for o in session.get_outputs()]
    return dict(zip(names, outputs, strict=True))


def _argmax(values):
    """Return the index of the largest entry of ``values`` as a Python int."""
    return int(np.argmax(values))


def _normalize_sample_mode(raw, do_sample=True):
    """Map a user-supplied sample mode to one of the ``SAMPLE_MODE_*`` constants.

    A recognized mode string is returned unchanged.  Anything else becomes
    greedy when ``do_sample`` is false and fixed otherwise.
    """
    s = str(raw or "").strip()
    if s in {SAMPLE_MODE_GREEDY, SAMPLE_MODE_FIXED, SAMPLE_MODE_FULL}:
        return s
    if not do_sample:
        return SAMPLE_MODE_GREEDY
    return SAMPLE_MODE_FIXED


def _softmax(values):
    """Return the softmax of ``values`` computed in float64 (max-shifted)."""
    mx = float(np.max(values))
    shifted = np.asarray(values - mx, dtype=np.float64)
    exps = np.exp(shifted)
    return exps / np.sum(exps, dtype=np.float64)


def _sample_from_scores(values, *, do_sample, temperature, top_k, top_p, rng):
    """Pick one index from a 1-D score vector.

    Args:
        values: Unnormalized scores (logits).
        do_sample: When false the argmax is returned and nothing is drawn.
        temperature: Divisor applied to the scores before filtering.  A value
            of zero or below falls back to the argmax instead of dividing by
            zero.
        top_k: Keep only the ``top_k`` best scores (ignored when <= 0 or not
            smaller than the vector length).
        top_p: Nucleus threshold (ignored unless strictly between 0 and 1).
        rng: Object with a ``random()`` method returning a float in [0, 1).

    Returns:
        The selected index as an int.
    """
    if not do_sample:
        return _argmax(values)
    temperature = float(temperature)
    if temperature <= 0.0:
        # Dividing by zero would turn every probability into NaN.
        return _argmax(values)
    scores = np.asarray(values, dtype=np.float32).copy() / temperature
    if top_k > 0 and top_k < scores.shape[0]:
        threshold = float(np.sort(scores)[::-1][top_k - 1])
        scores[scores < threshold] = float("-inf")
    if top_p > 0 and top_p < 1:
        indexed = list(enumerate(scores.tolist()))
        indexed.sort(key=lambda x: x[1], reverse=True)
        sorted_scores = np.asarray([x[1] for x in indexed], dtype=np.float32)
        sorted_probs = _softmax(sorted_scores)
        remove_mask = [False] * len(indexed)
        cumulative = 0.0
        for i, p in enumerate(sorted_probs):
            cumulative += float(p)
            if cumulative > float(top_p):
                remove_mask[i] = True
        # Shift the mask right by one so the entry that crosses the threshold
        # is kept, and never drop the best-scoring entry.
        for i in range(len(remove_mask) - 1, 0, -1):
            remove_mask[i] = remove_mask[i - 1]
        if remove_mask:
            remove_mask[0] = False
        for i, rm in enumerate(remove_mask):
            if rm:
                scores[indexed[i][0]] = float("-inf")
    probs = _softmax(scores)
    rv = float(rng.random())
    for i, p in enumerate(probs):
        rv -= float(p)
        if rv <= 0:
            return int(i)
    # Rounding left a sliver of probability mass unassigned.
    return _argmax(scores)


def _apply_repetition_penalty(values, prev_ids, penalty):
    """Return scores with previously emitted ids penalized.

    Negative scores are multiplied by ``penalty`` and non-negative ones are
    divided by it.  The input array is returned untouched (not copied) when
    there is nothing to penalize.
    """
    if not prev_ids or penalty == 1.0:
        return values
    result = values.copy()
    for tid in set(int(x) for x in prev_ids):
        if 0 <= tid < result.shape[0]:
            result[tid] = result[tid] * penalty if result[tid] < 0 else result[tid] / penalty
    return result


def _argmax_with_repetition_penalty(values, prev_set, penalty):
    """Return the argmax of ``values`` after penalizing indices in ``prev_set``.

    Uses the same multiply/divide rule as ``_apply_repetition_penalty`` but
    without allocating a penalized copy.
    """
    best_idx, best_val = 0, float("-inf")
    apply = bool(prev_set) and penalty != 1.0
    for i, v in enumerate(values):
        s = float(v)
        if apply and i in prev_set:
            s = s * penalty if s < 0 else s / penalty
        if s > best_val:
            best_val, best_idx = s, i
    return int(best_idx)


def _sample_assistant_text_token(text_logits, manifest, gen_defaults, rng):
    """Choose between the assistant audio-slot token and the audio-end token.

    Only the logits of those two token ids (read from
    ``manifest["tts_config"]``) take part in the decision.

    Returns:
        The chosen token id.
    """
    cids = np.asarray(
        [
            int(manifest["tts_config"]["audio_assistant_slot_token_id"]),
            int(manifest["tts_config"]["audio_end_token_id"]),
        ],
        dtype=np.int32,
    )
    cs = text_logits[cids]
    si = _sample_from_scores(
        cs,
        do_sample=bool(gen_defaults["do_sample"]),
        temperature=float(gen_defaults["text_temperature"]),
        top_k=min(int(gen_defaults["text_top_k"]), int(cs.shape[0])),
        top_p=float(gen_defaults["text_top_p"]),
        rng=rng,
    )
    return int(cids[si])


def _sample_audio_token(audio_logits, prev_ids, prev_set, gen_defaults, rng):
    """Pick one audio token for a channel, applying the repetition penalty.

    ``prev_ids`` and ``prev_set`` hold the tokens already emitted on this
    channel as a list and a set respectively.
    """
    rp = float(gen_defaults["audio_repetition_penalty"])
    if not bool(gen_defaults["do_sample"]):
        return _argmax_with_repetition_penalty(audio_logits, prev_set, rp)
    penalized = _apply_repetition_penalty(audio_logits, prev_ids, rp)
    return _sample_from_scores(
        penalized,
        do_sample=True,
        temperature=float(gen_defaults["audio_temperature"]),
        top_k=int(gen_defaults["audio_top_k"]),
        top_p=float(gen_defaults["audio_top_p"]),
        rng=rng,
    )


def _flatten3d(nested):
    """Flatten a 3-level nested sequence into a 1-D int32 array.

    The dimensions are taken from the first element at each level.

    Returns:
        ``(data, [d0, d1, d2])``.
    """
    d0, d1, d2 = len(nested), len(nested[0]), len(nested[0][0])
    data = np.zeros((d0 * d1 * d2,), dtype=np.int32)
    off = 0
    for i in range(d0):
        for j in range(d1):
            for k in range(d2):
                data[off] = int(nested[i][j][k])
                off += 1
    return data, [d0, d1, d2]


def _flatten2d(nested):
    """Flatten a 2-level nested sequence into a 1-D int32 array.

    Returns:
        ``(data, [d0, d1])`` with the dimensions taken from the first row.
    """
    d0, d1 = len(nested), len(nested[0])
    data = np.zeros((d0 * d1,), dtype=np.int32)
    off = 0
    for i in range(d0):
        for j in range(d1):
            data[off] = int(nested[i][j])
            off += 1
    return data, [d0, d1]


def _extract_last_hidden(hs):
    """Return ``hs`` as float32, keeping only the last step of a 3-D input.

    A 2-D array is returned as is; otherwise index ``-1`` of axis 1 is taken.
    """
    if hs.ndim == 2:
        return hs.astype(np.float32, copy=False)
    return hs[:, -1, :].astype(np.float32, copy=False)


def _slice_channel_major_audio(audio, start=0, end=None):
    """Return one float32 array per channel for batch item 0 of ``audio``.

    ``audio`` is indexed as ``audio[0, channel, sample]``; ``start`` and
    ``end`` are clamped to the valid sample range.
    """
    ch = int(audio.shape[1])
    total = int(audio.shape[2])
    s = max(0, int(start))
    e = total if end is None else max(s, min(int(end), total))
    return [audio[0, c, s:e].astype(np.float32, copy=False) for c in range(ch)]


def _contains_cjk(text):
    """Return True if ``text`` has a character in one of the checked ranges.

    The ranges are U+4E00-U+9FFF, U+3400-U+4DBF, U+3040-U+30FF and
    U+AC00-U+D7AF.
    """
    for c in str(text or ""):
        if (
            "\u4e00" <= c <= "\u9fff"
            or "\u3400" <= c <= "\u4dbf"
            or "\u3040" <= c <= "\u30ff"
            or "\uac00" <= c <= "\ud7af"
        ):
            return True
    return False


def _prepare_text_for_sentence_chunking(text):
    """Normalize a prompt before it is split into sentences.

    Line breaks become spaces and runs of spaces are collapsed.  Text that
    contains CJK characters gets a trailing ``。`` when it does not already end
    with sentence punctuation.  Other text is capitalized, given a trailing
    period when it ends with an alphanumeric character, and prefixed with a
    space when it has fewer than five words.

    Raises:
        ValueError: If the text is empty after stripping.
    """
    t = str(text or "").strip()
    if not t:
        raise ValueError("Text prompt cannot be empty.")
    t = t.replace("\r", " ").replace("\n", " ")
    # Built with multiplication so the two-space literal cannot be collapsed
    # by tooling; replacing " " with " " would loop forever.
    double_space = " " * 2
    while double_space in t:
        t = t.replace(double_space, " ")
    if _contains_cjk(t):
        if t[-1] not in SENTENCE_END_PUNCTUATION:
            t += "。"
        return t
    if t[:1].islower():
        t = t[:1].upper() + t[1:]
    if t[-1].isalnum():
        t += "."
    if len([x for x in t.split() if x]) < 5:
        # NOTE(review): the width of this padding may have been altered by
        # whitespace normalization; confirm against the upstream source.
        t = f" {t}"
    return t


def _split_by_punct(text, punct):
    """Split ``text`` after every character found in ``punct``.

    Closing quotes/brackets that directly follow the punctuation stay with
    the sentence, and whitespace after a split point is skipped.

    Returns:
        A list of stripped, non-empty pieces.
    """
    sentences, cur, i = [], [], 0
    while i < len(text):
        c = text[i]
        cur.append(c)
        if c in punct:
            la = i + 1
            while la < len(text) and text[la] in CLOSING_PUNCTUATION:
                cur.append(text[la])
                la += 1
            s = "".join(cur).strip()
            if s:
                sentences.append(s)
            cur.clear()
            while la < len(text) and text[la].isspace():
                la += 1
            i = la
            continue
        i += 1
    tail = "".join(cur).strip()
    if tail:
        sentences.append(tail)
    return sentences


def _merge_audio_channels(channels):
    """Stack per-channel 1-D arrays into a ``(samples, channels)`` float32 array.

    Channels are truncated to the shortest one.  An empty list yields a
    ``(0, 1)`` array.
    """
    if not channels:
        return np.zeros((0, 1), dtype=np.float32)
    if len(channels) == 1:
        return np.asarray(channels[0], dtype=np.float32).reshape(-1, 1)
    ml = min(int(c.shape[0]) for c in channels)
    return np.stack([np.asarray(c[:ml], dtype=np.float32) for c in channels], axis=1)


def _concat_waveforms(wfs):
    """Concatenate waveforms along axis 0, ignoring empty ones.

    When every waveform is empty the result has zero rows and the column
    count of the first waveform (at least 1).
    """
    if not wfs:
        return np.zeros((0, 1), dtype=np.float32)
    ne = [w for w in wfs if w.size > 0]
    if not ne:
        first = wfs[0]
        cols = int(first.shape[1]) if first.ndim > 1 and first.shape[1] > 0 else 1
        return np.zeros((0, max(1, cols)), dtype=np.float32)
    return np.concatenate(ne, axis=0)


def _write_wav(path, waveform, sr):
    """Write ``waveform`` to ``path`` as 16-bit PCM WAV and return the resolved path.

    Samples are clipped to [-1, 1] before conversion.  A 1-D waveform is
    treated as a single channel; parent directories are created.
    """
    p = Path(path).expanduser().resolve()
    p.parent.mkdir(parents=True, exist_ok=True)
    audio = np.asarray(waveform, dtype=np.float32)
    if audio.ndim == 1:
        audio = audio.reshape(-1, 1)
    pcm16 = np.round(np.clip(audio, -1.0, 1.0) * 32767.0).astype(np.int16)
    with wave.open(str(p), "wb") as f:
        f.setnchannels(int(pcm16.shape[1]))
        f.setsampwidth(2)
        f.setframerate(int(sr))
        f.writeframes(pcm16.tobytes())
    return p


@dataclass
class CodecStreamingSession:
    """Stateful wrapper around the codec's step-wise decode session.

    The state tensors named in ``codec_meta["streaming_decode"]`` are fed back
    into the session on every call to ``run_frames``.
    """

    codec_meta: dict
    session: ort.InferenceSession

    def __post_init__(self):
        """Read the state specs from the metadata and zero the state."""
        streaming = self.codec_meta.get("streaming_decode", {})
        self.transformer_specs = list(streaming.get("transformer_offsets", []))
        self.attention_specs = list(streaming.get("attention_caches", []))
        self.state_feeds = {}
        self.reset()

    def reset(self):
        """Reinitialize every state tensor (zeros; cached positions are -1)."""
        self.state_feeds = {}
        for s in self.transformer_specs:
            self.state_feeds[str(s["input_name"])] = np.zeros(tuple(s["shape"]), dtype=np.int32)
        for s in self.attention_specs:
            self.state_feeds[str(s["offset_input_name"])] = np.zeros(
                tuple(s["offset_shape"]), dtype=np.int32
            )
            self.state_feeds[str(s["cached_keys_input_name"])] = np.zeros(
                tuple(s["cache_shape"]), dtype=np.float32
            )
            self.state_feeds[str(s["cached_values_input_name"])] = np.zeros(
                tuple(s["cache_shape"]), dtype=np.float32
            )
            self.state_feeds[str(s["cached_positions_input_name"])] = np.full(
                tuple(s["positions_shape"]), -1, dtype=np.int32
            )

    def run_frames(self, frame_rows):
        """Decode a batch of code frames and advance the streaming state.

        Rows shorter than the number of quantizers are padded with 0.

        Returns:
            ``None`` for empty input, otherwise ``(audio, audio_length)`` where
            ``audio`` is the raw session output and ``audio_length`` an int.
        """
        if not frame_rows:
            return None
        nq = int(self.codec_meta["codec_config"]["num_quantizers"])
        fc = len(frame_rows)
        ac = np.zeros((1, fc, nq), dtype=np.int32)
        for fi, fr in enumerate(frame_rows):
            for ci in range(nq):
                ac[0, fi, ci] = int(fr[ci] if ci < len(fr) else 0)
        feeds = {"audio_codes": ac, "audio_code_lengths": np.asarray([fc], dtype=np.int32)}
        feeds.update(self.state_feeds)
        named = _run_named(self.session, feeds)
        # Carry the updated state over to the next call.
        for s in self.transformer_specs:
            self.state_feeds[str(s["input_name"])] = named[str(s["output_name"])]
        for s in self.attention_specs:
            self.state_feeds[str(s["offset_input_name"])] = named[str(s["offset_output_name"])]
            self.state_feeds[str(s["cached_keys_input_name"])] = named[
                str(s["cached_keys_output_name"])
            ]
            self.state_feeds[str(s["cached_values_input_name"])] = named[
                str(s["cached_values_output_name"])
            ]
            self.state_feeds[str(s["cached_positions_input_name"])] = named[
                str(s["cached_positions_output_name"])
            ]
        return named["audio"], int(named["audio_lengths"].reshape(-1)[0])


def _resolve_stream_decode_frame_budget(emitted_total, sr, first_audio_at):
    """Return how many pending frames to decode at once (1, 2, 4 or 8).

    The budget grows with the lead of the emitted audio duration over the
    wall-clock time elapsed since ``first_audio_at`` (a ``perf_counter``
    timestamp).  Before any audio was emitted, or with a non-positive sample
    rate, the budget is 1.
    """
    if not first_audio_at or sr <= 0:
        return 1
    elapsed = max(0.0, time.perf_counter() - first_audio_at)
    lead = emitted_total / float(sr) - elapsed
    if lead < 0.20:
        return 1
    if lead < 0.55:
        return 2
    if lead < 1.10:
        return 4
    return 8


class MossTtsRuntime:
    """Loads the manifest, tokenizer and ONNX sessions and runs synthesis."""

    def __init__(self, model_dir, thread_count=2, max_new_frames=375):
        """Locate the manifest under ``model_dir`` and create all sessions.

        Args:
            model_dir: Directory searched for ``browser_poc_manifest.json``.
            thread_count: Intra-op thread count per session (minimum 1).
            max_new_frames: Overrides the manifest's generation default when
                not ``None``.
        """
        self.model_dir = Path(model_dir).expanduser().resolve()
        self.thread_count = max(1, int(thread_count))
        self.manifest_path = self._find_manifest()
        self.manifest_dir = self.manifest_path.parent
        self.manifest = json.loads(self.manifest_path.read_text("utf-8"))
        if max_new_frames is not None:
            self.manifest["generation_defaults"]["max_new_frames"] = int(max_new_frames)
        self.rng = np.random.default_rng(1234)
        self.tts_meta_path = self._resolve_path(self.manifest["model_files"]["tts_meta"])
        self.codec_meta_path = self._resolve_path(self.manifest["model_files"]["codec_meta"])
        self.tts_meta = json.loads(self.tts_meta_path.read_text("utf-8"))
        self.codec_meta = json.loads(self.codec_meta_path.read_text("utf-8"))
        tok_path = str(
            self._resolve_path(
                self.manifest["model_files"].get("tokenizer_model", "tokenizer.model")
            )
        )
        self.sp = spm.SentencePieceProcessor(model_file=tok_path)
        self.sessions = self._create_sessions()
        self.codec_stream = CodecStreamingSession(
            self.codec_meta, self.sessions["codec_decode_step"]
        )

    def _find_manifest(self):
        """Return the first existing manifest candidate under ``model_dir``.

        Raises:
            FileNotFoundError: If none of the candidate paths is a file.
        """
        for rp in MANIFEST_CANDIDATE_RELATIVE_PATHS:
            c = (self.model_dir / rp).resolve()
            if c.is_file():
                return c
        raise FileNotFoundError(f"browser_poc_manifest.json not found under {self.model_dir}")

    def _resolve_path(self, rel):
        """Resolve ``rel`` against the manifest directory.

        When the direct path does not exist, directory names listed in
        ``MODEL_DIR_ALIAS_MAP`` are substituted and retried.  The direct
        (possibly missing) path is returned when nothing matches.
        """
        resolved = (self.manifest_dir / Path(rel)).resolve()
        if resolved.exists():
            return resolved
        rt = str(rel).replace("\\", "/")
        for old, new in MODEL_DIR_ALIAS_MAP.items():
            frag = f"/{old}/"
            if frag in f"/{rt}/":
                rw = (self.manifest_dir / Path(rt.replace(old, new))).resolve()
                if rw.exists():
                    return rw
        return resolved

    def _session(self, p):
        """Create a CPU inference session for the model file ``p``."""
        opts = ort.SessionOptions()
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.intra_op_num_threads = self.thread_count
        opts.inter_op_num_threads = 1
        return ort.InferenceSession(
            str(p), sess_options=opts, providers=["CPUExecutionProvider"]
        )

    def _create_sessions(self):
        """Create the required sessions plus any optional local-decoder variants."""
        td = self.tts_meta_path.parent
        cd = self.codec_meta_path.parent
        tts_files = self.tts_meta["files"]
        codec_files = self.codec_meta["files"]
        sess = {
            "prefill": self._session(td / tts_files["prefill"]),
            "decode": self._session(td / tts_files["decode_step"]),
            "local_decoder": self._session(td / tts_files["local_decoder"]),
            "codec_encode": self._session(cd / codec_files["encode"]),
            "codec_decode": self._session(cd / codec_files["decode_full"]),
            "codec_decode_step": self._session(cd / codec_files["decode_step"]),
        }
        # Optional graphs are loaded only when the metadata names a file.
        for key in ("local_greedy_frame", "local_fixed_sampled_frame", "local_cached_step"):
            if tts_files.get(key):
                sess[key] = self._session(td / tts_files[key])
        return sess

    def list_builtin_voices(self):
        """Return a copy of the manifest's ``builtin_voices`` list."""
        return list(self.manifest["builtin_voices"])

    def encode_text(self, text):
        """Tokenize ``text`` with SentencePiece into a list of ints."""
        return [int(t) for t in self.sp.encode(str(text or ""), out_type=int)]

    def count_text_tokens(self, text):
        """Return the number of tokens ``text`` encodes to."""
        return len(self.encode_text(text))

    def _load_ref_audio(self, path):
        """Load a reference clip, matched to the codec's sample rate and channels.

        Mono is repeated to fill multiple channels and multi-channel audio is
        averaged down to mono.

        Returns:
            A float32 array with a leading axis of size 1 added.

        Raises:
            ValueError: For any other channel-count combination.
        """
        wf, sr = torchaudio.load(str(Path(path).expanduser().resolve()))
        wf = wf.to(torch.float32)
        tsr = int(self.codec_meta["codec_config"]["sample_rate"])
        tch = int(self.codec_meta["codec_config"]["channels"])
        if sr != tsr:
            wf = torchaudio.functional.resample(wf, sr, tsr)
        cc = int(wf.shape[0])
        if cc == tch:
            pass
        elif cc == 1 and tch > 1:
            wf = wf.repeat(tch, 1)
        elif cc > 1 and tch == 1:
            wf = wf.mean(dim=0, keepdim=True)
        else:
            raise ValueError(f"Unsupported channel conversion: {cc} -> {tch}")
        return wf.unsqueeze(0).detach().cpu().numpy().astype(np.float32, copy=False)

    def encode_ref_audio(self, path):
        """Encode a reference clip into a list of per-frame code lists."""
        wf = self._load_ref_audio(path)
        wl = int(wf.shape[-1])
        named = _run_named(
            self.sessions["codec_encode"],
            {"waveform": wf, "input_lengths": np.asarray([wl], dtype=np.int32)},
        )
        ac = np.asarray(named["audio_codes"], dtype=np.int32)
        cl = int(np.asarray(named["audio_code_lengths"]).reshape(-1)[0])
        nq = int(self.codec_meta["codec_config"]["num_quantizers"])
        return [[int(ac[0, fi, qi]) for qi in range(nq)] for fi in range(cl)]

    def resolve_prompt_codes(self, *, voice, prompt_audio_path):
        """Return prompt codes from a reference clip or a built-in voice.

        A reference clip takes precedence.  With no ``voice`` given the first
        built-in voice is used.

        Raises:
            ValueError: If the named built-in voice does not exist.
        """
        if prompt_audio_path:
            return self.encode_ref_audio(prompt_audio_path)
        v = str(voice or self.list_builtin_voices()[0]["voice"])
        row = next((x for x in self.list_builtin_voices() if x["voice"] == v), None)
        if row is None:
            raise ValueError(f"Built-in voice not found: {v}")
        return list(row["prompt_audio_codes"])

    def build_text_rows(self, token_ids):
        """Build one row per text token: the token id followed by audio padding."""
        rw = int(self.manifest["tts_config"]["n_vq"]) + 1
        pad = int(self.manifest["tts_config"]["audio_pad_token_id"])
        rows = []
        for tid in token_ids:
            r = [pad] * rw
            r[0] = int(tid)
            rows.append(r)
        return rows

    def build_audio_prefix_rows(self, codes, slot_id=None):
        """Build one row per code frame: the slot token followed by the codes.

        ``slot_id`` defaults to the manifest's user audio-slot token.  Missing
        codes are padded and surplus codes are dropped.
        """
        rw = int(self.manifest["tts_config"]["n_vq"]) + 1
        pad = int(self.manifest["tts_config"]["audio_pad_token_id"])
        sid = int(
            self.manifest["tts_config"]["audio_user_slot_token_id"]
            if slot_id is None
            else slot_id
        )
        rows = []
        for cr in codes:
            r = [pad] * rw
            r[0] = sid
            for i in range(min(len(cr), rw - 1)):
                r[i + 1] = int(cr[i])
            rows.append(r)
        return rows

    def build_request_rows(self, codes, text_ids):
        """Assemble the prefill input: prompt prefix, reference audio, text, assistant prefix.

        Returns:
            ``{"inputIds": rows, "attentionMask": [[1, ...]]}`` with one mask
            entry per row.
        """
        cfg = self.manifest["tts_config"]
        templates = self.manifest["prompt_templates"]
        prefix = [
            *templates["user_prompt_prefix_token_ids"],
            int(cfg["audio_start_token_id"]),
        ]
        suffix = [
            int(cfg["audio_end_token_id"]),
            *templates["user_prompt_after_reference_token_ids"],
            *text_ids,
            *templates["assistant_prompt_prefix_token_ids"],
            int(cfg["audio_start_token_id"]),
        ]
        rows = [
            *self.build_text_rows(prefix),
            *self.build_audio_prefix_rows(codes),
            *self.build_text_rows(suffix),
        ]
        return {"inputIds": rows, "attentionMask": [[1 for _ in rows]]}

    def run_local_decoder(self, gh, text_tid, frame_prefix):
        """Run the uncached local decoder for one step.

        ``frame_prefix`` holds the audio tokens already chosen for the current
        frame; the remainder of the prefix input is padded.

        Returns:
            ``(text_logits, audio_logits)`` with the text logits flattened.
        """
        nvq = int(self.manifest["tts_config"]["n_vq"])
        apad = int(self.manifest["tts_config"]["audio_pad_token_id"])
        pp = np.full((1, nvq - 1), apad, dtype=np.int32)
        for i in range(min(len(frame_prefix), nvq - 1)):
            pp[0, i] = int(frame_prefix[i])
        nd = _run_named(
            self.sessions["local_decoder"],
            {
                "global_hidden": gh.astype(np.float32, copy=False),
                "text_token_id": np.asarray([int(text_tid)], dtype=np.int32),
                "audio_prefix_token_ids": pp,
            },
        )
        return nd["text_logits"].reshape(-1), nd["audio_logits"]

    def create_empty_local_past(self):
        """Return zero-length key/value caches for every local decoder layer."""
        cfg = self.tts_meta["model_config"]
        ll = int(cfg["local_layers"])
        lh = int(cfg["local_heads"])
        lhd = int(cfg["local_head_dim"])
        return {
            n: np.zeros((1, 0, lh, lhd), dtype=np.float32)
            for li in range(ll)
            for n in (f"local_past_key_{li}", f"local_past_value_{li}")
        }

    def run_local_cached_step(
        self, gh, *, text_tid, audio_tid, ch_idx, step_type, past_vl, past
    ):
        """Run one step of the cached local decoder.

        Returns:
            ``(text_logits, audio_logits, new_past)`` where ``new_past`` maps
            the ``local_past_*`` input names to the step's present tensors.
        """
        nd = _run_named(
            self.sessions["local_cached_step"],
            {
                "global_hidden": gh.astype(np.float32, copy=False),
                "text_token_id": np.asarray([int(text_tid)], dtype=np.int32),
                "audio_token_id": np.asarray([int(audio_tid)], dtype=np.int32),
                "channel_index": np.asarray([int(ch_idx)], dtype=np.int32),
                "step_type": np.asarray([int(step_type)], dtype=np.int32),
                "past_valid_lengths": np.asarray([int(past_vl)], dtype=np.int32),
                **past,
            },
        )
        npast = {
            n.replace("local_present_", "local_past_"): nd[n]
            for n in self.tts_meta["onnx"]["local_cached_output_names"][2:]
        }
        return nd["text_logits"].reshape(-1), nd["audio_logits"], npast

    def _build_repetition_mask(self, prev_sets):
        """Return an int32 ``(1, n_vq, codebook_size)`` mask of already-seen tokens."""
        acs = int(self.tts_meta["model_config"]["audio_codebook_sizes"][0])
        nvq = int(self.manifest["tts_config"]["n_vq"])
        rm = np.zeros((1, nvq, acs), dtype=np.int32)
        for ci, ts in enumerate(prev_sets):
            for tid in ts:
                if 0 <= tid < acs:
                    rm[0, ci, tid] = 1
        return rm

    def run_local_greedy_frame(self, gh, *, prev_sets, rep_penalty):
        """Generate one whole frame with the fused greedy graph.

        Returns:
            ``(should_continue, frame_token_ids)``.
        """
        nd = _run_named(
            self.sessions["local_greedy_frame"],
            {
                "global_hidden": gh.astype(np.float32, copy=False),
                "repetition_seen_mask": self._build_repetition_mask(prev_sets),
                "repetition_penalty": np.asarray([float(rep_penalty)], dtype=np.float32),
            },
        )
        cont = bool(int(np.asarray(nd["should_continue"]).reshape(-1)[0]))
        ftids = np.asarray(nd["frame_token_ids"]).reshape(-1).astype(np.int32, copy=False).tolist()
        return cont, [int(x) for x in ftids]

    def run_local_fixed_sampled_frame(self, gh, *, prev_sets):
        """Generate one whole frame with the fused fixed-sampling graph.

        The uniform random numbers are drawn here from ``self.rng`` (one for
        the assistant decision, then one per codebook) and clamped to
        [0, 0.99999994] before being passed to the graph.

        Returns:
            ``(should_continue, frame_token_ids)``.
        """
        nvq = int(self.manifest["tts_config"]["n_vq"])
        rm = self._build_repetition_mask(prev_sets)

        def draw():
            return min(0.99999994, max(0.0, float(self.rng.random())))

        aru = np.asarray([draw()], dtype=np.float32)
        au = np.asarray([[draw() for _ in range(nvq)]], dtype=np.float32)
        nd = _run_named(
            self.sessions["local_fixed_sampled_frame"],
            {
                "global_hidden": gh.astype(np.float32, copy=False),
                "repetition_seen_mask": rm,
                "assistant_random_u": aru,
                "audio_random_u": au,
            },
        )
        ftids = np.asarray(nd["frame_token_ids"]).reshape(-1).astype(np.int32, copy=False).tolist()
        cont = bool(int(np.asarray(nd["should_continue"]).reshape(-1)[0]))
        return cont, [int(x) for x in ftids]

    def slice_audio_channel_logits(self, alogits, ci):
        """Return the ``ci``-th block of the flattened logits.

        The block length is the size of the last axis of ``alogits``.
        """
        pc = int(alogits.shape[-1])
        flat = alogits.reshape(-1)
        return flat[ci * pc:(ci + 1) * pc]

    def decode_full_audio(self, frames):
        """Decode all frames in one call of the full codec decoder.

        Returns:
            ``(channel_arrays, audio_length)``; ``([], 0)`` for no frames.
        """
        if not frames:
            return [], 0
        ac, dims = _flatten3d([frames])
        nd = _run_named(
            self.sessions["codec_decode"],
            {
                "audio_codes": ac.reshape(dims),
                "audio_code_lengths": np.asarray([len(frames)], dtype=np.int32),
            },
        )
        al = int(nd["audio_lengths"].reshape(-1)[0])
        return _slice_channel_major_audio(nd["audio"], 0, al), al

    def generate_audio_frames(self, req_rows, on_frame=None):
        """Autoregressively generate audio code frames for a request.

        After the prefill pass, each iteration produces one frame using the
        first applicable local decoder (fused greedy, fused fixed-sampling,
        cached step, or the plain local decoder) and then advances the global
        decoder by one step.  Generation stops at ``max_new_frames`` or when
        the local decoder signals the end of audio.

        Args:
            req_rows: Output of ``build_request_rows``.
            on_frame: Optional callback invoked as
                ``on_frame(all_frames, step_index, frame)`` after each step.

        Returns:
            The list of generated frames (each a list of token ids).
        """
        gd = self.manifest["generation_defaults"]
        cfg = self.manifest["tts_config"]
        onnx_meta = self.tts_meta["onnx"]
        nvq = int(cfg["n_vq"])
        rw = nvq + 1
        slot_id = int(cfg["audio_assistant_slot_token_id"])
        pad_id = int(cfg["audio_pad_token_id"])

        pids, pdims = _flatten3d([req_rows["inputIds"]])
        pmask, pmdims = _flatten2d(req_rows["attentionMask"])
        nd = _run_named(
            self.sessions["prefill"],
            {"input_ids": pids.reshape(pdims), "attention_mask": pmask.reshape(pmdims)},
        )
        gh = _extract_last_hidden(nd["global_hidden"])
        pvl = sum(int(x) for x in req_rows["attentionMask"][0])
        past = {
            n.replace("present_", "past_"): nd[n]
            for n in onnx_meta["prefill_output_names"][1:]
        }

        gen_frames = []
        prev_by_ch = [[] for _ in range(nvq)]
        prev_set_by_ch = [set() for _ in range(nvq)]

        def remember(ci, tok):
            prev_by_ch[ci].append(tok)
            prev_set_by_ch[ci].add(tok)

        for si in range(int(gd["max_new_frames"])):
            frame = []
            if "local_greedy_frame" in self.sessions and not bool(gd["do_sample"]):
                cont, frame = self.run_local_greedy_frame(
                    gh,
                    prev_sets=prev_set_by_ch,
                    rep_penalty=float(gd["audio_repetition_penalty"]),
                )
                if not cont:
                    break
                for ci, st in enumerate(frame):
                    remember(ci, st)
            elif (
                "local_fixed_sampled_frame" in self.sessions
                # .get: the key is only set by synthesize(), not necessarily
                # present in the manifest.
                and gd.get("sample_mode") == SAMPLE_MODE_FIXED
            ):
                cont, frame = self.run_local_fixed_sampled_frame(gh, prev_sets=prev_set_by_ch)
                if not cont:
                    break
                for ci, st in enumerate(frame):
                    remember(ci, st)
            elif "local_cached_step" in self.sessions:
                lp = self.create_empty_local_past()
                lpvl = 0
                tl, _, lp = self.run_local_cached_step(
                    gh, text_tid=0, audio_tid=0, ch_idx=0, step_type=0, past_vl=lpvl, past=lp
                )
                lpvl += 1
                ntt = _sample_assistant_text_token(tl, self.manifest, gd, self.rng)
                if ntt != slot_id:
                    break
                _, alogits, lp = self.run_local_cached_step(
                    gh, text_tid=ntt, audio_tid=0, ch_idx=0, step_type=1, past_vl=lpvl, past=lp
                )
                lpvl += 1
                fl = self.slice_audio_channel_logits(alogits, 0).astype(np.float32, copy=False)
                st = _sample_audio_token(fl, prev_by_ch[0], prev_set_by_ch[0], gd, self.rng)
                frame.append(st)
                remember(0, st)
                prev = st
                for ci in range(1, nvq):
                    _, alogits, lp = self.run_local_cached_step(
                        gh,
                        text_tid=0,
                        audio_tid=prev,
                        ch_idx=ci - 1,
                        step_type=2,
                        past_vl=lpvl,
                        past=lp,
                    )
                    lpvl += 1
                    cl = self.slice_audio_channel_logits(alogits, ci).astype(
                        np.float32, copy=False
                    )
                    st = _sample_audio_token(
                        cl, prev_by_ch[ci], prev_set_by_ch[ci], gd, self.rng
                    )
                    frame.append(st)
                    remember(ci, st)
                    prev = st
            else:
                tl, _ = self.run_local_decoder(gh, 0, [])
                ntt = _sample_assistant_text_token(tl, self.manifest, gd, self.rng)
                if ntt != slot_id:
                    break
                for ci in range(nvq):
                    _, alogits = self.run_local_decoder(gh, ntt, frame)
                    cl = self.slice_audio_channel_logits(alogits, ci).astype(
                        np.float32, copy=False
                    )
                    st = _sample_audio_token(
                        cl, prev_by_ch[ci], prev_set_by_ch[ci], gd, self.rng
                    )
                    frame.append(st)
                    remember(ci, st)

            gen_frames.append(frame)

            # Feed the new frame back through the global decoder.
            nr = np.full((1, 1, rw), pad_id, dtype=np.int32)
            nr[0, 0, 0] = slot_id
            for i, t in enumerate(frame):
                nr[0, 0, i + 1] = int(t)
            df = {"input_ids": nr, "past_valid_lengths": np.asarray([pvl], dtype=np.int32)}
            for iname in onnx_meta["decode_input_names"][2:]:
                df[iname] = past[iname]
            dnd = _run_named(self.sessions["decode"], df)
            gh = _extract_last_hidden(dnd["global_hidden"])
            pvl += 1
            past = {
                n.replace("present_", "past_"): dnd[n]
                for n in onnx_meta["decode_output_names"][1:]
            }
            if on_frame is not None:
                on_frame(gen_frames, si, frame)
        return gen_frames

    def decode_full_audio_safe(self, frames):
        """Decode frames to a waveform, falling back to streaming decode on failure.

        The full decoder is tried first.  If it raises, a warning is logged and
        the frames are decoded eight at a time through the streaming session,
        which is reset before and after.

        Returns:
            A ``(samples, channels)`` float32 array.
        """
        try:
            ch_arrays, _ = self.decode_full_audio(frames)
            return _merge_audio_channels(ch_arrays)
        except Exception as exc:
            # Deliberately broad: any failure of the full graph triggers the fallback.
            logging.warning("full codec decode failed, falling back: %s", exc)
        self.codec_stream.reset()
        nch = int(self.codec_meta["codec_config"]["channels"])
        merged = [[] for _ in range(nch)]
        try:
            for si in range(0, len(frames), 8):
                chunk = frames[si:si + 8]
                dec = self.codec_stream.run_frames(chunk)
                if dec is None:
                    continue
                audio, al = dec
                if al <= 0:
                    continue
                for ci in range(nch):
                    merged[ci].append(np.asarray(audio[0, ci, :al], dtype=np.float32))
        finally:
            self.codec_stream.reset()
        return _merge_audio_channels(
            [np.concatenate(c) if c else np.zeros((0,), dtype=np.float32) for c in merged]
        )

    def split_text_chunks(self, text, max_tokens=75):
        """Split ``text`` into pieces of at most ``max_tokens`` tokens each.

        For every piece a binary search finds the longest prefix that fits;
        the cut is then moved back to the nearest punctuation mark or space
        within the last 24 characters of that prefix, when there is one.

        Returns:
            The list of pieces, ``[stripped_text]`` when no split was needed,
            or ``[]`` for empty text.
        """
        t = str(text or "").strip()
        if not t:
            return []
        pieces = []
        pref = set(CLAUSE_SPLIT_PUNCTUATION) | set(SENTENCE_END_PUNCTUATION) | {" "}
        while t:
            if self.count_text_tokens(t) <= max_tokens:
                pieces.append(t)
                break
            lo, hi, best = 1, len(t), 1
            while lo <= hi:
                mid = (lo + hi) // 2
                cand = t[:mid].strip()
                if cand and self.count_text_tokens(cand) <= max_tokens:
                    best = mid
                    lo = mid + 1
                else:
                    hi = mid - 1
                    if not cand:
                        lo = mid + 1
            ci = best
            pf = t[:best]
            pi = -1
            for si in range(len(pf) - 1, max(-1, len(pf) - 25), -1):
                if pf[si] in pref:
                    pi = si + 1
                    break
            if pi > 0:
                ci = pi
            piece = t[:ci].strip()
            if not piece:
                piece = t[:best].strip()
                ci = best
            pieces.append(piece)
            t = t[ci:].strip()
        return pieces if len(pieces) > 1 else [str(text or "").strip()]

    def synthesize(
        self,
        *,
        text,
        voice=None,
        prompt_audio_path=None,
        sample_mode="fixed",
        do_sample=True,
        streaming=True,
        max_new_frames=375,
    ):
        """Synthesize ``text`` and write the result to ``OUTPUT_DIR / "output.wav"``.

        The call updates ``max_new_frames``, ``sample_mode`` and ``do_sample``
        in the manifest's generation defaults.  With ``streaming`` enabled the
        frames are decoded incrementally while generation is running;
        otherwise all frames are decoded at the end.

        Returns:
            ``{"audio_path": str, "sample_rate": int, "frames": int}``.
        """
        gd = self.manifest["generation_defaults"]
        gd["max_new_frames"] = int(max_new_frames)
        nsm = _normalize_sample_mode(sample_mode, do_sample)
        gd["sample_mode"] = nsm
        gd["do_sample"] = nsm != SAMPLE_MODE_GREEDY
        codes = self.resolve_prompt_codes(voice=voice, prompt_audio_path=prompt_audio_path)
        tid = self.encode_text(text)
        req = self.build_request_rows(codes, tid)
        if streaming:
            pending = []
            emitted = []
            emitted_total = 0
            first_at = None
            self.codec_stream.reset()

            def decode_pending(force):
                """Decode queued frames once the budget is met (or all when forced)."""
                nonlocal emitted_total, first_at
                pc = len(pending)
                if pc <= 0:
                    return
                sr = int(self.codec_meta["codec_config"]["sample_rate"])
                budget = _resolve_stream_decode_frame_budget(emitted_total, sr, first_at)
                if not force and pc < max(1, budget):
                    return
                fb = pc if force else min(pc, max(1, budget))
                chunk = pending[:fb]
                del pending[:fb]
                dec = self.codec_stream.run_frames(chunk)
                if dec is None:
                    return
                audio, al = dec
                if al <= 0:
                    return
                if first_at is None:
                    first_at = time.perf_counter()
                emitted_total += al
                nch = int(self.codec_meta["codec_config"]["channels"])
                emitted.append(
                    _merge_audio_channels([audio[0, c, :al] for c in range(nch)])
                )

            def on_frame(gf, si, f):
                pending.append(list(f))
                decode_pending(False)

            try:
                gf = self.generate_audio_frames(req, on_frame=on_frame)
                decode_pending(True)
            finally:
                self.codec_stream.reset()
            waveform = _concat_waveforms(emitted)
        else:
            gf = self.generate_audio_frames(req)
            waveform = self.decode_full_audio_safe(gf)
        sr = int(self.codec_meta["codec_config"]["sample_rate"])
        out_path = OUTPUT_DIR / "output.wav"
        _write_wav(out_path, waveform, sr)
        return {"audio_path": str(out_path), "sample_rate": sr, "frames": len(gf)}


def _download_and_flatten(repo_id, target_dir, marker_name, nested_name, allow_patterns):
    """Download ``repo_id`` into ``target_dir`` unless ``marker_name`` already exists there.

    After the download, entries of ``target_dir / nested_name`` (if that
    directory exists) are moved up into ``target_dir`` without overwriting.
    """
    if (target_dir / marker_name).is_file():
        return
    target_dir.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id,
        local_dir=str(target_dir),
        local_dir_use_symlinks=False,
        allow_patterns=allow_patterns,
    )
    src = target_dir / nested_name
    if src.is_dir():
        for f in src.iterdir():
            dst = target_dir / f.name
            if not dst.exists():
                shutil.move(str(f), str(dst))


def ensure_models():
    """Download the TTS and codec model snapshots into ``MODEL_DIR`` when missing."""
    _download_and_flatten(
        DEFAULT_TTS_REPO,
        MODEL_DIR / "MOSS-TTS-Nano-100M-ONNX",
        "browser_poc_manifest.json",
        "MOSS-TTS-Nano-100M-ONNX",
        ["*.onnx", "*.data", "*.json", "tokenizer.model"],
    )
    _download_and_flatten(
        DEFAULT_CODEC_REPO,
        MODEL_DIR / "MOSS-Audio-Tokenizer-Nano-ONNX",
        "codec_browser_onnx_meta.json",
        "MOSS-Audio-Tokenizer-Nano-ONNX",
        ["*.onnx", "*.data", "*.json"],
    )


# Lazily created process-wide runtime; see get_runtime().
runtime = None


def get_runtime():
    """Return the shared ``MossTtsRuntime``, creating it on first use."""
    global runtime
    if runtime is not None:
        return runtime
    ensure_models()
    runtime = MossTtsRuntime(MODEL_DIR, thread_count=2, max_new_frames=375)
    return runtime


def synthesize_gradio(text, voice, audio_path, sample_mode, max_frames):
    """Gradio click handler: synthesize and return ``(audio_path, info_text)``.

    A reference audio file, when given, overrides the selected voice.

    Raises:
        gr.Error: If the text prompt is empty.
    """
    if not str(text or "").strip():
        raise gr.Error("Text prompt cannot be empty.")
    rt = get_runtime()
    t0 = time.time()
    result = rt.synthesize(
        text=text,
        voice=voice if not audio_path else None,
        prompt_audio_path=audio_path if audio_path else None,
        sample_mode=sample_mode,
        do_sample=(sample_mode != "greedy"),
        streaming=True,
        max_new_frames=int(max_frames),
    )
    elapsed = time.time() - t0
    return (
        result["audio_path"],
        f"Done in {elapsed:.1f}s | {result['sample_rate']}Hz | {result['frames']} frames",
    )


VOICES = [
    "Junhao", "Zhiming", "Weiguo", "Xiaoyu", "Yuewen", "Lingyu",
    "Trump", "Ava", "Bella", "Adam", "Nathan", "Soyo",
    "Saki", "Mortis", "Umiri", "Mei", "Anon", "Arisa",
]

with gr.Blocks(title="MOSS-TTS-Nano ONNX") as demo:
    gr.Markdown("# MOSS-TTS-Nano-100M-ONNX\nCPU-only TTS with voice cloning. First run downloads ~730MB model.")
    with gr.Row():
        with gr.Column():
            text_in = gr.Textbox(label="Text", value="Hello, welcome to MOSS TTS Nano.", lines=3)
            with gr.Row():
                voice_in = gr.Dropdown(
                    choices=VOICES, value="Junhao", label="Voice (overridden by ref audio)"
                )
                ref_audio = gr.Audio(
                    label="Reference Audio (optional, for voice cloning)", type="filepath"
                )
            with gr.Row():
                sample_mode = gr.Dropdown(
                    choices=["fixed", "greedy", "full"], value="fixed", label="Sample Mode"
                )
                max_frames = gr.Slider(16, 750, value=375, step=1, label="Max Frames")
            btn = gr.Button("Synthesize", variant="primary")
        with gr.Column():
            audio_out = gr.Audio(label="Generated Audio", type="filepath")
            info_out = gr.Textbox(label="Info")
    btn.click(
        fn=synthesize_gradio,
        inputs=[text_in, voice_in, ref_audio, sample_mode, max_frames],
        outputs=[audio_out, info_out],
    )

if __name__ == "__main__":
    get_runtime()
    demo.launch()