# ModuleMind / agents /modmind /moe_gradio.py
# Quazim0t0's picture
# Upload 7 files
# 73dd4cf verified
# Raw
# History Blame Contribute Delete
# 29.6 kB
"""
moe_gradio.py -- adapter that drives the agents/ Gradio MoE panel with the ModMind
SpikeWhale specialists instead of the byte-level ByteGPT experts.
It exposes the SAME public API as agents/orchestrator.py:
get_moe(device) -> obj with .available() .reload() .route() .shared_latent()
.generate() .run()
and .run() returns the SAME dict shape panel.py consumes:
{winner, weights, bits_per_byte, steps, generation, shared_latent}
so agents/panel.py can use it as a drop-in backend (see the MM_MOE_BACKEND switch there).
What this adapter handles (the real differences vs the byte-level experts):
* Each specialist has its OWN SpikeTokenizer (different vocab), so per-TOKEN NLL is
not comparable across experts. Routing is therefore by BITS-PER-BYTE -- total NLL
divided by the query's raw UTF-8 byte count -- which IS comparable across tokenizers.
* d_latent=256 (vs 64); the shared RecursiveLink is built at 256, and every ModMind
specialist already shares that width, so latent fusion works natively.
* Checkpoints are ModMind format ({"model_state": ...}) under
<MM_CKPT_ROOT or repo>/<domain>/checkpoints/step_*.pt, loaded exactly like
train_link.py does, and hot-reloaded when newer ones appear.
Display mapping (ModMind domain -> the panel's expert slot/emoji):
language -> language (the panel shows 📖), reasoning -> math (➗), tool_use -> tool (🛠️)
"""
from __future__ import annotations
import glob
import math
import os
import random
import string
from pathlib import Path
import torch
import torch.nn.functional as F
from model import RecursiveLink, SpikeWhaleLM
from specialist_presets import specialist_config
from spike_tokenizer import SpikeTokenizer
# Directory that contains this module; tokenizers are always read from here.
HERE = Path(__file__).parent

# Checkpoints may live on a Modal Volume (MM_CKPT_ROOT); tokenizers stay with the code
# (same rule as train_link.py).
CKPT_ROOT = Path(os.environ.get("MM_CKPT_ROOT", HERE))

D_LATENT = 256
MAX_CTX = 512  # cap prompt/scoring length so CPU routing + generation stay snappy

# (modmind_domain, panel_slot). reasoning trains on FineMath, so it IS the "math" expert.
SLOTS = [
    ("language", "language"),
    ("reasoning", "math"),
    ("tool_use", "tool"),
]
DOMAIN2SLOT = dict(SLOTS)

# Trained bridges to surface in the panel as (asker_domain, consultant_domain) "consult" demos.
LINKS = [("language", "reasoning")]

KEY_CHARS = string.ascii_letters + string.digits  # must match train_link.py
def _latest_ckpt(domain: str):
base = CKPT_ROOT / domain / "checkpoints"
sft = base / "model.safetensors"
if sft.exists():
return str(sft) # slim fp16 safetensors (preferred)
cks = sorted(glob.glob(str(base / "step_*.pt")))
return cks[-1] if cks else None # fall back to a full .pt checkpoint
class SpikeWhaleMoE:
    """Panel backend that routes, fuses and generates with the SpikeWhale specialists."""

    def __init__(self, device: str = "cpu"):
        """Create an empty registry on `device` and load whatever is on disk.

        All per-slot dictionaries are keyed by panel slot name. The bridge-related
        attributes start as None and are filled in by reload().
        """
        self.device = device
        self.models = {}
        self.toks = {}
        self.steps = {}
        self._mtime = {}
        # Untrained shared link used by shared_latent() to fuse the experts' latents.
        self.link = RecursiveLink(d_latent=D_LATENT).to(device).eval()
        # The TRAINED bridge (train_link.py output), for the consult demo, plus the
        # FULL fine-tuned asker used for reproducible key-recall.
        self.trained_link = None
        self.bridge_asker = None
        self.link_meta = None
        # The question->answer bridge (train_qa_link.py output).
        self.qa_link = None
        self.qa_asker = None
        self.qa_meta = None
        self._qa_mtime = None
        self.reload()
