Upload folder using huggingface_hub
Browse files- .gitattributes +1 -35
- README.md +88 -3
- __init__.py +0 -0
- config.json +23 -0
- configuration_aria.py +59 -0
- model.safetensors +3 -0
- modeling_aria.py +666 -0
- tokenization_aria.py +193 -0
- tokenizer_config.json +11 -0
.gitattributes
CHANGED
|
@@ -1,35 +1 @@
|
|
| 1 |
-
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
model.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,3 +1,88 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
datasets:
|
| 4 |
+
- loubb/aria-midi
|
| 5 |
+
language:
|
| 6 |
+
- en
|
| 7 |
+
tags:
|
| 8 |
+
- music
|
| 9 |
+
- MIDI
|
| 10 |
+
- piano
|
| 11 |
+
---
|
| 12 |
+
# Model
|
| 13 |
+
|
| 14 |
+
`Aria` is a pretrained autoregressive generative model for symbolic music based on the LLaMA 3.2 (1B) architecture. It was trained on ~60k hours of MIDI transcriptions of expressive solo-piano recordings. It has been finetuned to produce realistic continuations of solo-piano compositions as well as to produce general-purpose contrastive MIDI embeddings.
|
| 15 |
+
|
| 16 |
+
This HuggingFace page contains weights and usage instructions for the embedding model. For the pretrained base model, see [aria-medium-base](https://huggingface.co/loubb/aria-medium-base), and for the generative model, see [aria-medium-gen](https://huggingface.co/loubb/aria-medium-gen).
|
| 17 |
+
|
| 18 |
+
📖 Read our [release blog post](https://example.com/) and [paper](https://example.com/)
|
| 19 |
+
🚀 Check out the real-time demo in the official [GitHub repository](https://github.com/EleutherAI/aria)
|
| 20 |
+
📊 Get access to our training dataset [Aria-MIDI](https://huggingface.co/datasets/loubb/aria-midi) to train your own models
|
| 21 |
+
|
| 22 |
+
## Usage Guidelines
|
| 23 |
+
|
| 24 |
+
Our embedding model was trained to capture composition and performance-level attributes by learning to embed different random slices of transcriptions of solo-piano performances into similar regions of latent space. As the model was trained to produce global embeddings with data augmentation (e.g., pitch, tempo, etc.), it might not be appropriate for every use case. For more information, see our [paper](https://example.com/).
|
| 25 |
+
|
| 26 |
+
## Quickstart
|
| 27 |
+
|
| 28 |
+
All of our models were trained using MIDI tooling and tokenizer accessible in the [aria-utils](https://github.com/EleutherAI/aria-utils) repository. Install the aria-utils package with pip:
|
| 29 |
+
|
| 30 |
+
```bash
|
| 31 |
+
pip install git+https://github.com/EleutherAI/aria-utils.git
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
You can then generate an embedding for a (piano) MIDI file using the transformers library:
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
pip install transformers
|
| 38 |
+
pip install torch
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
```python
|
| 42 |
+
from transformers import AutoModelForCausalLM
|
| 43 |
+
from transformers import AutoTokenizer
|
| 44 |
+
|
| 45 |
+
PROMPT_MIDI_LOAD_PATH = "mydir/prompt.midi"
|
| 46 |
+
MAX_SEQ_LEN = 2048
|
| 47 |
+
|
| 48 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 49 |
+
"loubb/aria-medium-embedding",
|
| 50 |
+
trust_remote_code=True,
|
| 51 |
+
)
|
| 52 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 53 |
+
"loubb/aria-medium-embedding",
|
| 54 |
+
trust_remote_code=True,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
prompt = tokenizer.encode_from_file(
|
| 58 |
+
PROMPT_MIDI_LOAD_PATH, return_tensors="pt"
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# Only sequences up to 2048 are supported.
|
| 62 |
+
# Embedding is extracted from end-of-sequence token
|
| 63 |
+
assert prompt.shape[1] <= MAX_SEQ_LEN
|
| 64 |
+
assert prompt[0, -1] == tokenizer._convert_token_to_id(tokenizer.eos_token)
|
| 65 |
+
|
| 66 |
+
# Alternatively if the sequence is too long:
|
| 67 |
+
prompt = prompt[:, :MAX_SEQ_LEN]
|
| 68 |
+
prompt[:, -1] = tokenizer._convert_token_to_id(tokenizer.eos_token)
|
| 69 |
+
|
| 70 |
+
# Generate and extract embedding
|
| 71 |
+
outputs = model.forward(prompt).squeeze(0)
|
| 72 |
+
embedding = outputs[-1]
|
| 73 |
+
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
## License and Attribution
|
| 77 |
+
|
| 78 |
+
The Aria project has been kindly supported by EleutherAI, Stability AI, as well as by a compute grant from the Ministry of Science and ICT of Korea. Our models and MIDI tooling are released under the Apache-2.0 license. If you use the models or tooling for follow-up work, please cite the paper in which they were introduced:
|
| 79 |
+
|
| 80 |
+
```bibtex
|
| 81 |
+
@inproceedings{bradshawscaling,
|
| 82 |
+
title={Scaling Self-Supervised Representation Learning for Symbolic Piano Performance},
|
| 83 |
+
author={Bradshaw, Louis and Fan, Honglu and Spangher, Alex and Biderman, Stella and Colton, Simon},
|
| 84 |
+
booktitle={arXiv preprint},
|
| 85 |
+
year={2025},
|
| 86 |
+
url={https://arxiv.org/abs/2504.15071}
|
| 87 |
+
}
|
| 88 |
+
```
|
__init__.py
ADDED
|
File without changes
|
config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"AriaForSequenceEmbedding"
|
| 4 |
+
],
|
| 5 |
+
"eos_token_id": 1,
|
| 6 |
+
"pad_token_id": 2,
|
| 7 |
+
"hidden_size": 1536,
|
| 8 |
+
"embedding_size": 512,
|
| 9 |
+
"intermediate_size": 6144,
|
| 10 |
+
"max_seq_len": 2048,
|
| 11 |
+
"model_type": "aria",
|
| 12 |
+
"num_attention_heads": 24,
|
| 13 |
+
"num_hidden_layers": 16,
|
| 14 |
+
"torch_dtype": "bfloat16",
|
| 15 |
+
"transformers_version": "4.45.0",
|
| 16 |
+
"use_cache": false,
|
| 17 |
+
"vocab_size": 17727,
|
| 18 |
+
"auto_map": {
|
| 19 |
+
"AutoConfig": "configuration_aria.AriaConfig",
|
| 20 |
+
"AutoModel": "modeling_aria.AriaModel",
|
| 21 |
+
"AutoModelForCausalLM": "modeling_aria.AriaForSequenceEmbedding"
|
| 22 |
+
}
|
| 23 |
+
}
|
configuration_aria.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class AriaConfig(PretrainedConfig):
    """Configuration for the Aria symbolic-music decoder.

    Holds the architecture hyperparameters (width, depth, heads, vocabulary)
    plus the runtime flags that `transformers` reads at inference time.
    """

    model_type = "aria"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size: int = 17727,
        hidden_size: int = 1536,
        embedding_size: int | None = None,
        num_hidden_layers: int = 16,
        num_attention_heads: int = 64,
        intermediate_size: int = 6144,
        max_seq_len: int = 8192,
        use_cache: bool = True,
        eos_token_id: int = 1,
        pad_token_id: int = 2,
        tie_word_embeddings: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        # Core model dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_seq_len = max_seq_len
        # Runtime / output behaviour flags.
        self.use_cache = use_cache
        self.tie_word_embeddings = tie_word_embeddings
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
        self.return_dict = return_dict

        # Both `ff_mult` (below) and the per-head dimension must be integral.
        if self.intermediate_size % self.hidden_size:
            raise ValueError(
                "The intermediate size needs to be divisible by hidden size."
            )

        if self.hidden_size % self.num_attention_heads:
            raise ValueError(
                "The hidden size needs to be divisible by the number of attention heads."
            )

    @property
    def ff_mult(self):
        """Feed-forward expansion factor (intermediate_size / hidden_size)."""
        return self.intermediate_size // self.hidden_size
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
__all__ = ["AriaConfig"]
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8d49acf495f1cf91d26b297f6e902a3464215a4487906f3d6e918fee39ce5477
|
| 3 |
+
size 2528401656
|
modeling_aria.py
ADDED
|
@@ -0,0 +1,666 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This is lightly adapted from https://github.com/EleutherAI/aria/blob/main/aria/model.py
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Union, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.utils.checkpoint
|
| 7 |
+
|
| 8 |
+
from torch import nn as nn
|
| 9 |
+
from torch.nn import functional as F, CrossEntropyLoss
|
| 10 |
+
|
| 11 |
+
from transformers import Cache, DynamicCache, StaticCache
|
| 12 |
+
from transformers.utils import logging
|
| 13 |
+
from transformers.generation import GenerationMixin
|
| 14 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 15 |
+
from transformers.modeling_outputs import (
|
| 16 |
+
BaseModelOutputWithPast,
|
| 17 |
+
CausalLMOutputWithPast,
|
| 18 |
+
BaseModelOutputWithPoolingAndProjection,
|
| 19 |
+
)
|
| 20 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 21 |
+
|
| 22 |
+
from .configuration_aria import AriaConfig
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class AriaPreTrainedModel(PreTrainedModel):
    """Base class hooking Aria modules into the HF `PreTrainedModel` API.

    Declares the capability flags that `transformers` inspects at load time
    (SDPA support, cache classes, device-placement hints) and the weight
    initialization used for fresh parameters.
    """

    config_class = AriaConfig
    base_model_prefix = "aria"
    supports_gradient_checkpointing = True
    _no_split_modules = ["AriaBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = False
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_sdpa = True
    _supports_flex_attn = False

    def _init_weights(self, module):
        """Initialize a freshly created submodule.

        Linear/Embedding weights draw from N(0, initializer_range); Linear
        biases and the Embedding padding row are zeroed; LayerNorm is reset
        to weight=1, bias=0.

        NOTE(review): `initializer_range` is not set by AriaConfig.__init__
        and is not in config.json — confirm it is supplied via kwargs before
        relying on fresh (non-checkpoint) initialization.
        """
        if isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            std = self.config.initializer_range
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            std = self.config.initializer_range
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class TransformerBlock(nn.Module):
    """Pre-norm transformer decoder layer used by Aria.

    Self-attention with rotary position embeddings (fused QKV projection,
    SDPA kernel) followed by a SwiGLU feed-forward block; both sub-layers
    have residual connections with pre-LayerNorm.

    Args:
        model_config (AriaConfig): Model hyperparameters.
        layer_idx (int): Index of this layer, used to address the KV cache.
    """

    def __init__(self, model_config: AriaConfig, layer_idx: int):
        super().__init__()

        self.drop_p = 0.0  # dropout disabled
        self.n_heads = model_config.num_attention_heads
        self.d_model = model_config.hidden_size
        self.d_head = (
            model_config.hidden_size // model_config.num_attention_heads
        )
        self.max_seq_len = model_config.max_seq_len
        self.layer_idx = layer_idx

        # Attention: fused Q/K/V projection plus output projection.
        self.mixed_qkv = nn.Linear(
            in_features=self.d_model,
            out_features=3 * self.d_model,
            bias=False,
        )
        self.att_proj_linear = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_model,
            bias=False,
        )

        # SwiGLU feed-forward: gate/up expand by ff_mult, down projects back.
        self.ff_gate_proj = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_model * model_config.ff_mult,
            bias=False,
        )
        self.ff_up_proj = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_model * model_config.ff_mult,
            bias=False,
        )
        self.ff_down_proj = nn.Linear(
            in_features=self.d_model * model_config.ff_mult,
            out_features=self.d_model,
            bias=False,
        )

        # Pre layer norms
        self.norm1 = nn.LayerNorm(self.d_model)
        self.norm2 = nn.LayerNorm(self.d_model)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        """Run one decoder layer.

        Returns:
            ``(hidden_states, present, att)`` when ``use_cache`` is truthy,
            otherwise ``(hidden_states, att)``. ``att`` is ``None`` unless
            ``output_attentions`` is set (see note in ``_att_block``).
        """
        attn_output, attn_weights, present = self._att_block(
            self.norm1(x),
            attention_mask,
            freqs_cis,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )

        # Residual connections around attention and feed-forward.
        x = x + attn_output
        x = x + self._ff_block(self.norm2(x))

        # Fixed: removed a dead `outputs = (x, present)` assignment that was
        # unconditionally overwritten by this branch.
        if use_cache:
            outputs = (x, present, attn_weights)
        else:
            outputs = (x, attn_weights)

        return outputs

    def _att_block(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        """Self-attention sub-layer; expects pre-normalized input.

        NOTE(review): when ``output_attentions`` is set, the second return
        value is the raw SDPA output, not per-head softmax weights — the
        SDPA kernel does not expose attention probabilities.
        """
        batch_size, seq_len, _ = x.shape
        mixed_qkv = self.mixed_qkv(x)
        xq, xk, xv = mixed_qkv.chunk(3, -1)

        # Reshape for rotary embeddings
        # Need contiguous for q, k since in-place RoPE cannot be applied on a view
        xq = xq.reshape(
            batch_size, seq_len, self.n_heads, self.d_head
        ).contiguous()
        xk = xk.reshape(
            batch_size, seq_len, self.n_heads, self.d_head
        ).contiguous()
        xv = xv.view(batch_size, seq_len, self.n_heads, self.d_head)

        # apply_rotary_post_emb expects: (b_sz, s_len, n_head, d_head)
        xq = apply_rotary_emb(xq, freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis)
        xq, xk, xv = map(lambda t: t.transpose(1, 2), (xq, xk, xv))

        if past_key_values is not None:
            # Append this step's k/v to the layer's cache entry and read back
            # the full key/value tensors.
            cache_kwargs = {
                "cache_position": cache_position,
            }
            xk, xv = past_key_values.update(
                xk, xv, self.layer_idx, cache_kwargs
            )

        att = F.scaled_dot_product_attention(
            query=xq,
            key=xk,
            value=xv,
            # Trim mask columns to the actual key length.
            attn_mask=attention_mask[..., : xk.shape[2]],
        )

        # Reshape for out: (b_sz, s_len, n_head, d_head)
        out = att.transpose(1, 2).contiguous()
        out = out.view(batch_size, seq_len, self.n_heads * self.d_head)

        if not output_attentions:
            att = None

        return self.att_proj_linear(out), att, past_key_values

    def _ff_block(self, x: torch.Tensor):
        """SwiGLU feed-forward: down(silu(gate(x)) * up(x))."""
        return self.ff_down_proj(
            F.silu(self.ff_gate_proj(x)) * self.ff_up_proj(x)
        )
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class AriaModel(AriaPreTrainedModel):
    """Transformer decoder with no language model head.

    Args:
        model_config (ModelConfig): Model config settings.
    """

    def __init__(self, model_config: AriaConfig):
        super().__init__(model_config)
        self.model_config = model_config
        # Built lazily on first forward so they land on the right device.
        self.freqs_cis = None
        self.causal_mask = None

        self.tok_embeddings = nn.Embedding(
            num_embeddings=model_config.vocab_size,
            embedding_dim=model_config.hidden_size,
        )

        self.out_layer_norm = nn.LayerNorm(model_config.hidden_size)
        self.encode_layers = nn.ModuleList()
        for i in range(model_config.num_hidden_layers):
            self.encode_layers.append(TransformerBlock(model_config, i))

        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        """Forward pass of the decoder stack.

        Args:
            input_ids: Token ids of shape (batch_size, seq_len). Mutually
                exclusive with ``inputs_embeds``.
            attention_mask: Padding mask of shape (batch_size, seq_len);
                nonzero marks tokens to attend to.
            inputs_embeds: Pre-computed embeddings of shape (batch_size,
                seq_len, hidden_size).
            past_key_values: KV cache (``Cache`` instance, or the legacy
                tuple-of-tuples format, converted with a warning).

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` is set, else a
            tuple of the non-None values among (last_hidden_state,
            past_key_values, all_hidden_states, all_attentions).

        Raises:
            ValueError: if the sequence exceeds ``max_seq_len`` or neither /
                both of ``input_ids`` and ``inputs_embeds`` are given.
        """
        if (
            input_ids is not None
            and input_ids.shape[1] > self.model_config.max_seq_len
        ):
            raise ValueError(
                f"Sequence length ({input_ids.shape[1]}) exceeds max_seq_len "
                f"({self.model_config.max_seq_len})."
            )
        if (
            inputs_embeds is not None
            and inputs_embeds.shape[1] > self.model_config.max_seq_len
        ):
            raise ValueError(
                f"Sequence length ({inputs_embeds.shape[1]}) exceeds max_seq_len "
                f"({self.model_config.max_seq_len})."
            )

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.model_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.model_config.output_hidden_states
        )
        return_dict = (
            return_dict
            if return_dict is not None
            else self.model_config.use_return_dict
        )
        use_cache = (
            use_cache if use_cache is not None else self.model_config.use_cache
        )

        # XOR: exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.tok_embeddings(input_ids)

        # Accept the legacy tuple cache format, converting to DynamicCache
        # and converting back before returning.
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(
                    past_key_values
                )
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        seq_length = inputs_embeds.shape[1]
        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length()
                if past_key_values is not None
                else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + seq_length,
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        hidden_states = inputs_embeds

        # Fixed: take the device from inputs_embeds — input_ids is None when
        # the caller passes embeddings directly, so `input_ids.device` raised
        # AttributeError on that path.
        if self.causal_mask is None:
            self.causal_mask = precompute_causal_mask(
                max_seq_len=self.model_config.max_seq_len,
            ).to(inputs_embeds.device)

        if self.freqs_cis is None:
            self.freqs_cis = precompute_freqs_cis(
                seq_len=self.model_config.max_seq_len,
                n_elem=self.model_config.hidden_size
                // self.model_config.num_attention_heads,
                base=500000,
                dtype=hidden_states.dtype,
            ).to(inputs_embeds.device)

        freqs_cis = self.freqs_cis[cache_position]

        if use_cache is True:
            # Select only the mask rows for the current positions.
            causal_mask = self.causal_mask[None, None, cache_position]
        else:
            causal_mask = self.causal_mask[None, None, :seq_length, :seq_length]

        if attention_mask is not None:
            # Right-pad the (batch, seq) padding mask up to the key length
            # and AND it with the causal mask.
            pad_len = causal_mask.shape[3] - attention_mask.shape[1]
            padded_attention_mask = F.pad(attention_mask, (0, pad_len), value=1)
            padded_attention_mask = padded_attention_mask[:, None, None, :]
            padded_attention_mask = padded_attention_mask.bool()

            causal_mask = causal_mask & padded_attention_mask

        kwargs = {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
            "cache_position": cache_position,
        }
        # Fixed: these accumulators were previously only initialized on the
        # non-checkpointing path, so the gradient-checkpointing branch hit a
        # NameError when building the return value below.
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        next_decoder_cache = None
        if self.gradient_checkpointing:
            for layer in self.encode_layers:

                def create_custom_forward(module):
                    def custom_forward(*args):
                        return module(*args)[0]

                    return custom_forward

                # NOTE(review): reentrant checkpointing with extra keyword
                # args is fragile across torch versions — confirm against the
                # torch version in use.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    hidden_states,
                    causal_mask,
                    freqs_cis,
                    **kwargs,
                    preserve_rng_state=True,
                    use_reentrant=True,
                )
        else:
            for layer in self.encode_layers:
                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)
                outputs = layer(
                    hidden_states, causal_mask, freqs_cis=freqs_cis, **kwargs
                )
                hidden_states = outputs[0]
                if use_cache is True:
                    next_decoder_cache = outputs[1]
                if output_attentions:
                    # Layer output layout depends on use_cache; see
                    # TransformerBlock.forward.
                    all_attentions = all_attentions + (
                        outputs[2 if use_cache else 1],
                    )
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.out_layer_norm(hidden_states)
        next_cache = next_decoder_cache if use_cache else None

        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
