Update README.md
Browse files
README.md
CHANGED
|
@@ -32,22 +32,25 @@ Generate symbolic music from a text prompt:
|
|
| 32 |
```python
|
| 33 |
from transformers import T5Tokenizer
|
| 34 |
from model.transformer_model import Transformer
|
| 35 |
-
from
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 39 |
-
artifact_folder = 'artifacts'
|
| 40 |
|
| 41 |
-
tokenizer_filepath = os.path.join(artifact_folder, "vocab_remi.pkl")
|
| 42 |
# Load the tokenizer dictionary
|
| 43 |
-
with open(
|
| 44 |
r_tokenizer = pickle.load(f)
|
| 45 |
|
| 46 |
# Get the vocab size
|
| 47 |
vocab_size = len(r_tokenizer)
|
| 48 |
print("Vocab size: ", vocab_size)
|
| 49 |
-
model = Transformer(vocab_size, 768, 8,
|
| 50 |
-
model.load_state_dict(torch.load(
|
| 51 |
model.eval()
|
| 52 |
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
|
| 53 |
|
|
@@ -61,7 +64,6 @@ output = model.generate(input_ids, attention_mask, max_len=2000,temperature = 1.
|
|
| 61 |
output_list = output[0].tolist()
|
| 62 |
generated_midi = r_tokenizer.decode(output_list)
|
| 63 |
generated_midi.dump_midi("output.mid")
|
| 64 |
-
post_processing("output.mid", "output.mid")
|
| 65 |
```
|
| 66 |
|
| 67 |
## Installation
|
|
@@ -78,26 +80,27 @@ The MidiCaps dataset is a large-scale dataset of 168k MIDI files paired with ric
|
|
| 78 |
|
| 79 |
Each question is rated on a Likert scale from 1 (very bad) to 7 (very good). The table shows the average ratings per question for each group of participants.
|
| 80 |
|
| 81 |
-
|
|
| 82 |
-
|
| 83 |
-
|
|
| 84 |
-
|
|
| 85 |
-
|
|
| 86 |
-
|
|
| 87 |
-
|
|
| 88 |
-
|
|
|
|
|
| 89 |
|
| 90 |
|
| 91 |
## Objective Evaluations
|
| 92 |
|
| 93 |
| Metric | text2midi | MidiCaps | MuseCoco |
|
| 94 |
|---------------------|-----------|----------|----------|
|
| 95 |
-
| CR ↑ | 2.
|
| 96 |
-
| CLAP ↑ | 0.
|
| 97 |
-
| TB (%) ↑ |
|
| 98 |
-
| TBT (%) ↑ |
|
| 99 |
-
| CK (%) ↑ |
|
| 100 |
-
| CKD (%) ↑ |
|
| 101 |
|
| 102 |
**Note**:
|
| 103 |
CR = Compression ratio
|
|
|
|
| 32 |
```python
|
| 33 |
from transformers import T5Tokenizer
|
| 34 |
from model.transformer_model import Transformer
|
| 35 |
+
from huggingface_hub import hf_hub_download
|
| 36 |
+
|
| 37 |
+
repo_id = "amaai-lab/text2midi"
|
| 38 |
+
# Download the model.bin file
|
| 39 |
+
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
|
| 40 |
+
# Download the vocab_remi.pkl file
|
| 41 |
+
tokenizer_path = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")
|
| 42 |
|
| 43 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
| 44 |
|
|
|
|
| 45 |
# Load the tokenizer dictionary
|
| 46 |
+
with open(tokenizer_path, "rb") as f:
|
| 47 |
r_tokenizer = pickle.load(f)
|
| 48 |
|
| 49 |
# Get the vocab size
|
| 50 |
vocab_size = len(r_tokenizer)
|
| 51 |
print("Vocab size: ", vocab_size)
|
| 52 |
+
model = Transformer(vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device)
|
| 53 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
| 54 |
model.eval()
|
| 55 |
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
|
| 56 |
|
|
|
|
| 64 |
output_list = output[0].tolist()
|
| 65 |
generated_midi = r_tokenizer.decode(output_list)
|
| 66 |
generated_midi.dump_midi("output.mid")
|
|
|
|
| 67 |
```
|
| 68 |
|
| 69 |
## Installation
|
|
|
|
| 80 |
|
| 81 |
Each question is rated on a Likert scale from 1 (very bad) to 7 (very good). The table shows the average ratings per question for each group of participants.
|
| 82 |
|
| 83 |
+
| Question | MidiCaps | text2midi | MuseCoco |
|
| 84 |
+
|---------------------|----------|-----------|----------|
|
| 85 |
+
| Musical Quality | 5.79 | 4.62 | 4.40 |
|
| 86 |
+
| Overall Matching | 5.42 | 4.67 | 4.07 |
|
| 87 |
+
| Genre Matching | 5.54 | 4.98 | 4.40 |
|
| 88 |
+
| Mood Matching | 5.70 | 5.00 | 4.32 |
|
| 89 |
+
| Key Matching | 4.61 | 3.64 | 3.36 |
|
| 90 |
+
| Chord Matching | 3.20 | 2.50 | 2.00 |
|
| 91 |
+
| Tempo Matching | 5.89 | 5.42 | 4.94 |
|
| 92 |
|
| 93 |
|
| 94 |
## Objective Evaluations
|
| 95 |
|
| 96 |
| Metric | text2midi | MidiCaps | MuseCoco |
|
| 97 |
|---------------------|-----------|----------|----------|
|
| 98 |
+
| CR ↑ | 2.14 | 3.43 | 2.12 |
|
| 99 |
+
| CLAP ↑ | 0.22 | 0.26 | 0.21 |
|
| 100 |
+
| TB (%) ↑ | 27.85 | - | 21.71 |
|
| 101 |
+
| TBT (%) ↑ | 57.78 | - | 54.63 |
|
| 102 |
+
| CK (%) ↑ | 7.69 | - | 13.70 |
|
| 103 |
+
| CKD (%) ↑ | 14.80 | - | 14.59 |
|
| 104 |
|
| 105 |
**Note**:
|
| 106 |
CR = Compression ratio
|