TinyMozart_85M / use.py
LH-Tech-AI's picture
Create use.py
55d79c8 verified
# Before running this script, install the General MIDI soundfont it needs (Debian/Ubuntu):
# apt-get update -y > /dev/null
# apt-get install -y fluid-soundfont-gm > /dev/null
import torch
import torch.nn as nn
import torch.nn.functional as F
from miditok import REMI
from midi2audio import FluidSynth
import os
class Head(nn.Module):
    """Single head of causal (masked) scaled dot-product self-attention.

    Args:
        head_size: Output width of the key/query/value projections.
        n_embd: Width of the last dimension of the input. Defaults to 512.
        block_size: Side length of the causal mask, i.e. the longest
            sequence the head can attend over. Defaults to 512.
        dropout: Dropout probability applied to the attention weights.
            Defaults to 0.1.
    """

    def __init__(self, head_size, n_embd=512, block_size=512, dropout=0.1):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Registered as a (persistent) buffer, so 'tril' appears in the
        # state_dict; the name and default shape must stay unchanged for
        # existing checkpoints to keep loading.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over a 3-D input and return a tensor whose last dim is head_size."""
        _, T, _ = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        # Scale the scores by 1/sqrt(head_size).
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # Positions above the diagonal (the future) get -inf so softmax
        # assigns them zero weight.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        # The dropout module was constructed but never applied before;
        # it is a no-op in eval mode, so inference output is unchanged.
        wei = self.dropout(wei)
        return wei @ v
class MultiHeadAttention(nn.Module):
    """Run several attention heads side by side and merge their outputs.

    Args:
        num_heads: Number of Head modules to create.
        head_size: Output width of each head.

    The output projection is fixed at 512 -> 512.
    NOTE(review): this presumably requires num_heads * head_size == 512 -
    confirm against callers.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        head_modules = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(head_modules)
        self.proj = nn.Linear(512, 512)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Concatenate every head's output on the last dim, project, then apply dropout."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
class FeedForward(nn.Module):
    """Two-layer MLP: widen to 4 * n_embd, GELU, project back to n_embd, dropout.

    Args:
        n_embd: Width of the input and of the output.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden_dim = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, n_embd),
            nn.Dropout(0.1),
        ]
        # Positional order inside the Sequential fixes the state_dict keys
        # (net.0.*, net.2.*).
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to x."""
        return self.net(x)
class Block(nn.Module):
    """Transformer block: self-attention then feed-forward, each on a
    layer-normalised input and each added back to its input (residual).

    Args:
        n_embd: Width used for the layer norms and the feed-forward net.
        n_head: Number of attention heads; each head gets n_embd // n_head.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return x after the attention and feed-forward residual updates."""
        after_attention = x + self.sa(self.ln1(x))
        return after_attention + self.ffwd(self.ln2(after_attention))
class TinyMozart(nn.Module):
    """Decoder-only transformer language model over integer token ids.

    Built from a token embedding, a learned position embedding with 512
    slots, eight Block(512, 8) layers, a final layer norm and a linear
    output layer.

    Args:
        vocab_size: Number of distinct token ids; sizes the token embedding
            table and the output layer.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, 512)
        self.position_embedding_table = nn.Embedding(512, 512)
        self.blocks = nn.Sequential(*[Block(512, 8) for _ in range(8)])
        self.ln_f = nn.LayerNorm(512)
        self.lm_head = nn.Linear(512, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: 2-D integer tensor of token ids (batch, time). The time
                dimension may not exceed the 512 position slots.
            targets: Optional integer tensor with the same number of
                elements as idx, holding the expected next token ids.

        Returns:
            A (logits, loss) tuple. logits has shape (batch, time,
            vocab_size). loss is None when targets is None, otherwise the
            mean cross-entropy over all positions.
        """
        _, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        x = self.token_embedding_table(idx) + self.position_embedding_table(positions)
        x = self.blocks(x)
        logits = self.lm_head(self.ln_f(x))
        if targets is None:
            return logits, None
        # Previously targets was accepted but ignored. cross_entropy wants
        # (N, classes) logits and (N,) class ids, so flatten batch and time.
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = REMI()
# NOTE(review): presumably this has to match the vocabulary size the
# checkpoint was trained with - confirm against the tokenizer.
vocab_size = 300
model = TinyMozart(vocab_size).to(device)
checkpoint_path = 'model.pt'
if os.path.exists(checkpoint_path):
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model loaded from checkpoint (Iter {checkpoint['iter']})")
else:
    # Without a checkpoint the model keeps its random initial weights;
    # say so instead of silently generating noise.
    print(f"Warning: checkpoint '{checkpoint_path}' not found - using randomly initialised weights.")