def reload(self):
    """Load or refresh every specialist whose checkpoint is new or changed on disk.

    A slot is skipped when it has no checkpoint, no tokenizer.json next to the
    code, or when the checkpoint's mtime matches the one already loaded. The
    bridges are refreshed afterwards. Returns the list of loaded slot names.
    """
    for domain, slot in SLOTS:
        ckpt = _latest_ckpt(domain)
        tok_path = HERE / domain / "tokenizer.json"
        if ckpt is None or not tok_path.exists():
            continue
        stamp = os.path.getmtime(ckpt)
        if self._mtime.get(slot) == stamp:
            continue  # already serving this exact file

        model = SpikeWhaleLM(specialist_config(domain)).to(self.device).eval()
        if ckpt.endswith(".safetensors"):
            from safetensors import safe_open
            from safetensors.torch import load_file

            # Stored tensors may be fp16; run the model in fp32.
            weights = {}
            for name, tensor in load_file(ckpt, device=self.device).items():
                weights[name] = tensor.float() if tensor.is_floating_point() else tensor
            model.load_state_dict(weights)
            with safe_open(ckpt, framework="pt") as fh:
                step = int((fh.metadata() or {}).get("step", 0))
        else:
            # NOTE(review): weights_only=False unpickles arbitrary objects -- only
            # point CKPT_ROOT at checkpoints you trust.
            state = torch.load(ckpt, map_location=self.device, weights_only=False)
            model.load_state_dict(state["model_state"])
            step = int(state.get("step", 0))

        self.models[slot] = model
        self.toks[slot] = SpikeTokenizer(vocab_file=str(tok_path))
        self.steps[slot] = step
        self._mtime[slot] = stamp

    self._load_links()
    self._load_qa()
    return list(self.models)
def available(self):
    """Return the names of the panel slots that currently have a loaded specialist."""
    return [*self.models]
def ensure_device(self, device):
    """Move every loaded model, link and the cached merged model to `device`.

    On ZeroGPU this is called inside the @spaces.GPU function so CUDA is only
    touched there, never at import/startup. `.to(device)` is a no-op when the
    module is already there, so calling this on every request is safe.
    """
    for slot, model in list(self.models.items()):
        self.models[slot] = model.to(device)

    # Optional modules: each may still be None if its checkpoint is absent.
    for attr in ("link", "trained_link", "bridge_asker", "qa_link", "qa_asker"):
        module = getattr(self, attr)
        if module is not None:
            setattr(self, attr, module.to(device))

    cached = getattr(self, "_merge_cache", None)
    if cached is not None:
        self._merge_cache = (cached[0], cached[1].to(device))

    self.device = device
def to_gpu_if_available(self):
    """Best-effort move to CUDA; any failure is ignored and the current device is kept."""
    try:
        has_cuda = torch.cuda.is_available()
        if has_cuda:
            self.ensure_device("cuda")
    except Exception:
        # Deliberately silent: whatever went wrong, stay on the current device.
        return
