"""
mm_grad.py -- pure-numpy forward + backward (REINFORCE gradient) for the Modular
Mind policy, so the boss can be **finetuned from real player data on a CPU** with
no torch at runtime.
The math is identical to mm_torch.ModularMindPolicy, hand-differentiated so a
gradient step is a few thousand FLOPs (microseconds). test_grad() compares the
coordinator-weight gradient against torch autograd (expected error < 1e-6).
Pipeline:
player plays a fight -> browser logs (state, action, bossHP, playerHP) per boss
decision + who died -> /learn -> we rebuild the per-step rewards (damage dealt
- taken, + kill/- death), compute REINFORCE returns, and take one Adam step that
nudges the policy toward what worked against real humans. A frozen copy of the
sim-trained weights is kept as an anchor (small pull-back) so it can't drift far.
"""
from __future__ import annotations
import numpy as np
from features import ACTIONS, NF, extract_features, legal_mask
from modular_mind import SPEC_DEFS, D_LATENT, H
# Number of actions; the drive/logit, one-hot and mask vectors below have this length.
NA = len(ACTIONS)
# Added to the variance inside _ln_fwd before taking the square root.
EPS = 1e-5
def _ln_fwd(x, w, b):
    """Layer-normalise ``x`` over all of its elements, then scale by ``w`` and shift by ``b``.

    Returns:
        ``(y, cache)`` where ``y = xhat * w + b`` and ``cache`` is the tuple
        ``(xhat, std, w)`` that ``_ln_bwd`` needs to back-propagate.
    """
    mu = x.mean()
    centered = x - mu
    var = (centered ** 2).mean()
    # EPS keeps the division finite when every element of x is equal.
    std = np.sqrt(var + EPS)
    xhat = centered / std
    return xhat * w + b, (xhat, std, w)
def _ln_bwd(gy, cache):
    """Back-propagate through the layer norm computed by ``_ln_fwd``.

    Args:
        gy: gradient of the loss with respect to the layer-norm output.
        cache: the ``(xhat, std, w)`` tuple returned by ``_ln_fwd``.

    Returns:
        ``(gx, gw, gb)``: gradients with respect to the input, the scale ``w``
        and the shift ``b``.
    """
    xhat, std, w = cache
    gw = gy * xhat
    gb = gy.copy()  # copy so the caller's gy is never aliased by an accumulated grad
    gxhat = gy * w
    # Remove the components that the mean-subtraction and the division by std
    # project out of the input gradient.
    gx = (gxhat - gxhat.mean() - xhat * (gxhat * xhat).mean()) / std
    return gx, gw, gb
def _relu(x):
    """Return the element-wise rectified linear unit ``max(x, 0)``."""
    return np.maximum(x, 0.0)
