File size: 4,992 Bytes
55d79c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Before running this script, install the soundfont it needs (the audio step looks for /usr/share/sounds/sf2/FluidR3_GM.sf2):
# apt-get update -y > /dev/null
# apt-get install -y fluid-soundfont-gm > /dev/null

import torch
import torch.nn as nn
import torch.nn.functional as F
from miditok import REMI
from midi2audio import FluidSynth
import os

class Head(nn.Module):
    """One head of causal (lower-triangular masked) self-attention.

    Args:
        head_size: Output width of the key, query and value projections.
        n_embd: Width of the incoming embeddings. Defaults to 512, the value
            that used to be hard-coded.
        block_size: Side length of the causal mask, i.e. the longest sequence
            this head can attend over. Defaults to 512.
        dropout: Dropout probability applied to the attention weights.
            Defaults to 0.1.

    The parameter and buffer names (``key``, ``query``, ``value``, ``tril``)
    are part of the checkpoint format and must not be renamed.
    """

    def __init__(self, head_size, n_embd=512, block_size=512, dropout=0.1):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones: position t may only look at positions <= t.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, time, n_embd).

        Returns a tensor of shape (batch, time, head_size).
        """
        _, T, _ = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        # Scaled dot-product scores; scaling by 1/sqrt(head_size).
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # Future positions get -inf so softmax assigns them zero weight.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        # The dropout module was constructed but never used before; apply it
        # to the attention weights (a no-op in eval mode).
        wei = self.dropout(wei)
        return wei @ v

class MultiHeadAttention(nn.Module):
    """Several attention heads run in parallel, then merged by a linear layer.

    Args:
        num_heads: Number of ``Head`` modules to create.
        head_size: Output width of each head.

    The output projection is a fixed ``Linear(512, 512)``, so the
    concatenated head outputs (``num_heads * head_size``) must be 512 wide.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        attention_heads = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(attention_heads)
        self.proj = nn.Linear(512, 512)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Concatenate every head's output on the last axis, project, drop out."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

class FeedForward(nn.Module):
    """Two-layer MLP: widen to 4 * n_embd, GELU, narrow back, then dropout.

    Args:
        n_embd: Input and output feature width.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden_width = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(0.1),
        ]
        # Positions inside the Sequential determine the state-dict keys
        # (net.0.*, net.2.*), so the layer order must stay as is.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``; the last dimension must be ``n_embd``."""
        return self.net(x)