# Inference only: switch dropout layers off.
model.eval()
@torch.no_grad()
def generate_tinymozart_pro(max_len=1000, temp=1.0, top_p=0.9, top_k=0):
    """Sample a token sequence from the module-level `model`.

    Starts from a single token 0 and appends one sampled token per step.

    Args:
        max_len: Number of tokens to sample.
        temp: Divisor applied to the logits before filtering and sampling.
        top_p: Nucleus threshold on cumulative probability (tokens ranked
            by descending logit); the token that crosses it is still kept.
        top_k: When > 0, additionally keep only logits at least as large
            as the k-th largest one.

    Returns:
        A list of ints: the initial 0 followed by the max_len sampled ids.
    """
    sequence = torch.zeros((1, 1), dtype=torch.long, device=device)
    for _ in range(max_len):
        # Feed at most the last 512 tokens to the model.
        context = sequence[:, -512:]
        all_logits, _ = model(context)
        step_logits = all_logits[:, -1, :] / temp

        # --- nucleus (top-p) filtering ---
        ranked_logits, ranked_ids = torch.sort(step_logits, descending=True)
        cumulative_mass = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)
        over_budget = cumulative_mass > top_p
        # Shift the mask one rank to the right: the token that crosses the
        # threshold survives and the best-ranked token is never dropped.
        never_drop_first = torch.zeros_like(over_budget[..., :1])
        drop_ranked = torch.cat((never_drop_first, over_budget[..., :-1]), dim=-1)
        drop_ids = ranked_ids[drop_ranked]
        step_logits.scatter_(1, drop_ids.unsqueeze(0), float('-inf'))

        # --- optional top-k filtering ---
        if top_k > 0:
            k = min(top_k, step_logits.size(-1))
            kth_best = torch.topk(step_logits, k).values[:, -1:]
            step_logits[step_logits < kth_best] = float('-inf')

        probabilities = F.softmax(step_logits, dim=-1)
        sampled = torch.multinomial(probabilities, num_samples=1)
        sequence = torch.cat((sequence, sampled), dim=1)
    return sequence[0].cpu().tolist()
print("TinyMozart creates music...")
tokens = generate_tinymozart_pro(
    max_len=600,
    temp=1.2,
    top_p=0.95,
    top_k=50,
)
# Strip a leading token 0 (the generator seeds its sequence with a 0)
# before decoding.
if tokens[0] == 0:
    tokens = tokens[1:]
print(tokens)
print("Converting tokens to MIDI...")
generated_score = tokenizer.decode([tokens])
generated_score.dump_midi("output_mozart.mid")
SF2_PATH = "/usr/share/sounds/sf2/FluidR3_GM.sf2"
print("Creating Audio with System-Soundfont...")
if os.path.exists(SF2_PATH):
    fs = FluidSynth(SF2_PATH, sample_rate=44100)
    fs.midi_to_audio("output_mozart.mid", "output_mozart.wav")
    # `display` is only pre-defined inside IPython/Jupyter sessions;
    # import it explicitly so the script does not hit a NameError when
    # run with plain Python.
    from IPython.display import Audio, display
    display(Audio("output_mozart.wav"))
else:
    print("Soundfont not found! Look at the top of this script (use.py) and follow the instructions in the comments on the top!")