class OnlineLearner:
    """Holds the live weights + Adam state; updates them from player trajectories.

    ``weights`` is a mapping of parameter name -> numpy array. The names read
    by this class are ``s{i}_fc1_w/b``, ``s{i}_lat_w/b`` and (for specialists
    that own an action) ``s{i}_drv_w/b`` for every entry ``i`` of SPEC_DEFS,
    plus ``link_ni_w/b``, ``link_g``, ``link_v``, ``link_d``, ``link_no_w/b``
    and ``coord_w/b``.
    """

    def __init__(self, weights, lr=5e-3, gamma=0.97, anchor_pull=0.02,
                 w_deal=6.0, w_take=5.0, time_pen=0.01, entropy_coef=0.01):
        """Copy ``weights`` to float64 and initialise the optimiser state.

        Args:
            weights: mapping of parameter name -> numpy array (copied, never mutated).
            lr: Adam step size.
            gamma: discount factor used when turning rewards into returns.
            anchor_pull: strength of the pull back toward the initial weights.
            w_deal: reward weight on the player's HP loss per step.
            w_take: penalty weight on the boss's HP loss per step.
            time_pen: constant penalty subtracted from every step's reward.
            entropy_coef: weight of the entropy regulariser in ``logpi_grad``.
        """
        self.W = {k: v.astype(np.float64).copy() for k, v in weights.items()}
        self.anchor = {k: v.copy() for k, v in self.W.items()}  # sim-trained anchor
        self.lr, self.gamma, self.anchor_pull = lr, gamma, anchor_pull
        self.w_deal, self.w_take, self.time_pen = w_deal, w_take, time_pen
        self.entropy_coef = entropy_coef
        # Per specialist: index of the action it drives directly, or None.
        self.owns = [ACTIONS.index(o) if o else None for _, o, _ in SPEC_DEFS]
        # Adam first/second moment estimates and step counter.
        self.m = {k: np.zeros_like(v) for k, v in self.W.items()}
        self.v = {k: np.zeros_like(v) for k, v in self.W.items()}
        self.t = 0

    # ---- forward with cached intermediates -------------------------------
    def _forward(self, f):
        """Compute action logits for feature vector ``f``.

        Returns ``(logits, cache)``; ``cache`` holds the intermediates that
        ``_backward`` reads.
        """
        W = self.W
        hs, lats, drives = [], [], np.zeros(NA)
        for i, owns in enumerate(self.owns):
            # Specialist i: tanh hidden layer -> latent contribution (+ optional drive).
            pre = W[f"s{i}_fc1_w"] @ f + W[f"s{i}_fc1_b"]
            h = np.tanh(pre)
            hs.append(h)
            lat = W[f"s{i}_lat_w"] @ h + W[f"s{i}_lat_b"]
            lats.append(lat)
            if owns is not None:
                drives[owns] += W[f"s{i}_drv_w"][0] @ h + W[f"s{i}_drv_b"][0]
        # Shared latent: sum of the specialists' latents.
        z = np.sum(lats, axis=0)
        zn, ln_in_c = _ln_fwd(z, W["link_ni_w"], W["link_ni_b"])
        # Gated unit: relu(link_g @ zn) * (link_v @ zn), projected by link_d.
        pre_g = W["link_g"] @ zn
        g_act = _relu(pre_g)
        v_act = W["link_v"] @ zn
        reglu = g_act * v_act
        out = W["link_d"] @ reglu
        # Residual connection around the gated unit, then a second layer norm.
        shared, ln_out_c = _ln_fwd(out + z, W["link_no_w"], W["link_no_b"])
        modulation = W["coord_w"] @ shared + W["coord_b"]
        logits = drives + modulation
        cache = dict(f=f, hs=hs, lats=lats, z=z, zn=zn, ln_in_c=ln_in_c, pre_g=pre_g,
                     g_act=g_act, v_act=v_act, reglu=reglu, out=out, shared=shared,
                     ln_out_c=ln_out_c)
        return logits, cache

    # ---- backward: accumulate grads of (advantage * -logpi - H) ----------
    def _backward(self, cache, g_logits, grads):
        """Back-propagate ``g_logits`` through ``_forward`` and add into ``grads``.

        Args:
            cache: the cache dict returned by ``_forward``.
            g_logits: gradient of the loss with respect to the logits.
            grads: dict keyed like ``self.W``; updated in place (accumulated).
        """
        W = self.W
        # coordinator
        grads["coord_w"] += np.outer(g_logits, cache["shared"])
        grads["coord_b"] += g_logits
        g_shared = W["coord_w"].T @ g_logits
        # owned-action drives: each receives the gradient of its own logit
        g_drive = {}
        for i, owns in enumerate(self.owns):
            if owns is not None:
                g_drive[i] = g_logits[owns]
        # shared = layernorm(out + z)
        g_outz, gw, gb = _ln_bwd(g_shared, cache["ln_out_c"])
        grads["link_no_w"] += gw
        grads["link_no_b"] += gb
        g_out = g_outz
        g_z = g_outz.copy()  # residual branch; copied because it is added to below
        # out = Wd @ reglu
        grads["link_d"] += np.outer(g_out, cache["reglu"])
        g_reglu = W["link_d"].T @ g_out
        # reglu = relu(Wg@zn) * (Wv@zn)
        g_g_act = g_reglu * cache["v_act"]
        g_v_act = g_reglu * cache["g_act"]
        g_pre_g = g_g_act * (cache["pre_g"] > 0)
        grads["link_g"] += np.outer(g_pre_g, cache["zn"])
        grads["link_v"] += np.outer(g_v_act, cache["zn"])
        g_zn = W["link_g"].T @ g_pre_g + W["link_v"].T @ g_v_act
        # zn = layernorm(z)
        g_z_ln, gw, gb = _ln_bwd(g_zn, cache["ln_in_c"])
        grads["link_ni_w"] += gw
        grads["link_ni_b"] += gb
        g_z += g_z_ln
        # z = sum(lat_i) -> every specialist's latent sees the same gradient
        for i, owns in enumerate(self.owns):
            h = cache["hs"][i]
            g_lat = g_z
            grads[f"s{i}_lat_w"] += np.outer(g_lat, h)
            grads[f"s{i}_lat_b"] += g_lat
            g_h = W[f"s{i}_lat_w"].T @ g_lat
            if owns is not None:
                grads[f"s{i}_drv_w"][0] += g_drive[i] * h
                grads[f"s{i}_drv_b"][0] += g_drive[i]
                g_h = g_h + W[f"s{i}_drv_w"][0] * g_drive[i]
            # h = tanh(pre)  ->  dh/dpre = 1 - h^2
            g_pre = g_h * (1.0 - h * h)
            grads[f"s{i}_fc1_w"] += np.outer(g_pre, cache["f"])
            grads[f"s{i}_fc1_b"] += g_pre

    def logpi_grad(self, f, action, advantage, mask):
        """Grad of advantage * -log pi(action|state) minus the entropy bonus.

        Args:
            f: feature vector fed to ``_forward``.
            action: integer index of the action that was taken.
            advantage: scalar weight on the policy-gradient term.
            mask: per-action legality; entries > 0.5 are legal. May be an
                array or any sequence ``np.asarray`` accepts.

        Returns:
            A fresh dict keyed like ``self.W`` holding the gradient for this step.
        """
        logits, cache = self._forward(f)
        legal = np.asarray(mask) > 0.5
        # Softmax restricted to the legal actions.
        masked = np.where(legal, logits, -1e9)
        p = np.exp(masked - masked.max())
        p = p / p.sum()
        onehot = np.zeros(NA)
        onehot[action] = 1.0
        # d(-adv * log pi(action)) / dlogits = adv * (p - onehot)
        g_logits = advantage * (p - onehot)
        # Entropy regulariser (encourages exploration). With H = -sum(p * log p):
        # d(-entropy_coef * H) / dlogit_i = entropy_coef * p_i * (log p_i + H)
        with np.errstate(divide="ignore"):
            logp = np.where(p > 1e-12, np.log(p), 0.0)
        ent_term = self.entropy_coef * p * (logp + (p * (-logp)).sum())
        g_logits = g_logits + np.where(legal, ent_term, 0.0)
        grads = {k: np.zeros_like(v) for k, v in self.W.items()}
        self._backward(cache, g_logits, grads)
        return grads

    def _trajectory_rewards(self, steps, result):
        """Rebuild per-decision rewards from logged HP (damage dealt - taken).

        For step ``t`` the HP values after the decision are those logged at step
        ``t + 1``; for the last step they are the same step's values, or 0 for
        whichever side ``result`` reports as dead. The reward is
        ``w_deal * player HP lost - w_take * boss HP lost - time_pen``, with a
        terminal +8 when the player died, else -5 when the boss died.

        Returns an array with one reward per step (empty for empty ``steps``).
        """
        n = len(steps)
        rews = np.zeros(n)
        if n == 0:
            # No decisions logged: there is no last step to attach a terminal bonus to.
            return rews
        for t in range(n):
            nb = steps[t + 1]["bossHP"] if t + 1 < n else (0.0 if result.get("bossDied") else steps[t]["bossHP"])
            npl = steps[t + 1]["playerHP"] if t + 1 < n else (0.0 if result.get("playerDied") else steps[t]["playerHP"])
            dealt = max(0.0, steps[t]["playerHP"] - npl)
            taken = max(0.0, steps[t]["bossHP"] - nb)
            rews[t] = dealt * self.w_deal - taken * self.w_take - self.time_pen
        if result.get("playerDied"):
            rews[-1] += 8.0
        elif result.get("bossDied"):
            rews[-1] -= 5.0
        return rews

    def update(self, trajectories):
        """trajectories: list of {steps:[{state,action,bossHP,playerHP}], result:{}}.

        Takes one Adam step on ``self.W`` (mutated in place) using the averaged
        per-step gradients plus the anchor pull-back. Trajectories with fewer
        than two steps (including a missing or null ``steps``) are skipped.

        Returns:
            ``{"updated": False, "reason": "not enough data"}`` when nothing
            usable was supplied, otherwise a dict with ``updated``, ``steps``,
            ``trajectories`` and ``avg_return``.
        """
        grads = {k: np.zeros_like(v) for k, v in self.W.items()}
        all_returns, nsteps = [], 0
        # first pass: gather returns for a baseline
        per_traj = []
        for tr in trajectories:
            # `or` also covers an explicit null coming from the JSON payload.
            steps = tr.get("steps") or []
            if len(steps) < 2:
                continue
            rews = self._trajectory_rewards(steps, tr.get("result") or {})
            # Discounted return-to-go, accumulated from the last step backwards.
            G = np.zeros(len(rews))
            acc = 0.0
            for t in reversed(range(len(rews))):
                acc = rews[t] + self.gamma * acc
                G[t] = acc
            per_traj.append((steps, G))
            all_returns.extend(G.tolist())
        if not per_traj:
            return {"updated": False, "reason": "not enough data"}
        baseline = float(np.mean(all_returns))
        adv_std = float(np.std(all_returns)) + 1e-6  # epsilon guards identical returns
        # second pass: accumulate gradient
        for steps, G in per_traj:
            for t, st in enumerate(steps):
                s = st["state"]
                f = extract_features(s).astype(np.float64)
                mask = legal_mask(s)
                # Actions may be logged either by name or by integer index.
                action = ACTIONS.index(st["action"]) if isinstance(st["action"], str) else int(st["action"])
                adv = (G[t] - baseline) / adv_std
                g = self.logpi_grad(f, action, adv, mask)
                for k in grads:
                    grads[k] += g[k]
                nsteps += 1
        # average + anchor pull-back (stay near the sim-trained policy)
        self.t += 1
        b1, b2 = 0.9, 0.999
        for k in self.W:
            gk = grads[k] / max(1, nsteps) + self.anchor_pull * (self.W[k] - self.anchor[k])
            self.m[k] = b1 * self.m[k] + (1 - b1) * gk
            self.v[k] = b2 * self.v[k] + (1 - b2) * (gk * gk)
            # Bias-corrected moment estimates.
            mhat = self.m[k] / (1 - b1 ** self.t)
            vhat = self.v[k] / (1 - b2 ** self.t)
            self.W[k] -= self.lr * mhat / (np.sqrt(vhat) + 1e-8)
        return {"updated": True, "steps": nsteps, "trajectories": len(per_traj),
                "avg_return": round(baseline, 3)}
