File size: 3,364 Bytes
ab7c6e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Złączenie przez ensemble (logit-mix) DWÓCH modeli o WSPÓLNYM słowniku.
W każdym kroku: logits = alpha*A + (1-alpha)*B -> sampluj. Melodia wychodzi stylistycznie pomiędzy.
To jest baseline „płaskiego ważenia" (KMS5). Stitch reprezentacji = osobny eksperyment (E1).
Użycie:
  python src/compose/fuse.py --a data/models/waltz_ckpt.pt --b data/models/reel_sv_ckpt.pt \
      --alpha 0.5 --meter 3/4 --keys D,G,Emin --inst piano --out data/recordings/fuzja
"""
import argparse, sys, os, math
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from torch.nn import functional as F
from core.gpt import GPT
from core.abc_to_midi import to_midi
from music21 import instrument as M

# Lookup from an instrument name to the music21 instrument class to instantiate;
# "none" maps to None, meaning no instrument object is created.
INST = {
    "piano": M.Piano,
    "violin": M.Violin,
    "none": None,
}

def load(path):
    """Load a checkpoint from *path* and rebuild its GPT model in eval mode.

    The checkpoint is read onto the CPU. The model is constructed from
    ``checkpoint["config"]`` and its weights are restored from
    ``checkpoint["model"]``.

    Returns:
        A ``(model, checkpoint)`` tuple: the model switched to eval mode and
        the raw checkpoint object as returned by ``torch.load``.
    """
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so only checkpoints from a trusted source should be loaded.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    model = GPT(checkpoint["config"])
    model.load_state_dict(checkpoint["model"])
    model.eval()
    return model, checkpoint

def first_tune(raw):
    """Return the text of *raw* up to (not including) the next ``X:`` line.

    The very first line is always kept, even if it starts with ``X:``; the
    result is cut at the first later line that starts with ``X:``. Leading and
    trailing whitespace is stripped from the returned text.
    """
    lines = raw.split("\n")
    # Position of the first "X:" line after line 0, or the end if there is none.
    cut = next(
        (pos for pos, text in enumerate(lines) if pos > 0 and text.startswith("X:")),
        len(lines),
    )
    return "\n".join(lines[:cut]).strip()

@torch.no_grad()
def gen_mix(A, B, idx, n, alpha, temp, topk, block):
    """Sample *n* new tokens from a weighted mix of two models' logits.

    At every step both models are called on the last *block* tokens of *idx*,
    their last-position logits are blended as ``alpha*A + (1-alpha)*B``,
    divided by *temp*, optionally restricted to the *topk* largest entries,
    and one token is drawn from the resulting softmax distribution.

    Args:
        A, B: callables taking the token context and returning a tuple whose
            first element is the logits (indexed as ``[:, -1, :]``).
        idx: starting token ids; sliced and concatenated along dim 1
            (presumably shape (batch, time) -- confirm against callers).
        n: number of tokens to append.
        alpha: weight of model A; model B gets ``1 - alpha``.
        temp: temperature the mixed logits are divided by.
        topk: keep only this many highest logits; a falsy value disables
            the filter.
        block: maximum number of trailing tokens fed to the models.

    Returns:
        *idx* extended with the *n* sampled tokens along dim 1.
    """
    for _ in range(n):
        context = idx[:, -block:]
        logits_a, _ = A(context)
        logits_b, _ = B(context)
        logits = (alpha * logits_a[:, -1, :] + (1 - alpha) * logits_b[:, -1, :]) / temp
        if topk:
            # torch.topk raises when k exceeds the size of the last dim, so
            # clamp to the vocabulary size (small vocabularies are plausible).
            k = min(topk, logits.size(-1))
            kth, _ = torch.topk(logits, k)
            # Mask everything strictly below the k-th largest logit.
            logits[logits < kth[:, [-1]]] = float("-inf")
        probs = F.softmax(logits, dim=-1)
        idx = torch.cat([idx, torch.multinomial(probs, 1)], dim=1)
    return idx

def main():
    """Command-line entry point: generate one fused tune per requested key.

    Loads the two checkpoints given by ``--a`` and ``--b``, verifies that they
    share the same ``stoi`` vocabulary, then for every key in ``--keys`` seeds
    an ABC header, samples ``--new`` tokens with ``gen_mix``, keeps the first
    tune, writes it as ``fuzja_<i>_<key>.abc`` in ``--out`` and passes it to
    ``to_midi`` to produce the matching ``.mid`` file.

    Raises:
        ValueError: if the two checkpoints do not share the same vocabulary.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--a", required=True)
    ap.add_argument("--b", required=True)
    ap.add_argument("--alpha", type=float, default=0.5)
    ap.add_argument("--meter", default="3/4")
    ap.add_argument("--keys", default="D,G,Emin")
    ap.add_argument("--inst", default="piano", choices=list(INST))
    ap.add_argument("--out", required=True)
    ap.add_argument("--new", type=int, default=420)
    ap.add_argument("--temp", type=float, default=0.85)
    ap.add_argument("--topk", type=int, default=18)
    a = ap.parse_args()
    sys.stdout.reconfigure(encoding="utf-8")

    A, ckA = load(a.a)
    B, ckB = load(a.b)
    # Explicit raise instead of assert: asserts are stripped under `python -O`,
    # which would silently mix logits over mismatched vocabularies.
    if ckA["stoi"] != ckB["stoi"]:
        raise ValueError("Modele mają RÓŻNY słownik — najpierw wspólny słownik!")
    stoi, itos = ckA["stoi"], ckA["itos"]
    # Both models receive the same context, so it must fit the smaller window.
    block = min(ckA["config"].block_size, ckB["config"].block_size)
    os.makedirs(a.out, exist_ok=True)
    inst_cls = INST[a.inst]
    # Parse the key list once; tolerate spaces around commas and empty entries.
    keys = [k.strip() for k in a.keys.split(",") if k.strip()]
    print(f"FUZJA α={a.alpha} (A={a.a} : B={a.b}) | wspólny słownik {len(stoi)} | -> {a.out}")
    torch.manual_seed(20260621)  # fixed seed so repeated runs sample the same tokens
    ok = 0
    for i, key in enumerate(keys, 1):
        seed = f"X:1\nM:{a.meter}\nK:{key}\n"
        idx = torch.tensor([[stoi[c] for c in seed]])
        gen = gen_mix(A, B, idx, a.new, a.alpha, a.temp, a.topk, block)[0].tolist()
        tune = first_tune("".join(itos[t] for t in gen))
        base = f"{a.out}/fuzja_{i}_{key}"
        # Context manager guarantees the .abc file is flushed and closed.
        with open(base + ".abc", "w", encoding="utf-8") as fh:
            fh.write(tune + "\n")
        good = to_midi(tune, base + ".mid", inst=inst_cls() if inst_cls else None)
        # Count by truthiness, matching how `good` is interpreted in the log line.
        ok += bool(good)
        print(f"  #{i} ({key}) [{'MIDI OK' if good else 'błąd'}] -> {base}.mid")
    print(f"\ngotowe: {ok}/{len(keys)} fuzji w {a.out}/")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()