# ModuleMind / mm_grad.py
# Uploaded by Quazim0t0 ("Add files using upload-large-folder tool"),
# commit 45e7dfb, 10.7 kB.
"""
mm_grad.py -- pure-numpy forward + backward (REINFORCE gradient) for the Modular
Mind policy, so the boss can be **finetuned from real player data on a CPU** with
no torch at runtime.
The math is identical to mm_torch.ModularMindPolicy, hand-differentiated so a
gradient step is a few thousand FLOPs (microseconds). test_grad() checks it
against torch autograd (it compares the coordinator-weight gradient; expected
relative error < 1e-6).
Pipeline:
player plays a fight -> browser logs (state, action, bossHP, playerHP) per boss
decision + who died -> /learn -> we rebuild the per-step rewards (damage dealt
- taken, + kill/- death), compute REINFORCE returns, and take one Adam step that
nudges the policy toward what worked against real humans. A frozen copy of the
sim-trained weights is kept as an anchor (small pull-back) so it can't drift far.
"""
from __future__ import annotations
import numpy as np
from features import ACTIONS, NF, extract_features, legal_mask
from modular_mind import SPEC_DEFS, D_LATENT, H
# Number of discrete actions; the policy emits one logit per entry of ACTIONS.
NA: int = len(ACTIONS)
# Variance epsilon added inside the layer-norm square root (see _ln_fwd).
EPS: float = 1e-5
def _ln_fwd(x, w, b):
    """Layer-norm forward pass over all elements of ``x``.

    Computes ``y = (x - mean(x)) / sqrt(var(x) + EPS) * w + b`` using the
    population variance. The callers in this file pass 1-D vectors.

    Returns:
        ``(y, cache)`` where ``cache = (xhat, std, w)`` is exactly what
        ``_ln_bwd`` expects.
    """
    centered = x - x.mean()
    std = np.sqrt(np.mean(centered ** 2) + EPS)
    xhat = centered / std
    y = xhat * w + b
    return y, (xhat, std, w)