def _load_links(self):
    """Load a TRAINED RecursiveLink bridge (train_link.py output).

    The bridge is read from ``links/<asker>__from__<consultant>.safetensors`` (or
    ``.pt``) under CKPT_ROOT, for the first pair in LINKS whose two specialists are
    loaded. Its bridge-trained injection weights are overlaid onto the asker; the
    injection only fires when inject_latent is passed (the consult demo), so normal
    routing/generation is unaffected. If the file also carries a fully fine-tuned
    asker, that model is kept in `bridge_asker` for the key-recall demo.

    reload() runs on every get_moe() call, so the result is cached on
    (file, file mtime, asker checkpoint mtime) -- the same idea _load_qa uses --
    instead of re-reading the file and rebuilding the asker on every request.
    When no bridge file is found all bridge state is cleared.
    """
    found = None
    for asker_domain, consultant_domain in LINKS:
        a, c = DOMAIN2SLOT.get(asker_domain), DOMAIN2SLOT.get(consultant_domain)
        if a not in self.models or c not in self.models:
            continue
        base = CKPT_ROOT / "links" / f"{asker_domain}__from__{consultant_domain}"
        # safetensors is preferred over the pickled .pt when both exist.
        for candidate in (Path(str(base) + ".safetensors"), Path(str(base) + ".pt")):
            if candidate.exists():
                found = (asker_domain, a, c, candidate)
                break
        if found is not None:
            break  # one bridge is enough for the panel demo

    if found is None:
        self.trained_link = None
        self.bridge_asker = None
        self.link_meta = None
        self._bridge_ali = None
        self._link_sig = None
        return

    asker_domain, a, c, path = found
    # The asker's own mtime is part of the key: a freshly reloaded asker has lost
    # the injection overlay and needs it re-applied.
    signature = (str(path), os.path.getmtime(path), self._mtime.get(a))
    if self.trained_link is not None and getattr(self, "_link_sig", None) == signature:
        return

    # Drop the old bridge first so a failed load never leaves half-updated state.
    self.trained_link = None
    self.bridge_asker = None
    self.link_meta = None
    self._bridge_ali = None
    self._link_sig = None

    if path.suffix == ".safetensors":
        from safetensors import safe_open
        from safetensors.torch import load_file

        t = load_file(str(path), device=self.device)
        t = {k: (v.float() if v.is_floating_point() else v) for k, v in t.items()}
        with safe_open(str(path), framework="pt") as f:
            md = f.metadata() or {}
        # One flat file: the three state dicts are told apart by key prefix.
        link_sd = {k[5:]: v for k, v in t.items() if k.startswith("link.")}
        ali_sd = {k[4:]: v for k, v in t.items() if k.startswith("ali.")}
        asker_sd = {k[6:]: v for k, v in t.items() if k.startswith("asker.")} or None
        key_len = int(md.get("key_len", 6))
        prompt = md.get("prompt", "KEY> ")
        wl = float(md.get("with_latent", "nan"))
        nl = float(md.get("without_latent", "nan"))
    else:
        # NOTE(review): weights_only=False unpickles arbitrary objects -- only load
        # bridge files you trust.
        st = torch.load(str(path), map_location=self.device, weights_only=False)
        link_sd = st["link"]
        ali_sd = st["asker_latent_inject"]
        asker_sd = st.get("asker_state")
        key_len = int(st.get("key_len", 6))
        prompt = st.get("prompt", "KEY> ")
        wl = float(st.get("with_latent", float("nan")))
        nl = float(st.get("without_latent", float("nan")))

    link = RecursiveLink(d_latent=D_LATENT).to(self.device).eval()
    link.load_state_dict(link_sd)
    self._bridge_ali = ali_sd
    try:
        self.models[a].model.latent_inject.load_state_dict(ali_sd)
    except Exception:
        # Best effort by design: an asker without a matching injection module still
        # works for plain routing/generation.
        pass

    # FULL fine-tuned asker (if present) -> lets us reproduce key-recall live.
    if asker_sd:
        ba = SpikeWhaleLM(specialist_config(asker_domain)).to(self.device).eval()
        ba.load_state_dict(asker_sd)
        self.bridge_asker = ba

    self.trained_link = link
    self.link_meta = {"asker": a, "consultant": c, "key_len": key_len, "prompt": prompt,
                      "with_latent": wl, "without_latent": nl}
    self._link_sig = signature
def _load_qa(self):
    """Load the question->answer bridge (train_qa_link.py output).

    The file holds a NEW RecursiveLink plus a fully fine-tuned asker that answer
    arithmetic shown only to the consultant. It is mtime-cached (the file is
    ~190MB) and hot-reloaded as training improves it. A missing file clears the
    QA state; missing specialists leave the current state untouched.
    """
    asker_domain, consultant_domain = LINKS[0]
    path = CKPT_ROOT / "links" / f"qa__{asker_domain}__from__{consultant_domain}.safetensors"
    if not path.exists():
        self.qa_link = None
        self.qa_asker = None
        self.qa_meta = None
        self._qa_mtime = None
        return

    stamp = os.path.getmtime(path)
    if self.qa_link is not None and self._qa_mtime == stamp:
        return  # unchanged since the last load

    asker_slot = DOMAIN2SLOT[asker_domain]
    consultant_slot = DOMAIN2SLOT[consultant_domain]
    if not (asker_slot in self.models and consultant_slot in self.models):
        return

    from safetensors import safe_open
    from safetensors.torch import load_file

    tensors = {}
    for name, tensor in load_file(str(path), device=self.device).items():
        tensors[name] = tensor.float() if tensor.is_floating_point() else tensor
    with safe_open(str(path), framework="pt") as fh:
        md = fh.metadata() or {}

    def _strip(prefix):
        # Sub-state-dict of every tensor stored under `prefix`, with the prefix removed.
        return {name[len(prefix):]: t for name, t in tensors.items() if name.startswith(prefix)}

    link = RecursiveLink(d_latent=D_LATENT).to(self.device).eval()
    link.load_state_dict(_strip("link."))
    ask = SpikeWhaleLM(specialist_config(asker_domain)).to(self.device).eval()
    ask.load_state_dict(_strip("asker."))

    self.qa_link = link
    self.qa_asker = ask
    self.qa_meta = {
        "asker": asker_slot,
        "consultant": consultant_slot,
        "ans_len": int(md.get("ans_len", 3)),
        "prompt": md.get("prompt", "ANS> "),
        "holdout_exact": float(md.get("holdout_exact", "nan")),
        "step": int(md.get("step", 0)),
    }
    self._qa_mtime = stamp
def qa_available(self):
    """True when both the QA link and the QA asker have been loaded."""
    return not (self.qa_link is None or self.qa_asker is None)
