Spaces:
Running on Zero
Running on Zero
| """ | |
| piano_mind.py -- a Modular Mind that plays piano, built with the SAME method as the | |
| boss-fight brain: several tiny specialist MLPs emit latents, a RecursiveLink merges | |
| them into one shared latent, and a coordinator reads it to pick the next note. Some | |
| specialists "own" a register (Bass/Tenor/Soprano/Rest) and add a drive to those keys; | |
| two are modulators (Onset/Phrase) that only write into the shared latent. | |
| Trained by next-note prediction on a note sequence transcribed from a song | |
| (piano/notes.json). Tiny -> trains in seconds and runs instantly for the self-playing | |
| piano in the Space. torch is used (already a Space dep); the architecture mirrors | |
| mm_torch / modular_mind. | |
| """ | |
| from __future__ import annotations | |
| import json, os | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
# Directory containing this module; used as the default location of
# notes.json and piano_weights.pt.
HERE = os.path.dirname(os.path.abspath(__file__))
K = 40 # context window (frames of recent notes) = 5s at 8fps
# H: hidden width of each specialist's first layer.
# D_LATENT: size of each specialist's latent and of the merged shared latent.
H, D_LATENT = 64, 32
# (name, owned register or None) -- mirrors the boss's 5 owners + 2 modulators
# In this file: four register owners (low/mid/high/rest) and two modulators
# (owned register is None) that only contribute a latent.
SPECS = [("Bass", "low"), ("Tenor", "mid"), ("Soprano", "high"),
         ("Rest", "rest"), ("Onset", None), ("Phrase", None)]
