Spaces:
Running on Zero
Running on Zero
| """ | |
| moe_gradio.py -- adapter that drives the agents/ Gradio MoE panel with the ModMind | |
| SpikeWhale specialists instead of the byte-level ByteGPT experts. | |
| It exposes the SAME public API as agents/orchestrator.py: | |
| get_moe(device) -> obj with .available() .reload() .route() .shared_latent() | |
| .generate() .run() | |
| and .run() returns the SAME dict shape panel.py consumes: | |
| {winner, weights, bits_per_byte, steps, generation, shared_latent} | |
| so agents/panel.py can use it as a drop-in backend (see the MM_MOE_BACKEND switch there). | |
| What this adapter handles (the real differences vs the byte-level experts): | |
| * Each specialist has its OWN SpikeTokenizer (different vocab), so per-TOKEN NLL is | |
| not comparable across experts. Routing is therefore by BITS-PER-BYTE -- total NLL | |
| divided by the query's raw UTF-8 byte count -- which IS comparable across tokenizers. | |
| * d_latent=256 (vs 64); the shared RecursiveLink is built at 256, and every ModMind | |
| specialist already shares that width, so latent fusion works natively. | |
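| * Checkpoints are read from <MM_CKPT_ROOT or repo>/<domain>/checkpoints/: the slim | |
| fp16 model.safetensors is preferred; otherwise the latest step_*.pt (ModMind | |
| format, {"model_state": ...}) is loaded the same way train_link.py does. A | |
| specialist is hot-reloaded whenever its checkpoint file's mtime changes. | |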
| Display mapping (ModMind domain -> the panel's expert slot/emoji): | |
| language -> language (the panel shows 📖), reasoning -> math (➗), tool_use -> tool (🛠️) | |
| """ | |
| from __future__ import annotations | |
| import glob | |
| import math | |
| import os | |
| import random | |
| import string | |
| from pathlib import Path | |
| import torch | |
| import torch.nn.functional as F | |
| from model import RecursiveLink, SpikeWhaleLM | |
| from specialist_presets import specialist_config | |
| from spike_tokenizer import SpikeTokenizer | |
HERE = Path(__file__).parent
# Checkpoints can be redirected (e.g. to a Modal Volume) via MM_CKPT_ROOT; tokenizers
# always live next to the code -- the same convention train_link.py follows.
CKPT_ROOT = Path(os.environ.get("MM_CKPT_ROOT", str(HERE)))
D_LATENT = 256  # latent width shared by every ModMind specialist and the RecursiveLink
MAX_CTX = 512  # token cap on prompts/scoring so CPU routing + generation stay snappy
# (modmind_domain, panel_slot) pairs. reasoning trains on FineMath, so it plays "math".
SLOTS = [
    ("language", "language"),
    ("reasoning", "math"),
    ("tool_use", "tool"),
]
DOMAIN2SLOT = dict(SLOTS)
# Trained bridges surfaced in the panel as (asker_domain, consultant_domain) "consult" demos.
LINKS = [("language", "reasoning")]
KEY_CHARS = string.ascii_letters + string.digits  # must match train_link.py
| def _latest_ckpt(domain: str): | |
| base = CKPT_ROOT / domain / "checkpoints" | |
| sft = base / "model.safetensors" | |
| if sft.exists(): | |
| return str(sft) # slim fp16 safetensors (preferred) | |
| cks = sorted(glob.glob(str(base / "step_*.pt"))) | |
| return cks[-1] if cks else None # fall back to a full .pt checkpoint | |
class SpikeWhaleMoE:
    """Mixture-of-experts backend over the ModMind SpikeWhale specialists.

    Holds one SpikeWhaleLM + SpikeTokenizer per panel slot (see SLOTS), an untrained
    shared RecursiveLink used by shared_latent(), and -- when their files exist -- the
    trained consult/key-recall bridge (train_link.py) and the question->answer bridge
    (train_qa_link.py). Public methods mirror agents/orchestrator.py.
    """

    def __init__(self, device: str = "cpu"):
        self.device = device
        # per-slot state, keyed by panel slot name
        self.models, self.toks, self.steps, self._mtime = {}, {}, {}, {}
        self.link = RecursiveLink(d_latent=D_LATENT).to(device).eval()
        self.trained_link = None  # the TRAINED bridge (train_link.py output), for the consult demo
        self.bridge_asker = None  # the FULL fine-tuned asker, for reproducible key-recall
        self.link_meta = None
        self._bridge_ali = None  # bridge-trained latent_inject weights, reused by _merged_model
        self._link_sig = None  # (path, mtime) of the loaded bridge file, for cache invalidation
        self.qa_link = None  # the question->answer bridge (train_qa_link.py output)
        self.qa_asker = None
        self.qa_meta = None
        self._qa_mtime = None
        self._merge_cache = None  # (alpha, merged SpikeWhaleLM) built by _merged_model
        self.reload()

    # ------------------------------------------------------------------ #
    # small shared helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _as_fp32(tensors):
        """Upcast floating-point tensors to fp32 (slim checkpoints store fp16); keep the rest."""
        return {k: (v.float() if v.is_floating_point() else v) for k, v in tensors.items()}

    @staticmethod
    def _sub_state(tensors, prefix):
        """Return the entries whose key starts with `prefix`, with that prefix stripped."""
        n = len(prefix)
        return {k[n:]: v for k, v in tensors.items() if k.startswith(prefix)}

    def _read_safetensors(self, path):
        """Load a .safetensors file onto self.device as fp32; returns (tensors, metadata)."""
        from safetensors import safe_open
        from safetensors.torch import load_file
        tensors = self._as_fp32(load_file(str(path), device=self.device))
        with safe_open(str(path), framework="pt") as f:
            meta = f.metadata() or {}
        return tensors, meta

    @staticmethod
    def _sample_next(logits, temperature, top_k):
        """Sample one token id per row from last-position logits of shape [B, V].

        Applies temperature (floored at 1e-5) and, when top_k is truthy, keeps only the
        top_k logits. Returns a [B, 1] tensor of token ids.
        """
        logits = logits / max(1e-5, temperature)
        if top_k:
            kth = torch.topk(logits, min(top_k, logits.size(-1)))[0][:, [-1]]
            logits = logits.masked_fill(logits < kth, -float("inf"))
        return torch.multinomial(F.softmax(logits, dim=-1), 1)

    @staticmethod
    def _is_eos(tok, nxt):
        """True when the freshly sampled [1, 1] token `nxt` is the tokenizer's EOS."""
        return tok.eos_token_id is not None and int(nxt.item()) == tok.eos_token_id

    # ------------------------------------------------------------------ #
    # loading
    # ------------------------------------------------------------------ #
    def reload(self):
        """Load/refresh any specialist whose checkpoint exists or changed on disk.

        Also (re)loads the trained bridges. Returns the list of loaded slot names.
        """
        changed = False
        for domain, slot in SLOTS:
            ck = _latest_ckpt(domain)
            tok_path = HERE / domain / "tokenizer.json"
            if ck is None or not tok_path.exists():
                continue
            mt = os.path.getmtime(ck)
            if self._mtime.get(slot) == mt:
                continue  # unchanged on disk
            self.models[slot], self.steps[slot] = self._load_specialist(domain, ck)
            self.toks[slot] = SpikeTokenizer(vocab_file=str(tok_path))
            self._mtime[slot] = mt
            changed = True
        if changed:
            # the cached weight-merge was built from the previous weights
            self._merge_cache = None
        self._load_links()
        self._load_qa()
        return list(self.models)

    def _load_specialist(self, domain, ck):
        """Build the `domain` specialist on self.device and load checkpoint `ck`.

        Returns (model, step); step is 0 when the checkpoint does not record one.
        """
        m = SpikeWhaleLM(specialist_config(domain)).to(self.device).eval()
        if ck.endswith(".safetensors"):
            sd, meta = self._read_safetensors(ck)
            m.load_state_dict(sd)
            return m, int(meta.get("step", 0))
        # NOTE: weights_only=False unpickles arbitrary objects -- only point
        # MM_CKPT_ROOT at checkpoints you trust.
        st = torch.load(ck, map_location=self.device, weights_only=False)
        m.load_state_dict(st["model_state"])
        return m, int(st.get("step", 0))

    def available(self):
        """Slot names of the specialists currently loaded."""
        return list(self.models)

    def ensure_device(self, device):
        """Move all loaded models/links to `device`. On ZeroGPU this is called inside the
        @spaces.GPU function so CUDA is only touched there, never at import/startup.
        .to(device) is a no-op when already there, so it's safe to call on every request."""
        for k in list(self.models):
            self.models[k] = self.models[k].to(device)
        if self.link is not None:
            self.link = self.link.to(device)
        if self.trained_link is not None:
            self.trained_link = self.trained_link.to(device)
        if self.bridge_asker is not None:
            self.bridge_asker = self.bridge_asker.to(device)
        if self.qa_link is not None:
            self.qa_link = self.qa_link.to(device)
        if self.qa_asker is not None:
            self.qa_asker = self.qa_asker.to(device)
        cache = getattr(self, "_merge_cache", None)
        if cache is not None:
            self._merge_cache = (cache[0], cache[1].to(device))
        self.device = device

    def to_gpu_if_available(self):
        """Best-effort move to CUDA; any failure leaves everything where it was."""
        try:
            if torch.cuda.is_available():
                self.ensure_device("cuda")
        except Exception:
            pass  # deliberate: a missing/broken GPU must not take the CPU path down

    def _load_links(self):
        """Load a TRAINED RecursiveLink bridge (train_link.py output) from
        links/<asker>__from__<consultant>.(safetensors|pt) and overlay its bridge-trained
        injection path onto the asker. The injection only fires when we pass inject_latent
        (the consult demo), so normal routing/generation is unaffected. When the bridge
        file changes, the cached weight-merge (which embeds the overlay) is dropped."""
        self.trained_link = None
        self.bridge_asker = None
        self.link_meta = None
        self._bridge_ali = None
        sig = None
        for asker_domain, consultant_domain in LINKS:
            a, c = DOMAIN2SLOT.get(asker_domain), DOMAIN2SLOT.get(consultant_domain)
            if a not in self.models or c not in self.models:
                continue
            base = CKPT_ROOT / "links" / f"{asker_domain}__from__{consultant_domain}"
            sft, pt = Path(str(base) + ".safetensors"), Path(str(base) + ".pt")
            if sft.exists():
                t, md = self._read_safetensors(sft)
                link_sd = self._sub_state(t, "link.")
                ali_sd = self._sub_state(t, "ali.")
                asker_sd = self._sub_state(t, "asker.") or None
                key_len = int(md.get("key_len", 6))
                prompt = md.get("prompt", "KEY> ")
                wl = float(md.get("with_latent", "nan"))
                nl = float(md.get("without_latent", "nan"))
                src = sft
            elif pt.exists():
                # NOTE: pickle-based load -- trusted files only.
                st = torch.load(str(pt), map_location=self.device, weights_only=False)
                link_sd, ali_sd = st["link"], st["asker_latent_inject"]
                asker_sd = st.get("asker_state")
                key_len = int(st.get("key_len", 6))
                prompt = st.get("prompt", "KEY> ")
                wl = float(st.get("with_latent", float("nan")))
                nl = float(st.get("without_latent", float("nan")))
                src = pt
            else:
                continue
            sig = (str(src), os.path.getmtime(src))
            link = RecursiveLink(d_latent=D_LATENT).to(self.device).eval()
            link.load_state_dict(link_sd)
            self._bridge_ali = ali_sd
            try:
                self.models[a].model.latent_inject.load_state_dict(ali_sd)
            except (AttributeError, KeyError, RuntimeError, TypeError):
                pass  # best-effort overlay: asker without a compatible latent_inject
            # FULL fine-tuned asker (if present) -> lets us reproduce key-recall live.
            if asker_sd:
                ba = SpikeWhaleLM(specialist_config(asker_domain)).to(self.device).eval()
                ba.load_state_dict(asker_sd)
                self.bridge_asker = ba
            self.trained_link = link
            self.link_meta = {"asker": a, "consultant": c, "key_len": key_len, "prompt": prompt,
                              "with_latent": wl, "without_latent": nl}
            break  # one bridge is enough for the panel demo
        if sig != getattr(self, "_link_sig", None):
            self._merge_cache = None  # merged model carries the old bridge overlay
        self._link_sig = sig

    def _load_qa(self):
        """Load the question->answer bridge (train_qa_link.py output): a NEW RecursiveLink
        + a fully fine-tuned asker that answer arithmetic shown only to the consultant.
        mtime-cached (the file is ~190MB) and hot-reloaded as training improves it."""
        a_dom, c_dom = LINKS[0]
        path = CKPT_ROOT / "links" / f"qa__{a_dom}__from__{c_dom}.safetensors"
        if not path.exists():
            self.qa_link = self.qa_asker = self.qa_meta = None
            self._qa_mtime = None
            return
        mt = os.path.getmtime(path)
        if self._qa_mtime == mt and self.qa_link is not None:
            return
        a, c = DOMAIN2SLOT[a_dom], DOMAIN2SLOT[c_dom]
        if a not in self.models or c not in self.models:
            return
        t, md = self._read_safetensors(path)
        link = RecursiveLink(d_latent=D_LATENT).to(self.device).eval()
        link.load_state_dict(self._sub_state(t, "link."))
        ask = SpikeWhaleLM(specialist_config(a_dom)).to(self.device).eval()
        ask.load_state_dict(self._sub_state(t, "asker."))
        self.qa_link, self.qa_asker = link, ask
        self.qa_meta = {"asker": a, "consultant": c,
                        "ans_len": int(md.get("ans_len", 3)), "prompt": md.get("prompt", "ANS> "),
                        "holdout_exact": float(md.get("holdout_exact", "nan")),
                        "step": int(md.get("step", 0))}
        self._qa_mtime = mt

    # ------------------------------------------------------------------ #
    # bridge demos
    # ------------------------------------------------------------------ #
    def qa_available(self):
        return self.qa_link is not None and self.qa_asker is not None

    def qa_info(self):
        return dict(self.qa_meta) if self.qa_meta else None

    @torch.no_grad()
    def ask_math(self, a: int, op: str, b: int, ablate: bool = False):
        """Language answers an arithmetic question SHOWN ONLY to Math: the frozen
        consultant encodes the question, the QA RecursiveLink carries it across, and the
        QA asker decodes the answer digits autoregressively from the latent alone (its
        own input is just the 'ANS> ' prompt -- the question never reaches it as text).
        ablate=True zeros the latent: the asker then has no question at all.

        Returns a result dict, or {"error": ...} for an unavailable bridge, an unsupported
        operator, or an answer outside [0, 10**ans_len)."""
        if not self.qa_available():
            return {"error": "qa bridge not trained yet"}
        meta = self.qa_meta
        a, b = int(a), int(b)
        if op not in ("+", "-", "*"):
            return {"error": "op must be one of + - *"}
        truth = {"+": a + b, "-": a - b, "*": a * b}[op]
        if not (0 <= truth < 10 ** meta["ans_len"]):
            return {"error": "answer out of the trained range"}
        q = f"{a} {op} {b} ="
        c_ids = torch.tensor([self.toks[meta["consultant"]].encode(q, add_special_tokens=False)],
                             device=self.device)
        latent = self.models[meta["consultant"]](input_ids=c_ids).latent
        inj = self.qa_link(latent)
        if ablate:
            inj = torch.zeros_like(inj)
        a_tok = self.toks[meta["asker"]]
        ids = torch.tensor([a_tok.encode(meta["prompt"], add_special_tokens=False)],
                           device=self.device)
        plen = ids.shape[1]
        for _ in range(meta["ans_len"]):  # greedy decode, one answer digit per step
            logits = self.qa_asker(input_ids=ids, inject_latent=inj).logits[:, -1, :]
            ids = torch.cat([ids, logits.argmax(-1, keepdim=True)], dim=1)
        digits = a_tok.decode(ids[0, plen:].tolist())
        want = f"{truth:0{meta['ans_len']}d}"
        return {"question": q, "digits": digits, "answer": digits.lstrip("0") or "0",
                "truth": truth, "want": want,
                "ok": [i < len(digits) and digits[i] == ch for i, ch in enumerate(want)],
                "exact": digits == want}

    def key_recall_available(self):
        return self.bridge_asker is not None and self.trained_link is not None

    @torch.no_grad()
    def _bridge_readback(self, secret: str, ablate: bool):
        """Show `secret` only to the consultant and read it back from the asker.

        The asker sees the prompt followed by `secret` (teacher forcing); its argmax
        predictions over the key positions are decoded. Returns (recovered, aligned),
        where aligned=False means a tokenizer fused characters so token positions no
        longer line up one-per-character.
        """
        a, c = self.link_meta["asker"], self.link_meta["consultant"]
        a_tok, c_tok = self.toks[a], self.toks[c]
        key_len = self.link_meta.get("key_len", 6)
        prompt_ids = a_tok.encode(self.link_meta.get("prompt", "KEY> "), add_special_tokens=False)
        plen = len(prompt_ids)
        c_ids = torch.tensor([c_tok.encode(secret, add_special_tokens=False)], device=self.device)
        a_ids = torch.tensor([prompt_ids + a_tok.encode(secret, add_special_tokens=False)],
                             device=self.device)
        aligned = c_ids.shape[1] == key_len and a_ids.shape[1] == plen + key_len
        latent = self.models[c](input_ids=c_ids).latent
        inj = self.trained_link(latent)
        if ablate:
            inj = torch.zeros_like(inj)
        logits = self.bridge_asker(input_ids=a_ids, inject_latent=inj).logits
        # position plen-1 predicts the first key token, and so on
        pred = logits[:, plen - 1:plen - 1 + key_len, :].argmax(-1)[0]
        return a_tok.decode(pred.tolist())[:len(secret)], aligned

    @torch.no_grad()
    def key_recall(self, n: int = 8, ablate: bool = False):
        """Reproduce the train_link.py forcing task with the FULL trained asker: a random key
        is shown ONLY to the consultant; the asker must emit it from the latent alone.
        Returns {acc, examples:[(key, recovered, ok)]}, or None without a bridge.
        ablate=True zeros the latent."""
        if not self.key_recall_available():
            return None
        key_len = self.link_meta.get("key_len", 6)
        keys = ["".join(random.choice(KEY_CHARS) for _ in range(key_len)) for _ in range(n)]
        examples = []
        for k in keys:
            out, _ = self._bridge_readback(k, ablate)
            examples.append((k, out, out == k))
        correct = sum(int(ok) for _, _, ok in examples)
        return {"acc": correct / max(1, n), "examples": examples}

    @torch.no_grad()
    def relay_secret(self, secret: str, ablate: bool = False):
        """Interactive bridge demo: a USER-CHOSEN key is shown only to the consultant;
        the asker reads it back from the latent alone (same mechanism as key_recall, but
        the human picks the secret). Non-alphanumeric characters are dropped first.
        Returns {secret, recovered, ok:[per-char bool], aligned} -- aligned=False means
        the tokenizer fused some characters into multi-char tokens the bridge never saw
        in training, so expect degradation."""
        if not self.key_recall_available():
            return {"error": "bridge unavailable"}
        s = "".join(ch for ch in (secret or "") if ch in KEY_CHARS)
        key_len = self.link_meta.get("key_len", 6)
        if len(s) != key_len:
            return {"error": f"need exactly {key_len} characters (letters and digits only)"}
        out, aligned = self._bridge_readback(s, ablate)
        return {"secret": s, "recovered": out,
                "ok": [i < len(out) and out[i] == ch for i, ch in enumerate(s)],
                "aligned": aligned}

    def consult_available(self):
        return self.trained_link is not None

    def consult_meta(self):
        return dict(self.link_meta) if self.link_meta else None

    @torch.no_grad()
    def consult(self, query: str, max_new: int = 120, temperature: float = 0.8,
                top_k: int = 40, ablate: bool = False):
        """The asker generates while reading the consultant's latent through the TRAINED
        RecursiveLink. ablate=True zeros the latent (the truth-test: output should lose the
        consultant's influence). Returns "" when no trained bridge is loaded."""
        if not self.consult_available():
            return ""
        a, c = self.link_meta["asker"], self.link_meta["consultant"]
        latent = self.models[c](input_ids=self._ids(c, query)).latent  # [1, 256]
        inj = self.trained_link(latent)
        if ablate:
            inj = torch.zeros_like(inj)
        m, tok = self.models[a], self.toks[a]
        ids = self._ids(a, query)
        start = ids.shape[1]
        ctx_max = int(getattr(m.config, "max_position_embeddings", 2048))
        for _ in range(max_new):
            logits = m(input_ids=ids[:, -ctx_max:], inject_latent=inj).logits[:, -1, :]
            nxt = self._sample_next(logits, temperature, top_k)
            ids = torch.cat([ids, nxt], dim=1)
            if self._is_eos(tok, nxt):
                break
        return tok.decode(ids[0, start:].tolist())

    def combine_available(self):
        return "language" in self.models and "math" in self.models

    @torch.no_grad()
    def combine(self, query: str, max_new: int = 60, blend: float = 0.5,
                temperature: float = 0.8, top_k: int = 40, consult: bool = False):
        """Token-level MoE: blend BOTH specialists' next-token distributions every step.
        Possible because they share the 16k tokenizer (same vocab). blend = weight on Math
        (0 -> pure Language, 1 -> pure Math, 0.5 -> equal mix), clamped to [0, 1].
        consult=True also injects Math's latent into Language through the trained
        RecursiveLink during the blend."""
        if not self.combine_available():
            return ""
        # outside [0, 1] the mixture has negative mass and multinomial() would raise
        blend = min(1.0, max(0.0, float(blend)))
        lang, math_ = self.models["language"], self.models["math"]
        tok = self.toks["language"]  # shared tokenizer (identical for both)
        ids = self._ids("language", query)
        start = ids.shape[1]
        cl = int(getattr(lang.config, "max_position_embeddings", 2048))
        cm = int(getattr(math_.config, "max_position_embeddings", 2048))
        inj = None
        if consult and self.trained_link is not None:
            latent = math_(input_ids=self._ids("math", query)).latent
            inj = self.trained_link(latent)
        t = max(1e-5, temperature)
        for _ in range(max_new):
            pl = F.softmax(lang(input_ids=ids[:, -cl:], inject_latent=inj).logits[:, -1, :] / t, dim=-1)
            pm = F.softmax(math_(input_ids=ids[:, -cm:]).logits[:, -1, :] / t, dim=-1)
            probs = (1.0 - blend) * pl + blend * pm  # mixture of the two experts' distributions
            if top_k:
                v, _ = torch.topk(probs, min(top_k, probs.size(-1)))
                probs = torch.where(probs < v[:, [-1]], torch.zeros_like(probs), probs)
                probs = probs / probs.sum(dim=-1, keepdim=True)
            nxt = torch.multinomial(probs, 1)
            ids = torch.cat([ids, nxt], dim=1)
            if self._is_eos(tok, nxt):
                break
        return tok.decode(ids[0, start:].tolist())

    # ------------------------------------------------------------------ #
    # routing / generation
    # ------------------------------------------------------------------ #
    def _ids(self, slot: str, text: str):
        """Encode `text` with `slot`'s tokenizer as a [1, T] long tensor.

        Keeps at most the first MAX_CTX tokens; empty input becomes a single pad (or 0) id.
        """
        tok = self.toks[slot]
        ids = tok.encode(text, add_special_tokens=False)[:MAX_CTX] or [tok.pad_token_id or 0]
        return torch.tensor([ids], dtype=torch.long, device=self.device)

    @torch.no_grad()
    def _bits_per_byte(self, slot: str, text: str) -> float:
        """Total next-token NLL of `text` under this specialist, per RAW UTF-8 byte.

        Byte-normalisation makes the score comparable across the experts' different
        tokenizers (a smaller vocab spends more tokens but each is cheaper). Returns inf
        when `text` encodes to fewer than two tokens (nothing to score).
        """
        nbytes = max(1, len(text.encode("utf-8")))
        ids = self._ids(slot, text)
        if ids.shape[1] < 2:
            return float("inf")
        # NOTE(review): the first token is never predicted and _ids truncates to MAX_CTX
        # tokens, yet nbytes counts the whole text -- confirm this bias is acceptable.
        logits = self.models[slot](input_ids=ids).logits
        ce = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.size(-1)),
                             ids[:, 1:].reshape(-1), reduction="sum")
        return (ce.item() / math.log(2)) / nbytes

    @torch.no_grad()
    def route(self, query: str, temp: float = 0.5):
        """Per-expert bits/byte + softmax routing weights + the winner (lowest bits/byte).

        Returns (winner, weights, bits), values rounded to 3 decimals. Experts scoring
        inf get weight 0; if every expert scores inf (e.g. a single-token query), weights
        are uniform instead of the NaN an all -inf softmax would give.
        """
        temp = max(1e-5, temp)
        bits = {slot: self._bits_per_byte(slot, query) for slot in self.models}
        names = list(bits)
        if any(math.isfinite(b) for b in bits.values()):
            scores = torch.tensor([-bits[n] / temp for n in names])
            w = F.softmax(scores, dim=0).tolist()
        else:
            w = [1.0 / len(names)] * len(names) if names else []
        weights = {n: round(wi, 3) for n, wi in zip(names, w)}
        winner = min(bits, key=bits.get)  # ValueError when no expert is loaded
        return winner, weights, {n: round(b, 3) for n, b in bits.items()}

    @torch.no_grad()
    def shared_latent(self, query: str):
        """Each expert's output latent -> sum -> RecursiveLink -> one shared latent (the bus).

        Returns ({slot: latent_list}, shared_latent_list).
        """
        lats = {slot: m(input_ids=self._ids(slot, query)).latent
                for slot, m in self.models.items()}
        z = torch.stack(list(lats.values()), 0).sum(0)  # [1, 256]
        shared = self.link(z)[0]
        return {s: lat[0].tolist() for s, lat in lats.items()}, shared.tolist()

    @torch.no_grad()
    def generate(self, query: str, expert: str | None = None, max_new: int = 160,
                 temperature: float = 0.8, top_k: int = 40):
        """Sample up to `max_new` tokens from `expert` (routed when None).

        Returns (expert, generated_text); stops early at EOS.
        """
        if expert is None:
            expert, _, _ = self.route(query)
        m, tok = self.models[expert], self.toks[expert]
        ids = self._ids(expert, query)
        start = ids.shape[1]
        ctx_max = int(getattr(m.config, "max_position_embeddings", 2048))
        for _ in range(max_new):
            logits = m(input_ids=ids[:, -ctx_max:]).logits[:, -1, :]
            nxt = self._sample_next(logits, temperature, top_k)
            ids = torch.cat([ids, nxt], dim=1)
            if self._is_eos(tok, nxt):
                break
        return expert, tok.decode(ids[0, start:].tolist())

    @torch.no_grad()  # decorator form re-enters no_grad around each yield
    def generate_stream(self, query: str, expert: str | None = None, max_new: int = 160,
                        temperature: float = 0.8, top_k: int = 40, chunk: int = 4):
        """Like generate(), but yields (expert, text_so_far) as tokens arrive, so the UI
        can show generation live instead of freezing until the whole thing is done.
        Yields every `chunk` tokens (at least 1), at EOS, and after the last token."""
        chunk = max(1, int(chunk))  # chunk <= 0 would raise or never flush mid-stream
        if expert is None:
            expert, _, _ = self.route(query)
        m, tok = self.models[expert], self.toks[expert]
        ids = self._ids(expert, query)
        start = ids.shape[1]
        ctx_max = int(getattr(m.config, "max_position_embeddings", 2048))
        for i in range(max_new):
            logits = m(input_ids=ids[:, -ctx_max:]).logits[:, -1, :]
            nxt = self._sample_next(logits, temperature, top_k)
            ids = torch.cat([ids, nxt], dim=1)
            done = self._is_eos(tok, nxt)
            if done or i % chunk == chunk - 1 or i == max_new - 1:
                yield expert, tok.decode(ids[0, start:].tolist())
            if done:
                break

    def run(self, query: str, max_new: int = 160, temperature: float = 0.8):
        """Full pass: route -> fuse latents -> generate from the winner."""
        if not self.models:
            return {"error": "no experts trained yet"}
        winner, weights, bits = self.route(query)
        _, shared = self.shared_latent(query)
        _, gen = self.generate(query, winner, max_new, temperature)
        return {
            "winner": winner, "weights": weights, "bits_per_byte": bits,
            "steps": {n: self.steps.get(n, 0) for n in self.models},
            "generation": gen, "shared_latent": [round(x, 3) for x in shared],
        }

    # ------------------------------------------------------------------ #
    # WEIGHT-MERGE section.
    # Both specialists are the identical dense architecture, so their weights
    # can be linearly merged into ONE model: W = (1-alpha)*Language + alpha*Math.
    # This is a real merged model (one forward pass), not an output ensemble.
    # ------------------------------------------------------------------ #
    def merge_available(self):
        return "language" in self.models and "math" in self.models

    def _merged_model(self, alpha: float):
        """Build (and cache the most recent) single weight-merged model at ratio alpha.

        alpha is rounded to 2 decimals. The cache is dropped by reload()/_load_links()
        whenever the underlying specialist or bridge weights change.
        """
        alpha = round(float(alpha), 2)
        cache = getattr(self, "_merge_cache", None)
        if cache is not None and abs(cache[0] - alpha) < 1e-6:
            return cache[1]
        lsd = self.models["language"].state_dict()
        msd = self.models["math"].state_dict()
        merged = {}
        for k, lv in lsd.items():
            if lv.is_floating_point():
                merged[k] = ((1.0 - alpha) * lv.float() + alpha * msd[k].float()).to(lv.dtype)
            else:
                merged[k] = lv.clone()  # integer/bool buffers: copy as-is
        m = SpikeWhaleLM(specialist_config("language")).to(self.device).eval()
        m.load_state_dict(merged)
        # make it bridge-ready: overlay the bridge-trained injection weights (loaded by
        # _load_links, format-agnostic). The injection only fires when we pass inject_latent.
        ali = getattr(self, "_bridge_ali", None)
        if ali is not None:
            try:
                m.model.latent_inject.load_state_dict(ali)
            except (AttributeError, KeyError, RuntimeError, TypeError):
                pass  # best-effort overlay, as in _load_links
        self._merge_cache = (alpha, m)
        return m

    @torch.no_grad()
    def merge_generate(self, query: str, alpha: float = 0.5, max_new: int = 60,
                       temperature: float = 0.8, top_k: int = 40, consult: bool = False):
        """Generate from the WEIGHT-MERGED single model. consult=True injects Math's latent
        into the merged model through the trained RecursiveLink. Returns "" when either
        specialist is missing."""
        if not self.merge_available():
            return ""
        m = self._merged_model(alpha)
        tok = self.toks["language"]  # shared tokenizer
        ids = self._ids("language", query)
        start = ids.shape[1]
        ctx = int(getattr(m.config, "max_position_embeddings", 2048))
        inj = None
        if consult and getattr(self, "trained_link", None) is not None:
            latent = self.models["math"](input_ids=self._ids("math", query)).latent
            inj = self.trained_link(latent)
        for _ in range(max_new):
            logits = m(input_ids=ids[:, -ctx:], inject_latent=inj).logits[:, -1, :]
            nxt = self._sample_next(logits, temperature, top_k)
            ids = torch.cat([ids, nxt], dim=1)
            if self._is_eos(tok, nxt):
                break
        return tok.decode(ids[0, start:].tolist())
_MOE = None  # process-wide singleton, created lazily by get_moe()


def get_moe(device: str = "cpu"):
    """Return the shared SpikeWhaleMoE, building it on first use.

    The first call constructs it on `device`; later calls ignore `device` and instead
    hot-reload any checkpoints that changed on disk.
    """
    global _MOE
    if _MOE is not None:
        _MOE.reload()
    else:
        _MOE = SpikeWhaleMoE(device=device)
    return _MOE
if __name__ == "__main__":
    # Smoke test: load whatever specialists exist and route three probe prompts.
    moe = get_moe()
    print("experts:", moe.available() or "(none trained yet)")
    if moe.available():
        probes = ("The mitochondria is the", "Solve for x: 2x + 3 =", "Book a flight to")
        for prompt in probes:
            result = moe.run(prompt, max_new=60)
            print(
                f"\nQ: {prompt!r}\n routed-> {result['winner']} "
                f"bits/byte {result['bits_per_byte']} weights {result['weights']}"
            )
            print(f" gen: {result['generation']!r}")