# mtn / app.py
# leesenx's picture
# Update app.py
# 48de81f verified
import os, json, math, time, wave, shutil
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Callable
os.environ["OMP_NUM_THREADS"] = "2"
import numpy as np
import onnxruntime as ort
import sentencepiece as spm
import torch
import torchaudio
import gradio as gr
from huggingface_hub import snapshot_download
# Sampling-mode names recognised by _normalize_sample_mode().
SAMPLE_MODE_GREEDY = "greedy"
SAMPLE_MODE_FIXED = "fixed"
SAMPLE_MODE_FULL = "full"
# NOTE(review): not referenced anywhere in the visible part of this file.
EXECUTION_PROVIDER_CPU = "cpu"
# Model and output locations; both can be overridden through the environment.
MODEL_DIR = Path(os.environ.get("MOSS_MODEL_DIR", "/app/models"))
OUTPUT_DIR = Path(os.environ.get("MOSS_OUTPUT_DIR", "/tmp/moss_output"))
# Import-time side effect: the output directory is created as soon as the module loads.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Character classes used by the text-splitting helpers. set() de-duplicates, so
# the repeated characters inside these literals are harmless.
# NOTE(review): the repeats look like full-width variants that were normalised
# to ASCII at some point - confirm against the upstream source.
SENTENCE_END_PUNCTUATION = set(".!?。!?;;")
CLAUSE_SPLIT_PUNCTUATION = set(",,、;;::")
CLOSING_PUNCTUATION = set("\"'\"')]})】》」』")
# Locations, relative to the model directory, probed for the manifest (first hit wins).
MANIFEST_CANDIDATE_RELATIVE_PATHS = (
"browser_poc_manifest.json",
"MOSS-TTS-Nano-100M-ONNX/browser_poc_manifest.json",
"MOSS-TTS-Nano-ONNX-CPU/browser_poc_manifest.json",
)
# Directory-name substitutions tried by MossTtsRuntime._resolve_path when a
# manifest-relative path does not exist as written (key -> replacement).
MODEL_DIR_ALIAS_MAP = {
"MOSS-TTS-Nano-ONNX-CPU": "MOSS-TTS-Nano-100M-ONNX",
"MOSS-Audio-Tokenizer-Nano-ONNX-CPU": "MOSS-Audio-Tokenizer-Nano-ONNX",
}
# Hugging Face repositories downloaded by ensure_models().
DEFAULT_TTS_REPO = "OpenMOSS-Team/MOSS-TTS-Nano-100M-ONNX"
DEFAULT_CODEC_REPO = "OpenMOSS-Team/MOSS-Audio-Tokenizer-Nano-ONNX"
# NOTE(review): neither pause constant is used in the visible file; the unit is
# presumably seconds - confirm before relying on them.
DEFAULT_INTER_CHUNK_PAUSE_SHORT = 0.40
DEFAULT_INTER_CHUNK_PAUSE_LONG = 0.24
def _argmax(values):
return int(np.argmax(values))
def _normalize_sample_mode(raw, do_sample=True):
    """Map a user-supplied sample mode onto one of the known mode names.

    A recognised name (after stripping whitespace) is returned unchanged.
    Anything else falls back to the fixed mode, or to the greedy mode when
    ``do_sample`` is false.
    """
    requested = str(raw or "").strip()
    known_modes = (SAMPLE_MODE_GREEDY, SAMPLE_MODE_FIXED, SAMPLE_MODE_FULL)
    if requested in known_modes:
        return requested
    return SAMPLE_MODE_FIXED if do_sample else SAMPLE_MODE_GREEDY
def _softmax(values):
mx = float(np.max(values))
shifted = np.asarray(values - mx, dtype=np.float64)
exps = np.exp(shifted)
return exps / np.sum(exps, dtype=np.float64)
def _sample_from_scores(values, *, do_sample, temperature, top_k, top_p, rng):
    """Pick one index from a 1-D array of unnormalised scores.

    Args:
        values: 1-D scores (logits); not modified.
        do_sample: when false the arg-max index is returned and every other
            argument is ignored.
        temperature: divisor applied to the scores before filtering. A value
            that is not strictly positive falls back to the arg-max, because
            dividing by it would turn the scores into inf/NaN.
        top_k: keep only scores at or above the ``top_k``-th largest; ignored
            unless ``0 < top_k < len(values)``.
        top_p: nucleus threshold; ignored unless ``0 < top_p < 1``. The
            highest-scoring entry is always kept.
        rng: object exposing ``random()`` returning a float in ``[0, 1)``;
            exactly one draw is consumed when sampling happens.

    Returns:
        The selected index as an ``int``.
    """
    if not do_sample:
        return _argmax(values)
    temp = float(temperature)
    if not temp > 0.0:
        # Zero, negative or NaN temperature: greedy is the well-defined limit.
        return _argmax(values)
    scores = np.asarray(values, dtype=np.float32) / temp
    if 0 < top_k < scores.shape[0]:
        threshold = float(np.sort(scores)[::-1][top_k - 1])
        scores[scores < threshold] = float("-inf")
    if 0 < top_p < 1:
        # Stable descending order, so ties keep their original relative order.
        order = np.argsort(-scores, kind="stable")
        cumulative = np.cumsum(_softmax(scores[order]))
        # An entry is dropped when the mass *before* it already exceeds top_p;
        # this keeps the entry that crosses the threshold and always the first.
        drop = np.zeros(order.shape[0], dtype=bool)
        drop[1:] = cumulative[:-1] > float(top_p)
        scores[order[drop]] = float("-inf")
    probs = _softmax(scores)
    remaining = float(rng.random())
    for index, prob in enumerate(probs):
        if prob <= 0.0:
            # Masked entries must never win, even when the draw is exactly 0.0.
            continue
        remaining -= float(prob)
        if remaining <= 0:
            return int(index)
    # Rounding left a sliver of probability mass unassigned: take the best score.
    return _argmax(scores)
def _apply_repetition_penalty(values, prev_ids, penalty):
if not prev_ids or penalty == 1.0:
return values
result = values.copy()
for tid in set(int(x) for x in prev_ids):
if 0 <= tid < result.shape[0]:
result[tid] = result[tid] * penalty if result[tid] < 0 else result[tid] / penalty
return result
def _argmax_with_repetition_penalty(values, prev_set, penalty):
best_idx, best_val = 0, float("-inf")
apply = bool(prev_set) and penalty != 1.0
for i, v in enumerate(values):
s = float(v)
if apply and i in prev_set:
s = s * penalty if s < 0 else s / penalty
if s > best_val:
best_val, best_idx = s, i
return int(best_idx)
def _sample_assistant_text_token(text_logits, manifest, gen_defaults, rng):
    """Choose between the assistant audio-slot token and the audio-end token.

    Only the two logits belonging to those token ids (read from
    ``manifest["tts_config"]``) take part in the draw; the text sampling
    settings come from ``gen_defaults``. Returns the chosen token id.
    """
    tts_cfg = manifest["tts_config"]
    candidates = np.asarray(
        [
            int(tts_cfg["audio_assistant_slot_token_id"]),
            int(tts_cfg["audio_end_token_id"]),
        ],
        dtype=np.int32,
    )
    candidate_scores = text_logits[candidates]
    # top_k can never exceed the number of candidates.
    capped_top_k = min(int(gen_defaults["text_top_k"]), int(candidate_scores.shape[0]))
    picked = _sample_from_scores(
        candidate_scores,
        do_sample=bool(gen_defaults["do_sample"]),
        temperature=float(gen_defaults["text_temperature"]),
        top_k=capped_top_k,
        top_p=float(gen_defaults["text_top_p"]),
        rng=rng,
    )
    return int(candidates[picked])
