from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from miditok import REMI
from symusic import Score
from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig


class Processor:
    def __init__(self, model_location_repo, model_tokenizer_file) -> None:
        self.config = AutoConfig.from_pretrained(model_location_repo)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_location_repo, config=self.config
        )
        tokenizer_file_location = hf_hub_download(
            repo_id=model_location_repo, filename=model_tokenizer_file
        )
        self.tokenizer = REMI(params=Path(tokenizer_file_location))
        self.generation_config = GenerationConfig(
            max_new_tokens=2000,
            num_beams=1,
            do_sample=True,
            temperature=0.9,
            top_k=15,
            top_p=0.95,
            epsilon_cutoff=3e-4,
            eta_cutoff=1e-3,
            pad_token_id=self.tokenizer["PAD_None"],
            bos_token_id=self.tokenizer["BOS_None"],
            eos_token_id=self.tokenizer["EOS_None"],
        )

    def transpose_midi(
        self,
        midi_bytes: bytes | None,
        max_new_tokens: int = 2000,
        temperature: float = 0.9,
        top_p: float = 0.95,
        do_sample: bool = True,
    ) -> bytes | None:
        """Generate new MIDI content from the input using a transformer model.

        Args:
            midi_bytes: Raw MIDI file bytes from the frontend.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            do_sample: Whether to sample (True) or decode greedily (False).

        Returns:
            Generated MIDI file bytes, or the original bytes on error.
        """
        if midi_bytes is None:
            return None

        try:
            # Parse the raw bytes into a symusic Score and tokenize it with REMI.
            score = Score.from_midi(midi_bytes)
            tokenized_input = self.tokenizer(score)

            max_len = self.model.config.max_position_embeddings
            print(f"Max position embeddings: {max_len}")
            max_len = 1024  # TODO: hardcoded for now as we are using a smaller model

            # Truncate the input if it exceeds the model's maximum context
            # length, keeping the most recent tokens so generation continues
            # from the end of the piece.
            input_ids = tokenized_input[0].ids
            if len(input_ids) >= max_len:
                print(
                    f"Warning: input sequence ({len(input_ids)} tokens) exceeds "
                    f"the context limit ({max_len}). Truncating."
                )
                input_ids = input_ids[-max_len:]

            tensor_sequence = torch.tensor([input_ids], dtype=torch.long)
            print(f"Input tensor shape: {tensor_sequence.shape}")
            input_token_length = tensor_sequence.shape[1]

            # Build a per-call generation config, overriding the defaults with
            # the caller-supplied sampling parameters.
            gen_config = GenerationConfig(
                max_new_tokens=int(max_new_tokens),
                num_beams=self.generation_config.num_beams,
                do_sample=do_sample,
                temperature=temperature,
                top_k=self.generation_config.top_k,
                top_p=top_p,
                epsilon_cutoff=self.generation_config.epsilon_cutoff,
                eta_cutoff=self.generation_config.eta_cutoff,
                pad_token_id=self.generation_config.pad_token_id,
                bos_token_id=self.generation_config.bos_token_id,
                eos_token_id=self.generation_config.eos_token_id,
            )

            # Generate the new token sequence.
            res = self.model.generate(
                inputs=tensor_sequence, generation_config=gen_config
            )
            print("Generated output shape:", res.shape)
            print(f"New tokens length: {res.shape[1] - input_token_length}")

            # Decode only the newly generated tokens (excluding the input part)
            # back into a symusic Score, then serialize it to MIDI bytes.
            decoded = self.tokenizer.decode([res[0][input_token_length:]])
            return decoded.dumps_midi()
        except Exception as e:
            print(f"Error processing MIDI: {e}")
            return midi_bytes  # Return the original bytes on error
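

# A minimal usage sketch of the Processor above. The repo id, tokenizer
# filename, and file paths below are hypothetical placeholders; substitute
# the Hugging Face repo that hosts your model and its miditok tokenizer
# config.
if __name__ == "__main__":
    processor = Processor(
        model_location_repo="your-username/midi-model",  # hypothetical repo id
        model_tokenizer_file="tokenizer.json",  # hypothetical filename
    )
    midi_in = Path("input.mid").read_bytes()  # hypothetical input file
    midi_out = processor.transpose_midi(
        midi_in, max_new_tokens=500, temperature=0.8
    )
    if midi_out is not None:
        Path("generated.mid").write_bytes(midi_out)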