def test_grad():
    """Verify the numpy logpi-gradient matches torch autograd.

    Exports a freshly constructed ``ModularMindPolicy`` to a temporary .npz,
    loads it into an ``OnlineLearner`` (entropy term disabled) and, for five
    random inputs, compares the ``coord_w`` gradient of ``-log pi(action)``
    with the torch gradient of ``coordinator.weight``.

    Prints and returns the maximum relative error seen.
    """
    import os

    import torch
    from mm_torch import ModularMindPolicy

    path = "_gradchk.npz"
    m = ModularMindPolicy().double()
    try:
        m.export_npz(path)
        # Close the archive before deleting it: an open NpzFile keeps the file
        # handle alive, which leaks it and blocks os.remove on Windows.
        with np.load(path) as npz:
            W = {k: v for k, v in npz.items()}
    finally:
        # Clean up the scratch file even if export or load fails.
        if os.path.exists(path):
            os.remove(path)
    learner = OnlineLearner(W, entropy_coef=0.0)
    rng = np.random.default_rng(0)
    maxrel = 0.0
    for _ in range(5):
        f = rng.normal(size=NF)
        action = int(rng.integers(NA))
        mask = np.ones(NA)
        # numpy grad of -logpi(action) (advantage=1)
        gnp = learner.logpi_grad(f, action, 1.0, mask)
        # torch grad
        m.zero_grad()
        x = torch.tensor(f, dtype=torch.float64).unsqueeze(0)
        logits, _ = m(x)
        logp = torch.log_softmax(logits, dim=-1)[0, action]
        (-logp).backward()
        # compare coordinator weight grad as a representative
        gt = m.coordinator.weight.grad.numpy()
        rel = np.abs(gnp["coord_w"] - gt).max() / (np.abs(gt).max() + 1e-9)
        maxrel = max(maxrel, rel)
    print(f"max relative grad error (coord_w) vs torch: {maxrel:.2e}")
    return maxrel
if __name__ == "__main__":
    # Run the numpy-vs-torch gradient check when executed as a script.
    test_grad()