loubb commited on
Commit
e51eaf7
·
verified ·
1 Parent(s): ddfed3c

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -1,35 +1 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,3 +1,88 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - loubb/aria-midi
5
+ language:
6
+ - en
7
+ tags:
8
+ - music
9
+ - MIDI
10
+ - piano
11
+ ---
12
+ # Model
13
+
14
+ `Aria` is a pretrained autoregressive generative model for symbolic music based on the LLaMA 3.2 (1B) architecture. It was trained on ~60k hours of MIDI transcriptions of expressive solo-piano recordings. It has been finetuned to produce realistic continuations of solo-piano compositions as well as to produce general-purpose contrastive MIDI embeddings.
15
+
16
+ This HuggingFace page contains weights and usage instructions for the embedding model. For the pretrained base model, see [aria-medium-base](https://huggingface.co/loubb/aria-medium-base), and for the generative model, see [aria-medium-gen](https://huggingface.co/loubb/aria-medium-gen).
17
+
18
+ 📖 Read our [release blog post](https://example.com/) and [paper](https://example.com/)
19
+ 🚀 Check out the real-time demo in the official [GitHub repository](https://github.com/EleutherAI/aria)
20
+ 📊 Get access to our training dataset [Aria-MIDI](https://huggingface.co/datasets/loubb/aria-midi) to train your own models
21
+
22
+ ## Usage Guidelines
23
+
24
+ Our embedding model was trained to capture composition and performance-level attributes by learning to embed different random slices of transcriptions of solo-piano performances into similar regions of latent space. As the model was trained to produce global embeddings with data augmentation (e.g., pitch, tempo, etc.), it might not be appropriate for every use case. For more information, see our [paper](https://example.com/).
25
+
26
+ ## Quickstart
27
+
28
+ All of our models were trained using MIDI tooling and tokenizer accessible in the [aria-utils](https://github.com/EleutherAI/aria-utils) repository. Install the aria-utils package with pip:
29
+
30
+ ```bash
31
+ pip install git+https://github.com/EleutherAI/aria-utils.git
32
+ ```
33
+
34
+ You can then generate an embedding for a (piano) MIDI file using the transformers library:
35
+
36
+ ```bash
37
+ pip install transformers
38
+ pip install torch
39
+ ```
40
+
41
+ ```python
42
+ from transformers import AutoModelForCausalLM
43
+ from transformers import AutoTokenizer
44
+
45
+ PROMPT_MIDI_LOAD_PATH = "mydir/prompt.midi"
46
+ MAX_SEQ_LEN = 2048
47
+
48
+ model = AutoModelForCausalLM.from_pretrained(
49
+ "loubb/aria-medium-embedding",
50
+ trust_remote_code=True,
51
+ )
52
+ tokenizer = AutoTokenizer.from_pretrained(
53
+ "loubb/aria-medium-embedding",
54
+ trust_remote_code=True,
55
+ )
56
+
57
+ prompt = tokenizer.encode_from_file(
58
+ PROMPT_MIDI_LOAD_PATH, return_tensors="pt"
59
+ )
60
+
61
+ # Only sequences up to 2048 are supported.
62
+ # Embedding is extracted from end-of-sequence token
63
+ assert prompt.shape[1] <= MAX_SEQ_LEN
64
+ assert prompt[0, -1] == tokenizer._convert_token_to_id(tokenizer.eos_token)
65
+
66
+ # Alternatively if the sequence is too long:
67
+ prompt = prompt[:, :MAX_SEQ_LEN]
68
+ prompt[:, -1] = tokenizer._convert_token_to_id(tokenizer.eos_token)
69
+
70
+ # Generate and extract embedding
71
+ outputs = model.forward(prompt).squeeze(0)
72
+ embedding = outputs[-1]
73
+
74
+ ```
75
+
76
+ ## License and Attribution
77
+
78
+ The Aria project has been kindly supported by EleutherAI, Stability AI, as well as by a compute grant from the Ministry of Science and ICT of Korea. Our models and MIDI tooling are released under the Apache-2.0 license. If you use the models or tooling for follow-up work, please cite the paper in which they were introduced:
79
+
80
+ ```bibtex
81
+ @inproceedings{bradshawscaling,
82
+ title={Scaling Self-Supervised Representation Learning for Symbolic Piano Performance},
83
+ author={Bradshaw, Louis and Fan, Honglu and Spangher, Alex and Biderman, Stella and Colton, Simon},
84
+ booktitle={arXiv preprint},
85
+ year={2025},
86
+ url={https://arxiv.org/abs/2504.15071}
87
+ }
88
+ ```
__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "AriaForSequenceEmbedding"
4
+ ],
5
+ "eos_token_id": 1,
6
+ "pad_token_id": 2,
7
+ "hidden_size": 1536,
8
+ "embedding_size": 512,
9
+ "intermediate_size": 6144,
10
+ "max_seq_len": 2048,
11
+ "model_type": "aria",
12
+ "num_attention_heads": 24,
13
+ "num_hidden_layers": 16,
14
+ "torch_dtype": "bfloat16",
15
+ "transformers_version": "4.45.0",
16
+ "use_cache": false,
17
+ "vocab_size": 17727,
18
+ "auto_map": {
19
+ "AutoConfig": "configuration_aria.AriaConfig",
20
+ "AutoModel": "modeling_aria.AriaModel",
21
+ "AutoModelForCausalLM": "modeling_aria.AriaForSequenceEmbedding"
22
+ }
23
+ }
configuration_aria.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class AriaConfig(PretrainedConfig):
    """Configuration for Aria decoder models.

    Holds the LLaMA-style decoder hyperparameters shared by the base,
    generative, and embedding variants. ``embedding_size`` is left as ``None``
    except for the contrastive-embedding head.
    """

    model_type = "aria"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size: int = 17727,
        hidden_size: int = 1536,
        embedding_size: int | None = None,
        num_hidden_layers: int = 16,
        num_attention_heads: int = 64,
        intermediate_size: int = 6144,
        max_seq_len: int = 8192,
        use_cache: bool = True,
        eos_token_id: int = 1,
        pad_token_id: int = 2,
        tie_word_embeddings: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        # Architecture hyperparameters.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_seq_len = max_seq_len
        # Runtime / output behaviour.
        self.use_cache = use_cache
        self.tie_word_embeddings = tie_word_embeddings
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
        self.return_dict = return_dict

        # ff_mult (below) relies on an exact integer expansion factor.
        if self.intermediate_size % self.hidden_size:
            raise ValueError(
                "The intermediate size needs to be divisible by hidden size."
            )

        # Attention heads must evenly split the model dimension.
        if self.hidden_size % self.num_attention_heads:
            raise ValueError(
                "The hidden size needs to be divisible by the number of attention heads."
            )

    @property
    def ff_mult(self):
        """Feed-forward expansion factor: intermediate_size / hidden_size."""
        return self.intermediate_size // self.hidden_size


__all__ = ["AriaConfig"]
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d49acf495f1cf91d26b297f6e902a3464215a4487906f3d6e918fee39ce5477
3
+ size 2528401656
modeling_aria.py ADDED
@@ -0,0 +1,666 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is lightly adapted from https://github.com/EleutherAI/aria/blob/main/aria/model.py
2
+
3
+ from typing import Optional, Union, Tuple
4
+
5
+ import torch
6
+ import torch.utils.checkpoint
7
+
8
+ from torch import nn as nn
9
+ from torch.nn import functional as F, CrossEntropyLoss
10
+
11
+ from transformers import Cache, DynamicCache, StaticCache
12
+ from transformers.utils import logging
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.modeling_outputs import (
16
+ BaseModelOutputWithPast,
17
+ CausalLMOutputWithPast,
18
+ BaseModelOutputWithPoolingAndProjection,
19
+ )
20
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
21
+
22
+ from .configuration_aria import AriaConfig
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
class AriaPreTrainedModel(PreTrainedModel):
    """Base class wiring Aria models into the transformers loading machinery.

    Declares the config class, weight-init scheme, and capability flags
    (SDPA yes, flash-attention / flex-attention no, cache classes yes).
    """

    config_class = AriaConfig
    base_model_prefix = "aria"
    supports_gradient_checkpointing = True
    _no_split_modules = ["AriaBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = False
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_sdpa = True
    _supports_flex_attn = False

    def _init_weights(self, module):
        """Initialize Linear/Embedding with a normal distribution, LayerNorm
        with ones/zeros (standard GPT-style init)."""
        # Fix: AriaConfig does not define `initializer_range`, so reading
        # `self.config.initializer_range` raised AttributeError whenever
        # transformers initialized missing weights. Fall back to the
        # conventional transformers default of 0.02.
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
57
+
58
+
59
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: RoPE + causal SDPA attention, then SwiGLU FFN.

    Args:
        model_config (AriaConfig): Model config settings.
        layer_idx (int): Index of this layer; used to address the KV cache.
    """

    def __init__(self, model_config: AriaConfig, layer_idx: int):
        super().__init__()

        self.drop_p = 0.0
        self.n_heads = model_config.num_attention_heads
        self.d_model = model_config.hidden_size
        self.d_head = (
            model_config.hidden_size // model_config.num_attention_heads
        )
        self.max_seq_len = model_config.max_seq_len
        self.layer_idx = layer_idx

        # Attention: fused QKV projection plus output projection.
        self.mixed_qkv = nn.Linear(
            in_features=self.d_model,
            out_features=3 * self.d_model,
            bias=False,
        )
        self.att_proj_linear = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_model,
            bias=False,
        )

        # FF layer (SwiGLU: down(silu(gate(x)) * up(x))).
        self.ff_gate_proj = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_model * model_config.ff_mult,
            bias=False,
        )
        self.ff_up_proj = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_model * model_config.ff_mult,
            bias=False,
        )
        self.ff_down_proj = nn.Linear(
            in_features=self.d_model * model_config.ff_mult,
            out_features=self.d_model,
            bias=False,
        )

        # Pre layer norms (applied before attention / FFN, not after).
        self.norm1 = nn.LayerNorm(self.d_model)
        self.norm2 = nn.LayerNorm(self.d_model)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        """Apply the block to ``x`` (batch, seq, d_model).

        Returns ``(x, present, attn_weights)`` when ``use_cache`` is truthy,
        otherwise ``(x, attn_weights)``. ``attn_weights`` is always ``None``
        here because SDPA does not expose attention probabilities.
        """
        attn_output, attn_weights, present = self._att_block(
            self.norm1(x),
            attention_mask,
            freqs_cis,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )

        # Residual connections around attention and FFN.
        x = x + attn_output
        x = x + self._ff_block(self.norm2(x))

        # Fix: the original assigned `outputs = (x, present)` here and then
        # unconditionally overwrote it below — dead code, now removed.
        if use_cache:
            outputs = (x, present, attn_weights)
        else:
            outputs = (x, attn_weights)

        return outputs

    def _att_block(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        """Self-attention with RoPE, optional KV cache, and SDPA kernel."""
        batch_size, seq_len, _ = x.shape
        mixed_qkv = self.mixed_qkv(x)
        xq, xk, xv = mixed_qkv.chunk(3, -1)

        # Reshape for rotary embeddings.
        # Need contiguous for q, k since in-place RoPE cannot be applied on a view.
        xq = xq.reshape(
            batch_size, seq_len, self.n_heads, self.d_head
        ).contiguous()
        xk = xk.reshape(
            batch_size, seq_len, self.n_heads, self.d_head
        ).contiguous()
        xv = xv.view(batch_size, seq_len, self.n_heads, self.d_head)

        # apply_rotary_post_emb expects: (b_sz, s_len, n_head, d_head)
        xq = apply_rotary_emb(xq, freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis)
        # SDPA expects (b_sz, n_head, s_len, d_head).
        xq, xk, xv = map(lambda t: t.transpose(1, 2), (xq, xk, xv))

        if past_key_values is not None:
            cache_kwargs = {
                "cache_position": cache_position,
            }
            # Append the new keys/values; returns the full cached K/V.
            xk, xv = past_key_values.update(
                xk, xv, self.layer_idx, cache_kwargs
            )

        # Mask is sliced to the current key length so cached decoding works.
        att = F.scaled_dot_product_attention(
            query=xq,
            key=xk,
            value=xv,
            attn_mask=attention_mask[..., : xk.shape[2]],
        )

        # Reshape for out: (b_sz, s_len, n_head, d_head)
        out = att.transpose(1, 2).contiguous()
        out = out.view(batch_size, seq_len, self.n_heads * self.d_head)

        # SDPA never materializes attention probabilities; report None.
        if not output_attentions:
            att = None

        return self.att_proj_linear(out), att, past_key_values

    def _ff_block(self, x: torch.Tensor):
        """SwiGLU feed-forward: down(silu(gate(x)) * up(x))."""
        return self.ff_down_proj(
            F.silu(self.ff_gate_proj(x)) * self.ff_up_proj(x)
        )
203
+
204
+
205
class AriaModel(AriaPreTrainedModel):
    """Transformer decoder stack with no language model head.

    Embeds token ids, runs them through ``num_hidden_layers`` pre-norm
    TransformerBlocks with RoPE and a causal mask, then applies a final
    LayerNorm.

    Args:
        model_config (AriaConfig): Model config settings.
    """

    def __init__(self, model_config: AriaConfig):
        super().__init__(model_config)
        self.model_config = model_config
        # RoPE table and causal mask are built lazily on first forward (they
        # depend on the runtime device/dtype).
        self.freqs_cis = None
        self.causal_mask = None

        self.tok_embeddings = nn.Embedding(
            num_embeddings=model_config.vocab_size,
            embedding_dim=model_config.hidden_size,
        )

        self.out_layer_norm = nn.LayerNorm(model_config.hidden_size)
        self.encode_layers = nn.ModuleList()
        for i in range(model_config.num_hidden_layers):
            self.encode_layers.append(TransformerBlock(model_config, i))

        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        """Forward pass of the decoder stack.

        Args:
            input_ids: Token ids of shape (batch_size, seq_len); mutually
                exclusive with ``inputs_embeds``.
            attention_mask: Optional padding mask (batch_size, seq_len);
                1 = attend, 0 = masked.
            inputs_embeds: Precomputed embeddings of shape
                (batch_size, seq_len, hidden_size).
            past_key_values: Optional KV cache (``Cache`` or legacy tuples).
            cache_position: Absolute positions of the current tokens.

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` is truthy,
            otherwise a tuple of the non-None values among
            (last_hidden_state, past_key_values, hidden_states, attentions).
        """
        # Hard limit: the RoPE table and causal mask are precomputed only up
        # to max_seq_len.
        if (
            input_ids is not None
            and input_ids.shape[1] > self.model_config.max_seq_len
        ):
            raise ValueError(
                f"Sequence length ({input_ids.shape[1]}) exceeds max_seq_len "
                f"({self.model_config.max_seq_len})."
            )
        if (
            inputs_embeds is not None
            and inputs_embeds.shape[1] > self.model_config.max_seq_len
        ):
            raise ValueError(
                f"Sequence length ({inputs_embeds.shape[1]}) exceeds max_seq_len "
                f"({self.model_config.max_seq_len})."
            )

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.model_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.model_config.output_hidden_states
        )
        return_dict = (
            return_dict
            if return_dict is not None
            else self.model_config.use_return_dict
        )
        use_cache = (
            use_cache if use_cache is not None else self.model_config.use_cache
        )

        # Exactly one of input_ids / inputs_embeds must be provided.
        # (Rewritten from the original XOR expression; same truth table,
        # clearer intent.)
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.tok_embeddings(input_ids)

        # Normalize legacy tuple-of-tuples caches into DynamicCache; remember
        # to convert back before returning.
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(
                    past_key_values
                )
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        seq_length = inputs_embeds.shape[1]
        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length()
                if past_key_values is not None
                else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + seq_length,
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        hidden_states = inputs_embeds

        # Fix: derive the device from hidden_states rather than input_ids so
        # the inputs_embeds-only path (input_ids is None) does not crash.
        if self.causal_mask is None:
            self.causal_mask = precompute_causal_mask(
                max_seq_len=self.model_config.max_seq_len,
            ).to(hidden_states.device)

        if self.freqs_cis is None:
            self.freqs_cis = precompute_freqs_cis(
                seq_len=self.model_config.max_seq_len,
                n_elem=self.model_config.hidden_size
                // self.model_config.num_attention_heads,
                base=500000,
                dtype=hidden_states.dtype,
            ).to(hidden_states.device)

        freqs_cis = self.freqs_cis[cache_position]

        # With a cache, mask rows are indexed by absolute position; without
        # one, take the top-left (seq_length, seq_length) triangle.
        if use_cache is True:
            causal_mask = self.causal_mask[None, None, cache_position]
        else:
            causal_mask = self.causal_mask[None, None, :seq_length, :seq_length]

        if attention_mask is not None:
            # Right-pad the user mask with 1s out to the key length, then AND
            # it into the causal mask (broadcast over heads and query rows).
            pad_len = causal_mask.shape[3] - attention_mask.shape[1]
            padded_attention_mask = F.pad(attention_mask, (0, pad_len), value=1)
            padded_attention_mask = padded_attention_mask[:, None, None, :]
            padded_attention_mask = padded_attention_mask.bool()

            causal_mask = causal_mask & padded_attention_mask

        kwargs = {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
            "cache_position": cache_position,
        }
        # Fix: initialize the accumulators before branching so the
        # gradient-checkpointing path cannot hit a NameError in the returns
        # below (they were previously defined only in the else branch).
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        next_decoder_cache = None
        if self.gradient_checkpointing:
            for layer in self.encode_layers:

                def create_custom_forward(module):
                    def custom_forward(*args):
                        return module(*args)[0]

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    hidden_states,
                    causal_mask,
                    freqs_cis,
                    **kwargs,
                    preserve_rng_state=True,
                    use_reentrant=True,
                )
        else:
            for layer in self.encode_layers:
                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)
                outputs = layer(
                    hidden_states, causal_mask, freqs_cis=freqs_cis, **kwargs
                )
                hidden_states = outputs[0]
                if use_cache is True:
                    next_decoder_cache = outputs[1]
                if output_attentions:
                    # Block returns (x, present, attn) with cache, (x, attn)
                    # without — pick the attention slot accordingly.
                    all_attentions = all_attentions + (
                        outputs[2 if use_cache else 1],
                    )
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.out_layer_norm(hidden_states)
        next_cache = next_decoder_cache if use_cache else None

        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
443
+
444
+
445
class AriaForCausalLM(AriaPreTrainedModel, GenerationMixin):
    """Aria decoder with a tied-free linear head for next-token prediction.

    Args:
        model_config (AriaConfig): Model config settings.
    """

    def __init__(self, model_config: AriaConfig):
        super().__init__(model_config)
        self.model_config = model_config
        self.max_seq_len = model_config.max_seq_len
        # Decoder backbone plus an untied vocabulary projection.
        self.model = AriaModel(model_config)
        self.lm_head = nn.Linear(
            model_config.hidden_size, model_config.vocab_size, bias=False
        )
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        """Run the decoder and project hidden states to vocabulary logits.

        When ``labels`` is given, additionally computes the shifted
        next-token cross-entropy loss (token ``t`` predicts ``t + 1``).
        """
        if return_dict is None:
            return_dict = self.model_config.use_return_dict

        decoder_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        logits = self.lm_head(decoder_outputs[0])

        loss = None
        if labels is not None:
            # Move labels to the logits device to support model parallelism,
            # then align: drop the last prediction and the first label.
            labels = labels.to(logits.device)
            shifted_logits = logits[:, :-1, :].contiguous()
            targets = labels[:, 1:].contiguous()
            loss = CrossEntropyLoss()(
                shifted_logits.view(-1, shifted_logits.size(-1)),
                targets.view(-1),
            )

        if not return_dict:
            out = (logits,) + decoder_outputs[1:]
            return out if loss is None else (loss,) + out

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=decoder_outputs.past_key_values,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
        )
522
+
523
+
524
class AriaForSequenceEmbedding(AriaPreTrainedModel):
    """Transformer decoder embedding head for contrastive learning.

    Args:
        model_config (ModelConfig): Model config settings.
    """

    def __init__(self, model_config: AriaConfig):
        super().__init__(model_config)
        # embedding_size is only populated on the embedding-variant config;
        # the base/causal-LM configs leave it as None.
        assert model_config.embedding_size

        self.model_config = model_config
        self.max_seq_len = model_config.max_seq_len
        self.model = AriaModel(model_config)
        # Projects final hidden states down to the contrastive embedding dim.
        self.emb_head = nn.Linear(
            model_config.hidden_size, model_config.embedding_size, bias=False
        )
        self.post_init()

    def get_pooled_embedding(
        self, input_ids: torch.Tensor, embedding: torch.Tensor
    ):
        """Pool per-token embeddings by selecting, for each sequence, the
        embedding at the first EOS token (argmax over the boolean mask
        returns the first True position)."""
        _batch_size = input_ids.shape[0]
        eos_mask = input_ids == self.config.eos_token_id
        # Every sequence in the batch must contain at least one EOS token.
        if not eos_mask.any(dim=1).all():
            raise ValueError("Each sequence must contain a EOS token")
        eos_pos = eos_mask.int().argmax(dim=1)

        # Fancy-index one row per batch element: (batch, embedding_size).
        pooled_embedding = embedding[
            torch.arange(_batch_size, device=input_ids.device), eos_pos
        ]

        return pooled_embedding

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        """Forward pass of Transformer decoder with embedding head. Pooled
        embedding is extracted from EOS token."""

        return_dict = (
            return_dict
            if return_dict is not None
            else self.model_config.use_return_dict
        )

        # Incremental decoding, custom positions, and loss computation are not
        # meaningful for a pooled-embedding head, so reject them explicitly.
        # NOTE: use_cache is tested for truthiness (not `is not None`), so
        # passing use_cache=False is accepted.
        if (
            position_ids is not None
            or inputs_embeds is not None
            or past_key_values is not None
            or labels is not None
            or cache_position is not None
            or use_cache
        ):
            raise ValueError("Provided args unsupported for embedding head")

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=False,
        )
        hidden = outputs[0]
        # Per-token projections; pooling then picks the first-EOS row.
        embedding = self.emb_head(hidden)
        pooled_embedding = self.get_pooled_embedding(
            input_ids=input_ids,
            embedding=embedding,
        )

        if not return_dict:
            output = (pooled_embedding,) + outputs[1:]
            return output

        return BaseModelOutputWithPoolingAndProjection(
            last_hidden_state=embedding,
            pooler_output=pooled_embedding,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
618
+
619
+
620
def precompute_causal_mask(max_seq_len: int):
    """Return a (max_seq_len, max_seq_len) lower-triangular boolean mask.

    True below/on the diagonal = position may attend.

    Fix: the original called ``.cuda()`` unconditionally, which crashes on
    CPU-only installs; the caller (AriaModel.forward) moves the mask to the
    input device anyway, so allocate on CPU here.
    """
    return torch.tril(
        torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
    )
624
+
625
+
626
def precompute_freqs_cis(
    seq_len: int,
    n_elem: int,
    base: int = 500000,
    dtype: torch.dtype = torch.bfloat16,
):
    """Build the RoPE lookup table.

    Returns a tensor of shape (seq_len, n_elem // 2, 2) whose last axis holds
    the (cos, sin) components of exp(i * t * freq) for every position t and
    inverse frequency.
    """
    exponents = torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem
    inv_freqs = 1.0 / (base**exponents)
    positions = torch.arange(seq_len, device=inv_freqs.device)
    # Outer product gives one rotation angle per (position, frequency) pair.
    angles = torch.outer(positions, inv_freqs)
    unit_phasors = torch.polar(torch.ones_like(angles), angles)
    table = torch.stack([unit_phasors.real, unit_phasors.imag], dim=-1)

    return table.to(dtype=dtype)
641
+
642
+
643
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to ``x`` in place and return it.

    Credits to Katherine Crowson for the in-place formulation.

    Args:
        x: activations of shape (b_sz, s_len, n_head, d_head); mutated
            in place (callers must pass a contiguous tensor, not a view).
        freqs_cis: table of shape (s_len, d_head // 2, 2) holding the
            (cos, sin) components per position and frequency.
    """
    half = x.shape[-1] // 2
    # Unsqueeze to broadcast the (s_len, half) tables over batch and heads.
    cos = freqs_cis[..., 0][None, :, None]
    sin = freqs_cis[..., 1][None, :, None]
    first, second = x[..., :half], x[..., half : 2 * half]
    original_first = first.clone()
    # first/second are views into x, so these in-place ops rotate x directly:
    #   first  <- first * cos - second * sin
    #   second <- second * cos + original_first * sin
    first.mul_(cos).addcmul_(second, sin, value=-1)
    second.mul_(cos).addcmul_(original_first, sin, value=1)
    return x
658
+
659
+
660
+ __all__ = [
661
+ "AriaPreTrainedModel",
662
+ "AriaModel",
663
+ "TransformerBlock",
664
+ "AriaForCausalLM",
665
+ "AriaForSequenceEmbedding",
666
+ ]
tokenization_aria.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING, List, Optional, Tuple
2
+
3
+ from transformers.tokenization_utils import PreTrainedTokenizer, BatchEncoding
4
+ from transformers.utils import logging, TensorType, to_py_obj
5
+
6
+ try:
7
+ from ariautils.midi import MidiDict
8
+ from ariautils.tokenizer import AbsTokenizer
9
+ from ariautils.tokenizer._base import Token
10
+ except ImportError:
11
+ raise ImportError(
12
+ "ariautils is not installed. Please try `pip install git+https://github.com/EleutherAI/aria-utils.git`."
13
+ )
14
+
15
+ if TYPE_CHECKING:
16
+ pass
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+
21
class AriaTokenizer(PreTrainedTokenizer):
    """
    Aria Tokenizer is NOT a BPE tokenizer. A midi file will be converted to a
    MidiDict (note: in fact, a MidiDict is not a single dict — it is closer to
    a list of "notes") which represents a sequence of notes, stops, etc. The
    Aria tokenizer is then simply a dictionary that maps MidiDict entries to
    discrete indices according to a hard-coded rule.

    For a FIM finetuned model, we also follow a simple FIM format to guide a
    piece of music to a (possibly very different) suffix according to the
    prompts:
    <GUIDANCE-START> ... <GUIDANCE-END> <S> <PROMPT-START> ... <PROMPT-END>
    This way, we expect a continuation that connects PROMPT and GUIDANCE.
    """

    vocab_files_names = {}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        add_eos_token=True,
        add_dim_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        **kwargs,
    ):
        # All real tokenization work is delegated to ariautils' rule-based
        # absolute-position tokenizer.
        self._tokenizer = AbsTokenizer()

        self.add_eos_token = add_eos_token
        self.add_dim_token = add_dim_token
        self.use_default_system_prompt = use_default_system_prompt

        # Special tokens are fixed by ariautils; they are not configurable
        # through kwargs.
        bos_token = self._tokenizer.bos_tok
        eos_token = self._tokenizer.eos_tok
        pad_token = self._tokenizer.pad_tok
        unk_token = self._tokenizer.unk_tok

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            use_default_system_prompt=use_default_system_prompt,
            **kwargs,
        )

    def __getstate__(self):
        # Pickling is deliberately a no-op: AbsTokenizer state is rebuilt
        # from scratch rather than serialized.
        return {}

    def __setstate__(self, d):
        raise NotImplementedError()

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self._tokenizer.vocab_size

    def get_vocab(self):
        """Return the token -> id mapping used by the underlying tokenizer."""
        return self._tokenizer.tok_to_id

    def tokenize(
        self,
        midi_dict: MidiDict,
        add_dim_token: Optional[bool] = None,
        add_eos_token: Optional[bool] = None,
        **kwargs,
    ) -> List[Token]:
        """Tokenize a MidiDict into ariautils Token objects.

        Per-call `add_dim_token` / `add_eos_token` override the instance
        defaults when not None.
        """
        return self._tokenizer.tokenize(
            midi_dict=midi_dict,
            add_dim_tok=(
                add_dim_token
                if add_dim_token is not None
                else self.add_dim_token
            ),
            add_eos_tok=(
                add_eos_token
                if add_eos_token is not None
                else self.add_eos_token
            ),
        )

    def _tokenize(
        self,
        midi_dict: MidiDict,
        add_dim_token: Optional[bool] = None,
        add_eos_token: Optional[bool] = None,
        **kwargs,
    ) -> List[Token]:
        # Internal hook required by PreTrainedTokenizer; unlike tokenize(),
        # it forwards the flags verbatim (None included) to ariautils.
        return self._tokenizer.tokenize(
            midi_dict=midi_dict,
            add_dim_tok=add_dim_token,
            add_eos_tok=add_eos_token,
        )

    def __call__(
        self,
        midi_dicts: MidiDict | list[MidiDict],
        padding: bool = False,
        max_length: int | None = None,
        pad_to_multiple_of: int | None = None,
        return_tensors: str | TensorType | None = None,
        return_attention_mask: bool | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encode one or more MidiDicts into a BatchEncoding.

        It is impossible to rely on the parent method because the inputs are
        MidiDict(s) instead of strings, so this reimplements __call__ with
        limited support of the useful arguments: truncation via `max_length`,
        `padding`, `pad_to_multiple_of`, and `return_tensors`.
        """
        if isinstance(midi_dicts, MidiDict):
            midi_dicts = [midi_dicts]

        all_tokens: list[list[int]] = []
        all_attn_masks: list[list[int]] = []
        max_len_encoded = 0
        for md in midi_dicts:
            # NOTE(review): this bypasses self.tokenize(), so the instance's
            # add_eos_token / add_dim_token settings are not applied here —
            # ariautils defaults are used instead; confirm this is intended.
            tokens = self._tokenizer.encode(self._tokenizer.tokenize(md))
            if max_length is not None:
                tokens = tokens[:max_length]
            max_len_encoded = max(max_len_encoded, len(tokens))
            all_tokens.append(tokens)
            all_attn_masks.append([True] * len(tokens))

        if pad_to_multiple_of is not None:
            # Round UP to the nearest multiple. The previous formula
            # `(x + m) // m * m` added a whole extra multiple when x was
            # already aligned (e.g. x=8, m=4 -> 12 instead of 8).
            max_len_encoded = (
                (max_len_encoded + pad_to_multiple_of - 1) // pad_to_multiple_of
            ) * pad_to_multiple_of
        if padding:
            for tokens, attn_mask in zip(all_tokens, all_attn_masks):
                # Compute the pad length BEFORE mutating `tokens`: the old
                # code extended `tokens` first, which made
                # `max_len_encoded - len(tokens)` zero and left the
                # attention mask unpadded (ragged batch on tensor conversion).
                pad_len = max_len_encoded - len(tokens)
                tokens.extend([self._tokenizer.pad_id] * pad_len)
                attn_mask.extend([False] * pad_len)

        data = {"input_ids": all_tokens}
        # The key must be "attention_mask" (singular) to match
        # `model_input_names` — models silently ignore an "attention_masks"
        # entry, which effectively disabled masking downstream.
        if return_attention_mask is not False:
            data["attention_mask"] = all_attn_masks

        return BatchEncoding(data, tensor_type=return_tensors)

    def decode(self, token_ids: List[int], **kwargs) -> MidiDict:
        """Decode a sequence of ids back into a MidiDict (not a string)."""
        token_ids = to_py_obj(token_ids)

        return self._tokenizer.detokenize(self._tokenizer.decode(token_ids))

    def batch_decode(
        self, token_ids_list: List[List[Token]], **kwargs
    ) -> List[MidiDict]:
        """Decode a batch of id sequences into MidiDicts."""
        return [self.decode(token_ids) for token_ids in token_ids_list]

    def encode_from_file(self, filename: str, **kwargs) -> BatchEncoding:
        """Load a MIDI file and encode it; kwargs are forwarded to __call__."""
        midi_dict = MidiDict.from_midi(filename)
        return self(midi_dict, **kwargs)

    def encode_from_files(
        self, filenames: list[str], **kwargs
    ) -> BatchEncoding:
        """Load several MIDI files and encode them as one batch."""
        midi_dicts = [MidiDict.from_midi(file) for file in filenames]
        return self(midi_dicts, **kwargs)

    def _convert_token_to_id(self, token: Token):
        """Converts a token (tuple or str) into an id."""
        # NOTE(review): assumes self.unk_token is itself a key of tok_to_id
        # (i.e. ariautils' unk_tok round-trips through HF unchanged) — verify.
        return self._tokenizer.tok_to_id.get(
            token, self._tokenizer.tok_to_id[self.unk_token]
        )

    def _convert_id_to_token(self, index: int):
        """Converts an index (integer) in a token (tuple or str)."""
        return self._tokenizer.id_to_tok.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[Token]) -> MidiDict:
        """Converts a sequence of tokens into a single MidiDict."""
        return self._tokenizer.detokenize(tokens)

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        # The vocabulary is hard-coded in ariautils; there is nothing to save.
        raise NotImplementedError()
tokenizer_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_eos_token": true,
3
+ "add_dim_token": false,
4
+ "auto_map": {
5
+ "AutoTokenizer": [
6
+ "tokenization_aria.AriaTokenizer",
7
+ null
8
+ ]
9
+ },
10
+ "tokenizer_class": "AriaTokenizer"
11
+ }