def _sample_audio_token(audio_logits, prev_ids, prev_set, gen_defaults, rng):
    """Pick the next audio token for one channel.

    ``prev_ids`` and ``prev_set`` both describe the tokens already emitted on
    this channel (list and set form). In greedy mode the penalised arg-max is
    returned; otherwise the repetition penalty is applied and the token is
    sampled with the audio temperature / top-k / top-p from ``gen_defaults``.
    """
    penalty = float(gen_defaults["audio_repetition_penalty"])
    if bool(gen_defaults["do_sample"]):
        adjusted = _apply_repetition_penalty(audio_logits, prev_ids, penalty)
        return _sample_from_scores(
            adjusted,
            do_sample=True,
            temperature=float(gen_defaults["audio_temperature"]),
            top_k=int(gen_defaults["audio_top_k"]),
            top_p=float(gen_defaults["audio_top_p"]),
            rng=rng,
        )
    return _argmax_with_repetition_penalty(audio_logits, prev_set, penalty)
def _flatten3d(nested):
d0, d1, d2 = len(nested), len(nested[0]), len(nested[0][0])
data = np.zeros((d0 * d1 * d2,), dtype=np.int32)
off = 0
for i in range(d0):
for j in range(d1):
for k in range(d2):
data[off] = int(nested[i][j][k])
off += 1
return data, [d0, d1, d2]
def _flatten2d(nested):
d0, d1 = len(nested), len(nested[0])
data = np.zeros((d0 * d1,), dtype=np.int32)
off = 0
for i in range(d0):
for j in range(d1):
data[off] = int(nested[i][j])
off += 1
return data, [d0, d1]
def _extract_last_hidden(hs):
if hs.ndim == 2:
return hs.astype(np.float32, copy=False)
return hs[:, -1, :].astype(np.float32, copy=False)
def _slice_channel_major_audio(audio, start=0, end=None):
ch = int(audio.shape[1])
total = int(audio.shape[2])
s = max(0, int(start))
e = total if end is None else max(s, min(int(end), total))
return [audio[0, c, s:e].astype(np.float32, copy=False) for c in range(ch)]
def _contains_cjk(text):
for c in str(text or ""):
if "\u4e00" <= c <= "\u9fff" or "\u3400" <= c <= "\u4dbf" or "\u3040" <= c <= "\u30ff" or "\uac00" <= c <= "\ud7af":
return True
return False
def _prepare_text_for_sentence_chunking(text):
    """Normalise a raw prompt before it is split into sentences.

    Line breaks become spaces and runs of spaces collapse to a single space.
    Text containing CJK characters gets a closing "。" unless it already ends
    in sentence-ending punctuation. Any other text has its first letter
    upper-cased, gains a trailing "." when it ends in a letter or digit, and is
    prefixed with a space when it has fewer than five whitespace-separated
    words.

    Raises:
        ValueError: if ``text`` is empty or whitespace only.
    """
    t = str(text or "").strip()
    if not t:
        raise ValueError("Text prompt cannot be empty.")
    t = t.replace("\r", " ").replace("\n", " ")
    # Collapse runs of spaces. The needle has to be TWO spaces: looping while a
    # single space is present (and replacing it with itself) never terminates.
    # It is built with ``* 2`` so whitespace-normalising tools cannot shrink it.
    double_space = " " * 2
    while double_space in t:
        t = t.replace(double_space, " ")
    if _contains_cjk(t):
        if t[-1] not in SENTENCE_END_PUNCTUATION:
            t += "。"
        return t
    if t[:1].islower():
        t = t[:1].upper() + t[1:]
    if t[-1].isalnum():
        t += "."
    if len([word for word in t.split() if word]) < 5:
        # NOTE(review): very short prompts get a leading space; the reason is
        # not visible in this file - confirm the intended padding upstream.
        t = f" {t}"
    return t
def _split_by_punct(text, punct):
    """Split ``text`` after every character contained in ``punct``.

    Closing quotes/brackets that directly follow the punctuation stay attached
    to the sentence they close, and whitespace after a split point is skipped.
    Each piece is stripped; empty pieces are dropped. Text after the last
    split point becomes the final piece.
    """
    length = len(text)
    pieces = []
    start = 0  # first character of the piece currently being collected
    pos = 0
    while pos < length:
        if text[pos] not in punct:
            pos += 1
            continue
        stop = pos + 1
        # Keep trailing closers such as quotes or brackets with this piece.
        while stop < length and text[stop] in CLOSING_PUNCTUATION:
            stop += 1
        sentence = text[start:stop].strip()
        if sentence:
            pieces.append(sentence)
        # The next piece starts after any whitespace that follows.
        while stop < length and text[stop].isspace():
            stop += 1
        start = pos = stop
    remainder = text[start:].strip()
    if remainder:
        pieces.append(remainder)
    return pieces
