File size: 3,843 Bytes
dbf06fc
ef96802
 
 
 
dbf06fc
ef96802
 
 
9aa1091
ef96802
 
 
48830a5
ef96802
48830a5
ef96802
 
 
 
 
 
 
 
 
 
 
 
 
 
329de81
ef96802
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f2a67e
ef96802
 
 
329de81
 
 
 
 
 
 
 
 
 
 
 
 
ef96802
 
329de81
ef96802
 
1f2a67e
ef96802
 
 
 
1f2a67e
ef96802
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from miditok import REMI
from transformers import AutoModelForCausalLM, GenerationConfig, AutoConfig
from huggingface_hub import hf_hub_download
import torch
from pathlib import Path
from symusic import Score


class Processor: 
    def __init__(self, model_location_repo, model_tokenizer_file) -> None:
        self.config = AutoConfig.from_pretrained(model_location_repo)
        self.model = AutoModelForCausalLM.from_pretrained(model_location_repo, config=self.config)
  
        tokenizer_file_location = hf_hub_download(repo_id=model_location_repo, filename=model_tokenizer_file)
  
        self.tokenizer = REMI(params=Path(tokenizer_file_location))
        self.generation_config = GenerationConfig(
            max_new_tokens=2000,
            num_beams=1,
            do_sample=True,
            temperature=0.9,
            top_k=15,
            top_p=0.95,
            epsilon_cutoff=3e-4,
            eta_cutoff=1e-3,
            pad_token_id=self.tokenizer['PAD_None'],
            bos_token_id=self.tokenizer['BOS_None'],
            eos_token_id=self.tokenizer['EOS_None'],
        )

    def transpose_midi(self, midi_bytes: bytes | None, max_new_tokens: int = 2000, temperature: float = 0.9, top_p: float = 0.95, do_sample: bool = True) -> bytes | None:
        """""
        Process the MIDI file using a transformer model to generate new MIDI content based on the input.
        
        Args:
            midi_bytes: Raw MIDI file bytes from the frontend
            
        Returns:
            Generated MIDI file bytes
        """""
    
        if midi_bytes is None:
            return None
        
        try:
            
            score = Score.from_midi(midi_bytes)

            tokenized_input = self.tokenizer(score)
            max_len = self.model.config.max_position_embeddings
            print(f"Max position embeddings: {self.model.config.max_position_embeddings}")
            max_len = 1024 #TODO for now as we are using a smaller model

            # Truncate input if it exceeds the model's maximum context length   
            input_ids = tokenized_input[0].ids
            if len(input_ids) >= max_len:
                print(f"Warning: Input sequence ({len(input_ids)}) longer than max_position_embeddings ({max_len}). Truncating.")
                input_ids = input_ids[-max_len:]

            tensor_sequence = torch.tensor([input_ids], dtype=torch.long)
            print(f"input tensor shape: {tensor_sequence.shape}")
            input_token_length = tensor_sequence.shape[1]

            # Generate the new token sequence
            gen_config = GenerationConfig(
                max_new_tokens=int(max_new_tokens),
                num_beams=self.generation_config.num_beams,
                do_sample=do_sample,
                temperature=temperature,
                top_k=self.generation_config.top_k,
                top_p=top_p,
                epsilon_cutoff=self.generation_config.epsilon_cutoff,
                eta_cutoff=self.generation_config.eta_cutoff,
                pad_token_id=self.generation_config.pad_token_id,
                bos_token_id=self.generation_config.bos_token_id,
                eos_token_id=self.generation_config.eos_token_id,
            )
            res = self.model.generate(
                inputs=tensor_sequence,
                generation_config=gen_config)

            print("Generated Output Shape", res.shape)
            print(f"New tokens length: {res.shape[1] - input_token_length}")

            # Decode the generated tokens (excluding the input part)
            decoded = self.tokenizer.decode([res[0][input_token_length:]])
            
            return decoded.dumps_midi()
            
        except Exception as e:
            print(f"Error processing MIDI: {e}")
            return midi_bytes  # Return original on error