def _ln_bwd(gy, cache):
xhat, std, w = cache
n = xhat.shape[0]
gw = gy * xhat
gb = gy.copy()
gxhat = gy * w
gx = (gxhat - gxhat.mean() - xhat * (gxhat * xhat).mean()) / std
return gx, gw, gb
def _relu(x):
return np.maximum(x, 0.0)
class OnlineLearner:
    """Holds the live weights + Adam state; updates them from player trajectories.

    Args:
        weights: mapping of parameter name -> array (the sim-trained policy).
            Copied to float64; a second copy is frozen as the anchor that every
            update is pulled back toward.
        lr: Adam learning rate.
        gamma: reward discount used for the REINFORCE returns.
        anchor_pull: L2 strength toward the frozen anchor weights.
        w_deal: reward weight for damage the boss deals (player HP lost).
        w_take: penalty weight for damage the boss takes (boss HP lost).
        time_pen: per-decision time penalty.
        entropy_coef: weight of the entropy bonus in the policy gradient.
        kill_bonus: terminal reward added when the player dies.
        death_penalty: terminal reward subtracted when the boss dies (applied
            only if the player did not also die).
    """

    def __init__(self, weights, lr=5e-3, gamma=0.97, anchor_pull=0.02,
                 w_deal=6.0, w_take=5.0, time_pen=0.01, entropy_coef=0.01,
                 kill_bonus=8.0, death_penalty=5.0):
        self.W = {k: np.array(v, dtype=np.float64) for k, v in weights.items()}
        self.anchor = {k: v.copy() for k, v in self.W.items()}  # sim-trained anchor
        self.lr, self.gamma, self.anchor_pull = lr, gamma, anchor_pull
        self.w_deal, self.w_take, self.time_pen = w_deal, w_take, time_pen
        self.entropy_coef = entropy_coef
        self.kill_bonus, self.death_penalty = kill_bonus, death_penalty
        # Per specialist: index into ACTIONS of the action it drives, or None.
        # NOTE(review): assumes SPEC_DEFS entries are 3-tuples whose middle
        # item is an action name (or falsy) -- confirm against modular_mind.
        self.owns = [ACTIONS.index(o) if o else None for _, o, _ in SPEC_DEFS]
        # Adam first/second moment estimates and step counter.
        self.m = {k: np.zeros_like(v) for k, v in self.W.items()}
        self.v = {k: np.zeros_like(v) for k, v in self.W.items()}
        self.t = 0

    # ---- forward with cached intermediates -------------------------------
    def _forward(self, f):
        """Compute action logits for feature vector ``f``.

        Returns:
            ``(logits, cache)``; ``cache`` holds the intermediates ``_backward``
            needs.
        """
        W = self.W
        hs, lats = [], []
        drives = np.zeros(NA)
        for i, owns in enumerate(self.owns):
            pre = W[f"s{i}_fc1_w"] @ f + W[f"s{i}_fc1_b"]
            h = np.tanh(pre)
            hs.append(h)
            lats.append(W[f"s{i}_lat_w"] @ h + W[f"s{i}_lat_b"])
            if owns is not None:
                # An owning specialist adds a scalar drive to its action's logit.
                drives[owns] += W[f"s{i}_drv_w"][0] @ h + W[f"s{i}_drv_b"][0]
        z = np.sum(lats, axis=0)
        zn, ln_in_c = _ln_fwd(z, W["link_ni_w"], W["link_ni_b"])
        pre_g = W["link_g"] @ zn
        g_act = _relu(pre_g)
        v_act = W["link_v"] @ zn
        reglu = g_act * v_act
        out = W["link_d"] @ reglu
        # Residual connection around the ReGLU link, then output layer norm.
        shared, ln_out_c = _ln_fwd(out + z, W["link_no_w"], W["link_no_b"])
        modulation = W["coord_w"] @ shared + W["coord_b"]
        logits = drives + modulation
        cache = dict(f=f, hs=hs, lats=lats, z=z, zn=zn, ln_in_c=ln_in_c, pre_g=pre_g,
                     g_act=g_act, v_act=v_act, reglu=reglu, out=out, shared=shared,
                     ln_out_c=ln_out_c)
        return logits, cache

    # ---- backward: accumulate grads of (advantage * -logpi - H) ----------
    def _backward(self, cache, g_logits, grads):
        """Backpropagate dL/dlogits through the network, adding into ``grads`` in place."""
        W = self.W
        # logits = drives + coord_w @ shared + coord_b
        grads["coord_w"] += np.outer(g_logits, cache["shared"])
        grads["coord_b"] += g_logits
        g_shared = W["coord_w"].T @ g_logits
        # shared = layernorm(out + z)
        g_outz, gw, gb = _ln_bwd(g_shared, cache["ln_out_c"])
        grads["link_no_w"] += gw
        grads["link_no_b"] += gb
        g_out = g_outz
        # Residual branch; the input layer norm's contribution is added below.
        g_z = g_outz.copy()
        # out = link_d @ reglu
        grads["link_d"] += np.outer(g_out, cache["reglu"])
        g_reglu = W["link_d"].T @ g_out
        # reglu = relu(link_g @ zn) * (link_v @ zn)
        g_g_act = g_reglu * cache["v_act"]
        g_v_act = g_reglu * cache["g_act"]
        g_pre_g = g_g_act * (cache["pre_g"] > 0)
        grads["link_g"] += np.outer(g_pre_g, cache["zn"])
        grads["link_v"] += np.outer(g_v_act, cache["zn"])
        g_zn = W["link_g"].T @ g_pre_g + W["link_v"].T @ g_v_act
        # zn = layernorm(z)
        g_z_ln, gw, gb = _ln_bwd(g_zn, cache["ln_in_c"])
        grads["link_ni_w"] += gw
        grads["link_ni_b"] += gb
        g_z += g_z_ln
        # z = sum_i lat_i, so every specialist's latent receives g_z unchanged.
        for i, owns in enumerate(self.owns):
            h = cache["hs"][i]
            grads[f"s{i}_lat_w"] += np.outer(g_z, h)
            grads[f"s{i}_lat_b"] += g_z
            g_h = W[f"s{i}_lat_w"].T @ g_z
            if owns is not None:
                g_drive = g_logits[owns]
                grads[f"s{i}_drv_w"][0] += g_drive * h
                grads[f"s{i}_drv_b"][0] += g_drive
                g_h = g_h + W[f"s{i}_drv_w"][0] * g_drive
            g_pre = g_h * (1.0 - h * h)  # tanh'(pre) = 1 - tanh(pre)^2
            grads[f"s{i}_fc1_w"] += np.outer(g_pre, cache["f"])
            grads[f"s{i}_fc1_b"] += g_pre

    def logpi_grad(self, f, action, advantage, mask, grads=None):
        """Gradient of ``advantage * -log pi(action|state) - entropy_coef * H``.

        Args:
            f: feature vector for the state.
            action: index of the taken action.
            advantage: scalar advantage weighting the log-probability term.
            mask: per-action legality; entries > 0.5 are legal.
            grads: optional dict of accumulators shaped like ``self.W``. When
                given, gradients are added into it in place (no per-call
                allocation).

        Returns:
            The gradient dict (``grads`` itself when it was supplied).
        """
        logits, cache = self._forward(f)
        legal = np.asarray(mask) > 0.5
        masked = np.where(legal, logits, -1e9)
        p = np.exp(masked - masked.max())
        p = p / p.sum()
        onehot = np.zeros(NA)
        onehot[action] = 1.0
        # d(-adv * log pi)/dlogits = adv * (p - onehot)
        g_logits = advantage * (p - onehot)
        # Entropy regularizer: d(-ent_coef * H)/dlogits = ent_coef * p * (log p + H)
        with np.errstate(divide="ignore"):
            logp = np.where(p > 1e-12, np.log(p), 0.0)
        ent_term = self.entropy_coef * p * (logp + (p * (-logp)).sum())
        # Masked logits are replaced by a constant before the softmax, so their
        # true gradient is zero. For a legal action this is a no-op; it keeps an
        # illegal logged action from pushing its masked logit.
        g_logits = np.where(legal, g_logits + ent_term, 0.0)
        if grads is None:
            grads = {k: np.zeros_like(v) for k, v in self.W.items()}
        self._backward(cache, g_logits, grads)
        return grads

    def _trajectory_rewards(self, steps, result):
        """Rebuild per-decision rewards from logged HP (damage dealt - taken).

        Step t's reward uses the HP change until step t+1. After the final step
        the HP of whoever died (per ``result``) is treated as 0. The terminal
        kill bonus / death penalty is applied to the last reward. ``result`` may
        be None, and empty ``steps`` yields an empty array.
        """
        result = result or {}
        boss_died = bool(result.get("bossDied"))
        player_died = bool(result.get("playerDied"))
        n = len(steps)
        rews = np.zeros(n)
        if n == 0:
            return rews
        for t in range(n):
            cur = steps[t]
            if t + 1 < n:
                nb, npl = steps[t + 1]["bossHP"], steps[t + 1]["playerHP"]
            else:
                nb = 0.0 if boss_died else cur["bossHP"]
                npl = 0.0 if player_died else cur["playerHP"]
            dealt = max(0.0, cur["playerHP"] - npl)
            taken = max(0.0, cur["bossHP"] - nb)
            rews[t] = dealt * self.w_deal - taken * self.w_take - self.time_pen
        if player_died:
            rews[-1] += self.kill_bonus
        elif boss_died:
            rews[-1] -= self.death_penalty
        return rews

    def update(self, trajectories):
        """trajectories: list of {steps:[{state,action,bossHP,playerHP}], result:{}}.

        ``result`` may carry truthy ``bossDied`` / ``playerDied`` flags.
        Trajectories with fewer than 2 steps, or with null ``steps``/``result``,
        are tolerated (the former are skipped).

        Returns:
            A dict of stats. When no trajectory is usable it returns
            ``{"updated": False, "reason": "not enough data"}``; otherwise
            ``self.W`` is mutated in place by one Adam step.

        Raises:
            ValueError: a string action is not in ACTIONS.
        """
        grads = {k: np.zeros_like(v) for k, v in self.W.items()}
        # First pass: discounted returns per trajectory, pooled for a baseline.
        per_traj, all_returns = [], []
        for tr in trajectories:
            steps = tr.get("steps") or []
            if len(steps) < 2:
                continue
            rews = self._trajectory_rewards(steps, tr.get("result") or {})
            G = np.zeros(len(rews))
            acc = 0.0
            for t in reversed(range(len(rews))):
                acc = rews[t] + self.gamma * acc
                G[t] = acc
            per_traj.append((steps, G))
            all_returns.extend(G.tolist())
        if not per_traj:
            return {"updated": False, "reason": "not enough data"}
        baseline = float(np.mean(all_returns))
        adv_std = float(np.std(all_returns)) + 1e-6
        # Second pass: accumulate the policy gradient directly into `grads`.
        nsteps = 0
        for steps, G in per_traj:
            for t, st in enumerate(steps):
                s = st["state"]
                f = extract_features(s).astype(np.float64)
                mask = legal_mask(s)
                raw = st["action"]
                action = ACTIONS.index(raw) if isinstance(raw, str) else int(raw)
                adv = (G[t] - baseline) / adv_std
                self.logpi_grad(f, action, adv, mask, grads)
                nsteps += 1
        # Average + anchor pull-back (stay near the sim-trained policy), then Adam.
        self.t += 1
        b1, b2 = 0.9, 0.999
        for k in self.W:
            gk = grads[k] / max(1, nsteps) + self.anchor_pull * (self.W[k] - self.anchor[k])
            self.m[k] = b1 * self.m[k] + (1 - b1) * gk
            self.v[k] = b2 * self.v[k] + (1 - b2) * (gk * gk)
            mhat = self.m[k] / (1 - b1 ** self.t)
            vhat = self.v[k] / (1 - b2 ** self.t)
            self.W[k] -= self.lr * mhat / (np.sqrt(vhat) + 1e-8)
        return {"updated": True, "steps": nsteps, "trajectories": len(per_traj),
                "avg_return": round(baseline, 3)}