class Block(nn.Module):
    """Transformer block: self-attention then feed-forward, each residual.

    LayerNorm is applied to the input of each sub-layer (pre-norm) and the
    sub-layer output is added back onto the unnormalised stream.

    Args:
        n_embd: Feature width of the residual stream.
        n_head: Number of attention heads; each head gets ``n_embd // n_head``
            features.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        after_attention = x + self.sa(self.ln1(x))
        return after_attention + self.ffwd(self.ln2(after_attention))

class TinyMozart(nn.Module):
    """Decoder-only transformer over a token vocabulary.

    Fixed configuration: 512-wide embeddings, 512 learned positions,
    8 blocks with 8 heads each.

    Args:
        vocab_size: Number of distinct token ids (size of the embedding table
            and of the output logits).
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, 512)
        self.position_embedding_table = nn.Embedding(512, 512)
        self.blocks = nn.Sequential(*[Block(512, 8) for _ in range(8)])
        self.ln_f = nn.LayerNorm(512)
        self.lm_head = nn.Linear(512, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Integer tensor of token ids, shape (batch, time). ``time``
                must not exceed the 512 rows of the position table.
            targets: Optional integer tensor with the same shape as ``idx``
                holding the expected next token at every position.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (batch, time, vocab_size).
            ``loss`` is ``None`` when ``targets`` is ``None``; otherwise it is
            the mean cross-entropy over all positions.
        """
        _, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        x = self.token_embedding_table(idx) + self.position_embedding_table(positions)
        x = self.blocks(x)
        logits = self.lm_head(self.ln_f(x))
        if targets is None:
            return logits, None
        # `targets` used to be silently ignored. Flatten (batch, time) so
        # cross_entropy sees one row of class scores per target token.
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = REMI()

# NOTE(review): presumably this has to equal the vocabulary size the
# checkpoint was trained with, since it sizes the embedding and output layers.
vocab_size = 300 

model = TinyMozart(vocab_size).to(device)
checkpoint_path = 'model.pt'

if os.path.exists(checkpoint_path):
    # SECURITY: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model loaded from checkpoint (Iter {checkpoint['iter']})")
else:
    # Without this the script continued silently with untrained weights.
    print(f"Warning: checkpoint '{checkpoint_path}' not found - generating with randomly initialized weights.")
# Inference only: switch dropout layers off.
model.eval()

@torch.no_grad()
def generate_tinymozart_pro(max_len=1000, temp=1.0, top_p=0.9, top_k=0):
    """Sample a token sequence from the module-level ``model``.

    Generation starts from a single token with id 0 and appends one sampled
    token per step. The mask scatter below indexes row 0 only, so this works
    for the batch of one used here.

    Args:
        max_len: Number of tokens to sample.
        temp: Temperature; the last-step logits are divided by it.
        top_p: Nucleus threshold. Tokens are ranked by probability and kept
            up to and including the first one whose cumulative probability
            exceeds ``top_p``; the top-ranked token is always kept.
        top_k: When greater than 0, additionally drop every token whose logit
            is below the k-th largest remaining logit.

    Returns:
        A list of ``max_len + 1`` ints: the leading 0 followed by the samples.
    """
    neg_inf = -float('Inf')
    sequence = torch.zeros((1, 1), dtype=torch.long, device=device)

    for _ in range(max_len):
        # The model has 512 positions, so feed at most the last 512 tokens.
        context = sequence[:, -512:]
        all_logits, _ = model(context)
        step_logits = all_logits[:, -1, :] / temp

        # --- nucleus (top-p) filtering ---
        ranked_logits, ranked_ids = torch.sort(step_logits, descending=True)
        running_mass = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)
        over_budget = running_mass > top_p

        # Shift the mask one rank to the right: the token that crosses the
        # threshold survives, and rank 0 is never dropped.
        drop_ranked = torch.zeros_like(over_budget)
        drop_ranked[..., 1:] = over_budget[..., :-1]

        dropped_ids = ranked_ids[drop_ranked]
        step_logits.scatter_(1, dropped_ids.unsqueeze(0), neg_inf)

        # --- optional top-k filtering on what is left ---
        if top_k > 0:
            keep = min(top_k, step_logits.size(-1))
            kth_best = torch.topk(step_logits, keep).values[:, [-1]]
            step_logits[step_logits < kth_best] = neg_inf

        distribution = F.softmax(step_logits, dim=-1)
        sampled = torch.multinomial(distribution, num_samples=1)
        sequence = torch.cat((sequence, sampled), dim=1)

    return sequence[0].cpu().numpy().tolist()

print("TinyMozart creates music...")
tokens = generate_tinymozart_pro(
    max_len=600, 
    temp=1.2,
    top_p=0.95,
    top_k=50
)
# The generator seeds the sequence with token 0; drop that seed before decoding.
if tokens[0] == 0:
    tokens = tokens[1:]

print(tokens)

print("Converting tokens to MIDI...")
generated_score = tokenizer.decode([tokens]) 
generated_score.dump_midi("output_mozart.mid")

SF2_PATH = "/usr/share/sounds/sf2/FluidR3_GM.sf2"

print("Creating Audio with System-Soundfont...")
if os.path.exists(SF2_PATH):
    # Render the MIDI file to WAV with the system soundfont.
    fs = FluidSynth(SF2_PATH, sample_rate=44100)
    fs.midi_to_audio("output_mozart.mid", "output_mozart.wav")
    
    # `display` is only injected as a builtin inside IPython/Jupyter; import
    # it explicitly so a plain `python` run does not fail with NameError.
    from IPython.display import Audio, display
    display(Audio("output_mozart.wav"))
else:
    print("Soundfont not found! Look at the top of this script (use.py) and follow the instructions in the comments on the top!")