def _merge_audio_channels(channels):
if not channels:
return np.zeros((0, 1), dtype=np.float32)
if len(channels) == 1:
return np.asarray(channels[0], dtype=np.float32).reshape(-1, 1)
ml = min(int(c.shape[0]) for c in channels)
return np.stack([np.asarray(c[:ml], dtype=np.float32) for c in channels], axis=1)
def _concat_waveforms(wfs):
if not wfs:
return np.zeros((0, 1), dtype=np.float32)
ne = [w for w in wfs if w.size > 0]
if not ne:
return np.zeros((0, max(1, int(wfs[0].shape[1]) if wfs[0].ndim > 1 and wfs[0].shape[1] > 0 else 1)), dtype=np.float32)
return np.concatenate(ne, axis=0)
def _write_wav(path, waveform, sr):
p = Path(path).expanduser().resolve()
p.parent.mkdir(parents=True, exist_ok=True)
audio = np.asarray(waveform, dtype=np.float32)
if audio.ndim == 1:
audio = audio.reshape(-1, 1)
pcm16 = np.round(np.clip(audio, -1.0, 1.0) * 32767.0).astype(np.int16)
with wave.open(str(p), "wb") as f:
f.setnchannels(int(pcm16.shape[1]))
f.setsampwidth(2)
f.setframerate(int(sr))
f.writeframes(pcm16.tobytes())
return p
@dataclass
class CodecStreamingSession:
    """Stateful wrapper around the codec's step-wise ONNX decoder.

    The state tensors described under ``codec_meta["streaming_decode"]`` are
    fed into every run and replaced by the matching outputs afterwards, so
    consecutive ``run_frames`` calls continue one stream until ``reset``.
    """

    codec_meta: dict
    session: ort.InferenceSession

    def __post_init__(self):
        streaming_cfg = self.codec_meta.get("streaming_decode", {})
        self.transformer_specs = list(streaming_cfg.get("transformer_offsets", []))
        self.attention_specs = list(streaming_cfg.get("attention_caches", []))
        self.state_feeds = {}
        self.reset()

    def reset(self):
        """Re-initialise every state tensor (zeros; cached positions are -1)."""
        feeds = self.state_feeds = {}
        for spec in self.transformer_specs:
            feeds[str(spec["input_name"])] = np.zeros(tuple(spec["shape"]), dtype=np.int32)
        for spec in self.attention_specs:
            cache_shape = tuple(spec["cache_shape"])
            feeds[str(spec["offset_input_name"])] = np.zeros(tuple(spec["offset_shape"]), dtype=np.int32)
            feeds[str(spec["cached_keys_input_name"])] = np.zeros(cache_shape, dtype=np.float32)
            feeds[str(spec["cached_values_input_name"])] = np.zeros(cache_shape, dtype=np.float32)
            feeds[str(spec["cached_positions_input_name"])] = np.full(
                tuple(spec["positions_shape"]), -1, dtype=np.int32
            )

    def run_frames(self, frame_rows):
        """Decode a batch of code frames and advance the streaming state.

        Rows shorter than the quantizer count are zero-padded, longer rows are
        truncated. Returns ``None`` for an empty batch, otherwise a tuple of
        the ``audio`` output and the first ``audio_lengths`` value as an int.
        """
        if not frame_rows:
            return None
        quantizers = int(self.codec_meta["codec_config"]["num_quantizers"])
        frame_count = len(frame_rows)
        codes = np.zeros((1, frame_count, quantizers), dtype=np.int32)
        for row_index, row in enumerate(frame_rows):
            for column in range(min(len(row), quantizers)):
                codes[0, row_index, column] = int(row[column])
        feeds = {
            "audio_codes": codes,
            "audio_code_lengths": np.asarray([frame_count], dtype=np.int32),
        }
        feeds.update(self.state_feeds)
        results = self.session.run(None, feeds)
        output_names = [output.name for output in self.session.get_outputs()]
        by_name = dict(zip(output_names, results, strict=True))
        # Carry each state output over as the next call's matching input.
        state = self.state_feeds
        for spec in self.transformer_specs:
            state[str(spec["input_name"])] = by_name[str(spec["output_name"])]
        for spec in self.attention_specs:
            state[str(spec["offset_input_name"])] = by_name[str(spec["offset_output_name"])]
            state[str(spec["cached_keys_input_name"])] = by_name[str(spec["cached_keys_output_name"])]
            state[str(spec["cached_values_input_name"])] = by_name[str(spec["cached_values_output_name"])]
            state[str(spec["cached_positions_input_name"])] = by_name[str(spec["cached_positions_output_name"])]
        audio_length = int(by_name["audio_lengths"].reshape(-1)[0])
        return by_name["audio"], audio_length
def _resolve_stream_decode_frame_budget(emitted_total, sr, first_audio_at):
if not first_audio_at or sr <= 0:
return 1
elapsed = max(0.0, time.perf_counter() - first_audio_at)
lead = emitted_total / float(sr) - elapsed
if not first_audio_at or lead < 0.20:
return 1
if lead < 0.55:
return 2
if lead < 1.10:
return 4
return 8
class MossTtsRuntime:
def __init__(self, model_dir, thread_count=2, max_new_frames=375):
    """Load the manifest, metadata, tokenizer and ONNX sessions.

    Args:
        model_dir: directory searched for ``browser_poc_manifest.json``.
        thread_count: intra-op thread count for each ONNX session (minimum 1).
        max_new_frames: when not ``None``, overrides
            ``generation_defaults["max_new_frames"]`` in the loaded manifest.
    """

    def read_json(path):
        return json.loads(path.read_text("utf-8"))

    self.model_dir = Path(model_dir).expanduser().resolve()
    self.thread_count = max(1, int(thread_count))
    self.manifest_path = self._find_manifest()
    self.manifest_dir = self.manifest_path.parent
    self.manifest = read_json(self.manifest_path)
    if max_new_frames is not None:
        self.manifest["generation_defaults"]["max_new_frames"] = int(max_new_frames)
    # Fixed seed: sampling is reproducible for a freshly constructed runtime.
    self.rng = np.random.default_rng(1234)
    model_files = self.manifest["model_files"]
    self.tts_meta_path = self._resolve_path(model_files["tts_meta"])
    self.codec_meta_path = self._resolve_path(model_files["codec_meta"])
    self.tts_meta = read_json(self.tts_meta_path)
    self.codec_meta = read_json(self.codec_meta_path)
    tokenizer_rel = model_files.get("tokenizer_model", "tokenizer.model")
    tokenizer_path = str(self._resolve_path(tokenizer_rel))
    self.sp = spm.SentencePieceProcessor(model_file=tokenizer_path)
    self.sessions = self._create_sessions()
    self.codec_stream = CodecStreamingSession(self.codec_meta, self.sessions["codec_decode_step"])
def _find_manifest(self):
    """Return the first manifest candidate under ``self.model_dir`` that exists.

    Raises:
        FileNotFoundError: when none of the candidate paths is a file.
    """
    candidates = (
        (self.model_dir / relative).resolve()
        for relative in MANIFEST_CANDIDATE_RELATIVE_PATHS
    )
    found = next((path for path in candidates if path.is_file()), None)
    if found is None:
        raise FileNotFoundError(f"browser_poc_manifest.json not found under {self.model_dir}")
    return found
def _resolve_path(self, rel):
    """Resolve a manifest-relative path, trying directory aliases as a fallback.

    If the path does not exist as written, each alias whose old directory name
    appears as a path segment is substituted and the first existing rewrite is
    returned. When nothing exists the original resolved path is returned
    anyway, so the caller sees the failure when it opens the file.
    """
    direct = (self.manifest_dir / Path(rel)).resolve()
    if direct.exists():
        return direct
    posix_rel = str(rel).replace("\\", "/")
    # Slashes on both ends let a name match only as a whole path segment.
    padded = f"/{posix_rel}/"
    for legacy_name, current_name in MODEL_DIR_ALIAS_MAP.items():
        if f"/{legacy_name}/" not in padded:
            continue
        aliased = (self.manifest_dir / Path(posix_rel.replace(legacy_name, current_name))).resolve()
        if aliased.exists():
            return aliased
    return direct
