"""TinyMozart: a small GPT-style transformer that generates music as token sequences.

To render audio, install the soundfont first:
    apt-get update -y > /dev/null
    apt-get install -y fluid-soundfont-gm > /dev/null
"""
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

# Model hyper-parameters (previously hard-coded 512 / 8 / 0.1 throughout).
N_EMBD = 512      # embedding / model width
BLOCK_SIZE = 512  # maximum context length (also the positional-table size)
N_HEAD = 8
N_LAYER = 8
DROPOUT = 0.1


class Head(nn.Module):
    """One head of masked (causal) self-attention."""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(N_EMBD, head_size, bias=False)
        self.query = nn.Linear(N_EMBD, head_size, bias=False)
        self.value = nn.Linear(N_EMBD, head_size, bias=False)
        # Lower-triangular mask: position t may only attend to positions <= t.
        self.register_buffer('tril', torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE)))
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        # Scaled dot-product attention scores, causally masked.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        # BUG FIX: self.dropout was constructed but never applied; apply it to
        # the attention weights (standard GPT placement). Identity in eval mode,
        # so inference with an existing checkpoint is unaffected.
        wei = self.dropout(wei)
        return wei @ v


class MultiHeadAttention(nn.Module):
    """Several attention heads in parallel; outputs concatenated then projected."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(N_EMBD, N_EMBD)
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        # num_heads * head_size must equal N_EMBD for the projection to line up.
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))


class FeedForward(nn.Module):
    """Position-wise MLP with the usual 4x expansion and GELU."""

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(DROPOUT),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    """Pre-norm transformer block: x + SA(LN(x)), then x + FF(LN(x))."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1, self.ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class TinyMozart(nn.Module):
    """Decoder-only language model over a music-token vocabulary."""

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, N_EMBD)
        self.position_embedding_table = nn.Embedding(BLOCK_SIZE, N_EMBD)
        self.blocks = nn.Sequential(*[Block(N_EMBD, N_HEAD) for _ in range(N_LAYER)])
        self.ln_f = nn.LayerNorm(N_EMBD)
        self.lm_head = nn.Linear(N_EMBD, vocab_size)

    def forward(self, idx, targets=None):
        """Return (logits, loss).

        idx: (B, T) token ids with T <= BLOCK_SIZE.
        targets: optional (B, T) next-token ids; when given, loss is the
        cross-entropy of the logits against them, otherwise loss is None.
        """
        B, T = idx.shape
        x = self.token_embedding_table(idx) + self.position_embedding_table(
            torch.arange(T, device=idx.device)
        )
        x = self.blocks(x)
        logits = self.lm_head(self.ln_f(x))
        # BUG FIX: `targets` was accepted but ignored and loss was always None;
        # compute cross-entropy when targets are supplied (None path unchanged).
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(B * T, -1), targets.view(B * T))
        return logits, loss


device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): 300 must match the tokenizer's vocabulary size — confirm
# against the REMI() configuration used for training.
vocab_size = 300
model = TinyMozart(vocab_size).to(device)

checkpoint_path = 'model.pt'
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model loaded from checkpoint (Iter {checkpoint['iter']})")
model.eval()


@torch.no_grad()
def generate_tinymozart_pro(max_len=1000, temp=1.0, top_p=0.9, top_k=0):
    """Autoregressively sample `max_len` tokens from the global `model`.

    Per-step pipeline: temperature scaling -> nucleus (top-p) filtering ->
    optional top-k filtering -> multinomial draw.
    NOTE(review): the top-p index gathering assumes batch size 1.
    Returns the sequence (including the initial 0 start token) as a list of ints.
    """
    x = torch.zeros((1, 1), dtype=torch.long, device=device)  # start from token 0
    for _ in range(max_len):
        x_cond = x[:, -BLOCK_SIZE:]  # crop to the model's context window
        logits, _ = model(x_cond)
        logits = logits[:, -1, :] / temp
        # --- nucleus (top-p) filtering ---
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is still kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits.scatter_(1, indices_to_remove.unsqueeze(0), -float('Inf'))
        # --- optional top-k filtering (applied after top-p, as designed) ---
        if top_k > 0:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        x = torch.cat((x, next_token), dim=1)
    return x[0].cpu().numpy().tolist()


def main():
    """Generate a token sequence and write it to output_mozart.mid."""
    # Imported lazily so the module can be imported (e.g. for testing)
    # without miditok installed.
    from miditok import REMI

    tokenizer = REMI()
    print("TinyMozart creates music...")
    tokens = generate_tinymozart_pro(max_len=600, temp=1.2, top_p=0.95, top_k=50)
    # Drop the artificial start token (0) if present.
    if tokens[0] == 0:
        tokens = tokens[1:]
    print(tokens)
    print("Converting tokens to MIDI...")
    generated_score = tokenizer.decode([tokens])
    generated_score.dump_midi("output_mozart.mid")


# BUG FIX: the generation script previously ran unconditionally at import time;
# guard it so importing this module has no side effects beyond model creation.
if __name__ == "__main__":
    main()
SF2_PATH = "/usr/share/sounds/sf2/FluidR3_GM.sf2"

print("Creating Audio with System-Soundfont...")
if os.path.exists(SF2_PATH):
    # Imported at point of use: midi2audio / IPython are only needed when
    # the soundfont is actually present.
    from midi2audio import FluidSynth
    # BUG FIX: `display` was called but never imported — it is only a builtin
    # inside IPython notebooks; as a plain script this raised NameError.
    from IPython.display import Audio, display

    fs = FluidSynth(SF2_PATH, sample_rate=44100)
    # Renders the MIDI written by the generation step above.
    fs.midi_to_audio("output_mozart.mid", "output_mozart.wav")
    display(Audio("output_mozart.wav"))
else:
    print("Soundfont not found! Look at the top of this script (use.py) and follow the instructions in the comments on the top!")