ModuleMind / piano /piano_mind.py
Quazim0t0's picture
Add files using upload-large-folder tool
45e7dfb verified
Raw
History Blame Contribute Delete
7.4 kB
"""
piano_mind.py -- a Modular Mind that plays piano, built with the SAME method as the
boss-fight brain: several tiny specialist MLPs emit latents, a RecursiveLink merges
them into one shared latent, and a coordinator reads it to pick the next note. Some
specialists "own" a group of tokens (Bass/Tenor/Soprano = low/mid/high register,
Rest = the rest token) and add a drive to those tokens; two are modulators
(Onset/Phrase) that only write into the shared latent.
Trained by next-note prediction on a note sequence transcribed from a song
(piano/notes.json). Tiny -> trains in seconds and runs instantly for the self-playing
piano in the Space. torch is used (already a Space dep); the architecture mirrors
mm_torch / modular_mind.
"""
from __future__ import annotations
import json, os
import torch
import torch.nn as nn
import torch.nn.functional as F
HERE = os.path.dirname(os.path.abspath(__file__))  # directory containing this module

# Model sizes.
K = 40  # context window in frames of recent notes (40 frames = 5 s at 8 fps)
H = 64  # hidden width of each specialist's tanh layer
D_LATENT = 32  # width of each specialist latent and of the shared latent

