# File size: 4,992 bytes
# To use this script, first install the FluidR3 General MIDI soundfont:
#   apt-get update -y > /dev/null
#   apt-get install -y fluid-soundfont-gm > /dev/null
import torch
import torch.nn as nn
import torch.nn.functional as F
from miditok import REMI
from midi2audio import FluidSynth
import os
class Head(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    The input is projected to key/query/value vectors of width ``head_size``.
    Attention scores are masked with a lower-triangular matrix so position
    ``t`` only attends to positions ``<= t``. The result is the
    attention-weighted sum of the values.

    Args:
        head_size: Width of the key/query/value projections.
        n_embd: Width of the input features. Defaults to 512, matching the
            rest of this file.
        block_size: Longest sequence the causal mask supports. Defaults to 512.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, head_size, n_embd=512, block_size=512, dropout=0.1):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # A (persistent) buffer: it follows .to(device) and is part of the
        # state_dict, so its (block_size, block_size) shape must match any
        # checkpoint that is loaded.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: Tensor of shape (B, T, n_embd) with T <= block_size.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        T = x.shape[1]
        k, q, v = self.key(x), self.query(x), self.value(x)
        # Scale by 1/sqrt(head_size) to keep the softmax inputs well-conditioned.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # Future positions get -inf, i.e. zero weight after the softmax.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        # Attention dropout; a no-op in eval mode.
        wei = self.dropout(wei)
        return wei @ v
class MultiHeadAttention(nn.Module):
    """Run ``num_heads`` causal attention heads side by side.

    The per-head outputs are concatenated along the feature axis. They then
    pass through a 512 -> 512 linear projection followed by dropout (p=0.1).
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        heads = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(512, 512)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # The concatenated width is num_heads * head_size; the projection
        # below only accepts it when that product is 512.
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x width, GELU, project back, dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        # The Sequential indices determine the state_dict keys
        # (net.0.*, net.2.*) that load_state_dict matches against.
        layers = [
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(0.1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class Block(nn.Module):
    """Pre-norm transformer block: masked self-attention, then feed-forward.

    Each sub-layer reads a LayerNorm-ed copy of its input, and its output is
    added back onto that input as a residual connection.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))