def _session(self, p):
    """Create a CPU-only ONNX Runtime session for the model file ``p``.

    All graph optimisations are enabled; intra-op parallelism uses
    ``self.thread_count`` threads and inter-op parallelism a single thread.
    """
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    options.intra_op_num_threads = self.thread_count
    options.inter_op_num_threads = 1
    model_path = str(p)
    return ort.InferenceSession(model_path, sess_options=options, providers=["CPUExecutionProvider"])
def _create_sessions(self):
    """Open every ONNX session named in the TTS and codec metadata.

    Six sessions are mandatory; three local-decoder variants are opened only
    when the TTS metadata lists a file for them. Returns a dict keyed by
    session name.
    """
    tts_dir = self.tts_meta_path.parent
    codec_dir = self.codec_meta_path.parent
    tts_files = self.tts_meta["files"]
    codec_files = self.codec_meta["files"]
    # (session name, base directory, file table, key inside the file table)
    required = (
        ("prefill", tts_dir, tts_files, "prefill"),
        ("decode", tts_dir, tts_files, "decode_step"),
        ("local_decoder", tts_dir, tts_files, "local_decoder"),
        ("codec_encode", codec_dir, codec_files, "encode"),
        ("codec_decode", codec_dir, codec_files, "decode_full"),
        ("codec_decode_step", codec_dir, codec_files, "decode_step"),
    )
    sessions = {}
    for name, base_dir, file_table, key in required:
        sessions[name] = self._session(base_dir / file_table[key])
    for optional in ("local_greedy_frame", "local_fixed_sampled_frame", "local_cached_step"):
        if tts_files.get(optional):
            sessions[optional] = self._session(tts_dir / tts_files[optional])
    return sessions
def list_builtin_voices(self):
    """Return a new list holding the manifest's ``builtin_voices`` entries."""
    voices = self.manifest["builtin_voices"]
    return [*voices]
def encode_text(self, text):
    """Tokenise ``text`` with the SentencePiece model into a list of ints.

    ``None`` (or any falsy value) is encoded as the empty string.
    """
    pieces = self.sp.encode(str(text or ""), out_type=int)
    return list(map(int, pieces))
def count_text_tokens(self, text):
    """Return how many tokens ``encode_text`` produces for ``text``."""
    token_ids = self.encode_text(text)
    return len(token_ids)
def _load_ref_audio(self, path):
    """Load a reference recording in the codec's sample rate and channel count.

    The audio is resampled when its rate differs from the codec's. Mono is
    duplicated across channels and multi-channel audio is averaged to mono.
    Returns a float32 NumPy array with a leading batch dimension of 1.

    Raises:
        ValueError: for any other channel-count mismatch.
    """
    source = Path(path).expanduser().resolve()
    waveform, source_rate = torchaudio.load(str(source))
    waveform = waveform.to(torch.float32)
    codec_cfg = self.codec_meta["codec_config"]
    target_rate = int(codec_cfg["sample_rate"])
    target_channels = int(codec_cfg["channels"])
    if source_rate != target_rate:
        waveform = torchaudio.functional.resample(waveform, source_rate, target_rate)
    have_channels = int(waveform.shape[0])
    if have_channels != target_channels:
        if have_channels == 1 and target_channels > 1:
            waveform = waveform.repeat(target_channels, 1)
        elif have_channels > 1 and target_channels == 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        else:
            raise ValueError(f"Unsupported channel conversion: {have_channels} -> {target_channels}")
    batched = waveform.unsqueeze(0).detach().cpu().numpy()
    return batched.astype(np.float32, copy=False)
def encode_ref_audio(self, path):
    """Encode a reference recording into codec token rows.

    Returns one list per encoded frame, each holding ``num_quantizers`` ints
    taken from the encoder's ``audio_codes`` output; only the first
    ``audio_code_lengths`` frames are kept.
    """
    waveform = self._load_ref_audio(path)
    sample_count = int(waveform.shape[-1])
    encoder = self.sessions["codec_encode"]
    feeds = {
        "waveform": waveform,
        "input_lengths": np.asarray([sample_count], dtype=np.int32),
    }
    results = encoder.run(None, feeds)
    output_names = [output.name for output in encoder.get_outputs()]
    by_name = dict(zip(output_names, results, strict=True))
    codes = np.asarray(by_name["audio_codes"], dtype=np.int32)
    frame_count = int(np.asarray(by_name["audio_code_lengths"]).reshape(-1)[0])
    quantizers = int(self.codec_meta["codec_config"]["num_quantizers"])
    return [
        [int(codes[0, frame, level]) for level in range(quantizers)]
        for frame in range(frame_count)
    ]
def resolve_prompt_codes(self, *, voice, prompt_audio_path):
    """Return the prompt audio codes for a request.

    A reference recording, when given, takes precedence and is encoded.
    Otherwise the named built-in voice is looked up; a missing ``voice``
    selects the first built-in entry.

    Raises:
        ValueError: if the named built-in voice does not exist.
    """
    if prompt_audio_path:
        return self.encode_ref_audio(prompt_audio_path)
    wanted = str(voice or self.list_builtin_voices()[0]["voice"])
    for entry in self.list_builtin_voices():
        if entry["voice"] == wanted:
            return list(entry["prompt_audio_codes"])
    raise ValueError(f"Built-in voice not found: {wanted}")
def build_text_rows(self, token_ids):
    """Turn text token ids into model input rows.

    Each row is ``n_vq + 1`` wide: the text token id in column 0 and the
    audio pad token id in every remaining column.
    """
    tts_cfg = self.manifest["tts_config"]
    width = int(tts_cfg["n_vq"]) + 1
    rows = []
    for token_id in token_ids:
        padding = [int(tts_cfg["audio_pad_token_id"])] * (width - 1)
        rows.append([int(token_id), *padding])
    return rows
def build_audio_prefix_rows(self, codes, slot_id=None):
    """Turn codec code rows into model input rows.

    Each row is ``n_vq + 1`` wide: the slot token id in column 0 (the user
    audio slot unless ``slot_id`` is given), then up to ``n_vq`` codes, then
    audio pad token ids for any columns left over.
    """
    tts_cfg = self.manifest["tts_config"]
    width = int(tts_cfg["n_vq"]) + 1
    slot = int(tts_cfg["audio_user_slot_token_id"] if slot_id is None else slot_id)
    rows = []
    for code_row in codes:
        usable = min(len(code_row), width - 1)
        row = [slot, *(int(code_row[index]) for index in range(usable))]
        row.extend([int(tts_cfg["audio_pad_token_id"])] * (width - len(row)))
        rows.append(row)
    return rows
