""" piano_mind.py -- a Modular Mind that plays piano, built with the SAME method as the boss-fight brain: several tiny specialist MLPs emit latents, a RecursiveLink merges them into one shared latent, and a coordinator reads it to pick the next note. Some specialists "own" a register (Bass/Tenor/Soprano/Rest) and add a drive to those keys; two are modulators (Onset/Phrase) that only write into the shared latent. Trained by next-note prediction on a note sequence transcribed from a song (piano/notes.json). Tiny -> trains in seconds and runs instantly for the self-playing piano in the Space. torch is used (already a Space dep); the architecture mirrors mm_torch / modular_mind. """ from __future__ import annotations import json, os import torch import torch.nn as nn import torch.nn.functional as F HERE = os.path.dirname(os.path.abspath(__file__)) K = 40 # context window (frames of recent notes) = 5s at 8fps H, D_LATENT = 64, 32 # (name, owned register or None) -- mirrors the boss's 5 owners + 2 modulators SPECS = [("Bass", "low"), ("Tenor", "mid"), ("Soprano", "high"), ("Rest", "rest"), ("Onset", None), ("Phrase", None)] def _group_masks(n_tokens): """token 0 = rest; notes 1..n-1 split into low/mid/high thirds.""" m = {g: torch.zeros(n_tokens) for g in ("low", "mid", "high", "rest")} m["rest"][0] = 1.0 notes = list(range(1, n_tokens)) third = max(1, len(notes) // 3) for j, t in enumerate(notes): g = "low" if j < third else ("mid" if j < 2 * third else "high") m[g][t] = 1.0 return m class PianoMind(nn.Module): def __init__(self, n_tokens): super().__init__() self.n_tokens = n_tokens self.fc1 = nn.ModuleList([nn.Linear(K, H) for _ in SPECS]) self.lat = nn.ModuleList([nn.Linear(H, D_LATENT) for _ in SPECS]) self.drv = nn.ModuleDict({n: nn.Linear(H, 1) for n, owns in SPECS if owns}) # RecursiveLink (ReGLU + residual), same shape as the boss bridge self.ni = nn.LayerNorm(D_LATENT) self.v = nn.Linear(D_LATENT, 2 * D_LATENT, bias=False) self.g = nn.Linear(D_LATENT, 2 * D_LATENT, bias=False) self.d = nn.Linear(2 * D_LATENT, D_LATENT, bias=False) self.no = nn.LayerNorm(D_LATENT) self.coord = nn.Linear(D_LATENT, n_tokens) gm = _group_masks(n_tokens) for g, t in gm.items(): self.register_buffer(f"mask_{g}", t) def forward(self, feat): # feat [B, K] in [0,1] B = feat.shape[0] drives = torch.zeros(B, self.n_tokens, device=feat.device) lats = [] for i, (name, owns) in enumerate(SPECS): h = torch.tanh(self.fc1[i](feat)) lats.append(self.lat[i](h)) if owns: drives = drives + self.drv[name](h) * getattr(self, f"mask_{owns}") z = torch.stack(lats, 0).sum(0) zn = self.ni(z) shared = self.no(self.d(F.relu(self.g(zn)) * self.v(zn)) + z) return drives + self.coord(shared) # logits [B, n_tokens] @torch.no_grad() def telemetry(self, feat): # feat [1, K] -> logits, per-specialist info, shared drives = torch.zeros(1, self.n_tokens) lats, per = [], [] for i, (name, owns) in enumerate(SPECS): hh = torch.tanh(self.fc1[i](feat)) lat = self.lat[i](hh); lats.append(lat) d = None if owns: d = float(self.drv[name](hh).item()) drives = drives + d * getattr(self, f"mask_{owns}") per.append({"name": name, "owns": owns, "drive": round(d, 3) if d is not None else None, "act": round(float(lat.norm().item()), 3)}) z = torch.stack(lats, 0).sum(0); zn = self.ni(z) shared = self.no(self.d(F.relu(self.g(zn)) * self.v(zn)) + z) logits = drives + self.coord(shared) return logits[0], per, [round(float(v), 2) for v in shared[0].tolist()] def _seq_feats(seq, n_tokens): import torch s = torch.tensor(seq, dtype=torch.long) X, Y = [], [] for t in range(K, len(s)): X.append(s[t - K:t].float() / n_tokens) Y.append(s[t]) return torch.stack(X), torch.tensor(Y) def train_and_save(notes_path=os.path.join(HERE, "notes.json"), out=os.path.join(HERE, "piano_weights.pt"), epochs=900, seed=0): torch.manual_seed(seed) meta = json.load(open(notes_path)) seq, n_tokens = meta["seq"], meta["n_tokens"] X, Y = _seq_feats(seq, n_tokens) model = PianoMind(n_tokens) opt = torch.optim.Adam(model.parameters(), lr=1e-2) n = X.shape[0]; bs = 512 for ep in range(epochs): perm = torch.randperm(n) tot = 0.0 for i in range(0, n, bs): idx = perm[i:i + bs] logits = model(X[idx]) loss = F.cross_entropy(logits, Y[idx]) opt.zero_grad(); loss.backward(); opt.step() tot += loss.item() * len(idx) if ep % 80 == 0 or ep == epochs - 1: with torch.no_grad(): acc = (model(X).argmax(1) == Y).float().mean().item() print(f" epoch {ep:4d} loss {tot/n:.3f} next-note acc {acc:.3f}") import numpy as _np arr = _np.array(seq) # seed from the most-sounding window (song fades in) bi = int(max(range(len(seq) - K), key=lambda i: int((arr[i:i + K] > 0).sum()))) torch.save({"state": model.state_dict(), "n_tokens": n_tokens, "tok2midi": meta["tok2midi"], "K": K, "fps": meta.get("fps", 8), "seed_seq": seq[bi:bi + K]}, out) print(f"saved -> {out}") return model, meta class PianoPlayer: """Loads the trained PianoMind and autoregressively yields the next note token.""" def __init__(self, weights=os.path.join(HERE, "piano_weights.pt")): ck = torch.load(weights, map_location="cpu") self.n_tokens = ck["n_tokens"]; self.K = ck["K"]; self.fps = ck["fps"] self.tok2midi = {int(k): int(v) for k, v in ck["tok2midi"].items()} self.seed_seq = ck["seed_seq"] self.model = PianoMind(self.n_tokens); self.model.load_state_dict(ck["state"]); self.model.eval() @torch.no_grad() def next_token(self, history, temperature=0.95, anti_silence=3, rep_penalty=1.5): h = list(history)[-self.K:] if len(h) < self.K: h = [0] * (self.K - len(h)) + h feat = torch.tensor([[x / self.n_tokens for x in h]], dtype=torch.float32) logits, per, shared = self.model.telemetry(feat) logits = logits / max(1e-3, temperature) # keep it musical: don't collapse to silence, don't get stuck on one key if anti_silence and all(t == 0 for t in h[-anti_silence:]): logits[0] -= 8.0 for t in set(h[-3:]): if t > 0: logits[t] -= rep_penalty tok = int(torch.multinomial(F.softmax(logits, -1), 1).item()) return tok, self.tok2midi.get(tok, 0), {"spec": per, "shared": shared} if __name__ == "__main__": model, meta = train_and_save() # quick listen: generate 40 notes from the song's seed p = PianoPlayer() hist = list(p.seed_seq) out = [] for _ in range(40): tok, midi, _ = p.next_token(hist); hist.append(tok); out.append(midi) print("sample MIDI stream:", out)