def qa_info(self):
    """Return a shallow copy of the QA bridge metadata, or None when there is none."""
    if not self.qa_meta:
        return None
    return dict(self.qa_meta)
@torch.no_grad()
def ask_math(self, a: int, op: str, b: int, ablate: bool = False):
    """Language answers an arithmetic question SHOWN ONLY to Math.

    The consultant encodes the question, the QA RecursiveLink carries its latent
    across, and the QA asker greedily decodes `ans_len` answer tokens from the
    latent alone (its own input is just the prompt -- the question never reaches
    it as text). ablate=True zeros the latent: the asker then has no question at
    all.

    Returns a dict with question/digits/answer/truth/want/ok/exact, or
    {"error": ...} when the bridge is missing, `op` is not one of + - *, or the
    true answer does not fit in `ans_len` non-negative digits.
    """
    if not self.qa_available():
        return {"error": "qa bridge not trained yet"}
    meta = self.qa_meta
    a, b = int(a), int(b)
    if op not in ("+", "-", "*"):
        return {"error": "op must be one of + - *"}
    if op == "+":
        truth = a + b
    elif op == "-":
        truth = a - b
    else:
        truth = a * b
    if truth < 0 or truth >= 10 ** meta["ans_len"]:
        return {"error": "answer out of the trained range"}

    q = f"{a} {op} {b} ="
    consultant = meta["consultant"]
    c_ids = torch.tensor(
        [self.toks[consultant].encode(q, add_special_tokens=False)], device=self.device
    )
    latent = self.models[consultant](input_ids=c_ids).latent
    inj = self.qa_link(latent)
    if ablate:
        inj = torch.zeros_like(inj)

    a_tok = self.toks[meta["asker"]]
    ids = torch.tensor(
        [a_tok.encode(meta["prompt"], add_special_tokens=False)], device=self.device
    )
    plen = ids.shape[1]
    # Greedy decoding: one argmax token per answer position.
    for _ in range(meta["ans_len"]):
        last = self.qa_asker(input_ids=ids, inject_latent=inj).logits[:, -1, :]
        ids = torch.cat([ids, last.argmax(-1, keepdim=True)], dim=1)

    digits = a_tok.decode(ids[0, plen:].tolist())
    want = f"{truth:0{meta['ans_len']}d}"
    per_digit = [i < len(digits) and digits[i] == ch for i, ch in enumerate(want)]
    return {
        "question": q,
        "digits": digits,
        "answer": digits.lstrip("0") or "0",
        "truth": truth,
        "want": want,
        "ok": per_digit,
        "exact": digits == want,
    }
def key_recall_available(self):
    """True when both the fine-tuned bridge asker and the trained link are loaded."""
    return not (self.bridge_asker is None or self.trained_link is None)
@torch.no_grad()
def key_recall(self, n: int = 8, ablate: bool = False):
    """Reproduce the train_link.py forcing task with the FULL trained asker.

    A random key is shown ONLY to the consultant; the asker is fed the prompt plus
    the key and its argmax predictions over the key positions are compared with the
    key. ablate=True zeros the injected latent.

    Returns None when the bridge is unavailable, otherwise
    {acc, examples: [(key, recovered, ok)]}.
    """
    if not self.key_recall_available():
        return None
    meta = self.link_meta
    asker_slot, consultant_slot = meta["asker"], meta["consultant"]
    a_tok, c_tok = self.toks[asker_slot], self.toks[consultant_slot]
    key_len = meta.get("key_len", 6)
    prompt = meta.get("prompt", "KEY> ")
    plen = len(a_tok.encode(prompt, add_special_tokens=False))
    # Position p predicts token p+1, so the key is read from plen-1 onwards.
    window = slice(plen - 1, plen - 1 + key_len)

    keys = ["".join(random.choice(KEY_CHARS) for _ in range(key_len)) for _ in range(n)]
    examples = []
    for key in keys:
        c_ids = torch.tensor([c_tok.encode(key, add_special_tokens=False)], device=self.device)
        asker_tokens = (a_tok.encode(prompt, add_special_tokens=False)
                        + a_tok.encode(key, add_special_tokens=False))
        a_ids = torch.tensor([asker_tokens], device=self.device)
        latent = self.models[consultant_slot](input_ids=c_ids).latent
        inj = self.trained_link(latent)
        if ablate:
            inj = torch.zeros_like(inj)
        logits = self.bridge_asker(input_ids=a_ids, inject_latent=inj).logits
        pred = logits[:, window, :].argmax(-1)[0]
        recovered = a_tok.decode(pred.tolist())[:len(key)]
        examples.append((key, recovered, recovered == key))

    hits = sum(int(ok) for _, _, ok in examples)
    return {"acc": hits / max(1, n), "examples": examples}