def build_request_rows(self, codes, text_ids):
    """Assemble the full prefill prompt for one synthesis request.

    The row order is: user prompt prefix, audio-start token, reference audio
    codes, audio-end token, post-reference prompt, the text tokens, assistant
    prompt prefix and a final audio-start token. Returns a dict with the rows
    under ``"inputIds"`` and an all-ones ``"attentionMask"`` of shape
    ``[1][len(rows)]``.
    """
    tts_cfg = self.manifest["tts_config"]
    templates = self.manifest["prompt_templates"]
    head = [
        *templates["user_prompt_prefix_token_ids"],
        int(tts_cfg["audio_start_token_id"]),
    ]
    tail = [
        int(tts_cfg["audio_end_token_id"]),
        *templates["user_prompt_after_reference_token_ids"],
        *text_ids,
        *templates["assistant_prompt_prefix_token_ids"],
        int(tts_cfg["audio_start_token_id"]),
    ]
    rows = []
    rows.extend(self.build_text_rows(head))
    rows.extend(self.build_audio_prefix_rows(codes))
    rows.extend(self.build_text_rows(tail))
    return {"inputIds": rows, "attentionMask": [[1] * len(rows)]}
def run_local_decoder(self, gh, text_tid, frame_prefix):
    """Run the uncached local decoder for one step.

    ``frame_prefix`` holds the audio tokens already chosen for the current
    frame; it is padded with the audio pad token (or truncated) to
    ``n_vq - 1`` entries. Returns the flattened ``text_logits`` and the raw
    ``audio_logits`` output.
    """
    tts_cfg = self.manifest["tts_config"]
    prefix_width = int(tts_cfg["n_vq"]) - 1
    pad_id = int(tts_cfg["audio_pad_token_id"])
    prefix = np.full((1, prefix_width), pad_id, dtype=np.int32)
    for position in range(min(len(frame_prefix), prefix_width)):
        prefix[0, position] = int(frame_prefix[position])
    decoder = self.sessions["local_decoder"]
    feeds = {
        "global_hidden": gh.astype(np.float32, copy=False),
        "text_token_id": np.asarray([int(text_tid)], dtype=np.int32),
        "audio_prefix_token_ids": prefix,
    }
    results = decoder.run(None, feeds)
    output_names = [output.name for output in decoder.get_outputs()]
    by_name = dict(zip(output_names, results, strict=True))
    return by_name["text_logits"].reshape(-1), by_name["audio_logits"]
def create_empty_local_past(self):
    """Return zero-length key/value caches for every local decoder layer.

    Keys are ``local_past_key_<i>`` and ``local_past_value_<i>``; each value is
    a float32 array of shape ``(1, 0, local_heads, local_head_dim)``.
    """
    model_cfg = self.tts_meta["model_config"]
    layer_count = int(model_cfg["local_layers"])
    empty_shape = (1, 0, int(model_cfg["local_heads"]), int(model_cfg["local_head_dim"]))
    past = {}
    for layer in range(layer_count):
        past[f"local_past_key_{layer}"] = np.zeros(empty_shape, dtype=np.float32)
        past[f"local_past_value_{layer}"] = np.zeros(empty_shape, dtype=np.float32)
    return past
def run_local_cached_step(self, gh, *, text_tid, audio_tid, ch_idx, step_type, past_vl, past):
    """Run one step of the cached local decoder.

    ``past`` maps ``local_past_*`` input names to cache tensors. Returns the
    flattened ``text_logits``, the raw ``audio_logits`` and the updated cache
    dict, whose keys are the ``local_present_*`` output names renamed to
    ``local_past_*`` so it can be fed straight into the next step.
    """

    def as_int32(value):
        return np.asarray([int(value)], dtype=np.int32)

    stepper = self.sessions["local_cached_step"]
    feeds = {
        "global_hidden": gh.astype(np.float32, copy=False),
        "text_token_id": as_int32(text_tid),
        "audio_token_id": as_int32(audio_tid),
        "channel_index": as_int32(ch_idx),
        "step_type": as_int32(step_type),
        "past_valid_lengths": as_int32(past_vl),
    }
    feeds.update(past)
    results = stepper.run(None, feeds)
    output_names = [output.name for output in stepper.get_outputs()]
    by_name = dict(zip(output_names, results, strict=True))
    next_past = {}
    # The first two listed outputs are skipped; the rest are treated as cache tensors.
    for present_name in self.tts_meta["onnx"]["local_cached_output_names"][2:]:
        next_past[present_name.replace("local_present_", "local_past_")] = by_name[present_name]
    return by_name["text_logits"].reshape(-1), by_name["audio_logits"], next_past
def run_local_greedy_frame(self, gh, *, prev_sets, rep_penalty):
    """Produce a whole frame with the fused greedy ONNX graph.

    ``prev_sets`` holds, per channel, the set of tokens already emitted; it is
    converted into a 0/1 mask of shape ``(1, n_vq, codebook_size)``. Returns
    ``(should_continue, frame_token_ids)``.
    """
    codebook_size = int(self.tts_meta["model_config"]["audio_codebook_sizes"][0])
    channel_count = int(self.manifest["tts_config"]["n_vq"])
    seen_mask = np.zeros((1, channel_count, codebook_size), dtype=np.int32)
    for channel, seen_tokens in enumerate(prev_sets):
        for token in seen_tokens:
            if token < 0 or token >= codebook_size:
                continue
            seen_mask[0, channel, token] = 1
    graph = self.sessions["local_greedy_frame"]
    feeds = {
        "global_hidden": gh.astype(np.float32, copy=False),
        "repetition_seen_mask": seen_mask,
        "repetition_penalty": np.asarray([float(rep_penalty)], dtype=np.float32),
    }
    results = graph.run(None, feeds)
    output_names = [output.name for output in graph.get_outputs()]
    by_name = dict(zip(output_names, results, strict=True))
    should_continue = bool(int(np.asarray(by_name["should_continue"]).reshape(-1)[0]))
    token_ids = np.asarray(by_name["frame_token_ids"]).reshape(-1).astype(np.int32, copy=False).tolist()
    return should_continue, [int(token) for token in token_ids]