# (specialist name, token group it drives, or None for a pure modulator).
# Four group owners + two modulators. The group names are the keys produced by
# _group_masks: "low", "mid", "high", "rest".
SPECS = [
    ("Bass", "low"),
    ("Tenor", "mid"),
    ("Soprano", "high"),
    ("Rest", "rest"),
    ("Onset", None),
    ("Phrase", None),
]
def _group_masks(n_tokens):
    """Build one 0/1 float mask of length ``n_tokens`` per token group.

    Token 0 is the rest token and is the only entry set in the "rest" mask.
    The note tokens 1..n_tokens-1 are cut into consecutive bands of
    ``band = max(1, (n_tokens - 1) // 3)`` tokens. "low" and "mid" each take
    up to ``band`` tokens, and "high" takes whatever remains (possibly none).

    Returns a dict with keys "low", "mid", "high", "rest", in that order.
    """
    masks = {}
    for group in ("low", "mid", "high", "rest"):
        masks[group] = torch.zeros(n_tokens)
    masks["rest"][0] = 1.0
    band = max(1, (n_tokens - 1) // 3)
    # Slices past the end are clamped, so tiny vocabularies leave bands empty.
    masks["low"][1:1 + band] = 1.0
    masks["mid"][1 + band:1 + 2 * band] = 1.0
    masks["high"][1 + 2 * band:] = 1.0
    return masks
class PianoMind(nn.Module):
    """Modular next-note predictor.

    Each specialist in SPECS is a one-hidden-layer tanh MLP over the last K
    normalized tokens. Every specialist writes a D_LATENT latent.

    The latents are summed and passed through a RecursiveLink
    (LayerNorm -> ReGLU -> residual -> LayerNorm) to form the shared latent.
    A linear coordinator maps the shared latent to per-token logits.

    Specialists that own a token group also add a scalar "drive" to every
    token set in that group's mask.
    """

    def __init__(self, n_tokens):
        super().__init__()
        self.n_tokens = n_tokens
        # NOTE: creation order is kept stable so seeded initialization and
        # state_dict keys match previously saved checkpoints.
        # One hidden layer + latent head per specialist (order follows SPECS).
        self.fc1 = nn.ModuleList([nn.Linear(K, H) for _ in SPECS])
        self.lat = nn.ModuleList([nn.Linear(H, D_LATENT) for _ in SPECS])
        # Scalar drive head only for specialists that own a token group.
        self.drv = nn.ModuleDict({n: nn.Linear(H, 1) for n, owns in SPECS if owns})
        # RecursiveLink (ReGLU + residual), same shape as the boss bridge
        self.ni = nn.LayerNorm(D_LATENT)
        self.v = nn.Linear(D_LATENT, 2 * D_LATENT, bias=False)
        self.g = nn.Linear(D_LATENT, 2 * D_LATENT, bias=False)
        self.d = nn.Linear(2 * D_LATENT, D_LATENT, bias=False)
        self.no = nn.LayerNorm(D_LATENT)
        self.coord = nn.Linear(D_LATENT, n_tokens)
        # Group masks are registered as buffers so they follow .to(device).
        for g, t in _group_masks(n_tokens).items():
            self.register_buffer(f"mask_{g}", t)

    def _link(self, lats):
        """Merge the specialist latents into the shared latent (RecursiveLink)."""
        z = torch.stack(lats, 0).sum(0)
        zn = self.ni(z)
        return self.no(self.d(F.relu(self.g(zn)) * self.v(zn)) + z)

    def forward(self, feat):
        """Return next-token logits ``[B, n_tokens]``.

        feat: ``[B, K]`` float tensor of recent tokens divided by n_tokens
        (so values lie in [0, 1]).
        """
        drives = torch.zeros(feat.shape[0], self.n_tokens, device=feat.device)
        lats = []
        for i, (name, owns) in enumerate(SPECS):
            h = torch.tanh(self.fc1[i](feat))
            lats.append(self.lat[i](h))
            if owns:
                # [B, 1] drive broadcast over the owned group's [n_tokens] mask.
                drives = drives + self.drv[name](h) * getattr(self, f"mask_{owns}")
        return drives + self.coord(self._link(lats))

    @torch.no_grad()
    def telemetry(self, feat):
        """Run one forward pass and also report per-specialist activity.

        feat: ``[1, K]`` float tensor. The batch size must be 1 because each
        drive is read with ``.item()``.

        Returns ``(logits, per, shared)``:
            logits: ``[n_tokens]`` tensor.
            per: list of dicts, one per specialist, with keys "name", "owns",
                "drive" (rounded to 3 dp, None for modulators) and "act"
                (L2 norm of its latent, rounded to 3 dp).
            shared: the shared latent as a list of floats rounded to 2 dp.
        """
        # Allocate on feat's device so telemetry also works for a model on GPU.
        drives = torch.zeros(1, self.n_tokens, device=feat.device)
        lats, per = [], []
        for i, (name, owns) in enumerate(SPECS):
            hh = torch.tanh(self.fc1[i](feat))
            lat = self.lat[i](hh)
            lats.append(lat)
            drive = None
            if owns:
                drive = float(self.drv[name](hh).item())
                drives = drives + drive * getattr(self, f"mask_{owns}")
            per.append({"name": name, "owns": owns,
                        "drive": round(drive, 3) if drive is not None else None,
                        "act": round(float(lat.norm().item()), 3)})
        shared = self._link(lats)
        logits = drives + self.coord(shared)
        return logits[0], per, [round(float(v), 2) for v in shared[0].tolist()]
def _seq_feats(seq, n_tokens, k=None):
    """Build (context, next-token) training pairs from a token sequence.

    Args:
        seq: sequence of integer tokens.
        n_tokens: vocabulary size; contexts are scaled by ``1 / n_tokens``.
        k: context length in frames; defaults to the module-level K.

    Returns:
        X: float32 tensor ``[len(seq) - k, k]`` with ``X[i] = seq[i:i + k] / n_tokens``.
        Y: int64 tensor ``[len(seq) - k]`` with ``Y[i] = seq[i + k]``.
        Both are empty (zero rows) when ``seq`` has no more than ``k`` tokens.
    """
    k = K if k is None else k
    s = torch.tensor(seq, dtype=torch.long)
    if len(s) <= k:
        return torch.empty(0, k, dtype=torch.float32), torch.empty(0, dtype=torch.long)
    # unfold yields every length-k window (starts 0 .. len-k). The last window
    # has no following token, so it is dropped.
    X = s.unfold(0, k, 1)[:-1].float() / n_tokens
    Y = s[k:].clone()
    return X, Y
def _densest_window(seq, k):
    """Return the start of the first length-``k`` window of ``seq`` with the most non-rest tokens."""
    from itertools import accumulate
    # prefix[i] = number of sounding (token > 0) frames in seq[:i]
    prefix = list(accumulate((1 if tok > 0 else 0 for tok in seq), initial=0))
    # Every full window start, including the last one (len(seq) - k).
    return max(range(len(seq) - k + 1), key=lambda i: prefix[i + k] - prefix[i])


def train_and_save(notes_path=os.path.join(HERE, "notes.json"),
                   out=os.path.join(HERE, "piano_weights.pt"), epochs=900, seed=0,
                   lr=1e-2, batch_size=512):
    """Train a PianoMind by next-note prediction and save a checkpoint.

    Args:
        notes_path: JSON file with keys "seq" (token list), "n_tokens",
            "tok2midi" and optionally "fps" (defaults to 8).
        out: path the checkpoint is written to with ``torch.save``.
        epochs: number of passes over the data.
        seed: torch RNG seed (controls init and batch order).
        lr: Adam learning rate.
        batch_size: minibatch size.

    Returns:
        ``(model, meta)``: the trained model and the parsed JSON dict.

    Raises:
        ValueError: if the sequence has no more than K tokens, so no training
            pair can be formed.
    """
    torch.manual_seed(seed)
    with open(notes_path, encoding="utf-8") as fh:
        meta = json.load(fh)
    seq, n_tokens = meta["seq"], meta["n_tokens"]
    if len(seq) <= K:
        raise ValueError(f"need more than {K} tokens to train, got {len(seq)}")
    X, Y = _seq_feats(seq, n_tokens)
    model = PianoMind(n_tokens)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    n = X.shape[0]
    for ep in range(epochs):
        perm = torch.randperm(n)
        tot = 0.0
        for i in range(0, n, batch_size):
            idx = perm[i:i + batch_size]
            logits = model(X[idx])
            loss = F.cross_entropy(logits, Y[idx])
            opt.zero_grad()
            loss.backward()
            opt.step()
            tot += loss.item() * len(idx)
        if ep % 80 == 0 or ep == epochs - 1:
            with torch.no_grad():
                acc = (model(X).argmax(1) == Y).float().mean().item()
            print(f" epoch {ep:4d} loss {tot/n:.3f} next-note acc {acc:.3f}")
    # Seed playback from the most-sounding window (the song fades in).
    bi = _densest_window(seq, K)
    torch.save({"state": model.state_dict(), "n_tokens": n_tokens,
                "tok2midi": meta["tok2midi"], "K": K, "fps": meta.get("fps", 8),
                "seed_seq": seq[bi:bi + K]}, out)
    print(f"saved -> {out}")
    return model, meta
class PianoPlayer:
    """Loads a trained PianoMind checkpoint and samples the next note token autoregressively."""

    def __init__(self, weights=os.path.join(HERE, "piano_weights.pt")):
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints (newer torch versions support weights_only=True).
        ck = torch.load(weights, map_location="cpu")
        self.n_tokens = ck["n_tokens"]
        self.K = ck["K"]
        self.fps = ck["fps"]
        # tok2midi was copied from JSON, whose object keys are strings.
        self.tok2midi = {int(k): int(v) for k, v in ck["tok2midi"].items()}
        self.seed_seq = ck["seed_seq"]
        self.model = PianoMind(self.n_tokens)
        self.model.load_state_dict(ck["state"])
        self.model.eval()

    @torch.no_grad()
    def next_token(self, history, temperature=0.95, anti_silence=3, rep_penalty=1.5,
                   rep_window=3, silence_penalty=8.0):
        """Sample the next token from the recent token history.

        Args:
            history: iterable of past tokens. Only the last K are used, and the
                window is left-padded with rest tokens (0) when shorter.
            temperature: logits are divided by ``max(1e-3, temperature)``.
            anti_silence: if the last ``anti_silence`` frames are all rest,
                ``silence_penalty`` is subtracted from the rest logit
                (0 disables the check).
            rep_penalty: subtracted from the logit of every distinct non-rest
                token in the last ``rep_window`` frames.
            rep_window: how many recent frames the repetition penalty looks at
                (0 disables it).
            silence_penalty: amount subtracted from the rest logit by the
                anti-silence rule.

        Returns:
            ``(token, midi, info)``: ``midi`` is ``tok2midi.get(token, 0)``, and
            ``info`` is ``{"spec": per-specialist telemetry, "shared": shared latent}``.
        """
        h = list(history)[-self.K:]
        if len(h) < self.K:
            h = [0] * (self.K - len(h)) + h
        feat = torch.tensor([[x / self.n_tokens for x in h]], dtype=torch.float32)
        logits, per, shared = self.model.telemetry(feat)
        logits = logits / max(1e-3, temperature)
        # keep it musical: don't collapse to silence, don't get stuck on one key
        if anti_silence and all(t == 0 for t in h[-anti_silence:]):
            logits[0] -= silence_penalty
        if rep_window > 0:  # h[-0:] would be the whole window, so guard it
            for t in set(h[-rep_window:]):
                if t > 0:
                    logits[t] -= rep_penalty
        tok = int(torch.multinomial(F.softmax(logits, -1), 1).item())
        return tok, self.tok2midi.get(tok, 0), {"spec": per, "shared": shared}
if __name__ == "__main__":
    model, meta = train_and_save()
    # Smoke test: continue the saved seed window for 40 steps and print the MIDI numbers.
    player = PianoPlayer()
    history = list(player.seed_seq)
    midi_stream = []
    for _step in range(40):
        token, midi_note, _info = player.next_token(history)
        history.append(token)
        midi_stream.append(midi_note)
    print("sample MIDI stream:", midi_stream)