@torch.no_grad()
def relay_secret(self, secret: str, ablate: bool = False):
    """Interactive bridge demo with a USER-CHOSEN key.

    The key is shown only to the consultant; the asker reads it back from the
    latent alone (same mechanism as key_recall, but the human picks the secret).
    Characters outside KEY_CHARS are dropped before the length check.

    Returns {secret, recovered, ok: [per-char bool], aligned}. aligned=False means
    a tokenizer did not split the key into exactly one token per character, so
    expect degradation. Returns {"error": ...} when the bridge is unavailable or
    the cleaned key does not have exactly `key_len` characters.
    """
    if not self.key_recall_available():
        return {"error": "bridge unavailable"}
    meta = self.link_meta
    key_len = meta.get("key_len", 6)
    s = "".join(filter(KEY_CHARS.__contains__, secret or ""))
    if len(s) != key_len:
        return {"error": f"need exactly {key_len} characters (letters and digits only)"}

    asker_slot, consultant_slot = meta["asker"], meta["consultant"]
    a_tok, c_tok = self.toks[asker_slot], self.toks[consultant_slot]
    prompt = meta.get("prompt", "KEY> ")
    plen = len(a_tok.encode(prompt, add_special_tokens=False))

    c_ids = torch.tensor([c_tok.encode(s, add_special_tokens=False)], device=self.device)
    asker_tokens = (a_tok.encode(prompt, add_special_tokens=False)
                    + a_tok.encode(s, add_special_tokens=False))
    a_ids = torch.tensor([asker_tokens], device=self.device)
    aligned = c_ids.shape[1] == key_len and a_ids.shape[1] == plen + key_len

    latent = self.models[consultant_slot](input_ids=c_ids).latent
    inj = self.trained_link(latent)
    if ablate:
        inj = torch.zeros_like(inj)
    logits = self.bridge_asker(input_ids=a_ids, inject_latent=inj).logits
    # Position p predicts token p+1, so the key is read from plen-1 onwards.
    pred = logits[:, plen - 1:plen - 1 + key_len, :].argmax(-1)[0]
    out = a_tok.decode(pred.tolist())[:len(s)]

    per_char = [i < len(out) and out[i] == ch for i, ch in enumerate(s)]
    return {"secret": s, "recovered": out, "ok": per_char, "aligned": aligned}
def consult_available(self):
    """True when a trained RecursiveLink bridge has been loaded."""
    has_bridge = self.trained_link is not None
    return has_bridge
def consult_meta(self):
    """Return a shallow copy of the trained bridge's metadata, or None when there is none."""
    if not self.link_meta:
        return None
    return dict(self.link_meta)
@torch.no_grad()
def consult(self, query: str, max_new: int = 120, temperature: float = 0.8,
            top_k: int = 40, ablate: bool = False):
    """Generate with the asker while it reads the consultant's latent.

    The consultant's latent for `query` goes through the TRAINED RecursiveLink and
    is injected into the asker at every step. ablate=True zeros the latent (the
    truth-test: output should lose the consultant's influence). Sampling uses
    temperature scaling and optional top-k filtering, and stops at EOS or after
    `max_new` tokens. Returns the decoded continuation, or "" without a bridge.
    """
    if not self.consult_available():
        return ""
    asker_slot = self.link_meta["asker"]
    consultant_slot = self.link_meta["consultant"]

    latent = self.models[consultant_slot](input_ids=self._ids(consultant_slot, query)).latent
    inj = self.trained_link(latent)
    if ablate:
        inj = torch.zeros_like(inj)

    m, tok = self.models[asker_slot], self.toks[asker_slot]
    ids = self._ids(asker_slot, query)
    start = ids.shape[1]
    ctx_max = int(getattr(m.config, "max_position_embeddings", 2048))
    temp = max(1e-5, temperature)  # avoid dividing by zero

    for _ in range(max_new):
        step_logits = m(input_ids=ids[:, -ctx_max:], inject_latent=inj).logits[:, -1, :] / temp
        if top_k:
            # Keep the k best logits; everything below the k-th gets zero probability.
            kth = torch.topk(step_logits, min(top_k, step_logits.size(-1))).values[:, [-1]]
            step_logits = step_logits.masked_fill(step_logits < kth, -float("inf"))
        nxt = torch.multinomial(F.softmax(step_logits, dim=-1), 1)
        ids = torch.cat([ids, nxt], dim=1)
        if tok.eos_token_id is not None and int(nxt.item()) == tok.eos_token_id:
            break
    return tok.decode(ids[0, start:].tolist())
