# TinyMozart (use.py) - sample music from a small decoder-only transformer.
#
# Audio rendering requires FluidSynth plus a General MIDI soundfont.
# On Debian/Ubuntu, for example:
#   sudo apt install fluidsynth fluid-soundfont-gm
# which provides /usr/share/sounds/sf2/FluidR3_GM.sf2 (the SF2_PATH used below).

import torch
import torch.nn as nn
import torch.nn.functional as F
from miditok import REMI
from midi2audio import FluidSynth
import os
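

# One head of masked (causal) self-attention; the embedding width and the
# maximum context length are both hard-coded to 512.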
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(512, head_size, bias=False)
        self.query = nn.Linear(512, head_size, bias=False)
        self.value = nn.Linear(512, head_size, bias=False)
        # Lower-triangular mask: each position may only attend to the past.
        self.register_buffer('tril', torch.tril(torch.ones(512, 512)))
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        # Scaled dot-product attention scores.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)  # was defined but never applied; a no-op under model.eval()
        return wei @ v
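

# Several attention heads run in parallel; their outputs are concatenated
# and projected back to the embedding width.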
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(512, 512)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))
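

# Position-wise feed-forward network with the usual 4x expansion.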
class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        return self.net(x)
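

# Pre-norm transformer block: self-attention and feed-forward sublayers,
# each wrapped in a residual connection.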
class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1, self.ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
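

# Decoder-only language model over MIDI tokens: 8 blocks with 8 heads each,
# embedding width 512, context length 512.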
class TinyMozart(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, 512)
        self.position_embedding_table = nn.Embedding(512, 512)
        self.blocks = nn.Sequential(*[Block(512, 8) for _ in range(8)])
        self.ln_f = nn.LayerNorm(512)
        self.lm_head = nn.Linear(512, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        x = self.token_embedding_table(idx) + self.position_embedding_table(
            torch.arange(T, device=idx.device))
        x = self.blocks(x)
        logits = self.lm_head(self.ln_f(x))
        if targets is None:
            return logits, None
        # Training path (unused during generation): standard next-token
        # cross-entropy over the flattened (B*T, vocab) logits.
        loss = F.cross_entropy(logits.view(B * T, -1), targets.view(B * T))
        return logits, loss


device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = REMI()
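
# Hard-coded to match the vocabulary the checkpoint was trained with; for a
# freshly built tokenizer, len(tokenizer) gives the exact size.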
vocab_size = 300

model = TinyMozart(vocab_size).to(device)
checkpoint_path = 'model.pt'

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model loaded from checkpoint (Iter {checkpoint['iter']})")
else:
    print(f"No checkpoint found at '{checkpoint_path}' - sampling from untrained weights!")
model.eval()
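

# Autoregressive sampling with temperature scaling, nucleus (top-p) filtering,
# and then optional top-k filtering, in that order.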
@torch.no_grad()
def generate_tinymozart_pro(max_len=1000, temp=1.0, top_p=0.9, top_k=0):
    # Seed with token 0 (the tokenizer's padding token); it is stripped again below.
    x = torch.zeros((1, 1), dtype=torch.long, device=device)

    for _ in range(max_len):
        # Crop the context to the model's block size of 512.
        x_cond = x[:, -512:]
        logits, _ = model(x_cond)
        logits = logits[:, -1, :] / temp

        # Nucleus (top-p) filtering: mask out everything beyond the smallest
        # set of tokens whose cumulative probability exceeds top_p.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Boolean indexing like this assumes batch size 1.
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits.scatter_(1, indices_to_remove.unsqueeze(0), -float('Inf'))

        # Optional top-k filtering: keep only the k highest-scoring tokens.
        if top_k > 0:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        x = torch.cat((x, next_token), dim=1)

    return x[0].cpu().tolist()
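

# Fairly adventurous settings: high temperature and a wide nucleus, with
# top-k as a safety net.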
print("TinyMozart creates music...")
tokens = generate_tinymozart_pro(
    max_len=600,
    temp=1.2,
    top_p=0.95,
    top_k=50
)
# Drop the leading seed token (id 0, the padding token).
if tokens[0] == 0:
    tokens = tokens[1:]

print(tokens)
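
# miditok's decode() takes a list of token sequences and returns a symusic
# Score, which can be written straight to a MIDI file.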
| print("Converting tokens to MIDI...") |
| generated_score = tokenizer.decode([tokens]) |
| generated_score.dump_midi("output_mozart.mid") |
|
|
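
# Default install path of the FluidR3 General MIDI soundfont on Debian/Ubuntu
# (see the setup comments at the top of this script).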
SF2_PATH = "/usr/share/sounds/sf2/FluidR3_GM.sf2"

print("Creating audio with the system soundfont...")
if os.path.exists(SF2_PATH):
    fs = FluidSynth(SF2_PATH, sample_rate=44100)
    fs.midi_to_audio("output_mozart.mid", "output_mozart.wav")

    # `display` only has an effect inside a Jupyter/IPython session and must
    # be imported explicitly in a plain script.
    from IPython.display import Audio, display
    display(Audio("output_mozart.wav"))
else:
    print("Soundfont not found! See the setup instructions in the comments at the top of this script (use.py).")