def run_local_fixed_sampled_frame(self, gh, *, prev_sets):
    """Produce a whole frame with the fused sampling ONNX graph.

    The graph receives the randomness as inputs: one uniform draw for the
    assistant token followed by one per channel, each clamped to
    ``[0, 0.99999994]``. ``prev_sets`` is converted into a 0/1 mask as for the
    greedy graph. Returns ``(should_continue, frame_token_ids)``.
    """

    def draw_uniform():
        return min(0.99999994, max(0.0, float(self.rng.random())))

    codebook_size = int(self.tts_meta["model_config"]["audio_codebook_sizes"][0])
    channel_count = int(self.manifest["tts_config"]["n_vq"])
    seen_mask = np.zeros((1, channel_count, codebook_size), dtype=np.int32)
    for channel, seen_tokens in enumerate(prev_sets):
        for token in seen_tokens:
            if token < 0 or token >= codebook_size:
                continue
            seen_mask[0, channel, token] = 1
    # Draw order matters for reproducibility: assistant first, then the channels.
    assistant_u = np.asarray([draw_uniform()], dtype=np.float32)
    audio_u = np.asarray([[draw_uniform() for _ in range(channel_count)]], dtype=np.float32)
    graph = self.sessions["local_fixed_sampled_frame"]
    feeds = {
        "global_hidden": gh.astype(np.float32, copy=False),
        "repetition_seen_mask": seen_mask,
        "assistant_random_u": assistant_u,
        "audio_random_u": audio_u,
    }
    results = graph.run(None, feeds)
    output_names = [output.name for output in graph.get_outputs()]
    by_name = dict(zip(output_names, results, strict=True))
    token_ids = np.asarray(by_name["frame_token_ids"]).reshape(-1).astype(np.int32, copy=False).tolist()
    should_continue = bool(int(np.asarray(by_name["should_continue"]).reshape(-1)[0]))
    return should_continue, [int(token) for token in token_ids]
def slice_audio_channel_logits(self, alogits, ci):
    """Return the logits of channel ``ci`` from the flattened ``alogits``.

    The size of the last dimension is the per-channel width; the array is
    flattened and the ``ci``-th block of that width is returned.
    """
    per_channel = int(alogits.shape[-1])
    begin = ci * per_channel
    return alogits.reshape(-1)[begin:begin + per_channel]
def decode_full_audio(self, frames):
    """Decode all frames at once with the non-streaming codec graph.

    Returns ``(channels, length)`` where ``channels`` is a list of per-channel
    float32 arrays trimmed to ``length`` samples. An empty ``frames`` list
    yields ``([], 0)`` without running the model.
    """
    if not frames:
        return [], 0
    flat_codes, dims = _flatten3d([frames])
    decoder = self.sessions["codec_decode"]
    feeds = {
        "audio_codes": flat_codes.reshape(dims),
        "audio_code_lengths": np.asarray([len(frames)], dtype=np.int32),
    }
    results = decoder.run(None, feeds)
    output_names = [output.name for output in decoder.get_outputs()]
    by_name = dict(zip(output_names, results, strict=True))
    audio_length = int(by_name["audio_lengths"].reshape(-1)[0])
    channels = _slice_channel_major_audio(by_name["audio"], 0, audio_length)
    return channels, audio_length
def generate_audio_frames(self, req_rows, on_frame=None):
    """Autoregressively generate audio code frames for a prepared request.

    Args:
        req_rows: dict with ``"inputIds"`` (prompt rows) and ``"attentionMask"``
            as produced by ``build_request_rows``.
        on_frame: optional callback invoked after every generated frame as
            ``on_frame(all_frames_so_far, frame_index, frame)``. The first
            argument is the same list that is returned at the end.

    Returns:
        List of frames; each frame is a list of audio token ids, one per channel.

    Generation settings are read from ``self.manifest["generation_defaults"]``.
    The ``sample_mode`` key is only read when the greedy path is not taken.
    """
    gd = self.manifest["generation_defaults"]
    rw = int(self.manifest["tts_config"]["n_vq"]) + 1
    # Prefill: run the whole prompt once to obtain the first hidden state and
    # the initial key/value cache.
    pids, pdims = _flatten3d([req_rows["inputIds"]])
    pmask, pmdims = _flatten2d(req_rows["attentionMask"])
    outs = self.sessions["prefill"].run(None, {"input_ids": pids.reshape(pdims), "attention_mask": pmask.reshape(pmdims)})
    on = [o.name for o in self.sessions["prefill"].get_outputs()]
    nd = dict(zip(on, outs, strict=True))
    gh = _extract_last_hidden(nd["global_hidden"])
    # Number of valid cache positions = number of attended prompt rows.
    pvl = sum(int(x) for x in req_rows["attentionMask"][0])
    # Every listed output after the first is renamed present_* -> past_* so it
    # can be fed back into the decode step.
    past = {n.replace("present_", "past_"): nd[n] for n in self.tts_meta["onnx"]["prefill_output_names"][1:]}
    gen_frames = []
    # Per-channel history of emitted tokens (list and set form) used for the
    # repetition penalty.
    prev_by_ch = [[] for _ in range(int(self.manifest["tts_config"]["n_vq"]))]
    prev_set_by_ch = [set() for _ in range(int(self.manifest["tts_config"]["n_vq"]))]
    for si in range(int(gd["max_new_frames"])):
        frame = []
        # Four ways to produce one frame, tried in this order.
        if "local_greedy_frame" in self.sessions and not bool(gd["do_sample"]):
            # 1) Fused greedy graph: the whole frame in a single run.
            cont, frame = self.run_local_greedy_frame(gh, prev_sets=prev_set_by_ch, rep_penalty=float(gd["audio_repetition_penalty"]))
            if not cont:
                break
            for ci, st in enumerate(frame):
                prev_by_ch[ci].append(st)
                prev_set_by_ch[ci].add(st)
        elif "local_fixed_sampled_frame" in self.sessions and gd["sample_mode"] == SAMPLE_MODE_FIXED:
            # 2) Fused sampling graph: the whole frame in a single run, with the
            #    random numbers supplied by self.rng.
            cont, frame = self.run_local_fixed_sampled_frame(gh, prev_sets=prev_set_by_ch)
            if not cont:
                break
            for ci, st in enumerate(frame):
                prev_by_ch[ci].append(st)
                prev_set_by_ch[ci].add(st)
        elif "local_cached_step" in self.sessions:
            # 3) Cached local decoder: one run per token, with a local cache
            #    that starts empty for every frame.
            lp = self.create_empty_local_past()
            lpvl = 0
            # step_type 0: obtain the text logits used for the stop decision.
            tl, _, lp = self.run_local_cached_step(gh, text_tid=0, audio_tid=0, ch_idx=0, step_type=0, past_vl=lpvl, past=lp)
            lpvl += 1
            ntt = _sample_assistant_text_token(tl, self.manifest, gd, self.rng)
            # Anything other than the assistant audio-slot token ends generation.
            if ntt != int(self.manifest["tts_config"]["audio_assistant_slot_token_id"]):
                break
            # step_type 1: feed the chosen text token, sample channel 0.
            _, alogits, lp = self.run_local_cached_step(gh, text_tid=ntt, audio_tid=0, ch_idx=0, step_type=1, past_vl=lpvl, past=lp)
            lpvl += 1
            fl = self.slice_audio_channel_logits(alogits, 0).astype(np.float32, copy=False)
            st = _sample_audio_token(fl, prev_by_ch[0], prev_set_by_ch[0], gd, self.rng)
            frame.append(st)
            prev_by_ch[0].append(st)
            prev_set_by_ch[0].add(st)
            prev = st
            for ci in range(1, int(self.manifest["tts_config"]["n_vq"])):
                # step_type 2: feed the previous channel's token (hence ci - 1)
                # and sample channel ci.
                _, alogits, lp = self.run_local_cached_step(gh, text_tid=0, audio_tid=prev, ch_idx=ci - 1, step_type=2, past_vl=lpvl, past=lp)
                lpvl += 1
                cl = self.slice_audio_channel_logits(alogits, ci).astype(np.float32, copy=False)
                st = _sample_audio_token(cl, prev_by_ch[ci], prev_set_by_ch[ci], gd, self.rng)
                frame.append(st)
                prev_by_ch[ci].append(st)
                prev_set_by_ch[ci].add(st)
                prev = st
        else:
            # 4) Uncached local decoder: re-run with the growing frame prefix
            #    for every channel.
            tl, _ = self.run_local_decoder(gh, 0, [])
            ntt = _sample_assistant_text_token(tl, self.manifest, gd, self.rng)
            if ntt != int(self.manifest["tts_config"]["audio_assistant_slot_token_id"]):
                break
            for ci in range(int(self.manifest["tts_config"]["n_vq"])):
                _, alogits = self.run_local_decoder(gh, ntt, frame)
                cl = self.slice_audio_channel_logits(alogits, ci).astype(np.float32, copy=False)
                st = _sample_audio_token(cl, prev_by_ch[ci], prev_set_by_ch[ci], gd, self.rng)
                frame.append(st)
                prev_by_ch[ci].append(st)
                prev_set_by_ch[ci].add(st)
        gen_frames.append(frame)
        # Feed the new frame back through the global decode step: column 0 is
        # the assistant audio-slot token, the rest are the frame's tokens.
        nr = np.full((1, 1, rw), int(self.manifest["tts_config"]["audio_pad_token_id"]), dtype=np.int32)
        nr[0, 0, 0] = int(self.manifest["tts_config"]["audio_assistant_slot_token_id"])
        for i, t in enumerate(frame):
            nr[0, 0, i + 1] = int(t)
        df = {"input_ids": nr, "past_valid_lengths": np.asarray([pvl], dtype=np.int32)}
        # Decode inputs after the first two are looked up by name in the cache.
        for iname in self.tts_meta["onnx"]["decode_input_names"][2:]:
            df[iname] = past[iname]
        dout = self.sessions["decode"].run(None, df)
        dn = [o.name for o in self.sessions["decode"].get_outputs()]
        dnd = dict(zip(dn, dout, strict=True))
        gh = _extract_last_hidden(dnd["global_hidden"])
        pvl += 1
        past = {n.replace("present_", "past_"): dnd[n] for n in self.tts_meta["onnx"]["decode_output_names"][1:]}
        if on_frame is not None:
            on_frame(gen_frames, si, frame)
    return gen_frames