def combine_available(self):
    """True when both the "language" and the "math" specialists are loaded."""
    return all(slot in self.models for slot in ("language", "math"))
@torch.no_grad()
def combine(self, query: str, max_new: int = 60, blend: float = 0.5,
            temperature: float = 0.8, top_k: int = 40, consult: bool = False):
    """Token-level MoE: blend BOTH specialists' next-token distributions every step.

    Possible because they share the 16k tokenizer (same vocab). blend = weight on
    Math (0 -> pure Language, 1 -> pure Math, 0.5 -> equal mix). consult=True also
    injects Math's latent into Language through the trained RecursiveLink during
    the blend. Returns the decoded continuation, or "" when either specialist is
    missing.
    """
    if not self.combine_available():
        return ""
    lang, math_ = self.models["language"], self.models["math"]
    tok = self.toks["language"]  # shared tokenizer (identical for both)
    ids = self._ids("language", query)
    start = ids.shape[1]
    lang_ctx = int(getattr(lang.config, "max_position_embeddings", 2048))
    math_ctx = int(getattr(math_.config, "max_position_embeddings", 2048))

    inj = None
    if consult and self.trained_link is not None:
        inj = self.trained_link(math_(input_ids=self._ids("math", query)).latent)

    t = max(1e-5, temperature)  # avoid dividing by zero
    for _ in range(max_new):
        lang_logits = lang(input_ids=ids[:, -lang_ctx:], inject_latent=inj).logits[:, -1, :]
        math_logits = math_(input_ids=ids[:, -math_ctx:]).logits[:, -1, :]
        pl = F.softmax(lang_logits / t, dim=-1)
        pm = F.softmax(math_logits / t, dim=-1)
        probs = (1.0 - blend) * pl + blend * pm  # mixture of the two experts' distributions
        if top_k:
            # Zero everything below the k-th best probability, then renormalise.
            kth = torch.topk(probs, min(top_k, probs.size(-1))).values[:, [-1]]
            probs = probs.masked_fill(probs < kth, 0.0)
            probs = probs / probs.sum(dim=-1, keepdim=True)
        nxt = torch.multinomial(probs, 1)
        ids = torch.cat([ids, nxt], dim=1)
        if tok.eos_token_id is not None and int(nxt.item()) == tok.eos_token_id:
            break
    return tok.decode(ids[0, start:].tolist())
def _ids(self, slot: str, text: str):
    """Encode `text` with `slot`'s tokenizer as a [1, T] long tensor on self.device.

    At most MAX_CTX tokens are kept (the first ones). An empty encoding is replaced
    by a single pad token (or 0 when the tokenizer has no pad id) so the model
    always receives at least one token.
    """
    tok = self.toks[slot]
    token_ids = tok.encode(text, add_special_tokens=False)[:MAX_CTX]
    if not token_ids:
        token_ids = [tok.pad_token_id or 0]
    return torch.tensor([token_ids], dtype=torch.long, device=self.device)
@torch.no_grad()
def _bits_per_byte(self, slot: str, text: str) -> float:
"""Total next-token NLL of `text` under this specialist, per RAW UTF-8 byte.
Byte-normalisation makes the score comparable across the experts' different
tokenizers (a smaller vocab spends more tokens but each is cheaper)."""
nbytes = max(1, len(text.encode("utf-8")))
ids = self._ids(slot, text)
if ids.shape[1] < 2:
return float("inf")
logits = self.models[slot](input_ids=ids).logits
ce = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.size(-1)),
ids[:, 1:].reshape(-1), reduction="sum")
return (ce.item() / math.log(2)) / nbytes
@torch.no_grad()
def route(self, query: str, temp: float = 0.5):
    """Score `query` under every expert and pick the cheapest one.

    Returns (winner, weights, bits):
      * winner  -- slot with the lowest bits/byte,
      * weights -- softmax over -bits/temp, rounded to 3 decimals,
      * bits    -- per-expert bits/byte, rounded to 3 decimals.

    An expert whose score is not finite (e.g. the query is too short to score)
    gets weight 0. When NO expert has a finite score the weights fall back to
    uniform instead of the NaNs a softmax over all -inf would produce. `temp` is
    clamped away from zero.
    """
    bits = {slot: self._bits_per_byte(slot, query) for slot in self.models}
    names = list(bits)
    # Treat NaN like +inf so one broken expert cannot poison the whole softmax.
    score = {n: (b if math.isfinite(b) else float("inf")) for n, b in bits.items()}

    if any(math.isfinite(s) for s in score.values()):
        t = max(1e-5, temp)
        logits = torch.tensor([-score[n] / t for n in names])
        w = F.softmax(logits, dim=0).tolist()
    else:
        w = [1.0 / len(names)] * len(names) if names else []

    weights = {n: round(wi, 3) for n, wi in zip(names, w)}
    winner = min(score, key=score.get)
    return winner, weights, {n: round(b, 3) for n, b in bits.items()}