| def _group_masks(n_tokens): | |
| """token 0 = rest; notes 1..n-1 split into low/mid/high thirds.""" | |
| m = {g: torch.zeros(n_tokens) for g in ("low", "mid", "high", "rest")} | |
| m["rest"][0] = 1.0 | |
| notes = list(range(1, n_tokens)) | |
| third = max(1, len(notes) // 3) | |
| for j, t in enumerate(notes): | |
| g = "low" if j < third else ("mid" if j < 2 * third else "high") | |
| m[g][t] = 1.0 | |
| return m | |
class PianoMind(nn.Module):
    """Next-note predictor made of small specialist MLPs and a shared latent.

    Each entry of ``SPECS`` gets its own ``Linear(K, H)`` -> tanh ->
    ``Linear(H, D_LATENT)`` branch. Specialists that own a register also
    emit one scalar "drive" that is added to the logits of the tokens in
    that register. All latents are summed, passed through the RecursiveLink
    (ReGLU + residual, layer-normed on both sides), and a linear coordinator
    maps the result to per-token logits.
    """

    def __init__(self, n_tokens):
        super().__init__()
        self.n_tokens = n_tokens
        # Attribute names are state_dict keys and creation order fixes the
        # seeded initialisation, so both are kept stable.
        self.fc1 = nn.ModuleList([nn.Linear(K, H) for _ in SPECS])
        self.lat = nn.ModuleList([nn.Linear(H, D_LATENT) for _ in SPECS])
        self.drv = nn.ModuleDict({n: nn.Linear(H, 1) for n, owns in SPECS if owns})
        # RecursiveLink (ReGLU + residual), same shape as the boss bridge
        self.ni = nn.LayerNorm(D_LATENT)
        self.v = nn.Linear(D_LATENT, 2 * D_LATENT, bias=False)
        self.g = nn.Linear(D_LATENT, 2 * D_LATENT, bias=False)
        self.d = nn.Linear(2 * D_LATENT, D_LATENT, bias=False)
        self.no = nn.LayerNorm(D_LATENT)
        self.coord = nn.Linear(D_LATENT, n_tokens)
        # Register masks are buffers so they follow the module across devices.
        for group, mask in _group_masks(n_tokens).items():
            self.register_buffer(f"mask_{group}", mask)

    def _recursive_link(self, lats):
        """Sum the specialist latents and apply the gated residual bridge."""
        z = torch.stack(lats, 0).sum(0)
        zn = self.ni(z)
        return self.no(self.d(F.relu(self.g(zn)) * self.v(zn)) + z)

    def forward(self, feat):
        """Return logits ``[B, n_tokens]`` for ``feat`` ``[B, K]`` in [0, 1]."""
        drives = torch.zeros(feat.shape[0], self.n_tokens, device=feat.device)
        lats = []
        for i, (name, owns) in enumerate(SPECS):
            h = torch.tanh(self.fc1[i](feat))
            lats.append(self.lat[i](h))
            if owns:
                # [B, 1] drive broadcast over the tokens of the owned register.
                drives = drives + self.drv[name](h) * getattr(self, f"mask_{owns}")
        shared = self._recursive_link(lats)
        return drives + self.coord(shared)

    def telemetry(self, feat):
        """Run one context window and expose the per-specialist internals.

        ``feat`` must be ``[1, K]`` (the drives are read with ``.item()``).

        Returns ``(logits, per, shared)``: the ``[n_tokens]`` logits, a list
        with one dict per specialist (``name``, ``owns``, ``drive`` rounded to
        3 decimals or None for modulators, ``act`` = norm of its latent
        rounded to 3 decimals), and the shared latent as a list of floats
        rounded to 2 decimals.
        """
        # Allocate on the input's device, like forward(); a default-device
        # tensor here breaks as soon as the model is moved off the CPU.
        drives = torch.zeros(1, self.n_tokens, device=feat.device)
        lats, per = [], []
        for i, (name, owns) in enumerate(SPECS):
            hh = torch.tanh(self.fc1[i](feat))
            lat = self.lat[i](hh)
            lats.append(lat)
            d = None
            if owns:
                d = float(self.drv[name](hh).item())
                drives = drives + d * getattr(self, f"mask_{owns}")
            per.append({"name": name, "owns": owns,
                        "drive": round(d, 3) if d is not None else None,
                        "act": round(float(lat.norm().item()), 3)})
        shared = self._recursive_link(lats)
        logits = drives + self.coord(shared)
        return logits[0], per, [round(float(v), 2) for v in shared[0].tolist()]
def _seq_feats(seq, n_tokens):
    """Turn a token sequence into (context, next-token) training pairs.

    Args:
        seq: 1-D sequence of integer note tokens.
        n_tokens: vocabulary size; contexts are scaled by ``1 / n_tokens``.

    Returns:
        ``(X, Y)`` where ``X`` is a float tensor ``[len(seq) - K, K]`` holding
        the ``K`` tokens preceding each target, divided by ``n_tokens``, and
        ``Y`` is the long tensor ``[len(seq) - K]`` of target tokens.

    Raises:
        ValueError: if ``seq`` has ``K`` or fewer frames, i.e. no pair exists.
    """
    s = torch.tensor(seq, dtype=torch.long)
    if s.shape[0] <= K:
        raise ValueError(
            f"need more than K={K} frames to build training pairs, got {s.shape[0]}")
    # unfold yields every length-K window; the last one has no following
    # token to predict, so it is dropped.
    X = s.unfold(0, K, 1)[:-1].float() / n_tokens
    Y = s[K:].clone()
    return X, Y
def _most_sounding_start(seq, width):
    """Return the first start index whose ``width``-frame window has the most non-rest tokens.

    Candidate starts are ``0 .. len(seq) - width - 1``. Raises ValueError when
    there is no candidate.
    """
    n_starts = len(seq) - width
    if n_starts <= 0:
        raise ValueError(
            f"sequence of {len(seq)} frames is too short for a {width}-frame seed window")
    count = sum(1 for tok in seq[:width] if tok > 0)
    best_start, best_count = 0, count
    for start in range(1, n_starts):
        # Slide by one frame: one token enters on the right, one leaves on the left.
        count += int(seq[start + width - 1] > 0) - int(seq[start - 1] > 0)
        if count > best_count:  # strict: ties keep the earliest window
            best_start, best_count = start, count
    return best_start


def train_and_save(notes_path=os.path.join(HERE, "notes.json"),
                   out=os.path.join(HERE, "piano_weights.pt"), epochs=900, seed=0):
    """Train a PianoMind by next-note prediction and write a checkpoint.

    Args:
        notes_path: JSON file with keys ``seq`` (token sequence), ``n_tokens``,
            ``tok2midi`` and optionally ``fps`` (defaults to 8 when absent).
        out: destination path for the ``torch.save`` checkpoint.
        epochs: number of passes over the training pairs.
        seed: value passed to ``torch.manual_seed``.

    Returns:
        ``(model, meta)``: the trained model and the parsed JSON metadata.

    The checkpoint holds ``state``, ``n_tokens``, ``tok2midi``, ``K``, ``fps``
    and ``seed_seq`` (the ``K``-frame window used to start playback).
    """
    torch.manual_seed(seed)
    with open(notes_path, encoding="utf-8") as fh:
        meta = json.load(fh)
    seq, n_tokens = meta["seq"], meta["n_tokens"]
    X, Y = _seq_feats(seq, n_tokens)
    model = PianoMind(n_tokens)
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    n = X.shape[0]
    bs = 512
    for ep in range(epochs):
        perm = torch.randperm(n)
        tot = 0.0
        for i in range(0, n, bs):
            idx = perm[i:i + bs]
            logits = model(X[idx])
            loss = F.cross_entropy(logits, Y[idx])
            opt.zero_grad()
            loss.backward()
            opt.step()
            # Weight by batch size so tot / n is the per-sample mean loss.
            tot += loss.item() * len(idx)
        if ep % 80 == 0 or ep == epochs - 1:
            with torch.no_grad():
                acc = (model(X).argmax(1) == Y).float().mean().item()
            print(f" epoch {ep:4d} loss {tot/n:.3f} next-note acc {acc:.3f}")
    # seed from the most-sounding window (song fades in)
    bi = _most_sounding_start(seq, K)
    torch.save({"state": model.state_dict(), "n_tokens": n_tokens,
                "tok2midi": meta["tok2midi"], "K": K, "fps": meta.get("fps", 8),
                "seed_seq": seq[bi:bi + K]}, out)
    print(f"saved -> {out}")
    return model, meta
class PianoPlayer:
    """Loads the trained PianoMind and autoregressively yields the next note token."""

    def __init__(self, weights=os.path.join(HERE, "piano_weights.pt")):
        """Load the checkpoint at ``weights`` onto the CPU and put the model in eval mode."""
        # NOTE(review): torch.load unpickles the file -- only point this at
        # checkpoints you trust.
        ck = torch.load(weights, map_location="cpu")
        self.n_tokens = ck["n_tokens"]
        self.K = ck["K"]
        self.fps = ck["fps"]
        # Keys may have been stringified on the way in, so coerce both sides to int.
        self.tok2midi = {int(k): int(v) for k, v in ck["tok2midi"].items()}
        self.seed_seq = ck["seed_seq"]
        self.model = PianoMind(self.n_tokens)
        self.model.load_state_dict(ck["state"])
        self.model.eval()

    def next_token(self, history, temperature=0.95, anti_silence=3, rep_penalty=1.5):
        """Sample the next token given the tokens played so far.

        Args:
            history: iterable of previous tokens; only the last ``K`` are used
                and shorter histories are left-padded with rests (0).
            temperature: softmax temperature, floored at 1e-3.
            anti_silence: if truthy and the last ``anti_silence`` tokens are
                all rests, the rest logit is lowered by 8.
            rep_penalty: amount subtracted from the logit of each note token
                seen in the last 3 frames.

        Returns:
            ``(token, midi, info)`` where ``midi`` is ``tok2midi[token]`` (0 if
            unmapped) and ``info`` has the telemetry under ``"spec"`` and
            ``"shared"``.
        """
        h = list(history)[-self.K:]
        if len(h) < self.K:
            h = [0] * (self.K - len(h)) + h
        feat = torch.tensor([[x / self.n_tokens for x in h]], dtype=torch.float32)
        # Pure inference: skip building an autograd graph on every step.
        with torch.no_grad():
            logits, per, shared = self.model.telemetry(feat)
            logits = logits / max(1e-3, temperature)
            # keep it musical: don't collapse to silence, don't get stuck on one key
            if anti_silence and all(t == 0 for t in h[-anti_silence:]):
                logits[0] -= 8.0
            for t in set(h[-3:]):
                # Ignore rests and any history token outside the vocabulary,
                # which would otherwise index past the end of the logits.
                if 0 < t < self.n_tokens:
                    logits[t] -= rep_penalty
            tok = int(torch.multinomial(F.softmax(logits, -1), 1).item())
        return tok, self.tok2midi.get(tok, 0), {"spec": per, "shared": shared}
if __name__ == "__main__":
    # Train, write the checkpoint, then reload it through the player.
    model, meta = train_and_save()
    # quick listen: generate 40 notes from the song's seed
    player = PianoPlayer()
    history = list(player.seed_seq)
    midi_stream = []
    for _step in range(40):
        token, midi_note, _info = player.next_token(history)
        history.append(token)
        midi_stream.append(midi_note)
    print("sample MIDI stream:", midi_stream)