# MIDI generation backend: wraps a Hugging Face causal LM and a miditok REMI
# tokenizer to produce continuations of user-supplied MIDI files.
from miditok import REMI
from transformers import AutoModelForCausalLM, GenerationConfig, AutoConfig
from huggingface_hub import hf_hub_download
import torch
from pathlib import Path
from symusic import Score


class Processor:
    def __init__(self, model_location_repo, model_tokenizer_file) -> None:
        self.config = AutoConfig.from_pretrained(model_location_repo)
        self.model = AutoModelForCausalLM.from_pretrained(model_location_repo, config=self.config)
        tokenizer_file_location = hf_hub_download(repo_id=model_location_repo, filename=model_tokenizer_file)
        self.tokenizer = REMI(params=Path(tokenizer_file_location))
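        # NOTE: miditok tokenizers expose their vocabulary via __getitem__, so
        # indexing with a token string (e.g. self.tokenizer['PAD_None'] below)
        # resolves that special token's integer id.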
        self.generation_config = GenerationConfig(
            max_new_tokens=2000,
            num_beams=1,
            do_sample=True,
            temperature=0.9,
            top_k=15,
            top_p=0.95,
            epsilon_cutoff=3e-4,
            eta_cutoff=1e-3,
            pad_token_id=self.tokenizer['PAD_None'],
            bos_token_id=self.tokenizer['BOS_None'],
            eos_token_id=self.tokenizer['EOS_None'],
        )
    def transpose_midi(self, midi_bytes: bytes | None, max_new_tokens: int = 2000, temperature: float = 0.9, top_p: float = 0.95, do_sample: bool = True) -> bytes | None:
        """
        Process the MIDI file with the transformer model to generate new MIDI
        content based on the input.

        Args:
            midi_bytes: Raw MIDI file bytes from the frontend.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling probability mass.
            do_sample: Whether to sample instead of decoding greedily.

        Returns:
            Generated MIDI file bytes, or None if no input was provided.
        """
        if midi_bytes is None:
            return None
        try:
            score = Score.from_midi(midi_bytes)
            tokenized_input = self.tokenizer(score)
            max_len = self.model.config.max_position_embeddings
            print(f"Max position embeddings: {max_len}")
            max_len = 1024  # TODO: hardcoded for now while we use a smaller model
            # Truncate input if it exceeds the model's maximum context length,
            # keeping the most recent tokens as the prompt
            input_ids = tokenized_input[0].ids
            if len(input_ids) >= max_len:
                print(f"Warning: Input sequence ({len(input_ids)} tokens) longer than max_len ({max_len}). Truncating.")
                input_ids = input_ids[-max_len:]
            tensor_sequence = torch.tensor([input_ids], dtype=torch.long)
            print(f"Input tensor shape: {tensor_sequence.shape}")
            input_token_length = tensor_sequence.shape[1]
            # Generate the new token sequence, overriding the default
            # generation config with the per-request parameters
            gen_config = GenerationConfig(
                max_new_tokens=int(max_new_tokens),
                num_beams=self.generation_config.num_beams,
                do_sample=do_sample,
                temperature=temperature,
                top_k=self.generation_config.top_k,
                top_p=top_p,
                epsilon_cutoff=self.generation_config.epsilon_cutoff,
                eta_cutoff=self.generation_config.eta_cutoff,
                pad_token_id=self.generation_config.pad_token_id,
                bos_token_id=self.generation_config.bos_token_id,
                eos_token_id=self.generation_config.eos_token_id,
            )
            res = self.model.generate(
                inputs=tensor_sequence,
                generation_config=gen_config,
            )
print("Generated Output Shape", res.shape)
print(f"New tokens length: {res.shape[1] - input_token_length}")
# Decode the generated tokens (excluding the input part)
decoded = self.tokenizer.decode([res[0][input_token_length:]])
return decoded.dumps_midi()
except Exception as e:
print(f"Error processing MIDI: {e}")
return midi_bytes # Return original on error
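

# Minimal usage sketch (not part of the original Space): instantiates the
# Processor with a hypothetical repo id and tokenizer filename, substitute
# your own, then generates a continuation for a local MIDI file.
if __name__ == "__main__":
    processor = Processor(
        model_location_repo="username/midi-model",  # hypothetical repo id
        model_tokenizer_file="tokenizer.json",      # hypothetical filename
    )
    midi_in = Path("input.mid").read_bytes()        # hypothetical input file
    midi_out = processor.transpose_midi(midi_in, max_new_tokens=500)
    if midi_out is not None:
        Path("output.mid").write_bytes(midi_out)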