@torch.no_grad()
def shared_latent(self, query: str):
    """Each expert's output latent -> sum -> RecursiveLink -> one shared latent (the bus).

    Returns ({slot: latent as list}, shared latent as list), taking the first
    batch row of each tensor.
    """
    lats = {}
    for slot, model in self.models.items():
        lats[slot] = model(input_ids=self._ids(slot, query)).latent
    fused = torch.stack(list(lats.values()), 0).sum(0)  # [1, 256]
    shared = self.link(fused)[0]
    per_expert = {slot: lat[0].tolist() for slot, lat in lats.items()}
    return per_expert, shared.tolist()
@torch.no_grad()
def generate(self, query: str, expert: str | None = None, max_new: int = 160,
             temperature: float = 0.8, top_k: int = 40):
    """Sample a continuation of `query` from one expert.

    When `expert` is None the router picks it. Sampling uses temperature scaling
    and optional top-k filtering, feeds at most the model's
    max_position_embeddings (default 2048) most recent tokens, and stops at EOS
    or after `max_new` tokens. Returns (expert, decoded continuation).
    """
    if expert is None:
        expert = self.route(query)[0]
    m, tok = self.models[expert], self.toks[expert]
    ids = self._ids(expert, query)
    start = ids.shape[1]
    ctx_max = int(getattr(m.config, "max_position_embeddings", 2048))
    temp = max(1e-5, temperature)  # avoid dividing by zero

    for _ in range(max_new):
        step_logits = m(input_ids=ids[:, -ctx_max:]).logits[:, -1, :] / temp
        if top_k:
            # Keep the k best logits; everything below the k-th gets zero probability.
            kth = torch.topk(step_logits, min(top_k, step_logits.size(-1))).values[:, [-1]]
            step_logits = step_logits.masked_fill(step_logits < kth, -float("inf"))
        nxt = torch.multinomial(F.softmax(step_logits, dim=-1), 1)
        ids = torch.cat([ids, nxt], dim=1)
        if tok.eos_token_id is not None and int(nxt.item()) == tok.eos_token_id:
            break
    return expert, tok.decode(ids[0, start:].tolist())
@torch.no_grad()
def generate_stream(self, query: str, expert: str | None = None, max_new: int = 160,
                    temperature: float = 0.8, top_k: int = 40, chunk: int = 4):
    """Like generate(), but yields (expert, text_so_far) as tokens arrive.

    This lets the UI show generation live instead of freezing until the whole
    thing is done. A snapshot is yielded every `chunk` tokens, on EOS, and on the
    final step.
    """
    if expert is None:
        expert = self.route(query)[0]
    m, tok = self.models[expert], self.toks[expert]
    ids = self._ids(expert, query)
    start = ids.shape[1]
    ctx_max = int(getattr(m.config, "max_position_embeddings", 2048))
    temp = max(1e-5, temperature)  # avoid dividing by zero
    last_step = max_new - 1

    for i in range(max_new):
        step_logits = m(input_ids=ids[:, -ctx_max:]).logits[:, -1, :] / temp
        if top_k:
            # Keep the k best logits; everything below the k-th gets zero probability.
            kth = torch.topk(step_logits, min(top_k, step_logits.size(-1))).values[:, [-1]]
            step_logits = step_logits.masked_fill(step_logits < kth, -float("inf"))
        nxt = torch.multinomial(F.softmax(step_logits, dim=-1), 1)
        ids = torch.cat([ids, nxt], dim=1)

        done = tok.eos_token_id is not None and int(nxt.item()) == tok.eos_token_id
        chunk_full = i % chunk == chunk - 1
        if done or chunk_full or i == last_step:
            yield expert, tok.decode(ids[0, start:].tolist())
        if done:
            break
@torch.no_grad()
def run(self, query: str, max_new: int = 160, temperature: float = 0.8):
    """Full pass: route -> fuse latents -> generate from the winner.

    Returns {winner, weights, bits_per_byte, steps, generation, shared_latent},
    or {"error": ...} when no expert is loaded.
    """
    if not self.models:
        return {"error": "no experts trained yet"}

    winner, weights, bits = self.route(query)
    shared = self.shared_latent(query)[1]
    gen = self.generate(query, winner, max_new, temperature)[1]

    steps = {}
    for name in self.models:
        steps[name] = self.steps.get(name, 0)

    result = {
        "winner": winner,
        "weights": weights,
        "bits_per_byte": bits,
        "steps": steps,
        "generation": gen,
        "shared_latent": [round(x, 3) for x in shared],
    }
    return result