class AriaForCausalLM(AriaPreTrainedModel, GenerationMixin):
    """Transformer decoder with a language-modelling head.

    Args:
        model_config (ModelConfig): Model config settings.
    """

    def __init__(self, model_config: AriaConfig):
        super().__init__(model_config)
        self.model_config = model_config
        self.max_seq_len = model_config.max_seq_len
        # Decoder backbone followed by an untied projection onto the vocabulary.
        self.model = AriaModel(model_config)
        self.lm_head = nn.Linear(
            model_config.hidden_size, model_config.vocab_size, bias=False
        )
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        """Forward pass of the decoder followed by the LM head.

        When `labels` is given, a shifted next-token cross-entropy loss is
        computed as well.
        """
        if return_dict is None:
            return_dict = self.model_config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        last_hidden = outputs[0]
        lm_logits = self.lm_head(last_hidden)

        lm_loss = None
        if labels is not None:
            # Move labels onto the logits device to support model parallelism.
            labels = labels.to(lm_logits.device)
            # Next-token prediction: logits at position t predict token t + 1.
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return output if lm_loss is None else (lm_loss,) + output

        return CausalLMOutputWithPast(
            loss=lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
class AriaForSequenceEmbedding(AriaPreTrainedModel):
    """Transformer decoder with an embedding head for contrastive learning.

    The pooled sequence embedding is read off at the position of the first
    EOS token in each sequence.

    Args:
        model_config (ModelConfig): Model config settings.
    """

    def __init__(self, model_config: AriaConfig):
        super().__init__(model_config)
        assert model_config.embedding_size

        self.model_config = model_config
        self.max_seq_len = model_config.max_seq_len
        self.model = AriaModel(model_config)
        self.emb_head = nn.Linear(
            model_config.hidden_size, model_config.embedding_size, bias=False
        )
        self.post_init()

    def get_pooled_embedding(
        self, input_ids: torch.Tensor, embedding: torch.Tensor
    ):
        """Gather, for each batch row, the embedding at the first EOS position."""
        batch_size = input_ids.shape[0]
        eos_mask = input_ids == self.config.eos_token_id
        if not eos_mask.any(dim=1).all():
            raise ValueError("Each sequence must contain a EOS token")
        # argmax over the 0/1 mask yields the index of the FIRST True per row.
        first_eos_pos = eos_mask.int().argmax(dim=1)
        row_idx = torch.arange(batch_size, device=input_ids.device)
        return embedding[row_idx, first_eos_pos]

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        """Forward pass of Transformer decoder with embedding head. Pooled
        embedding is extracted from EOS token."""
        if return_dict is None:
            return_dict = self.model_config.use_return_dict

        # Only plain full-sequence encoding is supported; anything cache- or
        # loss-related is rejected up front.
        unsupported = (
            position_ids is not None
            or inputs_embeds is not None
            or past_key_values is not None
            or labels is not None
            or cache_position is not None
            or bool(use_cache)
        )
        if unsupported:
            raise ValueError("Provided args unsupported for embedding head")

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=False,
        )
        embedding = self.emb_head(outputs[0])
        pooled_embedding = self.get_pooled_embedding(
            input_ids=input_ids, embedding=embedding
        )

        if not return_dict:
            return (pooled_embedding,) + outputs[1:]

        return BaseModelOutputWithPoolingAndProjection(
            last_hidden_state=embedding,
            pooler_output=pooled_embedding,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def precompute_causal_mask(max_seq_len: int, device: Union[str, torch.device] = "cuda"):
    """Build a boolean lower-triangular causal attention mask.

    Args:
        max_seq_len (int): Side length of the square mask.
        device: Device to place the mask on. Defaults to "cuda" to preserve
            the previous hard-coded behaviour; pass "cpu" (or any torch
            device) to run without a GPU. The old `.cuda()` call crashed on
            CPU-only hosts.

    Returns:
        torch.Tensor: (max_seq_len, max_seq_len) bool tensor that is True at
        and below the diagonal.
    """
    mask = torch.tril(
        torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
    )
    return mask.to(device)
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def precompute_freqs_cis(
    seq_len: int,
    n_elem: int,
    base: int = 500000,
    dtype: torch.dtype = torch.bfloat16,
):
    """Precompute the rotary-embedding cos/sin table.

    Args:
        seq_len: Number of positions to tabulate.
        n_elem: Head dimension; `n_elem // 2` frequency channels are built.
        base: RoPE base for the inverse-frequency schedule.
        dtype: Output dtype of the table.

    Returns:
        Tensor of shape (seq_len, n_elem // 2, 2) with cos in [..., 0] and
        sin in [..., 1].
    """
    exponents = torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem
    inv_freqs = 1.0 / (base**exponents)
    positions = torch.arange(seq_len, device=inv_freqs.device)
    angles = torch.outer(positions, inv_freqs)
    # Unit complex numbers e^(i * angle); real part is cos, imaginary is sin.
    unit_cis = torch.polar(torch.ones_like(angles), angles)
    table = torch.stack([unit_cis.real, unit_cis.imag], dim=-1)
    return table.to(dtype=dtype)
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """
    In-place RoPE. Credits to Katherine Crowson:
    x shape (b_sz, s_len, n_head, d_head).
    cos, sin shape (s_len, d_head // 2).

    Rotates each (x1, x2) feature pair by the per-position angle stored in
    `freqs_cis` (cos at [..., 0], sin at [..., 1]), mutating `x` in place
    and returning it.
    """

    d = x.shape[-1] // 2
    # Index to (s_len, d) then reshape to (1, s_len, 1, d) so cos/sin
    # broadcast over the batch and head dimensions of x.
    cos = freqs_cis[..., 0][None, :, None]
    sin = freqs_cis[..., 1][None, :, None]
    # x1 and x2 are views into the two disjoint halves of x's last dim,
    # so writing to them writes into x itself.
    x1, x2 = x[..., :d], x[..., d : d * 2]
    # Save the original x1: it is overwritten by the next line, but is still
    # needed to update x2. x2 is untouched when x1 is updated, so only x1
    # needs a copy.
    tmp = x1.clone()
    x1.mul_(cos).addcmul_(x2, sin, value=-1)  # x1 <- x1*cos - x2*sin
    x2.mul_(cos).addcmul_(tmp, sin, value=1)  # x2 <- x2*cos + old_x1*sin
    return x
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
__all__ = [
|
| 661 |
+
"AriaPreTrainedModel",
|
| 662 |
+
"AriaModel",
|
| 663 |
+
"TransformerBlock",
|
| 664 |
+
"AriaForCausalLM",
|
| 665 |
+
"AriaForSequenceEmbedding",
|
| 666 |
+
]
|
tokenization_aria.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
from transformers.tokenization_utils import PreTrainedTokenizer, BatchEncoding
|
| 4 |
+
from transformers.utils import logging, TensorType, to_py_obj
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
from ariautils.midi import MidiDict
|
| 8 |
+
from ariautils.tokenizer import AbsTokenizer
|
| 9 |
+
from ariautils.tokenizer._base import Token
|
| 10 |
+
except ImportError:
|
| 11 |
+
raise ImportError(
|
| 12 |
+
"ariautils is not installed. Please try `pip install git+https://github.com/EleutherAI/aria-utils.git`."
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
logger = logging.get_logger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class AriaTokenizer(PreTrainedTokenizer):
    """
    Aria Tokenizer is NOT a BPE tokenizer. A midi file will be converted to a MidiDict (note: in fact, a MidiDict is not a single dict. It is more about a list of "notes") which represents a sequence of notes, stops, etc. And then, aria tokenizer is simply a dictionary that maps MidiDict to discrete indices according to a hard-coded rule.

    For a FIM finetuned model, we also follow a simple FIM format to guide a piece of music to a (possibly very different) suffix according to the prompts:
    <GUIDANCE-START> ... <GUIDANCE-END> <S> <PROMPT-START> ... <PROMPT-END>
    This way, we expect a continuation that connects PROMPT and GUIDANCE.
    """

    # No vocab files on disk; the vocabulary lives inside ariautils' AbsTokenizer.
    vocab_files_names = {}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        add_eos_token=True,
        add_dim_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        **kwargs,
    ):
        # Rule-based MIDI tokenizer; all vocab lookups delegate to it.
        self._tokenizer = AbsTokenizer()

        self.add_eos_token = add_eos_token
        self.add_dim_token = add_dim_token
        self.use_default_system_prompt = use_default_system_prompt
        # NOTE(review): `clean_up_tokenization_spaces` is accepted but neither
        # stored nor forwarded — confirm this is intentional.

        # Special tokens are fixed by ariautils, not configurable here.
        bos_token = self._tokenizer.bos_tok
        eos_token = self._tokenizer.eos_tok
        pad_token = self._tokenizer.pad_tok
        unk_token = self._tokenizer.unk_tok

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            use_default_system_prompt=use_default_system_prompt,
            **kwargs,
        )

    def __getstate__(self):
        # Drop all state on pickling; see __setstate__.
        return {}

    def __setstate__(self, d):
        # Unpickling is unsupported: the AbsTokenizer state is not serialized.
        raise NotImplementedError()

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self._tokenizer.vocab_size

    def get_vocab(self):
        """Return the token -> id mapping of the underlying tokenizer."""
        return self._tokenizer.tok_to_id

    def tokenize(
        self,
        midi_dict: MidiDict,
        add_dim_token: Optional[bool] = None,
        add_eos_token: Optional[bool] = None,
        **kwargs,
    ) -> List[Token]:
        """Convert a MidiDict into a list of Aria tokens.

        `add_dim_token` / `add_eos_token` override the instance defaults when
        not None.
        """
        return self._tokenizer.tokenize(
            midi_dict=midi_dict,
            add_dim_tok=(
                add_dim_token
                if add_dim_token is not None
                else self.add_dim_token
            ),
            add_eos_tok=(
                add_eos_token
                if add_eos_token is not None
                else self.add_eos_token
            ),
        )

    def _tokenize(
        self,
        midi_dict: MidiDict,
        add_dim_token: Optional[bool] = None,
        add_eos_token: Optional[bool] = None,
        **kwargs,
    ) -> List[Token]:
        """Low-level hook used by PreTrainedTokenizer; no instance-default
        fallback is applied here (unlike `tokenize`)."""
        return self._tokenizer.tokenize(
            midi_dict=midi_dict,
            add_dim_tok=add_dim_token,
            add_eos_tok=add_eos_token,
        )

    def __call__(
        self,
        midi_dicts: MidiDict | list[MidiDict],
        padding: bool = False,
        max_length: int | None = None,
        pad_to_multiple_of: int | None = None,
        return_tensors: str | TensorType | None = None,
        return_attention_mask: bool | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """It is impossible to rely on the parent method because the inputs are MidiDict(s) instead of strings. I do not like the idea of going hacky so that two entirely different types of inputs can marry. So here I reimplement __call__ with limited support of certain useful arguments. I do not expect any conflict with other "string-in-ids-out" tokenizers. If you have to mix up the API of string-based tokenizers and our midi-based tokenizer, there must be a problem with your design."""
        if isinstance(midi_dicts, MidiDict):
            midi_dicts = [midi_dicts]

        all_tokens: list[list[int]] = []
        all_attn_masks: list[list[int]] = []
        max_len_encoded = 0
        for md in midi_dicts:
            tokens = self._tokenizer.encode(self._tokenizer.tokenize(md))
            if max_length is not None:
                tokens = tokens[:max_length]
            max_len_encoded = max(max_len_encoded, len(tokens))
            all_tokens.append(tokens)
            all_attn_masks.append([True] * len(tokens))

        if pad_to_multiple_of is not None:
            # Round UP to the nearest multiple. The previous formula
            # (max_len + multiple) // multiple * multiple added a full extra
            # multiple when the length was already aligned.
            max_len_encoded = (
                (max_len_encoded + pad_to_multiple_of - 1)
                // pad_to_multiple_of
            ) * pad_to_multiple_of
        if padding:
            for tokens, attn_mask in zip(all_tokens, all_attn_masks):
                # Compute the pad length BEFORE extending `tokens`; the old
                # code extended `tokens` first, so the mask's pad length was
                # always zero and masks came out shorter than input_ids.
                pad_len = max_len_encoded - len(tokens)
                tokens.extend([self._tokenizer.pad_id] * pad_len)
                attn_mask.extend([False] * pad_len)

        return BatchEncoding(
            {
                "input_ids": all_tokens,
                # Keyed "attention_mask" (singular) to match
                # `model_input_names`, so `model(**encoding)` works; the old
                # "attention_masks" key was silently ignored by the model.
                "attention_mask": all_attn_masks,
            },
            tensor_type=return_tensors,
        )

    def decode(self, token_ids: List[int], **kwargs) -> MidiDict:
        """Convert a sequence of ids back into a MidiDict."""
        token_ids = to_py_obj(token_ids)

        return self._tokenizer.detokenize(self._tokenizer.decode(token_ids))

    def batch_decode(
        self, token_ids_list: List[List[Token]], **kwargs
    ) -> List[MidiDict]:
        """Decode each id sequence independently; returns one MidiDict each."""
        return [self.decode(token_ids) for token_ids in token_ids_list]

    def encode_from_file(self, filename: str, **kwargs) -> BatchEncoding:
        """Load a MIDI file and encode it; kwargs are forwarded to __call__."""
        midi_dict = MidiDict.from_midi(filename)
        return self(midi_dict, **kwargs)

    def encode_from_files(
        self, filenames: list[str], **kwargs
    ) -> BatchEncoding:
        """Load and encode several MIDI files as one batch."""
        midi_dicts = [MidiDict.from_midi(file) for file in filenames]
        return self(midi_dicts, **kwargs)

    def _convert_token_to_id(self, token: Token):
        """Converts a token (tuple or str) into an id; unknown tokens map to unk."""
        return self._tokenizer.tok_to_id.get(
            token, self._tokenizer.tok_to_id[self.unk_token]
        )

    def _convert_id_to_token(self, index: int):
        """Converts an index (integer) in a token (tuple or str); unknown ids map to unk."""
        return self._tokenizer.id_to_tok.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[Token]) -> MidiDict:
        """Converts a sequence of tokens into a single MidiDict."""
        return self._tokenizer.detokenize(tokens)

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        # The vocabulary is hard-coded in ariautils; nothing to serialize.
        raise NotImplementedError()
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_eos_token": true,
|
| 3 |
+
"add_dim_token": false,
|
| 4 |
+
"auto_map": {
|
| 5 |
+
"AutoTokenizer": [
|
| 6 |
+
"tokenization_aria.AriaTokenizer",
|
| 7 |
+
null
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
"tokenizer_class": "AriaTokenizer"
|
| 11 |
+
}
|