def test_grad():
    """Verify the numpy logpi-gradient matches torch autograd.

    Builds a double-precision torch ``ModularMindPolicy``, exports its weights
    through a temporary ``.npz`` and compares the gradients for 5 random
    (feature vector, action) pairs. The numpy side takes the gradient of
    -log pi(action) with advantage 1, no entropy bonus and all actions legal.
    Only the coordinator weight gradient (``coord_w`` vs
    ``m.coordinator.weight.grad``) is compared.

    Returns:
        The maximum relative error observed (also printed).
    """
    import os
    import tempfile

    import torch
    from mm_torch import ModularMindPolicy

    m = ModularMindPolicy().double()
    # Export into a private temp dir so the file is always cleaned up (even if
    # loading fails) and the NpzFile is closed before the directory is removed.
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "_gradchk.npz")
        m.export_npz(path)
        with np.load(path) as data:
            W = {k: data[k] for k in data.files}
    learner = OnlineLearner(W, entropy_coef=0.0)
    rng = np.random.default_rng(0)
    maxrel = 0.0
    for _ in range(5):
        f = rng.normal(size=NF)
        action = int(rng.integers(NA))
        mask = np.ones(NA)
        # numpy grad of -logpi(action) (advantage=1)
        gnp = learner.logpi_grad(f, action, 1.0, mask)
        # torch grad
        m.zero_grad()
        x = torch.tensor(f, dtype=torch.float64).unsqueeze(0)
        logits, _ = m(x)
        logp = torch.log_softmax(logits, dim=-1)[0, action]
        (-logp).backward()
        # compare coordinator weight grad as a representative
        gt = m.coordinator.weight.grad.numpy()
        rel = np.abs(gnp["coord_w"] - gt).max() / (np.abs(gt).max() + 1e-9)
        maxrel = max(maxrel, rel)
    print(f"max relative grad error (coord_w) vs torch: {maxrel:.2e}")
    return maxrel
if __name__ == "__main__":
    # Running this module as a script performs the numpy-vs-torch gradient check.
    _max_rel_err = test_grad()