ModuleMind / piano /poly_mind.py
Quazim0t0's picture
Upload 88 files
6e3b98a verified
Raw
History Blame Contribute Delete
6.94 kB
"""
poly_mind.py -- a CHORD-capable Modular Mind (same method as the boss / mono piano:
specialists -> RecursiveLink -> coordinator), but multi-LABEL: it predicts the SET of
notes sounding in the next frame, not a single note. Trained by multi-label next-frame
prediction on a polyphonic transcription (poly_notes.json, a piano-roll of chords).
"""
from __future__ import annotations
import json, os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Directory containing this file; the default notes/weights paths used further
# down are resolved relative to it.
HERE = os.path.dirname(os.path.abspath(__file__))
# K: number of past frames flattened into one input vector (input width is
#    K * n_notes); H: width of each specialist's hidden layer;
#    D_LATENT: width of the latent every specialist projects into.
K, H, D_LATENT = 16, 64, 32 # context frames (=2s @8fps), hidden, latent
# One (name, owned_group) pair per specialist. A specialist whose owned_group is
# "low"/"mid"/"high" additionally gets a scalar drive head that is applied only
# to the notes of that group; specialists with None contribute a latent only.
SPECS = [("Bass", "low"), ("Tenor", "mid"), ("Soprano", "high"),
("Sustain", None), ("Onset", None), ("Phrase", None)]
def _group_masks(n):
m = {g: torch.zeros(n) for g in ("low", "mid", "high")}
third = max(1, n // 3)
for i in range(n):
g = "low" if i < third else ("mid" if i < 2 * third else "high")
m[g][i] = 1.0
return m
class PolyMind(nn.Module):
    """Multi-label next-frame predictor built from several specialist encoders.

    Every specialist in ``SPECS`` encodes the flattened context window into a
    hidden vector and projects it to a latent. The latents are summed, passed
    through a gated residual block and a coordinator layer that emits one logit
    per note. Specialists that own a register group also add a scalar drive to
    the logits of the notes in that group.
    """

    def __init__(self, n_notes):
        """Build the network for a vocabulary of ``n_notes`` notes."""
        super().__init__()
        self.n = n_notes
        in_dim = K * n_notes
        count = len(SPECS)
        # NOTE: submodules are created in a fixed order (all encoders, all
        # latent heads, drive heads, link, coordinator) so seeded
        # initialisation stays reproducible.
        self.fc1 = nn.ModuleList([nn.Linear(in_dim, H) for _ in range(count)])
        self.lat = nn.ModuleList([nn.Linear(H, D_LATENT) for _ in range(count)])
        self.drv = nn.ModuleDict()
        for name, group in SPECS:
            if group:
                self.drv[name] = nn.Linear(H, 1)
        # Gated residual link applied to the summed latents.
        self.ni = nn.LayerNorm(D_LATENT)
        self.v = nn.Linear(D_LATENT, 2 * D_LATENT, bias=False)
        self.g = nn.Linear(D_LATENT, 2 * D_LATENT, bias=False)
        self.d = nn.Linear(2 * D_LATENT, D_LATENT, bias=False)
        self.no = nn.LayerNorm(D_LATENT)
        self.coord = nn.Linear(D_LATENT, n_notes)
        # Masks are buffers, so they follow the module across devices and are
        # stored in the state dict.
        for group, mask in _group_masks(n_notes).items():
            self.register_buffer("mask_" + group, mask)

    def _core(self, feat, tel=False):
        """Run the full pipeline on ``feat`` (batch, K * n_notes).

        Returns ``(logits, per, shared)``: per-note logits, a list of
        per-specialist telemetry dicts (empty unless ``tel`` is true) and the
        shared latent that feeds the coordinator.
        """
        drive_total = torch.zeros(feat.shape[0], self.n, device=feat.device)
        latents = []
        report = []
        for (name, group), encoder, projector in zip(SPECS, self.fc1, self.lat):
            hidden = torch.tanh(encoder(feat))
            latent = projector(hidden)
            latents.append(latent)
            drive = self.drv[name](hidden) if group else None
            if drive is not None:
                # (batch, 1) drive broadcast over the notes of the owned group.
                drive_total = drive_total + drive * getattr(self, "mask_" + group)
            if tel:
                mean_drive = None
                if drive is not None:
                    mean_drive = round(float(drive.mean().item()), 3)
                report.append({"name": name, "owns": group,
                               "drive": mean_drive,
                               "act": round(float(latent.norm().item()), 3)})
        pooled = torch.stack(latents, 0).sum(0)
        normed = self.ni(pooled)
        gated = F.relu(self.g(normed)) * self.v(normed)
        shared = self.no(self.d(gated) + pooled)
        logits = drive_total + self.coord(shared)
        return logits, report, shared

    def forward(self, feat):
        """Return only the per-note logits for ``feat``."""
        logits, _, _ = self._core(feat)
        return logits

    @torch.no_grad()
    def telemetry(self, feat):
        """Return logits, specialist telemetry and rounded shared latent.

        Only the first batch item's logits and shared latent are returned.
        """
        logits, report, shared = self._core(feat, tel=True)
        shared_rounded = [round(float(x), 2) for x in shared[0].tolist()]
        return logits[0], report, shared_rounded
def _rolls(seq, n):
R = np.zeros((len(seq), n), dtype=np.float32)
for i, fr in enumerate(seq):
for t in fr:
R[i, t] = 1.0
return R
def train_and_save(notes=os.path.join(HERE, "poly_notes.json"),
                   out=os.path.join(HERE, "poly_weights.pt"), epochs=700, seed=0):
    """Train a PolyMind on a polyphonic transcription and save a checkpoint.

    Args:
        notes: Path to a JSON file providing ``seq`` (list of frames, each a
            list of note tokens), ``n_tokens`` and ``tok2midi``; ``fps`` is
            optional and defaults to 8 in the checkpoint.
        out: Destination path for the ``torch.save`` checkpoint.
        epochs: Number of passes over the training windows.
        seed: Value passed to ``torch.manual_seed`` before anything else.

    Returns:
        ``(model, meta)``: the trained model and the parsed JSON metadata.

    Raises:
        ValueError: If the transcription has no more than ``K`` frames, so no
            (context, next-frame) training pair can be formed.
    """
    torch.manual_seed(seed)
    # Context manager so the file handle is closed deterministically.
    with open(notes, encoding="utf-8") as fh:
        meta = json.load(fh)
    seq, n = meta["seq"], meta["n_tokens"]
    R = _rolls(seq, n)
    if len(R) <= K:
        raise ValueError(
            f"need more than {K} frames to build training windows, got {len(R)}")

    # Each sample is the K previous frames flattened; the target is the next frame.
    targets = range(K, len(R))
    X = torch.tensor(np.stack([R[t - K:t].reshape(-1) for t in targets]))
    Y = torch.tensor(np.stack([R[t] for t in targets]))

    model = PolyMind(n)
    opt = torch.optim.Adam(model.parameters(), lr=8e-3)
    pw = torch.tensor(float((Y.numel() - Y.sum()) / (Y.sum() + 1)))  # balance sparse positives
    N, bs = X.shape[0], 512
    for ep in range(epochs):
        perm = torch.randperm(N)
        tot = 0.0
        for i in range(0, N, bs):
            idx = perm[i:i + bs]
            loss = F.binary_cross_entropy_with_logits(model(X[idx]), Y[idx], pos_weight=pw)
            opt.zero_grad()
            loss.backward()
            opt.step()
            tot += loss.item() * len(idx)
        if ep % 100 == 0 or ep == epochs - 1:
            with torch.no_grad():
                p = (torch.sigmoid(model(X)) > 0.5).float()
                tp = (p * Y).sum()
                # +1 in the denominators avoids division by zero.
                prec = tp / (p.sum() + 1)
                rec = tp / (Y.sum() + 1)
            print(f" ep {ep:4d} loss {tot/N:.3f} precision {prec:.2f} recall {rec:.2f}")

    # Seed window: the K consecutive frames with the largest total note count.
    cnt = R.sum(1)
    bi = int(np.argmax(np.convolve(cnt, np.ones(K), "valid")))
    torch.save({"state": model.state_dict(), "n_tokens": n, "tok2midi": meta["tok2midi"],
                "K": K, "fps": meta.get("fps", 8), "seed": list(seq[bi:bi + K])}, out)
    print("saved ->", out)
    return model, meta
class PolyPlayer:
    """Inference wrapper that loads trained PolyMind weights and emits frames."""

    def __init__(self, weights=None):
        """Load model weights and metadata.

        When ``weights`` is None and ``poly_weights.safetensors`` exists next
        to this file, it is used (metadata is read from the safetensors header).
        Otherwise a ``torch.save`` checkpoint is loaded from ``weights``, or
        from ``poly_weights.pt`` next to this file.

        NOTE(review): depending on the installed torch version's defaults,
        ``torch.load`` may unpickle arbitrary objects -- only load trusted files.
        """
        safetensors_path = os.path.join(HERE, "poly_weights.safetensors")
        prefer_safetensors = weights is None and os.path.exists(safetensors_path)
        if not prefer_safetensors:
            ckpt_path = weights or os.path.join(HERE, "poly_weights.pt")
            ckpt = torch.load(ckpt_path, map_location="cpu")
            self.n, self.K, self.fps = ckpt["n_tokens"], ckpt["K"], ckpt["fps"]
            self.tok2midi = {int(tok): int(midi) for tok, midi in ckpt["tok2midi"].items()}
            self.seed = ckpt["seed"]
            state = ckpt["state"]
        else:
            from safetensors.torch import load_file
            from safetensors import safe_open
            state = load_file(safetensors_path)
            with safe_open(safetensors_path, framework="pt") as handle:
                header = handle.metadata() or {}
            # Header values are strings, so parse them back into Python types.
            self.n = int(header["n_tokens"])
            self.K = int(header["K"])
            self.fps = int(header["fps"])
            mapping = json.loads(header["tok2midi"])
            self.tok2midi = {int(tok): int(midi) for tok, midi in mapping.items()}
            self.seed = json.loads(header["seed"])
        self.model = PolyMind(self.n)
        self.model.load_state_dict(state)
        self.model.eval()

    @torch.no_grad()
    def next_frame(self, history, thresh=0.45, maxn=4):
        """Predict the set of notes for the frame following ``history``.

        Args:
            history: Iterable of past frames (each an iterable of note tokens);
                only the last ``self.K`` are used, left-padded with empty frames.
            thresh: Probability threshold for a note to be emitted.
            maxn: Maximum number of notes; above it the top ``maxn`` are kept.

        Returns:
            ``(tokens, midi, info)``: sorted note tokens, their MIDI numbers
            (0 for tokens missing from ``tok2midi``) and a telemetry dict with
            keys ``"spec"`` and ``"shared"``.
        """
        window = list(history)[-self.K:]
        shortfall = self.K - len(window)
        if shortfall > 0:
            window = [[] for _ in range(shortfall)] + window

        roll = np.zeros((self.K, self.n), dtype=np.float32)
        for row, frame in enumerate(window):
            # Out-of-vocabulary tokens are ignored.
            valid = [tok for tok in frame if 0 <= tok < self.n]
            if valid:
                roll[row, valid] = 1.0

        feat = torch.tensor(roll.reshape(1, -1))
        logits, per, shared = self.model.telemetry(feat)
        probs = torch.sigmoid(logits)

        # anti-silence: if the last 2 frames were empty, force the most likely note(s)
        if not any(len(frame) for frame in window[-2:]):
            relaxed = float(probs.max()) * 0.6 + 1e-3
            if relaxed < thresh:
                thresh = relaxed

        chosen = torch.nonzero(probs > thresh).reshape(-1).tolist()
        if len(chosen) > maxn:
            chosen = probs.topk(maxn).indices.tolist()
        toks = sorted(int(tok) for tok in chosen)
        midi = [self.tok2midi.get(tok, 0) for tok in toks]
        return toks, midi, {"spec": per, "shared": shared}
if __name__ == "__main__":
    # Train, reload through the player, then print ten generated frames
    # starting from the seed window stored with the weights.
    train_and_save()
    player = PolyPlayer()
    history = list(player.seed)
    print("sample chords (MIDI sets):")
    for _step in range(10):
        tokens, midi_notes, _info = player.next_frame(history)
        history.append(tokens)
        print(" ", midi_notes)