"""poly_mind.py -- a CHORD-capable Modular Mind.

Same method as the boss / mono piano (specialists -> RecursiveLink ->
coordinator), but multi-LABEL: it predicts the SET of notes sounding in the
next frame, not a single note. Trained by multi-label next-frame prediction on
a polyphonic transcription (poly_notes.json, a piano-roll of chords).
"""
from __future__ import annotations

import json
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

HERE = os.path.dirname(os.path.abspath(__file__))
K, H, D_LATENT = 16, 64, 32  # context frames (=2s @8fps), hidden, latent
# (specialist name, register group it drives). None = the specialist drives no
# notes directly and only contributes a latent to the shared coordinator.
SPECS = [("Bass", "low"), ("Tenor", "mid"), ("Soprano", "high"),
         ("Sustain", None), ("Onset", None), ("Phrase", None)]


def _group_masks(n):
    """Split token indices 0..n-1 into contiguous low / mid / high thirds.

    Returns a dict mapping "low", "mid" and "high" to a float 0/1 mask of
    length ``n``. Any remainder (n not divisible by 3) lands in "high"; for
    n < 3 the higher groups may be empty.
    """
    m = {g: torch.zeros(n) for g in ("low", "mid", "high")}
    third = max(1, n // 3)
    for i in range(n):
        g = "low" if i < third else ("mid" if i < 2 * third else "high")
        m[g][i] = 1.0
    return m


class PolyMind(nn.Module):
    """Specialists -> gated shared latent -> coordinator, giving per-note logits.

    Every specialist reads the flattened ``k * n_notes`` context window. The
    specialists that own a register group (see ``SPECS``) add a scalar
    "drive" to the logits of every note in that group. All specialists emit a
    latent vector; the latents are summed, passed through a ReLU-gated
    residual block with LayerNorms, and mapped to ``n_notes`` logits by the
    coordinator.

    Args:
        n_notes: number of note tokens (width of the output).
        k: context length in frames. Defaults to the module constant ``K``;
            must match the context length the weights were trained with.
    """

    def __init__(self, n_notes, k=K):
        super().__init__()
        self.n = n_notes
        self.k = k
        nf = k * n_notes
        self.fc1 = nn.ModuleList([nn.Linear(nf, H) for _ in SPECS])
        self.lat = nn.ModuleList([nn.Linear(H, D_LATENT) for _ in SPECS])
        self.drv = nn.ModuleDict({nm: nn.Linear(H, 1) for nm, o in SPECS if o})
        self.ni = nn.LayerNorm(D_LATENT)
        self.v = nn.Linear(D_LATENT, 2 * D_LATENT, bias=False)
        self.g = nn.Linear(D_LATENT, 2 * D_LATENT, bias=False)
        self.d = nn.Linear(2 * D_LATENT, D_LATENT, bias=False)
        self.no = nn.LayerNorm(D_LATENT)
        self.coord = nn.Linear(D_LATENT, n_notes)
        # Register masks as buffers so they follow .to(device) and are saved
        # in the state dict as mask_low / mask_mid / mask_high.
        for gname, t in _group_masks(n_notes).items():
            self.register_buffer("mask_" + gname, t)

    def _core(self, feat, tel=False):
        """Shared forward pass.

        Args:
            feat: float tensor of shape (batch, k * n_notes).
            tel: when True, also collect per-specialist telemetry.

        Returns:
            (logits, per, shared): logits of shape (batch, n_notes); ``per``, a
            list of per-specialist dicts (empty unless ``tel``); ``shared``, the
            (batch, D_LATENT) coordinator input.
        """
        drives = torch.zeros(feat.shape[0], self.n, device=feat.device)
        lats, per = [], []
        for i, (nm, owns) in enumerate(SPECS):
            h = torch.tanh(self.fc1[i](feat))
            lat = self.lat[i](h)
            lats.append(lat)
            dd = None
            if owns:
                # One scalar per example, broadcast onto this specialist's register.
                dd = self.drv[nm](h)
                drives = drives + dd * getattr(self, "mask_" + owns)
            if tel:
                per.append({
                    "name": nm,
                    "owns": owns,
                    "drive": round(float(dd.mean().item()), 3) if dd is not None else None,
                    "act": round(float(lat.norm().item()), 3),
                })
        z = torch.stack(lats, 0).sum(0)
        zn = self.ni(z)
        # ReLU-gated expansion to 2*D_LATENT, projected back, residual on z.
        shared = self.no(self.d(F.relu(self.g(zn)) * self.v(zn)) + z)
        logits = drives + self.coord(shared)
        return logits, per, shared

    def forward(self, feat):
        """Return per-note logits of shape (batch, n_notes) for ``feat``."""
        return self._core(feat)[0]

    @torch.no_grad()
    def telemetry(self, feat):
        """Run without gradients and report diagnostics for the FIRST batch item.

        Returns:
            (logits[0], per-specialist dicts, shared latent of item 0 rounded
            to 2 decimals as a list of floats).
        """
        logits, per, shared = self._core(feat, tel=True)
        return logits[0], per, [round(float(v), 2) for v in shared[0].tolist()]


def _rolls(seq, n):
    """Convert frames (each an iterable of token ids) to a (len(seq), n) 0/1 roll.

    NOTE(review): tokens are assumed to lie in [0, n); an id >= n raises
    IndexError and a negative id would index from the end of the row.
    """
    R = np.zeros((len(seq), n), dtype=np.float32)
    for i, fr in enumerate(seq):
        for t in fr:
            R[i, t] = 1.0
    return R


def train_and_save(notes=os.path.join(HERE, "poly_notes.json"),
                   out=os.path.join(HERE, "poly_weights.pt"),
                   epochs=700, seed=0):
    """Train a PolyMind on a piano-roll JSON file and save a checkpoint.

    ``notes`` must contain "seq" (list of frames, each a list of token ids),
    "n_tokens" and "tok2midi", plus an optional "fps" (default 8). Training
    pairs are (K previous frames -> next frame), optimised with BCE whose
    positive weight balances the sparse on-notes. The checkpoint written to
    ``out`` holds the state dict, vocabulary info, K, fps and a K-frame seed
    (the window with the most sounding notes).

    Returns:
        (model, meta): the trained model and the parsed JSON metadata.

    Raises:
        ValueError: if the sequence has K frames or fewer (no training pair).
    """
    torch.manual_seed(seed)
    with open(notes) as fh:  # close the handle rather than leaking it
        meta = json.load(fh)
    seq, n = meta["seq"], meta["n_tokens"]
    if len(seq) <= K:
        raise ValueError(f"need more than {K} frames to train, got {len(seq)}")
    R = _rolls(seq, n)
    X = np.stack([R[t - K:t].reshape(-1) for t in range(K, len(R))])
    Y = R[K:]  # target frame for each window above
    X, Y = torch.tensor(X), torch.tensor(Y)
    model = PolyMind(n)
    opt = torch.optim.Adam(model.parameters(), lr=8e-3)
    n_pos = Y.sum()
    pw = torch.tensor(float((Y.numel() - n_pos) / (n_pos + 1)))  # balance sparse positives
    N, bs = X.shape[0], 512
    for ep in range(epochs):
        perm = torch.randperm(N)
        tot = 0.0
        for i in range(0, N, bs):
            idx = perm[i:i + bs]
            loss = F.binary_cross_entropy_with_logits(model(X[idx]), Y[idx], pos_weight=pw)
            opt.zero_grad()
            loss.backward()
            opt.step()
            tot += loss.item() * len(idx)
        if ep % 100 == 0 or ep == epochs - 1:
            with torch.no_grad():
                p = (torch.sigmoid(model(X)) > 0.5).float()
                tp = (p * Y).sum()
                prec = tp / (p.sum() + 1)  # +1 avoids division by zero
                rec = tp / (Y.sum() + 1)
            print(f"  ep {ep:4d}  loss {tot/N:.3f}  precision {prec:.2f}  recall {rec:.2f}")
    # Seed = the K-frame window with the most sounding notes.
    cnt = R.sum(1)
    bi = int(np.argmax(np.convolve(cnt, np.ones(K), "valid")))
    torch.save({"state": model.state_dict(), "n_tokens": n, "tok2midi": meta["tok2midi"],
                "K": K, "fps": meta.get("fps", 8),
                "seed": [seq[bi + j] for j in range(K)]}, out)
    print("saved ->", out)
    return model, meta


class PolyPlayer:
    """Inference wrapper: loads trained weights and predicts chords frame by frame."""

    def __init__(self, weights=None):
        """Load a trained PolyMind.

        If ``weights`` is None and poly_weights.safetensors exists next to this
        file, load it (n_tokens / K / fps / tok2midi / seed come from its string
        metadata). Otherwise ``torch.load`` the checkpoint at ``weights`` (or
        poly_weights.pt next to this file).
        """
        sft = os.path.join(HERE, "poly_weights.safetensors")
        if weights is None and os.path.exists(sft):
            from safetensors.torch import load_file
            from safetensors import safe_open
            state = load_file(sft)
            with safe_open(sft, framework="pt") as f:
                md = f.metadata() or {}
            self.n = int(md["n_tokens"])
            self.K = int(md["K"])
            self.fps = int(md["fps"])
            self.tok2midi = {int(k): int(v) for k, v in json.loads(md["tok2midi"]).items()}
            self.seed = json.loads(md["seed"])
        else:
            # NOTE: torch.load unpickles; only load checkpoints you trust.
            ck = torch.load(weights or os.path.join(HERE, "poly_weights.pt"), map_location="cpu")
            self.n, self.K, self.fps = ck["n_tokens"], ck["K"], ck["fps"]
            self.tok2midi = {int(k): int(v) for k, v in ck["tok2midi"].items()}
            self.seed = ck["seed"]
            state = ck["state"]
        # Build with the checkpoint's own context length so fc1 shapes match.
        self.model = PolyMind(self.n, k=self.K)
        self.model.load_state_dict(state)
        self.model.eval()

    @torch.no_grad()
    def next_frame(self, history, thresh=0.45, maxn=4):
        """Predict the set of tokens sounding in the frame after ``history``.

        Args:
            history: sequence of frames, each an iterable of token ids. Only the
                last K frames are used, left-padded with empty frames; token
                ids outside [0, n) are ignored.
            thresh: sigmoid probability a note must exceed to be emitted.
            maxn: maximum notes per frame; when more pass the threshold, the
                ``maxn`` most probable notes are kept.

        Returns:
            (tokens, midi, info): sorted token ids, their MIDI numbers (0 for
            unmapped tokens), and {"spec": per-specialist telemetry,
            "shared": rounded shared latent}.
        """
        h = list(history)[-self.K:]
        h = [[]] * (self.K - len(h)) + h  # left-pad with (read-only) empty frames
        R = np.zeros((self.K, self.n), dtype=np.float32)
        for i, fr in enumerate(h):
            for t in fr:
                if 0 <= t < self.n:
                    R[i, t] = 1.0
        feat = torch.tensor(R.reshape(1, -1))
        logits, per, shared = self.model.telemetry(feat)
        probs = torch.sigmoid(logits)
        # anti-silence: if the last 2 frames were empty, force the most likely note(s)
        if all(len(f) == 0 for f in h[-2:]):
            thresh = min(thresh, float(probs.max()) * 0.6 + 1e-3)
        idx = (probs > thresh).nonzero().flatten().tolist()
        if len(idx) > maxn:
            idx = torch.topk(probs, maxn).indices.tolist()
        toks = sorted(int(t) for t in idx)
        return toks, [self.tok2midi.get(t, 0) for t in toks], {"spec": per, "shared": shared}


if __name__ == "__main__":
    train_and_save()
    p = PolyPlayer()
    hist = list(p.seed)
    print("sample chords (MIDI sets):")
    for _ in range(10):
        toks, mid, _ = p.next_frame(hist)
        hist.append(toks)
        print("  ", mid)