def decode_full_audio_safe(self, frames):
    """Decode frames to a waveform, falling back to streaming decode on error.

    The one-shot codec graph is tried first. If it raises, a warning is logged
    and the frames are decoded in batches of eight through the streaming
    session, which is reset before and after. Returns a
    ``(samples, channels)`` float32 array.
    """
    try:
        per_channel, _ = self.decode_full_audio(frames)
        return _merge_audio_channels(per_channel)
    except Exception as exc:
        import logging
        logging.warning("full codec decode failed, falling back: %s", exc)
    self.codec_stream.reset()
    channel_count = int(self.codec_meta["codec_config"]["channels"])
    buckets = [[] for _ in range(channel_count)]
    try:
        for begin in range(0, len(frames), 8):
            decoded = self.codec_stream.run_frames(frames[begin:begin + 8])
            if decoded is None:
                continue
            audio, length = decoded
            if length <= 0:
                continue
            for channel, bucket in enumerate(buckets):
                bucket.append(np.asarray(audio[0, channel, :length], dtype=np.float32))
    finally:
        # Leave the shared streaming session clean whatever happened.
        self.codec_stream.reset()
    joined = [
        np.concatenate(bucket) if bucket else np.zeros((0,), dtype=np.float32)
        for bucket in buckets
    ]
    return _merge_audio_channels(joined)
def split_text_chunks(self, text, max_tokens=75):
    """Split ``text`` into pieces of at most ``max_tokens`` tokens each.

    For every piece a binary search finds the longest prefix that fits. The
    cut is then moved back to just after the nearest punctuation mark or space
    within the last 24 characters of that prefix, if there is one. Returns
    ``[]`` for empty text and the stripped original text as a single item
    when no split was necessary.
    """
    remaining = str(text or "").strip()
    if not remaining:
        return []
    break_chars = set(CLAUSE_SPLIT_PUNCTUATION) | set(SENTENCE_END_PUNCTUATION) | {" "}
    chunks = []
    while remaining:
        if self.count_text_tokens(remaining) <= max_tokens:
            chunks.append(remaining)
            break
        # Binary search for the longest prefix length that fits the budget.
        low, high, fit = 1, len(remaining), 1
        while low <= high:
            probe = (low + high) // 2
            candidate = remaining[:probe].strip()
            if candidate and self.count_text_tokens(candidate) <= max_tokens:
                fit = probe
                low = probe + 1
            else:
                high = probe - 1
                if not candidate:
                    low = probe + 1
        head = remaining[:fit]
        window_floor = max(-1, len(head) - 25)
        # Cut just after the last preferred break character, if any.
        cut = next(
            (index + 1 for index in range(len(head) - 1, window_floor, -1) if head[index] in break_chars),
            -1,
        )
        if cut <= 0:
            cut = fit
        piece = remaining[:cut].strip()
        if not piece:
            piece = remaining[:fit].strip()
            cut = fit
        chunks.append(piece)
        remaining = remaining[cut:].strip()
    if len(chunks) > 1:
        return chunks
    return [str(text or "").strip()]
