""" moe_gradio.py -- adapter that drives the agents/ Gradio MoE panel with the ModMind SpikeWhale specialists instead of the byte-level ByteGPT experts. It exposes the SAME public API as agents/orchestrator.py: get_moe(device) -> obj with .available() .reload() .route() .shared_latent() .generate() .run() and .run() returns the SAME dict shape panel.py consumes: {winner, weights, bits_per_byte, steps, generation, shared_latent} so agents/panel.py can use it as a drop-in backend (see the MM_MOE_BACKEND switch there). What this adapter handles (the real differences vs the byte-level experts): * Each specialist has its OWN SpikeTokenizer (different vocab), so per-TOKEN NLL is not comparable across experts. Routing is therefore by BITS-PER-BYTE -- total NLL divided by the query's raw UTF-8 byte count -- which IS comparable across tokenizers. * d_latent=256 (vs 64); the shared RecursiveLink is built at 256, and every ModMind specialist already shares that width, so latent fusion works natively. * Checkpoints are ModMind format ({"model_state": ...}) under //checkpoints/step_*.pt, loaded exactly like train_link.py does, and hot-reloaded when newer ones appear. Display mapping (ModMind domain -> the panel's expert slot/emoji): language -> language (the panel shows 📖), reasoning -> math (➗), tool_use -> tool (🛠️) """ from __future__ import annotations import glob import math import os import random import string from pathlib import Path import torch import torch.nn.functional as F from model import RecursiveLink, SpikeWhaleLM from specialist_presets import specialist_config from spike_tokenizer import SpikeTokenizer HERE = Path(__file__).parent # checkpoints may live on a Modal Volume; tokenizers stay with the code (same rule as train_link.py) CKPT_ROOT = Path(os.environ.get("MM_CKPT_ROOT", str(HERE))) D_LATENT = 256 MAX_CTX = 512 # cap prompt/scoring length so CPU routing + generation stay snappy # (modmind_domain, panel_slot). reasoning trains on FineMath, so it IS the "math" expert. SLOTS = [("language", "language"), ("reasoning", "math"), ("tool_use", "tool")] DOMAIN2SLOT = {d: s for d, s in SLOTS} # trained bridges to surface in the panel as (asker_domain, consultant_domain) "consult" demos LINKS = [("language", "reasoning")] KEY_CHARS = string.ascii_letters + string.digits # must match train_link.py def _latest_ckpt(domain: str): base = CKPT_ROOT / domain / "checkpoints" sft = base / "model.safetensors" if sft.exists(): return str(sft) # slim fp16 safetensors (preferred) cks = sorted(glob.glob(str(base / "step_*.pt"))) return cks[-1] if cks else None # fall back to a full .pt checkpoint class SpikeWhaleMoE: def __init__(self, device: str = "cpu"): self.device = device self.models, self.toks, self.steps, self._mtime = {}, {}, {}, {} self.link = RecursiveLink(d_latent=D_LATENT).to(device).eval() self.trained_link = None # the TRAINED bridge (train_link.py output), for the consult demo self.bridge_asker = None # the FULL fine-tuned asker, for reproducible key-recall self.link_meta = None self.qa_link = None # the question->answer bridge (train_qa_link.py output) self.qa_asker = None self.qa_meta = None self._qa_mtime = None self.reload() def reload(self): """Load/refresh any specialist whose checkpoint exists or changed on disk.""" for domain, slot in SLOTS: ck = _latest_ckpt(domain) tok_path = HERE / domain / "tokenizer.json" if ck is None or not tok_path.exists(): continue mt = os.path.getmtime(ck) if self._mtime.get(slot) == mt: continue cfg = specialist_config(domain) m = SpikeWhaleLM(cfg).to(self.device).eval() if ck.endswith(".safetensors"): from safetensors.torch import load_file from safetensors import safe_open sd = load_file(ck, device=self.device) sd = {k: (v.float() if v.is_floating_point() else v) for k, v in sd.items()} m.load_state_dict(sd) with safe_open(ck, framework="pt") as f: step = int((f.metadata() or {}).get("step", 0)) else: st = torch.load(ck, map_location=self.device, weights_only=False) m.load_state_dict(st["model_state"]); step = int(st.get("step", 0)) self.models[slot] = m self.toks[slot] = SpikeTokenizer(vocab_file=str(tok_path)) self.steps[slot] = step self._mtime[slot] = mt self._load_links() self._load_qa() return list(self.models) def available(self): return list(self.models) def ensure_device(self, device): """Move all loaded models/links to `device`. On ZeroGPU this is called inside the @spaces.GPU function so CUDA is only touched there, never at import/startup. .to(device) is a no-op when already there, so it's safe to call on every request.""" for k in list(self.models): self.models[k] = self.models[k].to(device) if self.link is not None: self.link = self.link.to(device) if self.trained_link is not None: self.trained_link = self.trained_link.to(device) if self.bridge_asker is not None: self.bridge_asker = self.bridge_asker.to(device) if self.qa_link is not None: self.qa_link = self.qa_link.to(device) if self.qa_asker is not None: self.qa_asker = self.qa_asker.to(device) cache = getattr(self, "_merge_cache", None) if cache is not None: self._merge_cache = (cache[0], cache[1].to(device)) self.device = device def to_gpu_if_available(self): try: if torch.cuda.is_available(): self.ensure_device("cuda") except Exception: pass def _load_links(self): """Load a TRAINED RecursiveLink bridge (train_link.py output) from links/__from__.pt and overlay its bridge-trained injection path onto the asker. The injection only fires when we pass inject_latent (the consult demo), so normal routing/generation is unaffected.""" self.trained_link = None self.bridge_asker = None self.link_meta = None self._bridge_ali = None for asker_domain, consultant_domain in LINKS: a, c = DOMAIN2SLOT.get(asker_domain), DOMAIN2SLOT.get(consultant_domain) if a not in self.models or c not in self.models: continue base = CKPT_ROOT / "links" / f"{asker_domain}__from__{consultant_domain}" sft, pt = Path(str(base) + ".safetensors"), Path(str(base) + ".pt") if sft.exists(): from safetensors.torch import load_file from safetensors import safe_open t = load_file(str(sft), device=self.device) t = {k: (v.float() if v.is_floating_point() else v) for k, v in t.items()} with safe_open(str(sft), framework="pt") as f: md = f.metadata() or {} link_sd = {k[5:]: v for k, v in t.items() if k.startswith("link.")} ali_sd = {k[4:]: v for k, v in t.items() if k.startswith("ali.")} asker_sd = {k[6:]: v for k, v in t.items() if k.startswith("asker.")} or None key_len = int(md.get("key_len", 6)); prompt = md.get("prompt", "KEY> ") wl = float(md.get("with_latent", "nan")); nl = float(md.get("without_latent", "nan")) elif pt.exists(): st = torch.load(str(pt), map_location=self.device, weights_only=False) link_sd = st["link"]; ali_sd = st["asker_latent_inject"]; asker_sd = st.get("asker_state") key_len = int(st.get("key_len", 6)); prompt = st.get("prompt", "KEY> ") wl = float(st.get("with_latent", float("nan"))); nl = float(st.get("without_latent", float("nan"))) else: continue link = RecursiveLink(d_latent=D_LATENT).to(self.device).eval() link.load_state_dict(link_sd) self._bridge_ali = ali_sd try: self.models[a].model.latent_inject.load_state_dict(ali_sd) except Exception: pass # FULL fine-tuned asker (if present) -> lets us reproduce key-recall live. if asker_sd: ba = SpikeWhaleLM(specialist_config(asker_domain)).to(self.device).eval() ba.load_state_dict(asker_sd) self.bridge_asker = ba self.trained_link = link self.link_meta = {"asker": a, "consultant": c, "key_len": key_len, "prompt": prompt, "with_latent": wl, "without_latent": nl} return # one bridge is enough for the panel demo def _load_qa(self): """Load the question->answer bridge (train_qa_link.py output): a NEW RecursiveLink + a fully fine-tuned asker that answer arithmetic shown only to the consultant. mtime-cached (the file is ~190MB) and hot-reloaded as training improves it.""" a_dom, c_dom = LINKS[0] path = CKPT_ROOT / "links" / f"qa__{a_dom}__from__{c_dom}.safetensors" if not path.exists(): self.qa_link = self.qa_asker = self.qa_meta = None self._qa_mtime = None return mt = os.path.getmtime(path) if self._qa_mtime == mt and self.qa_link is not None: return a, c = DOMAIN2SLOT[a_dom], DOMAIN2SLOT[c_dom] if a not in self.models or c not in self.models: return from safetensors.torch import load_file from safetensors import safe_open t = load_file(str(path), device=self.device) t = {k: (v.float() if v.is_floating_point() else v) for k, v in t.items()} with safe_open(str(path), framework="pt") as f: md = f.metadata() or {} link = RecursiveLink(d_latent=D_LATENT).to(self.device).eval() link.load_state_dict({k[5:]: v for k, v in t.items() if k.startswith("link.")}) ask = SpikeWhaleLM(specialist_config(a_dom)).to(self.device).eval() ask.load_state_dict({k[6:]: v for k, v in t.items() if k.startswith("asker.")}) self.qa_link, self.qa_asker = link, ask self.qa_meta = {"asker": a, "consultant": c, "ans_len": int(md.get("ans_len", 3)), "prompt": md.get("prompt", "ANS> "), "holdout_exact": float(md.get("holdout_exact", "nan")), "step": int(md.get("step", 0))} self._qa_mtime = mt def qa_available(self): return self.qa_link is not None and self.qa_asker is not None def qa_info(self): return dict(self.qa_meta) if self.qa_meta else None @torch.no_grad() def ask_math(self, a: int, op: str, b: int, ablate: bool = False): """Language answers an arithmetic question SHOWN ONLY to Math: the frozen consultant encodes the question, the QA RecursiveLink carries it across, and the QA asker decodes the answer digits autoregressively from the latent alone (its own input is just the 'ANS> ' prompt -- the question never reaches it as text). ablate=True zeros the latent: the asker then has no question at all.""" if not self.qa_available(): return {"error": "qa bridge not trained yet"} meta = self.qa_meta a, b = int(a), int(b) if op not in ("+", "-", "*"): return {"error": "op must be one of + - *"} truth = {"+": a + b, "-": a - b, "*": a * b}[op] if not (0 <= truth < 10 ** meta["ans_len"]): return {"error": "answer out of the trained range"} q = f"{a} {op} {b} =" c_ids = torch.tensor([self.toks[meta["consultant"]].encode(q, add_special_tokens=False)], device=self.device) latent = self.models[meta["consultant"]](input_ids=c_ids).latent inj = torch.zeros_like(self.qa_link(latent)) if ablate else self.qa_link(latent) a_tok = self.toks[meta["asker"]] ids = torch.tensor([a_tok.encode(meta["prompt"], add_special_tokens=False)], device=self.device) plen = ids.shape[1] for _ in range(meta["ans_len"]): logits = self.qa_asker(input_ids=ids, inject_latent=inj).logits[:, -1, :] ids = torch.cat([ids, logits.argmax(-1, keepdim=True)], dim=1) digits = a_tok.decode(ids[0, plen:].tolist()) want = f"{truth:0{meta['ans_len']}d}" return {"question": q, "digits": digits, "answer": digits.lstrip("0") or "0", "truth": truth, "want": want, "ok": [i < len(digits) and digits[i] == ch for i, ch in enumerate(want)], "exact": digits == want} def key_recall_available(self): return self.bridge_asker is not None and self.trained_link is not None @torch.no_grad() def key_recall(self, n: int = 8, ablate: bool = False): """Reproduce the train_link.py forcing task with the FULL trained asker: a random key is shown ONLY to the consultant; the asker must emit it from the latent alone. Returns {acc, examples:[(key, recovered, ok)]}. ablate=True zeros the latent.""" if not self.key_recall_available(): return None a, c = self.link_meta["asker"], self.link_meta["consultant"] a_tok, c_tok = self.toks[a], self.toks[c] key_len = self.link_meta.get("key_len", 6) prompt = self.link_meta.get("prompt", "KEY> ") plen = len(a_tok.encode(prompt, add_special_tokens=False)) keys = ["".join(random.choice(KEY_CHARS) for _ in range(key_len)) for _ in range(n)] examples, correct = [], 0 for k in keys: c_ids = torch.tensor([c_tok.encode(k, add_special_tokens=False)], device=self.device) a_ids = torch.tensor([a_tok.encode(prompt, add_special_tokens=False) + a_tok.encode(k, add_special_tokens=False)], device=self.device) latent = self.models[c](input_ids=c_ids).latent inj = torch.zeros_like(self.trained_link(latent)) if ablate else self.trained_link(latent) logits = self.bridge_asker(input_ids=a_ids, inject_latent=inj).logits pred = logits[:, plen - 1:plen - 1 + key_len, :].argmax(-1)[0] out = a_tok.decode(pred.tolist())[:len(k)] ok = (out == k) correct += int(ok) examples.append((k, out, ok)) return {"acc": correct / max(1, n), "examples": examples} @torch.no_grad() def relay_secret(self, secret: str, ablate: bool = False): """Interactive bridge demo: a USER-CHOSEN key is shown only to the consultant; the asker reads it back from the latent alone (same mechanism as key_recall, but the human picks the secret). Returns {secret, recovered, ok:[per-char bool], aligned} -- aligned=False means the tokenizer fused some characters into multi-char tokens the bridge never saw in training, so expect degradation.""" if not self.key_recall_available(): return {"error": "bridge unavailable"} s = "".join(ch for ch in (secret or "") if ch in KEY_CHARS) key_len = self.link_meta.get("key_len", 6) if len(s) != key_len: return {"error": f"need exactly {key_len} characters (letters and digits only)"} a, c = self.link_meta["asker"], self.link_meta["consultant"] a_tok, c_tok = self.toks[a], self.toks[c] prompt = self.link_meta.get("prompt", "KEY> ") plen = len(a_tok.encode(prompt, add_special_tokens=False)) c_ids = torch.tensor([c_tok.encode(s, add_special_tokens=False)], device=self.device) a_ids = torch.tensor([a_tok.encode(prompt, add_special_tokens=False) + a_tok.encode(s, add_special_tokens=False)], device=self.device) aligned = c_ids.shape[1] == key_len and a_ids.shape[1] == plen + key_len latent = self.models[c](input_ids=c_ids).latent inj = torch.zeros_like(self.trained_link(latent)) if ablate else self.trained_link(latent) logits = self.bridge_asker(input_ids=a_ids, inject_latent=inj).logits pred = logits[:, plen - 1:plen - 1 + key_len, :].argmax(-1)[0] out = a_tok.decode(pred.tolist())[:len(s)] return {"secret": s, "recovered": out, "ok": [i < len(out) and out[i] == ch for i, ch in enumerate(s)], "aligned": aligned} def consult_available(self): return self.trained_link is not None def consult_meta(self): return dict(self.link_meta) if self.link_meta else None @torch.no_grad() def consult(self, query: str, max_new: int = 120, temperature: float = 0.8, top_k: int = 40, ablate: bool = False): """The asker generates while reading the consultant's latent through the TRAINED RecursiveLink. ablate=True zeros the latent (the truth-test: output should lose the consultant's influence).""" if not self.consult_available(): return "" a, c = self.link_meta["asker"], self.link_meta["consultant"] latent = self.models[c](input_ids=self._ids(c, query)).latent # [1, 256] inj = self.trained_link(latent) if ablate: inj = torch.zeros_like(inj) m, tok = self.models[a], self.toks[a] ids = self._ids(a, query); start = ids.shape[1] ctx_max = int(getattr(m.config, "max_position_embeddings", 2048)) for _ in range(max_new): logits = m(input_ids=ids[:, -ctx_max:], inject_latent=inj).logits[:, -1, :] / max(1e-5, temperature) if top_k: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float("inf") nxt = torch.multinomial(F.softmax(logits, dim=-1), 1) ids = torch.cat([ids, nxt], dim=1) if tok.eos_token_id is not None and int(nxt.item()) == tok.eos_token_id: break return tok.decode(ids[0, start:].tolist()) def combine_available(self): return "language" in self.models and "math" in self.models @torch.no_grad() def combine(self, query: str, max_new: int = 60, blend: float = 0.5, temperature: float = 0.8, top_k: int = 40, consult: bool = False): """Token-level MoE: blend BOTH specialists' next-token distributions every step. Possible because they share the 16k tokenizer (same vocab). blend = weight on Math (0 -> pure Language, 1 -> pure Math, 0.5 -> equal mix). consult=True also injects Math's latent into Language through the trained RecursiveLink during the blend.""" if not self.combine_available(): return "" lang, math_ = self.models["language"], self.models["math"] tok = self.toks["language"] # shared tokenizer (identical for both) ids = self._ids("language", query); start = ids.shape[1] cl = int(getattr(lang.config, "max_position_embeddings", 2048)) cm = int(getattr(math_.config, "max_position_embeddings", 2048)) inj = None if consult and self.trained_link is not None: latent = math_(input_ids=self._ids("math", query)).latent inj = self.trained_link(latent) t = max(1e-5, temperature) for _ in range(max_new): pl = F.softmax(lang(input_ids=ids[:, -cl:], inject_latent=inj).logits[:, -1, :] / t, dim=-1) pm = F.softmax(math_(input_ids=ids[:, -cm:]).logits[:, -1, :] / t, dim=-1) probs = (1.0 - blend) * pl + blend * pm # mixture of the two experts' distributions if top_k: v, _ = torch.topk(probs, min(top_k, probs.size(-1))) probs = torch.where(probs < v[:, [-1]], torch.zeros_like(probs), probs) probs = probs / probs.sum(dim=-1, keepdim=True) nxt = torch.multinomial(probs, 1) ids = torch.cat([ids, nxt], dim=1) if tok.eos_token_id is not None and int(nxt.item()) == tok.eos_token_id: break return tok.decode(ids[0, start:].tolist()) def _ids(self, slot: str, text: str): tok = self.toks[slot] ids = tok.encode(text, add_special_tokens=False)[:MAX_CTX] or [tok.pad_token_id or 0] return torch.tensor([ids], dtype=torch.long, device=self.device) @torch.no_grad() def _bits_per_byte(self, slot: str, text: str) -> float: """Total next-token NLL of `text` under this specialist, per RAW UTF-8 byte. Byte-normalisation makes the score comparable across the experts' different tokenizers (a smaller vocab spends more tokens but each is cheaper).""" nbytes = max(1, len(text.encode("utf-8"))) ids = self._ids(slot, text) if ids.shape[1] < 2: return float("inf") logits = self.models[slot](input_ids=ids).logits ce = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.size(-1)), ids[:, 1:].reshape(-1), reduction="sum") return (ce.item() / math.log(2)) / nbytes @torch.no_grad() def route(self, query: str, temp: float = 0.5): """Per-expert bits/byte + softmax routing weights + the winner (lowest bits/byte).""" bits = {slot: self._bits_per_byte(slot, query) for slot in self.models} names = list(bits) logits = torch.tensor([-bits[n] / temp for n in names]) w = F.softmax(logits, dim=0).tolist() weights = {n: round(wi, 3) for n, wi in zip(names, w)} winner = min(bits, key=bits.get) return winner, weights, {n: round(b, 3) for n, b in bits.items()} @torch.no_grad() def shared_latent(self, query: str): """Each expert's output latent -> sum -> RecursiveLink -> one shared latent (the bus).""" lats = {slot: self.models[slot](input_ids=self._ids(slot, query)).latent for slot in self.models} z = torch.stack([lats[s] for s in lats], 0).sum(0) # [1, 256] shared = self.link(z)[0] return {s: lats[s][0].tolist() for s in lats}, shared.tolist() @torch.no_grad() def generate(self, query: str, expert: str | None = None, max_new: int = 160, temperature: float = 0.8, top_k: int = 40): if expert is None: expert, _, _ = self.route(query) m, tok = self.models[expert], self.toks[expert] ids = self._ids(expert, query) start = ids.shape[1] ctx_max = int(getattr(m.config, "max_position_embeddings", 2048)) for _ in range(max_new): logits = m(input_ids=ids[:, -ctx_max:]).logits[:, -1, :] / max(1e-5, temperature) if top_k: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float("inf") nxt = torch.multinomial(F.softmax(logits, dim=-1), 1) ids = torch.cat([ids, nxt], dim=1) if tok.eos_token_id is not None and int(nxt.item()) == tok.eos_token_id: break return expert, tok.decode(ids[0, start:].tolist()) @torch.no_grad() def generate_stream(self, query: str, expert: str | None = None, max_new: int = 160, temperature: float = 0.8, top_k: int = 40, chunk: int = 4): """Like generate(), but yields (expert, text_so_far) as tokens arrive, so the UI can show generation live instead of freezing until the whole thing is done.""" if expert is None: expert, _, _ = self.route(query) m, tok = self.models[expert], self.toks[expert] ids = self._ids(expert, query) start = ids.shape[1] ctx_max = int(getattr(m.config, "max_position_embeddings", 2048)) for i in range(max_new): logits = m(input_ids=ids[:, -ctx_max:]).logits[:, -1, :] / max(1e-5, temperature) if top_k: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float("inf") nxt = torch.multinomial(F.softmax(logits, dim=-1), 1) ids = torch.cat([ids, nxt], dim=1) done = tok.eos_token_id is not None and int(nxt.item()) == tok.eos_token_id if done or i % chunk == chunk - 1 or i == max_new - 1: yield expert, tok.decode(ids[0, start:].tolist()) if done: break @torch.no_grad() def run(self, query: str, max_new: int = 160, temperature: float = 0.8): """Full pass: route -> fuse latents -> generate from the winner.""" if not self.models: return {"error": "no experts trained yet"} winner, weights, bits = self.route(query) _, shared = self.shared_latent(query) _, gen = self.generate(query, winner, max_new, temperature) return { "winner": winner, "weights": weights, "bits_per_byte": bits, "steps": {n: self.steps.get(n, 0) for n in self.models}, "generation": gen, "shared_latent": [round(x, 3) for x in shared], } # ------------------------------------------------------------------ # # WEIGHT-MERGE section (self-contained; does not touch anything above). # Both specialists are the identical dense architecture, so their weights # can be linearly merged into ONE model: W = (1-alpha)*Language + alpha*Math. # This is a real merged model (one forward pass), not an output ensemble. # ------------------------------------------------------------------ # def merge_available(self): return "language" in self.models and "math" in self.models def _merged_model(self, alpha: float): """Build (and cache the most recent) single weight-merged model at ratio alpha.""" alpha = round(float(alpha), 2) cache = getattr(self, "_merge_cache", None) if cache is not None and abs(cache[0] - alpha) < 1e-6: return cache[1] lsd = self.models["language"].state_dict() msd = self.models["math"].state_dict() merged = {} for k in lsd: if lsd[k].is_floating_point(): merged[k] = ((1.0 - alpha) * lsd[k].float() + alpha * msd[k].float()).to(lsd[k].dtype) else: merged[k] = lsd[k].clone() # integer/bool buffers: copy as-is m = SpikeWhaleLM(specialist_config("language")).to(self.device).eval() m.load_state_dict(merged) # make it bridge-ready: overlay the bridge-trained injection weights (loaded by # _load_links, format-agnostic). The injection only fires when we pass inject_latent. ali = getattr(self, "_bridge_ali", None) if ali is not None: try: m.model.latent_inject.load_state_dict(ali) except Exception: pass self._merge_cache = (alpha, m) return m @torch.no_grad() def merge_generate(self, query: str, alpha: float = 0.5, max_new: int = 60, temperature: float = 0.8, top_k: int = 40, consult: bool = False): """Generate from the WEIGHT-MERGED single model. consult=True injects Math's latent into the merged model through the trained RecursiveLink.""" if not self.merge_available(): return "" m = self._merged_model(alpha) tok = self.toks["language"] # shared tokenizer ids = self._ids("language", query); start = ids.shape[1] ctx = int(getattr(m.config, "max_position_embeddings", 2048)) inj = None if consult and getattr(self, "trained_link", None) is not None: latent = self.models["math"](input_ids=self._ids("math", query)).latent inj = self.trained_link(latent) t = max(1e-5, temperature) for _ in range(max_new): logits = m(input_ids=ids[:, -ctx:], inject_latent=inj).logits[:, -1, :] / t if top_k: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float("inf") nxt = torch.multinomial(F.softmax(logits, dim=-1), 1) ids = torch.cat([ids, nxt], dim=1) if tok.eos_token_id is not None and int(nxt.item()) == tok.eos_token_id: break return tok.decode(ids[0, start:].tolist()) _MOE = None def get_moe(device: str = "cpu"): global _MOE if _MOE is None: _MOE = SpikeWhaleMoE(device=device) else: _MOE.reload() return _MOE if __name__ == "__main__": moe = get_moe() print("experts:", moe.available() or "(none trained yet)") if moe.available(): for q in ["The mitochondria is the", "Solve for x: 2x + 3 =", "Book a flight to"]: r = moe.run(q, max_new=60) print(f"\nQ: {q!r}\n routed-> {r['winner']} bits/byte {r['bits_per_byte']} weights {r['weights']}") print(f" gen: {r['generation']!r}")