Upload OPTForMusicGeneration
Browse files
model.py
CHANGED
|
@@ -4,16 +4,14 @@ import torch
|
|
| 4 |
from miditok import TokSequence
|
| 5 |
|
| 6 |
|
| 7 |
-
# class OPTForMusicGenerationConfig(OPTConfig):
|
| 8 |
-
|
| 9 |
-
|
| 10 |
class OPTForMusicGeneration(OPTForCausalLM):
|
| 11 |
|
| 12 |
-
def generate_music(self, **kwargs):
|
| 13 |
input = torch.tensor([[self.config.bos_token_id]], device=self.device)
|
| 14 |
midi = self.generate(input, **kwargs)
|
| 15 |
generated_ts = TokSequence(ids=midi.tolist()[0], ids_bpe_encoded=True)
|
| 16 |
-
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
OPTForMusicGeneration.register_for_auto_class("AutoModel")
|
|
|
|
| 4 |
from miditok import TokSequence
|
| 5 |
|
| 6 |
|
|
|
|
|
|
|
|
|
|
| 7 |
class OPTForMusicGeneration(OPTForCausalLM):
|
| 8 |
|
| 9 |
+
def generate_music(self, tokenizer, **kwargs):
|
| 10 |
input = torch.tensor([[self.config.bos_token_id]], device=self.device)
|
| 11 |
midi = self.generate(input, **kwargs)
|
| 12 |
generated_ts = TokSequence(ids=midi.tolist()[0], ids_bpe_encoded=True)
|
| 13 |
+
generated_score = tokenizer(generated_ts)
|
| 14 |
+
return generated_score
|
| 15 |
|
| 16 |
|
| 17 |
OPTForMusicGeneration.register_for_auto_class("AutoModel")
|