def synthesize(self, *, text, voice=None, prompt_audio_path=None, sample_mode="fixed", do_sample=True, streaming=True, max_new_frames=375):
    """Synthesise ``text`` and write the result to ``OUTPUT_DIR / "output.wav"``.

    The call stores ``max_new_frames``, the normalised sample mode and the
    derived ``do_sample`` flag in the manifest's ``generation_defaults``, which
    is where frame generation reads them from. With ``streaming=True`` frames
    are decoded through the streaming codec session while generation is still
    running; otherwise all frames are decoded afterwards.

    Returns:
        Dict with ``audio_path`` (str), ``sample_rate`` (int) and ``frames``
        (number of generated frames).
    """
    defaults = self.manifest["generation_defaults"]
    defaults["max_new_frames"] = int(max_new_frames)
    mode = _normalize_sample_mode(sample_mode, do_sample)
    defaults["sample_mode"] = mode
    defaults["do_sample"] = mode != SAMPLE_MODE_GREEDY
    prompt_codes = self.resolve_prompt_codes(voice=voice, prompt_audio_path=prompt_audio_path)
    text_ids = self.encode_text(text)
    request = self.build_request_rows(prompt_codes, text_ids)
    if not streaming:
        frames = self.generate_audio_frames(request)
        waveform = self.decode_full_audio_safe(frames)
    else:
        queue = []   # generated frames waiting to be decoded
        pieces = []  # decoded waveform chunks in emission order
        progress = {"samples": 0, "first_audio_at": None}
        self.codec_stream.reset()

        def flush(force):
            waiting = len(queue)
            if waiting <= 0:
                return
            rate = int(self.codec_meta["codec_config"]["sample_rate"])
            budget = max(
                1,
                _resolve_stream_decode_frame_budget(progress["samples"], rate, progress["first_audio_at"]),
            )
            if force:
                take = waiting
            elif waiting < budget:
                # Not enough frames queued yet for the current batch size.
                return
            else:
                take = min(waiting, budget)
            batch = queue[:take]
            del queue[:take]
            decoded = self.codec_stream.run_frames(batch)
            if decoded is None:
                return
            audio, length = decoded
            if length <= 0:
                return
            if progress["first_audio_at"] is None:
                progress["first_audio_at"] = time.perf_counter()
            progress["samples"] += length
            channel_count = int(self.codec_meta["codec_config"]["channels"])
            pieces.append(_merge_audio_channels([audio[0, channel, :length] for channel in range(channel_count)]))

        def handle_frame(_all_frames, _index, frame):
            queue.append(list(frame))
            flush(False)

        try:
            frames = self.generate_audio_frames(request, on_frame=handle_frame)
            flush(True)
        finally:
            self.codec_stream.reset()
        waveform = _concat_waveforms(pieces)
    sample_rate = int(self.codec_meta["codec_config"]["sample_rate"])
    destination = OUTPUT_DIR / "output.wav"
    _write_wav(destination, waveform, sample_rate)
    return {"audio_path": str(destination), "sample_rate": sample_rate, "frames": len(frames)}
def ensure_models():
    """Download the TTS and codec model files unless they are already present.

    A repository counts as present when its marker JSON file exists in the
    target directory. After a download, files found in a sub-directory named
    like the target directory are moved up one level unless a file of that
    name already exists there.
    """

    def fetch(repo_id, folder_name, marker_name, patterns):
        target = MODEL_DIR / folder_name
        if (target / marker_name).is_file():
            return
        target.mkdir(parents=True, exist_ok=True)
        snapshot_download(repo_id, local_dir=str(target), local_dir_use_symlinks=False, allow_patterns=patterns)
        nested = target / folder_name
        if not nested.is_dir():
            return
        for entry in nested.iterdir():
            destination = target / entry.name
            if not destination.exists():
                shutil.move(str(entry), str(destination))

    fetch(
        DEFAULT_TTS_REPO,
        "MOSS-TTS-Nano-100M-ONNX",
        "browser_poc_manifest.json",
        ["*.onnx", "*.data", "*.json", "tokenizer.model"],
    )
    fetch(
        DEFAULT_CODEC_REPO,
        "MOSS-Audio-Tokenizer-Nano-ONNX",
        "codec_browser_onnx_meta.json",
        ["*.onnx", "*.data", "*.json"],
    )
# Process-wide runtime instance, created on first use by get_runtime().
runtime = None


def get_runtime():
    """Return the shared ``MossTtsRuntime``, creating it on the first call.

    The first call makes sure the model files are present and then loads the
    runtime from ``MODEL_DIR``.
    """
    global runtime
    if runtime is None:
        ensure_models()
        runtime = MossTtsRuntime(MODEL_DIR, thread_count=2, max_new_frames=375)
    return runtime
def synthesize_gradio(text, voice, audio_path, sample_mode, max_frames):
    """Gradio click handler: synthesise speech and report timing.

    A reference recording, when supplied, replaces the built-in voice.
    Returns the path of the written WAV file and a one-line status string.
    """
    engine = get_runtime()
    started = time.time()
    use_reference = bool(audio_path)
    result = engine.synthesize(
        text=text,
        voice=None if use_reference else voice,
        prompt_audio_path=audio_path if use_reference else None,
        sample_mode=sample_mode,
        do_sample=(sample_mode != "greedy"),
        streaming=True,
        max_new_frames=int(max_frames),
    )
    elapsed = time.time() - started
    status = f"Done in {elapsed:.1f}s | {result['sample_rate']}Hz | {result['frames']} frames"
    return result["audio_path"], status
# Voice names offered in the dropdown.
# NOTE(review): hard-coded here rather than read from the manifest's
# "builtin_voices" list; a name missing from the manifest makes
# resolve_prompt_codes raise ValueError - confirm the two stay in sync.
VOICES = ["Junhao", "Zhiming", "Weiguo", "Xiaoyu", "Yuewen", "Lingyu", "Trump", "Ava", "Bella", "Adam", "Nathan", "Soyo", "Saki", "Mortis", "Umiri", "Mei", "Anon", "Arisa"]
# UI definition. NOTE(review): the nesting of rows and columns below is
# reconstructed; the original indentation was not available - confirm the layout.
with gr.Blocks(title="MOSS-TTS-Nano ONNX") as demo:
    gr.Markdown("# MOSS-TTS-Nano-100M-ONNX\nCPU-only TTS with voice cloning. First run downloads ~730MB model.")
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            text_in = gr.Textbox(label="Text", value="Hello, welcome to MOSS TTS Nano.", lines=3)
            with gr.Row():
                voice_in = gr.Dropdown(choices=VOICES, value="Junhao", label="Voice (overridden by ref audio)")
                ref_audio = gr.Audio(label="Reference Audio (optional, for voice cloning)", type="filepath")
            with gr.Row():
                sample_mode = gr.Dropdown(choices=["fixed", "greedy", "full"], value="fixed", label="Sample Mode")
                max_frames = gr.Slider(16, 750, value=375, step=1, label="Max Frames")
            btn = gr.Button("Synthesize", variant="primary")
        # Right column: outputs.
        with gr.Column():
            audio_out = gr.Audio(label="Generated Audio", type="filepath")
            info_out = gr.Textbox(label="Info")
    # The input order must match the parameter order of synthesize_gradio.
    btn.click(fn=synthesize_gradio, inputs=[text_in, voice_in, ref_audio, sample_mode, max_frames], outputs=[audio_out, info_out])
if __name__ == "__main__":
    # Load (and if necessary download) the models before the UI starts serving.
    get_runtime()
    demo.launch()