Spaces:
Running on Zero
Running on Zero
| """ | |
| poly_mind.py -- a CHORD-capable Modular Mind (same method as the boss / mono piano: | |
| specialists -> RecursiveLink -> coordinator), but multi-LABEL: it predicts the SET of | |
| notes sounding in the next frame, not a single note. Trained by multi-label next-frame | |
| prediction on a polyphonic transcription (poly_notes.json, a piano-roll of chords). | |
| """ | |
| from __future__ import annotations | |
| import json, os | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
# Directory of this module; default data / weight paths are resolved against it.
HERE = os.path.dirname(os.path.abspath(__file__))

# Model dimensions.
K = 16         # context length in frames (=2s @8fps)
H = 64         # hidden width of each specialist
D_LATENT = 32  # latent width shared by the specialists and the mixer

# Specialists as (name, owned note band).  A band of "low"/"mid"/"high" gives
# the specialist a drive head over that third of the note indices; None means
# it only contributes a latent.
SPECS = [
    ("Bass", "low"),
    ("Tenor", "mid"),
    ("Soprano", "high"),
    ("Sustain", None),
    ("Onset", None),
    ("Phrase", None),
]
def _group_masks(n):
    """Split ``n`` note indices into contiguous low / mid / high bands.

    Returns a dict with keys ``"low"``, ``"mid"``, ``"high"`` (in that order),
    each a 0/1 tensor of length ``n`` in the default float dtype.  "low" and
    "mid" are each ``max(1, n // 3)`` indices wide; "high" takes everything
    after them, including the remainder.  For very small ``n`` the upper
    bands can be empty.
    """
    width = max(1, n // 3)
    spans = {
        "low": (0, width),
        "mid": (width, 2 * width),
        "high": (2 * width, n),
    }
    masks = {}
    for band, (start, stop) in spans.items():
        mask = torch.zeros(n)
        mask[start:stop] = 1.0  # out-of-range slices are simply empty
        masks[band] = mask
    return masks
class PolyMind(nn.Module):
    """Multi-label next-frame model: specialists -> gated mixer -> coordinator.

    ``feat`` rows are flattened ``K``-frame windows (``K * n_notes`` features).
    Every entry of ``SPECS`` is a specialist with its own ``tanh`` hidden
    layer (``fc1``) and latent head (``lat``).  Specialists that own a band
    ("low"/"mid"/"high") also have a scalar drive head (``drv``).  That drive
    is added to every note index in the band through the ``mask_<band>``
    buffer; the bands are presumably pitch-ordered, so confirm this against
    the tokenizer.  The latents are summed and passed through a gated
    residual block.  ``coord`` then maps the result to one logit per note,
    and the band drives are added on top.

    Attribute names and the order in which layers are created are
    load-bearing.  They fix the ``state_dict`` keys of saved weights and the
    sequence of RNG draws made during parameter initialisation.
    """

    def __init__(self, n_notes):
        super().__init__()
        self.n = n_notes
        n_in = K * n_notes  # flattened context window
        # Per-specialist layers, created in SPECS order.
        self.fc1 = nn.ModuleList([nn.Linear(n_in, H) for _ in SPECS])
        self.lat = nn.ModuleList([nn.Linear(H, D_LATENT) for _ in SPECS])
        # Drive heads only for specialists that own a band.
        self.drv = nn.ModuleDict(
            {name: nn.Linear(H, 1) for name, band in SPECS if band})
        # Gated residual mixer over the summed latents.
        self.ni = nn.LayerNorm(D_LATENT)
        self.v = nn.Linear(D_LATENT, 2 * D_LATENT, bias=False)
        self.g = nn.Linear(D_LATENT, 2 * D_LATENT, bias=False)
        self.d = nn.Linear(2 * D_LATENT, D_LATENT, bias=False)
        self.no = nn.LayerNorm(D_LATENT)
        self.coord = nn.Linear(D_LATENT, n_notes)
        # Fixed 0/1 band masks, registered as buffers so they are saved in
        # state_dict and follow the module across devices.
        for band, mask in _group_masks(n_notes).items():
            self.register_buffer(f"mask_{band}", mask)

    def _core(self, feat, tel=False):
        """Return ``(logits, per_spec_stats, shared_latent)``.

        ``per_spec_stats`` is a list of dicts with keys ``name``, ``owns``,
        ``drive`` and ``act``.  It is filled only when ``tel`` is true and is
        empty otherwise.
        """
        batch = feat.shape[0]
        drives = torch.zeros(batch, self.n, device=feat.device)
        latents = []
        report = []
        for idx, (name, band) in enumerate(SPECS):
            hidden = torch.tanh(self.fc1[idx](feat))
            latent = self.lat[idx](hidden)
            latents.append(latent)
            drive = self.drv[name](hidden) if band else None
            if drive is not None:
                # (batch, 1) drive broadcast over the band's note columns.
                drives = drives + drive * getattr(self, f"mask_{band}")
            if not tel:
                continue
            drive_stat = None if drive is None else round(float(drive.mean().item()), 3)
            report.append({
                "name": name,
                "owns": band,
                "drive": drive_stat,
                "act": round(float(latent.norm().item()), 3),
            })
        z = torch.stack(latents, 0).sum(0)
        z_norm = self.ni(z)
        gated = F.relu(self.g(z_norm)) * self.v(z_norm)
        shared = self.no(self.d(gated) + z)  # residual connection, then norm
        logits = drives + self.coord(shared)
        return logits, report, shared

    def forward(self, feat):
        """Return per-note logits for each row of ``feat``."""
        logits, _, _ = self._core(feat)
        return logits

    def telemetry(self, feat):
        """Inspect the first row of ``feat``.

        Returns ``(logits_row0, per_spec_stats, shared_row0_rounded)``.
        """
        logits, report, shared = self._core(feat, tel=True)
        shared_row = [round(float(x), 2) for x in shared[0].tolist()]
        return logits[0], report, shared_row
def _rolls(seq, n):
    """Convert a frame sequence into a dense piano-roll.

    Args:
        seq: sequence of frames; each frame is an iterable of integer token ids.
        n: number of columns (vocabulary size).

    Returns:
        A ``float32`` array of shape ``(len(seq), n)``.  It is 1.0 where a
        token sounds in a frame and 0.0 elsewhere.  Tokens are used as raw
        numpy indices with no range check.
    """
    roll = np.zeros((len(seq), n), dtype=np.float32)
    # Each row is a view into `roll`, so assigning through it fills the roll.
    for row, frame in zip(roll, seq):
        row[list(frame)] = 1.0
    return roll
def train_and_save(notes=os.path.join(HERE, "poly_notes.json"),
                   out=os.path.join(HERE, "poly_weights.pt"), epochs=700, seed=0):
    """Train a PolyMind on a polyphonic piano-roll and save a checkpoint.

    The task is multi-label next-frame prediction.  From the previous ``K``
    frames (flattened) the model predicts which notes sound in the next
    frame.  Loss is BCE-with-logits, with ``pos_weight`` set to the
    negative:positive cell ratio so the sparse note-on cells are not
    swamped by silence.

    Args:
        notes: JSON file.  Keys read here are ``"seq"`` (list of frames, each
            a list of token ids), ``"n_tokens"``, ``"tok2midi"``, and optionally
            ``"fps"`` (defaults to 8 in the checkpoint).
        out: destination path for ``torch.save``.
        epochs: number of passes over all training windows.
        seed: value for ``torch.manual_seed`` (weight init and shuffling).

    Returns:
        ``(model, meta)``: the trained PolyMind and the parsed JSON dict.

    Raises:
        ValueError: if ``seq`` has ``K`` or fewer frames, so no
            (context, target) training pair exists.
    """
    torch.manual_seed(seed)
    with open(notes, encoding="utf-8") as fh:  # close the file deterministically
        meta = json.load(fh)
    seq, n = meta["seq"], meta["n_tokens"]
    R = _rolls(seq, n)
    if len(R) <= K:
        raise ValueError(
            f"need more than K={K} frames to build training pairs, got {len(R)}")

    # Sliding windows: X[j] = frames t-K .. t-1 flattened, Y[j] = frame t.
    X = np.stack([R[t - K:t].reshape(-1) for t in range(K, len(R))])
    Y = R[K:]
    X, Y = torch.tensor(X), torch.tensor(Y)  # torch.tensor copies the numpy data

    model = PolyMind(n)
    opt = torch.optim.Adam(model.parameters(), lr=8e-3)
    # Balance sparse positives; the +1 keeps this finite on an all-silent roll.
    pw = torch.tensor(float((Y.numel() - Y.sum()) / (Y.sum() + 1)))
    N, bs = X.shape[0], 512
    for ep in range(epochs):
        perm = torch.randperm(N)
        tot = 0.0
        for i in range(0, N, bs):
            idx = perm[i:i + bs]
            loss = F.binary_cross_entropy_with_logits(model(X[idx]), Y[idx], pos_weight=pw)
            opt.zero_grad()
            loss.backward()
            opt.step()
            tot += loss.item() * len(idx)  # sample-weighted for a per-example mean
        if ep % 100 == 0 or ep == epochs - 1:
            with torch.no_grad():
                p = (torch.sigmoid(model(X)) > 0.5).float()
                tp = (p * Y).sum()
                # +1 smoothing keeps the ratios finite when p or Y is empty.
                prec = tp / (p.sum() + 1)
                rec = tp / (Y.sum() + 1)
            print(f" ep {ep:4d} loss {tot/N:.3f} precision {prec:.2f} recall {rec:.2f}")

    # Playback seed: the K-frame window with the most sounding notes.
    # With a symmetric all-ones kernel, "valid" output index bi is the sum of R[bi:bi+K].
    window_counts = np.convolve(R.sum(1), np.ones(K), "valid")
    bi = int(np.argmax(window_counts))
    torch.save({"state": model.state_dict(), "n_tokens": n, "tok2midi": meta["tok2midi"],
                "K": K, "fps": meta.get("fps", 8),
                "seed": [seq[bi + j] for j in range(K)]}, out)
    print("saved ->", out)
    return model, meta
class PolyPlayer:
    """Load trained PolyMind weights and generate chords one frame at a time.

    Loading order:
    - If ``weights`` is None and ``poly_weights.safetensors`` exists next to
      this file, tensors are read from it.  Settings (``n_tokens``, ``K``,
      ``fps``, ``tok2midi``, ``seed``) come from its string metadata.
    - Otherwise a ``torch.save`` checkpoint is loaded from ``weights``
      (default ``poly_weights.pt`` next to this file).

    Sets ``n``, ``K``, ``fps``, ``tok2midi`` (int -> int), ``seed`` (frames
    used to prime generation) and ``model`` (a PolyMind in eval mode).
    """

    def __init__(self, weights=None):
        sft = os.path.join(HERE, "poly_weights.safetensors")
        if weights is None and os.path.exists(sft):
            # Optional dependency, only needed on the safetensors path.
            from safetensors.torch import load_file
            from safetensors import safe_open
            state = load_file(sft)
            with safe_open(sft, framework="pt") as f:
                md = f.metadata() or {}
            # Metadata values are strings; dict/list fields are JSON-encoded.
            self.n = int(md["n_tokens"])
            self.K = int(md["K"])
            self.fps = int(md["fps"])
            self.tok2midi = {int(k): int(v) for k, v in json.loads(md["tok2midi"]).items()}
            self.seed = json.loads(md["seed"])
        else:
            ck = torch.load(weights or os.path.join(HERE, "poly_weights.pt"), map_location="cpu")
            self.n, self.K, self.fps = ck["n_tokens"], ck["K"], ck["fps"]
            # tok2midi keys may be strings (JSON origin); normalise to int -> int.
            self.tok2midi = {int(k): int(v) for k, v in ck["tok2midi"].items()}
            self.seed = ck["seed"]
            state = ck["state"]
        self.model = PolyMind(self.n)
        self.model.load_state_dict(state)
        self.model.eval()

    def next_frame(self, history, thresh=0.45, maxn=4):
        """Predict the set of notes sounding in the frame after ``history``.

        Args:
            history: iterable of past frames, each an iterable of token ids.
                Only the last ``K`` frames are used, left-padded with empty
                frames.  Tokens outside ``[0, n)`` are ignored.
            thresh: a note is emitted when its probability exceeds this value.
            maxn: at most this many notes are emitted; if more pass the
                threshold, the ``maxn`` most probable are kept.

        Returns:
            ``(tokens, midi, info)``:
            - ``tokens``: sorted token ids.
            - ``midi``: their MIDI numbers (0 for tokens missing from
              ``tok2midi``).
            - ``info``: dict with ``"spec"`` (per-specialist stats) and
              ``"shared"`` (rounded shared latent).
        """
        h = list(history)[-self.K:]
        h = [[]] * (self.K - len(h)) + h  # left-pad; frames are only read
        R = np.zeros((self.K, self.n), dtype=np.float32)
        for i, fr in enumerate(h):
            for t in fr:
                if 0 <= t < self.n:
                    R[i, t] = 1.0
        feat = torch.tensor(R.reshape(1, -1))
        # Pure inference: don't build an autograd graph on every call.
        with torch.no_grad():
            logits, per, shared = self.model.telemetry(feat)
            probs = torch.sigmoid(logits)
        # Anti-silence: if the last 2 frames were empty, force the most likely note(s).
        forced = all(len(f) == 0 for f in h[-2:])
        if forced:
            thresh = min(thresh, float(probs.max()) * 0.6 + 1e-3)
        idx = (probs > thresh).nonzero().flatten().tolist()
        if forced and not idx:
            # 0.6*max + 1e-3 exceeds max when max <= 0.0025; still emit the argmax.
            idx = [int(probs.argmax())]
        if len(idx) > maxn:
            idx = torch.topk(probs, maxn).indices.tolist()
        toks = sorted(int(t) for t in idx)
        return toks, [self.tok2midi.get(t, 0) for t in toks], {"spec": per, "shared": shared}
if __name__ == "__main__":
    # Train, then sample from the checkpoint that was just written.  Pass the
    # .pt path explicitly: with weights=None, PolyPlayer prefers an existing
    # poly_weights.safetensors next to this file, which would ignore the
    # freshly trained weights.
    train_and_save()
    p = PolyPlayer(weights=os.path.join(HERE, "poly_weights.pt"))
    hist = list(p.seed)
    print("sample chords (MIDI sets):")
    for _step in range(10):
        toks, mid, _info = p.next_frame(hist)
        hist.append(toks)  # feed the prediction back as context
        print(" ", mid)