# ------------------------------------------------------------------ #
# WEIGHT-MERGE section (self-contained; does not touch anything above).
# Both specialists are the identical dense architecture, so their weights
# can be linearly merged into ONE model: W = (1-alpha)*Language + alpha*Math.
# This is a real merged model (one forward pass), not an output ensemble.
# ------------------------------------------------------------------ #
def merge_available(self):
    """True when both specialists needed for the weight merge ("language", "math") are loaded."""
    return all(slot in self.models for slot in ("language", "math"))
def _merged_model(self, alpha: float):
"""Build (and cache the most recent) single weight-merged model at ratio alpha."""
alpha = round(float(alpha), 2)
cache = getattr(self, "_merge_cache", None)
if cache is not None and abs(cache[0] - alpha) < 1e-6:
return cache[1]
lsd = self.models["language"].state_dict()
msd = self.models["math"].state_dict()
merged = {}
for k in lsd:
if lsd[k].is_floating_point():
merged[k] = ((1.0 - alpha) * lsd[k].float() + alpha * msd[k].float()).to(lsd[k].dtype)
else:
merged[k] = lsd[k].clone() # integer/bool buffers: copy as-is
m = SpikeWhaleLM(specialist_config("language")).to(self.device).eval()
m.load_state_dict(merged)
# make it bridge-ready: overlay the bridge-trained injection weights (loaded by
# _load_links, format-agnostic). The injection only fires when we pass inject_latent.
ali = getattr(self, "_bridge_ali", None)
if ali is not None:
try:
m.model.latent_inject.load_state_dict(ali)
except Exception:
pass
self._merge_cache = (alpha, m)
return m
@torch.no_grad()
def merge_generate(self, query: str, alpha: float = 0.5, max_new: int = 60,
                   temperature: float = 0.8, top_k: int = 40, consult: bool = False):
    """Generate from the WEIGHT-MERGED single model.

    consult=True injects Math's latent into the merged model through the trained
    RecursiveLink. Sampling uses temperature scaling and optional top-k filtering,
    and stops at EOS or after `max_new` tokens. Returns the decoded continuation,
    or "" when the merge is unavailable.
    """
    if not self.merge_available():
        return ""
    m = self._merged_model(alpha)
    tok = self.toks["language"]  # shared tokenizer
    ids = self._ids("language", query)
    start = ids.shape[1]
    ctx = int(getattr(m.config, "max_position_embeddings", 2048))

    inj = None
    if consult and getattr(self, "trained_link", None) is not None:
        math_latent = self.models["math"](input_ids=self._ids("math", query)).latent
        inj = self.trained_link(math_latent)

    t = max(1e-5, temperature)  # avoid dividing by zero
    for _ in range(max_new):
        step_logits = m(input_ids=ids[:, -ctx:], inject_latent=inj).logits[:, -1, :] / t
        if top_k:
            # Keep the k best logits; everything below the k-th gets zero probability.
            kth = torch.topk(step_logits, min(top_k, step_logits.size(-1))).values[:, [-1]]
            step_logits = step_logits.masked_fill(step_logits < kth, -float("inf"))
        nxt = torch.multinomial(F.softmax(step_logits, dim=-1), 1)
        ids = torch.cat([ids, nxt], dim=1)
        if tok.eos_token_id is not None and int(nxt.item()) == tok.eos_token_id:
            break
    return tok.decode(ids[0, start:].tolist())
# Process-wide singleton, created lazily by get_moe().
_MOE = None


def get_moe(device: str = "cpu"):
    """Return the shared SpikeWhaleMoE, creating it on first use.

    `device` is only used when the instance is first created; later calls just
    refresh the existing instance from disk via reload().
    """
    global _MOE
    if _MOE is not None:
        _MOE.reload()
        return _MOE
    _MOE = SpikeWhaleMoE(device=device)
    return _MOE
if __name__ == "__main__":
    # Smoke test: list the loaded experts and run one full pass per demo prompt.
    moe = get_moe()
    experts = moe.available()
    print("experts:", experts or "(none trained yet)")
    if experts:
        demo_queries = (
            "The mitochondria is the",
            "Solve for x: 2x + 3 =",
            "Book a flight to",
        )
        for q in demo_queries:
            r = moe.run(q, max_new=60)
            print(f"\nQ: {q!r}\n routed-> {r['winner']} bits/byte {r['bits_per_byte']} weights {r['weights']}")
            print(f" gen: {r['generation']!r}")