class TinyMozart(nn.Module):
    """Decoder-only transformer that predicts the next MIDI token.

    The model is built as follows:
    - token embeddings plus learned positional embeddings (512 positions,
      width 512);
    - 8 transformer blocks with 8 heads each;
    - a final LayerNorm;
    - a linear head producing one logit per vocabulary entry.

    Args:
        vocab_size: Number of token ids the embedding and output head cover.
    """

    # Number of rows in the positional embedding table, i.e. the longest
    # sequence forward() accepts.
    block_size = 512

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, 512)
        self.position_embedding_table = nn.Embedding(self.block_size, 512)
        self.blocks = nn.Sequential(*[Block(512, 8) for _ in range(8)])
        self.ln_f = nn.LayerNorm(512)
        self.lm_head = nn.Linear(512, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: LongTensor of token ids with shape (B, T), T <= block_size.
            targets: Optional LongTensor of shape (B, T). Each entry is the
                expected next token id at that position.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (B, T, vocab_size).
            ``loss`` is the mean cross-entropy against ``targets``, or
            ``None`` when ``targets`` is not given.

        Raises:
            ValueError: If T exceeds ``block_size``.
        """
        B, T = idx.shape
        if T > self.block_size:
            # Without this check the positional lookup fails with an
            # IndexError on CPU or an opaque device-side assert on CUDA.
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )
        positions = torch.arange(T, device=idx.device)
        x = self.token_embedding_table(idx) + self.position_embedding_table(positions)
        x = self.blocks(x)
        logits = self.lm_head(self.ln_f(x))
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.reshape(B * T, logits.size(-1)), targets.reshape(B * T))
        return logits, loss
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = REMI()
# Must equal the vocab size the checkpoint was trained with: the embedding and
# lm_head weight shapes in the state_dict depend on it.
# NOTE(review): hard-coded rather than derived from the tokenizer; confirm that
# 300 agrees with the REMI configuration used at training time.
vocab_size = 300
model = TinyMozart(vocab_size).to(device)
checkpoint_path = 'model.pt'
if os.path.exists(checkpoint_path):
    # SECURITY: torch.load unpickles arbitrary Python objects; only load
    # checkpoint files from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model loaded from checkpoint (Iter {checkpoint.get('iter', '?')})")
else:
    # Previously this fell through silently and generated with random weights.
    print(f"WARNING: checkpoint '{checkpoint_path}' not found - using untrained (random) weights!")
model.eval()
@torch.no_grad()
def generate_tinymozart_pro(max_len=1000, temp=1.0, top_p=0.9, top_k=0):
    """Autoregressively sample a token sequence from the module-level ``model``.

    Generation starts from a single seed token (id 0) and appends ``max_len``
    tokens. At each step:
    1. The last 512 tokens are fed to the model.
    2. The final-position logits are divided by ``temp``.
    3. Nucleus (top-p) filtering is applied, then top-k filtering if
       ``top_k > 0``.
    4. The next token is sampled from the remaining distribution.

    Args:
        max_len: Number of tokens to generate. The returned list has
            ``max_len + 1`` entries, including the seed token.
        temp: Softmax temperature. ``0`` selects the arg-max token at every
            step (greedy decoding).
        top_p: Nucleus threshold. Keeps the smallest set of most-likely
            tokens whose cumulative probability exceeds it. ``>= 1.0``
            disables the filter.
        top_k: If > 0, keep only the ``top_k`` highest logits.

    Returns:
        List of int token ids, starting with the seed token 0.

    Raises:
        ValueError: If ``temp`` is negative.
    """
    if temp < 0:
        raise ValueError(f"temp must be >= 0, got {temp}")
    x = torch.zeros((1, 1), dtype=torch.long, device=device)
    for _ in range(max_len):
        # Crop to the 512-token context covered by the positional table.
        x_cond = x[:, -512:]
        logits, _ = model(x_cond)
        logits = logits[:, -1, :]
        if temp == 0:
            # Dividing by zero would yield inf/NaN probabilities and make
            # torch.multinomial raise; greedy decoding is the temp -> 0 limit.
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            logits = logits / temp
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_remove = cumulative_probs > top_p
                # Shift right so the token that crosses the threshold is kept,
                # and always keep the single most likely token.
                sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
                sorted_remove[..., 0] = False
                # Map the mask from sorted order back to vocabulary order, row by row.
                remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
                logits = logits.masked_fill(remove, float('-inf'))
            if top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits = logits.masked_fill(logits < v[:, [-1]], float('-inf'))
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        x = torch.cat((x, next_token), dim=1)
    return x[0].tolist()
print("TinyMozart creates music...")
tokens = generate_tinymozart_pro(
    max_len=600,
    temp=1.2,
    top_p=0.95,
    top_k=50,
)
# Drop the seed token (id 0) that generation starts from.
if tokens and tokens[0] == 0:
    tokens = tokens[1:]
print(tokens)
print("Converting tokens to MIDI...")
# NOTE(review): the sequence is wrapped in a list, presumably one entry per
# track; confirm against the installed miditok version's decode() signature.
generated_score = tokenizer.decode([tokens])
generated_score.dump_midi("output_mozart.mid")
SF2_PATH = "/usr/share/sounds/sf2/FluidR3_GM.sf2"
print("Creating Audio with System-Soundfont...")
if os.path.exists(SF2_PATH):
    fs = FluidSynth(SF2_PATH, sample_rate=44100)
    fs.midi_to_audio("output_mozart.mid", "output_mozart.wav")
    # `display` is only a builtin inside IPython/Jupyter. Import it
    # explicitly, and fall back to a message when IPython is not installed,
    # so the script also runs under plain `python`.
    try:
        from IPython.display import Audio, display
    except ImportError:
        print("Audio written to output_mozart.wav")
    else:
        display(Audio("output_mozart.wav"))
else:
    print("Soundfont not found! Look at the top of this script (use.py) and follow the instructions in the comments on the top!")