TundraandTabor commited on
Commit
12b2634
·
verified ·
1 Parent(s): ac50976

Upload 38 files

Browse files
Files changed (39) hide show
  1. .gitattributes +3 -0
  2. 1_batch_xml2abc.py +54 -0
  3. 2_data_preprocess.py +181 -0
  4. 3_batch_abc2xml.py +56 -0
  5. LICENSE.txt +21 -0
  6. README (1).md +31 -0
  7. README (2).md +293 -0
  8. README.md +44 -0
  9. abc2xml (1).py +0 -0
  10. abc2xml (2).py +0 -0
  11. abc2xml.py +0 -0
  12. config (1).py +67 -0
  13. config (2).py +38 -0
  14. config (3).py +15 -0
  15. config (4).py +18 -0
  16. config (5).py +39 -0
  17. config.py +35 -0
  18. data.py +136 -0
  19. demo.ipynb +821 -0
  20. demo.py +236 -0
  21. extract_clamp2.py +194 -0
  22. illustration.png +3 -0
  23. illustration_online.png +3 -0
  24. inference (1).py +271 -0
  25. inference.py +318 -0
  26. notagen.png +3 -0
  27. prompts.txt +112 -0
  28. requirements (6).txt +7 -0
  29. statistics.py +68 -0
  30. train-gen (1).py +325 -0
  31. train-gen.py +374 -0
  32. train.py +186 -0
  33. utils (1).py +483 -0
  34. utils (2).py +423 -0
  35. utils (3).py +423 -0
  36. utils (4).py +423 -0
  37. utils (5).py +421 -0
  38. utils.py +406 -0
  39. xml2abc.py +1609 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ illustration_online.png filter=lfs diff=lfs merge=lfs -text
37
+ illustration.png filter=lfs diff=lfs merge=lfs -text
38
+ notagen.png filter=lfs diff=lfs merge=lfs -text
1_batch_xml2abc.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ORI_FOLDER = "" # Replace with the path to your folder containing XML (.xml, .mxl, .musicxml) files
2
+ DES_FOLDER = "" # The script will convert the musicxml files and output standard abc notation files to this folder
3
+
4
+ import os
5
+ import math
6
+ import random
7
+ import subprocess
8
+ from tqdm import tqdm
9
+ from multiprocessing import Pool
10
+
11
+
12
def convert_xml2abc(file_list):
    """Convert a batch of MusicXML files to ABC notation via xml2abc.py.

    For each file, runs ``python xml2abc.py -d 8 -c 6 -x <file>`` and writes
    the converter's stdout to ``DES_FOLDER/<name>.abc``. Files for which the
    converter produces empty output, or which raise an exception, are logged
    to ``logs/xml2abc_error_log.txt`` and skipped.

    Args:
        file_list: paths of .xml/.mxl/.musicxml files to convert.
    """
    cmd = 'python xml2abc.py -d 8 -c 6 -x '
    # Hoisted out of the loop: the output directory only needs creating once,
    # not once per input file.
    os.makedirs(DES_FOLDER, exist_ok=True)
    for file in tqdm(file_list):
        filename = os.path.basename(file)
        try:
            # NOTE(review): shell=True with string concatenation breaks on
            # file names containing double quotes and is injection-prone;
            # a list argv with shell=False would be safer. Kept as-is to
            # preserve the exact command line behavior.
            p = subprocess.Popen(cmd + '"' + file + '"', stdout=subprocess.PIPE, shell=True)
            result = p.communicate()
            output = result[0].decode('utf-8')

            if output == '':
                # Empty stdout means xml2abc.py failed silently; record and move on.
                with open("logs/xml2abc_error_log.txt", "a", encoding="utf-8") as f:
                    f.write(file + '\n')
                continue
            with open(os.path.join(DES_FOLDER, filename.rsplit('.', 1)[0] + '.abc'), 'w', encoding='utf-8') as f:
                f.write(output)
        except Exception as e:
            with open("logs/xml2abc_error_log.txt", "a", encoding="utf-8") as f:
                f.write(file + ' ' + str(e) + '\n')
33
+
34
+
35
if __name__ == '__main__':
    file_list = []
    os.makedirs("logs", exist_ok=True)

    # Traverse the specified folder for XML/MXL files
    for root, dirs, files in os.walk(os.path.abspath(ORI_FOLDER)):
        for file in files:
            if file.endswith((".mxl", ".xml", ".musicxml")):
                # Normalize Windows separators so downstream string handling
                # (logging, basename splitting) sees forward slashes.
                file_list.append(os.path.join(root, file).replace("\\", "/"))

    # Shuffle so each worker gets a roughly even mix of cheap and expensive files.
    random.shuffle(file_list)
    # BUG FIX: os.cpu_count() can return None on some platforms, which would
    # crash range()/the stride slices below; fall back to a single process.
    num_processes = os.cpu_count() or 1
    # Stride-split the file list into one chunk per worker.
    file_lists = [file_list[i::num_processes] for i in range(num_processes)]

    # Context manager ensures worker processes are terminated and joined.
    with Pool(processes=num_processes) as pool:
        pool.map(convert_xml2abc, file_lists)
2_data_preprocess.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ORI_FOLDER = '' # Replace with the path to your folder containing standard ABC notation files
2
+ INTERLEAVED_FOLDER = '' # Output interleaved ABC notation files to this folder
3
+ AUGMENTED_FOLDER = '' # Output key-augmented and rest-omitted ABC notation files to this folder
4
+ EVAL_SPLIT = 0.1 # The ratio of eval data
5
+
6
+ import os
7
+ import re
8
+ import json
9
+ import shutil
10
+ import random
11
+ from tqdm import tqdm
12
+ from abctoolkit.utils import (
13
+ remove_information_field,
14
+ remove_bar_no_annotations,
15
+ Quote_re,
16
+ Barlines,
17
+ extract_metadata_and_parts,
18
+ extract_global_and_local_metadata,
19
+ extract_barline_and_bartext_dict)
20
+ from abctoolkit.convert import unidecode_abc_lines
21
+ from abctoolkit.rotate import rotate_abc
22
+ from abctoolkit.check import check_alignment_unrotated
23
+ from abctoolkit.transpose import Key2index, transpose_an_abc_text
24
+
25
+ os.makedirs(INTERLEAVED_FOLDER, exist_ok=True)
26
+ os.makedirs(AUGMENTED_FOLDER, exist_ok=True)
27
+ for key in Key2index.keys():
28
+ key_folder = os.path.join(AUGMENTED_FOLDER, key)
29
+ os.makedirs(key_folder, exist_ok=True)
30
+
31
+
32
def abc_preprocess_pipeline(abc_path):
    """Clean, validate, rotate, transpose, and rest-reduce one ABC file.

    Pipeline: strip blank lines and selected information fields, drop bar
    number annotations and barline-containing quoted text, sanity-check bar
    alignment, then write (a) an interleaved (rotated) copy to
    INTERLEAVED_FOLDER and (b) one key-transposed, rest-omitted copy per key
    in Key2index to AUGMENTED_FOLDER/<key>/.

    Args:
        abc_path: path to a standard ABC notation file.

    Returns:
        (abc_name, ori_key): the file's base name (no extension) and its
        original key signature ('none' is normalized to 'C').

    Raises:
        Exception: if the bar alignment check fails or voices have unequal
        bar counts.
    """
    with open(abc_path, 'r', encoding='utf-8') as f:
        abc_lines = f.readlines()

    # delete blank lines
    abc_lines = [line for line in abc_lines if line.strip() != '']

    # unidecode (normalize non-ASCII characters)
    abc_lines = unidecode_abc_lines(abc_lines)

    # clean information fields (titles, composers, lyrics, MIDI directives...)
    abc_lines = remove_information_field(abc_lines=abc_lines, info_fields=['X:', 'T:', 'C:', 'W:', 'w:', 'Z:', '%%MIDI'])

    # delete bar number annotations
    abc_lines = remove_bar_no_annotations(abc_lines)

    # delete escaped quotes (\") from music lines; metadata (X:) and
    # comment (%) lines are left untouched.
    for i, line in enumerate(abc_lines):
        if re.search(r'^[A-Za-z]:', line) or line.startswith('%'):
            continue
        if r'\"' in line:
            abc_lines[i] = abc_lines[i].replace(r'\"', '')

    # delete quoted text annotations that contain barline symbols
    # (they would confuse bar splitting later)
    for i, line in enumerate(abc_lines):
        quote_contents = re.findall(Quote_re, line)
        for quote_content in quote_contents:
            for barline in Barlines:
                if barline in quote_content:
                    line = line.replace(quote_content, '')
                    abc_lines[i] = line

    # check bar alignment across voices
    # BUG FIX: the original `except: raise Exception` discarded the real
    # error; chain it so the caller's log can show the cause.
    try:
        _, bar_no_equal_flag, _ = check_alignment_unrotated(abc_lines)
    except Exception as e:
        raise Exception(f'{abc_path}: bar alignment check failed') from e
    if not bar_no_equal_flag:
        print(abc_path, 'Unequal bar number')
        raise Exception('Unequal bar number')

    # deal with text annotations: drop empty quotes; for ^/_ annotations,
    # collapse consecutive non-alphanumeric characters and drop annotations
    # still longer than 40 chars.
    for i, line in enumerate(abc_lines):
        quote_matches = re.findall(r'"[^"]*"', line)
        for match in quote_matches:
            if match == '""':
                line = line.replace(match, '')
            if match[1] in ['^', '_']:
                sub_string = re.sub(r'([^a-zA-Z0-9])\1+', r'\1', match)
                if len(sub_string) <= 40:
                    line = line.replace(match, sub_string)
                else:
                    line = line.replace(match, '')
        abc_lines[i] = line

    abc_name = os.path.splitext(os.path.split(abc_path)[-1])[0]

    # read the original key from global metadata ('none' is treated as 'C')
    metadata_lines, part_text_dict = extract_metadata_and_parts(abc_lines)
    global_metadata_dict, local_metadata_dict = extract_global_and_local_metadata(metadata_lines)
    if global_metadata_dict['K'][0] == 'none':
        global_metadata_dict['K'][0] = 'C'
    ori_key = global_metadata_dict['K'][0]

    # write the interleaved (rotated) version, compatible with CLaMP 2
    interleaved_abc = rotate_abc(abc_lines)
    interleaved_path = os.path.join(INTERLEAVED_FOLDER, abc_name + '.abc')
    # BUG FIX: explicit utf-8 encoding -- the platform default could fail on
    # non-ASCII content and was inconsistent with every other write here.
    with open(interleaved_path, 'w', encoding='utf-8') as w:
        w.writelines(interleaved_abc)

    # key augmentation: one transposed + rest-reduced copy per key
    for key in Key2index.keys():
        transposed_abc_text = transpose_an_abc_text(abc_lines, key)
        transposed_abc_lines = [l + '\n' for l in transposed_abc_text.split('\n') if l]

        # rest reduction: keep a bar only for voices that actually play
        # (i.e. contain a pitch letter other than rest symbols Z/z/X/x)
        metadata_lines, prefix_dict, left_barline_dict, bar_text_dict, right_barline_dict = \
            extract_barline_and_bartext_dict(transposed_abc_lines)
        reduced_abc_lines = metadata_lines
        for i in range(len(bar_text_dict['V:1'])):
            line = ''
            for symbol in prefix_dict.keys():
                valid_flag = any(
                    char.isalpha() and char not in ['Z', 'z', 'X', 'x']
                    for char in bar_text_dict[symbol][i])
                if valid_flag:
                    if i == 0:
                        # first bar carries the voice prefix and left barline
                        part_patch = '[' + symbol + ']' + prefix_dict[symbol] + left_barline_dict[symbol][0] + bar_text_dict[symbol][0] + right_barline_dict[symbol][0]
                    else:
                        part_patch = '[' + symbol + ']' + bar_text_dict[symbol][i] + right_barline_dict[symbol][i]
                    line += part_patch
            line += '\n'
            reduced_abc_lines.append(line)

        reduced_abc_name = abc_name + '_' + key
        reduced_abc_path = os.path.join(AUGMENTED_FOLDER, key, reduced_abc_name + '.abc')
        with open(reduced_abc_path, 'w', encoding='utf-8') as w:
            w.writelines(reduced_abc_lines)

    return abc_name, ori_key
139
+
140
+
141
+
142
+
143
+
144
if __name__ == '__main__':

    data = []
    file_list = os.listdir(ORI_FOLDER)
    for file in tqdm(file_list):
        ori_abc_path = os.path.join(ORI_FOLDER, file)
        try:
            abc_name, ori_key = abc_preprocess_pipeline(ori_abc_path)
        except Exception:
            # BUG FIX: the original bare `except:` also swallowed
            # KeyboardInterrupt/SystemExit, making the loop uninterruptible.
            print(ori_abc_path, 'failed to pre-process.')
            continue

        data.append({
            'path': os.path.join(AUGMENTED_FOLDER, abc_name),
            'key': ori_key
        })

    # Random train/eval split: the first EVAL_SPLIT fraction goes to eval.
    random.shuffle(data)
    split_point = int(EVAL_SPLIT * len(data))
    eval_data = data[:split_point]
    train_data = data[split_point:]

    # Write three jsonl index files: full set, eval split, train split.
    index_files = [
        (AUGMENTED_FOLDER + '.jsonl', data),
        (AUGMENTED_FOLDER + '_eval.jsonl', eval_data),
        (AUGMENTED_FOLDER + '_train.jsonl', train_data),
    ]
    for index_path, entries in index_files:
        with open(index_path, 'w', encoding='utf-8') as w:
            for d in entries:
                w.write(json.dumps(d) + '\n')
179
+
180
+
181
+
3_batch_abc2xml.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ORI_FOLDER = "" # Replace with the path to your folder containing standard/interleaved abc files
2
+ DES_FOLDER = "" # The script will convert the abc files and output musicxml files to this folder
3
+
4
+ import os
5
+ import math
6
+ import random
7
+ import subprocess
8
+ from tqdm import tqdm
9
+ from multiprocessing import Pool
10
+
11
def convert_abc2xml(file_list):
    """Convert a batch of ABC files to MusicXML via abc2xml.py.

    For each file, runs ``python abc2xml.py <file>`` and writes the
    converter's stdout to ``DES_FOLDER/<name>.xml``. Files with empty
    converter output or exceptions are logged to
    ``logs/abc2xml_error_log.txt`` and skipped.

    Args:
        file_list: paths of .abc files to convert.
    """
    cmd = 'python abc2xml.py '
    # Hoisted out of the loop: the output directory only needs creating once.
    os.makedirs(DES_FOLDER, exist_ok=True)
    for file in tqdm(file_list):
        # os.path.basename is more robust than splitting on '/' by hand.
        filename = os.path.basename(file)
        try:
            # NOTE(review): shell=True with string concatenation breaks on
            # file names containing double quotes; kept to preserve the
            # exact command line behavior.
            p = subprocess.Popen(cmd + '"' + file + '"', stdout=subprocess.PIPE, shell=True)
            result = p.communicate()
            output = result[0].decode('utf-8')

            if output == '':
                # Empty stdout means abc2xml.py failed silently; record and move on.
                with open("logs/abc2xml_error_log.txt", "a", encoding="utf-8") as f:
                    f.write(file + '\n')
                continue
            output_path = os.path.join(DES_FOLDER, os.path.splitext(filename)[0] + ".xml")
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(output)
        except Exception as e:
            with open("logs/abc2xml_error_log.txt", "a", encoding="utf-8") as f:
                f.write(file + ' ' + str(e) + '\n')
34
+
35
if __name__ == '__main__':
    file_list = []
    os.makedirs("logs", exist_ok=True)

    # Traverse the specified folder for ABC files
    for root, dirs, files in os.walk(ORI_FOLDER):
        for file in files:
            if file.endswith(".abc"):
                # Normalize Windows separators for consistent downstream handling.
                file_list.append(os.path.join(root, file).replace("\\", "/"))

    # Shuffle and split into one chunk per worker.
    random.shuffle(file_list)
    # BUG FIX: os.cpu_count() can return None on some platforms; fall back to 1.
    num_processes = os.cpu_count() or 1
    file_lists = [file_list[i::num_processes] for i in range(num_processes)]

    # BUG FIX: the original never closed/joined the Pool; the context manager
    # guarantees worker cleanup and matches the style of 1_batch_xml2abc.py.
    with Pool(processes=num_processes) as pool:
        pool.map(convert_abc2xml, file_lists)
LICENSE.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Yashan Wang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README (1).md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Local Gradio Demo
2
+
3
+ 1. Set up the environment:
4
+
5
+ ```
6
+ conda create --name notagen python=3.10
7
+ conda activate notagen
8
+ conda install pytorch==2.3.0 pytorch-cuda=11.8 -c pytorch -c nvidia
9
+ pip install accelerate
10
+ pip install optimum
11
+ pip install -r requirements.txt
12
+ ```
13
+
14
+ 2. Download [NotaGen-X](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth) and put it under ```gradio/```.
15
+
16
+ 3. run ```demo.py```:
17
+
18
+ ```
19
+ cd gradio/
20
+ python demo.py
21
+ ```
22
+
23
+ 4. Then you can view the demo page at 0.0.0.0:7861.
24
+
25
+ <p align="center">
26
+ <img src="illustration.png" alt="NotaGen Gradio Demo">
27
+ </p>
28
+
29
+ You can choose period, composer, and instrumentation as a prompt combination for NotaGen's conditional generation. After generation completes, you can save the ABC notation and MusicXML files locally.
30
+
31
+ It is with some regret that the current combination of prompts is limited to 112, which is constrained by the number of pieces of music under each prompt in the fine-tuning dataset. We hope to expand the combinations and forms of prompts in the future.
README (2).md ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🎵 NotaGen: Advancing Musicality in Symbolic Music Generation with Large Language Model Training Paradigms
2
+
3
+ <p align="center">
4
+ <!-- ArXiv -->
5
+ <a href="https://arxiv.org/abs/2502.18008">
6
+ <img src="https://img.shields.io/badge/NotaGen_Paper-ArXiv-%23B31B1B?logo=arxiv&logoColor=white" alt="Paper">
7
+ </a>
8
+ &nbsp;&nbsp;
9
+ <!-- HuggingFace -->
10
+ <a href="https://huggingface.co/ElectricAlexis/NotaGen">
11
+ <img src="https://img.shields.io/badge/NotaGen_Weights-HuggingFace-%23FFD21F?logo=huggingface&logoColor=white" alt="Weights">
12
+ </a>
13
+ &nbsp;&nbsp;
14
+ <!-- HuggingFace Space -->
15
+ <a href="https://huggingface.co/spaces/ElectricAlexis/NotaGen">
16
+ <img src="https://img.shields.io/badge/NotaGen_Space-Huggingface-✨️?logo=huggingface&logoColor=white" alt="Space">
17
+ </a>
18
+ &nbsp;&nbsp;
19
+ <!-- Web Demo -->
20
+ <a href="https://electricalexis.github.io/notagen-demo/">
21
+ <img src="https://img.shields.io/badge/NotaGen_Demo-Web-%23007ACC?logo=google-chrome&logoColor=white" alt="Demo">
22
+ </a>
23
+ </p>
24
+
25
+ <p align="center">
26
+ <img src="notagen.png" alt="NotaGen" width="50%">
27
+ </p>
28
+
29
+
30
+ ## 📖 Overview
31
+ **NotaGen** is a symbolic music generation model that explores the potential of producing **high-quality classical sheet music**. Inspired by the success of Large Language Models (LLMs), NotaGen adopts a three-stage training paradigm:
32
+ - 🧠 **Pre-training** on 1.6M musical pieces
33
+ - 🎯 **Fine-tuning** on ~9K classical compositions with `period-composer-instrumentation` prompts
34
+ - 🚀 **Reinforcement Learning** using our novel **CLaMP-DPO** method (no human annotations or pre-defined rewards required.)
35
+
36
+ Check our [demo page](https://electricalexis.github.io/notagen-demo/) and enjoy music composed by NotaGen!
37
+
38
+ ## ⚙️ Environment Setup
39
+
40
+ ```bash
41
+ conda create --name notagen python=3.10
42
+ conda activate notagen
43
+ conda install pytorch==2.3.0 pytorch-cuda=11.8 -c pytorch -c nvidia
44
+ pip install accelerate
45
+ pip install optimum
46
+ pip install -r requirements.txt
47
+ ```
48
+
49
+ ## 🏋️ NotaGen Model Weights
50
+
51
+ ### Pre-training
52
+ We provide pre-trained weights of different scales:
53
+ | Models | Parameters | Patch-level Decoder Layers | Character-level Decoder Layers | Hidden Size | Patch Length (Context Length) |
54
+ | ---- | ---- | ---- | ---- | ---- | ---- |
55
+ | [NotaGen-small](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagen_pretrain_p_size_16_p_length_2048_p_layers_12_c_layers_3_h_size_768_lr_0.0002_batch_8.pth) | 110M | 12 | 3 | 768 | 2048 |
56
+ | [NotaGen-medium](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagen_pretrain_p_size_16_p_length_2048_p_layers_16_c_layers_3_h_size_1024_lr_0.0001_batch_4.pth) | 244M | 16 | 3 | 1024 | 2048 |
57
+ | [NotaGen-large](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagen_pretrain_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_0.0001_batch_4.pth) | 516M | 20 | 6 | 1280 | 1024 |
58
+
59
+ **Notice**: The pre-trained weights cannot be used for conditional generation based on 'period-composer-instrumentation'.
60
+
61
+ ### Fine-tuning
62
+
63
+ We fine-tuned NotaGen-large on a corpus of approximately 9k classical pieces. You can download the weights [here](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagen_pretrain-finetune_p_size_16_p_length_1024_p_layers_c_layers_6_20_h_size_1280_lr_1e-05_batch_1.pth).
64
+
65
+ ### Reinforcement-Learning
66
+
67
+ After pre-training and fine-tuning, we optimized NotaGen-large with 3 iterations of CLaMP-DPO. You can download the weights [here](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagen_pretrain-finetune-RL3_beta_0.1_lambda_10_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-06_batch_1.pth).
68
+
69
+ ### 🌟 NotaGen-X
70
+
71
+ Inspired by Deepseek-R1, we further optimized the training procedures of NotaGen and released a better version --- [NotaGen-X](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth). Compared to the version in the paper, NotaGen-X incorporates the following improvements:
72
+
73
+ - We introduced a post-training stage between pre-training and fine-tuning, refining the model with a classical-style subset of the pre-training dataset.
74
+ - We removed the key augmentation in the Fine-tune stage, making the instrument range of the generated compositions more reasonable.
75
+ - After RL, we utilized the resulting checkpoint to gather a new set of post-training data. Starting from the pre-trained checkpoint, we conducted another round of post-training, fine-tuning, and reinforcement learning.
76
+
77
+ If you want to add a new composer style to NotaGen-X, please refer to issue [#18](https://github.com/ElectricAlexis/NotaGen/issues/18) for more instructions :D
78
+
79
+ ## 🎹 Demo
80
+
81
+ ### Online Gradio Demo
82
+
83
+ We developed an [online gradio demo](https://huggingface.co/spaces/ElectricAlexis/NotaGen) on Huggingface Space for NotaGen-X. You can input **"Period-Composer-Instrumentation"** as the prompt to have NotaGen generate music, preview the audio / pdf scores, and download them :D
84
+
85
+ <p align="center">
86
+ <img src="gradio/illustration_online.png" alt="NotaGen Gradio Demo">
87
+ </p>
88
+
89
+ ### Local Gradio Demo
90
+
91
+ We developed a local Gradio demo for NotaGen-X. You can input **"Period-Composer-Instrumentation"** as the prompt to have NotaGen generate music!
92
+
93
+ <p align="center">
94
+ <img src="gradio/illustration.png" alt="NotaGen Gradio Demo">
95
+ </p>
96
+
97
+ Deploying NotaGen-X inference locally may require 8GB of GPU memory. For implementation details, please view [gradio/README.md](https://github.com/ElectricAlexis/NotaGen/blob/main/gradio/README.md). We are also working on developing an online demo.
98
+
99
+ ### Online Colab Notebook
100
+
101
+ Thanks for [@deeplearn-art](https://github.com/deeplearn-art/NotaGen)'s contribution of a [Google Colab notebook for NotaGen](https://colab.research.google.com/drive/1yJA1wG0fiwNeehdQxAUw56i4bTXzoVVv?usp=sharing)! You can run it and access a Gradio public link to play with this demo. 🤩
102
+
103
+ ### ComfyUI
104
+
105
+ Thanks for [@billwuhao](https://github.com/billwuhao/ComfyUI_NotaGen)'s contribution of [a ComfyUI node for NotaGen](https://github.com/billwuhao/ComfyUI_NotaGen)! It can automatically convert generated .abc to .xml, .mp3, and .png formats. You can listen to the generated music and see the sheet music too! Please visit the [repository page](https://github.com/billwuhao/ComfyUI_NotaGen) for more information. 🤩
106
+
107
+ <p align="center">
108
+ <img src="https://github.com/billwuhao/ComfyUI_NotaGen/blob/master/images/2025-03-10_06-24-03.png" alt="NotaGen ComfyUI">
109
+ </p>
110
+
111
+
112
+ ## 🛠️ Data Pre-processing & Post-processing
113
+
114
+ For converting **ABC notation** files from / to **MusicXML** files, please view [data/README.md](https://github.com/ElectricAlexis/NotaGen/blob/main/data/README.md) for instructions.
115
+
116
+ To illustrate the specific data format, we provide a small dataset of **Schubert's lieder** compositions from the [OpenScore Lieder](https://github.com/OpenScore/Lieder), which includes:
117
+ - 🗂️ Interleaved ABC folders
118
+ - 🗂️ Augmented ABC folders
119
+ - 📄 Data index files for training and evaluation
120
+
121
+ You can download it [here](https://drive.google.com/drive/folders/1iVLkcywzXGcHFodce9nDQyEmK4UDmBtY?usp=sharing) and put it under ```data/```.
122
+
123
+ In the instructions of **Fine-tuning** and **Reinforcement Learning** below, we will use this dataset as an example of our implementation. **It won't include the "period-composer-instrumentation" conditioning**, just for showing how to adapt the pretrained NotaGen to a specific music style.
124
+
125
+
126
+ ## 🧠 Pre-train
127
+ If you want to use your own data to pre-train a blank **NotaGen** model, please:
128
+ 1. Preprocess the data and generate the data index files following the instructions in [data/README.md](https://github.com/ElectricAlexis/NotaGen/blob/main/data/README.md)
129
+ 2. Modify the parameters in ```pretrain/config.py```
130
+
131
+ Use this command for pre-training:
132
+ ```bash
133
+ cd pretrain/
134
+ accelerate launch --multi_gpu --mixed_precision fp16 train-gen.py
135
+ ```
136
+
137
+ ## 🎯 Fine-tune
138
+
139
+ Here we give an example on fine-tuning **NotaGen-large** with the **Schubert's lieder** data mentioned above.
140
+
141
+ **Notice:** The use of **NotaGen-large** requires at least **24GB of GPU memory** for training and inference. Alternatively, you may use **NotaGen-small** or **NotaGen-medium** and change the configuration of models in ```finetune/config.py```.
142
+
143
+ ### Configuration
144
+ - In ```finetune/config.py```:
145
+ - Modify the ```DATA_TRAIN_INDEX_PATH``` and ```DATA_EVAL_INDEX_PATH```:
146
+ ```python
147
+ # Configuration for the data
148
+ DATA_TRAIN_INDEX_PATH = "../data/schubert_augmented_train.jsonl"
149
+ DATA_EVAL_INDEX_PATH = "../data/schubert_augmented_eval.jsonl"
150
+ ```
151
+ - Download pre-trained NotaGen weights, and modify the ```PRETRAINED_PATH```:
152
+ ```python
153
+ PRETRAINED_PATH = "../pretrain/weights_notagen_pretrain_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_0.0001_batch_4.pth" # Use NotaGen-large
154
+ ```
155
+ - ```EXP_TAG``` is for differentiating the models. It will be integrated into the ckpt's name. Here we set it to ```schubert```.
156
+ - You can also modify other parameters like the learning rate.
157
+
158
+ ### Execution
159
+ Use this command for fine-tuning:
160
+ ```bash
161
+ cd finetune/
162
+ CUDA_VISIBLE_DEVICES=0 python train-gen.py
163
+ ```
164
+
165
+ ## 🚀 Reinforcement Learning (CLaMP-DPO)
166
+
167
+ Here we give an example on how to use **CLaMP-DPO** to enhance the model fine-tuned with **Schubert's lieder** data.
168
+
169
+ ### ⚙️ [CLaMP 2](https://github.com/sanderwood/clamp2) Setup
170
+
171
+ Download model weights and put them under the ```clamp2/```folder:
172
+ - [CLaMP 2 Model Weights](https://huggingface.co/sander-wood/clamp2/blob/main/weights_clamp2_h_size_768_lr_5e-05_batch_128_scale_1_t_length_128_t_model_FacebookAI_xlm-roberta-base_t_dropout_True_m3_True.pth)
173
+ - [M3 Model Weights](https://huggingface.co/sander-wood/clamp2/blob/main/weights_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.pth)
174
+
175
+ ### 🔍 Extract Ground Truth Features
176
+ Modify ```input_dir``` and ```output_dir``` in ```clamp2/extract_clamp2.py```:
177
+ ```python
178
+ input_dir = '../data/schubert_interleaved' # interleaved abc folder
179
+ output_dir = 'feature/schubert_interleaved' # feature folder
180
+ ```
181
+ Extract the features:
182
+ ```
183
+ cd clamp2/
184
+ python extract_clamp2.py
185
+ ```
186
+
187
+ ### 🔄 CLaMP-DPO
188
+
189
+ Here we give an example of an iteration of **CLaMP-DPO** from the initial model fine-tuned on **Schubert's lieder** data.
190
+
191
+ #### 1. Inference
192
+ - Modify the ```INFERENCE_WEIGHTS_PATH``` to path of the fine-tuned weights and ```NUM_SAMPLES``` to generate in ```inference/config.py```:
193
+ ```python
194
+ INFERENCE_WEIGHTS_PATH = '../finetune/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1.pth'
195
+ NUM_SAMPLES = 1000
196
+ ```
197
+ - Inference:
198
+ ```
199
+ cd inference/
200
+ python inference.py
201
+ ```
202
+ This will generate an ```output/```folder with two subfolders: ```original``` and ```interleaved```. The ```original/``` subdirectory stores the raw inference outputs from the model, while the ```interleaved/``` subdirectory contains data post-processed with rest measure completion, compatible with CLaMP 2. Each of these subdirectories will contain a model-specific folder, named as a combination of the model's name and its sampling parameters.
203
+
204
+ #### 2. Extract Generated Data Features
205
+
206
+ Modify ```input_dir``` and ```output_dir``` in ```clamp2/extract_clamp2.py```:
207
+ ```python
208
+ input_dir = '../output/interleaved/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1_k_9_p_0.9_temp_1.2' # interleaved abc folder
209
+ output_dir = 'feature/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1_k_9_p_0.9_temp_1.2' # feature folder
210
+ ```
211
+ Extract the features:
212
+ ```
213
+ cd clamp2/
214
+ python extract_clamp2.py
215
+ ```
216
+
217
+ #### 3. Statistics on Average CLaMP 2 Score (Optional)
218
+ If you're interested in the **Average CLaMP 2 Score** of the current model, modify the parameters in ```clamp2/statistics.py```:
219
+ ```python
220
+ gt_feature_folder = 'feature/schubert_interleaved'
221
+ output_feature_folder = 'feature/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1_k_9_p_0.9_temp_1.2'
222
+ ```
223
+ Then run this script:
224
+ ```
225
+ cd clamp2/
226
+ python statistics.py
227
+ ```
228
+
229
+ #### 4. Construct Preference Data
230
+ Modify the parameters in ```RL/data.py```:
231
+ ```python
232
+ gt_feature_folder = '../clamp2/feature/schubert_interleaved'
233
+ output_feature_folder = '../clamp2/feature/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1_k_9_p_0.9_temp_1.2'
234
+ output_original_abc_folder = '../output/original/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1_k_9_p_0.9_temp_1.2'
235
+ output_interleaved_abc_folder = '../output/interleaved/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1_k_9_p_0.9_temp_1.2'
236
+ data_index_path = 'schubert_RL1.json' # Data for the first iteration of RL
237
+ data_select_portion = 0.1
238
+ ```
239
+ In this script, the **CLaMP 2 Score** of each generated piece will be calculated and sorted. The portion of data in the chosen and rejected sets is determined by ```data_select_portion```. Additionally, there are also three rules to exclude problematic sheets from the chosen set:
240
+ - Sheets with duration alignment problems are excluded;
241
+ - Sheets that may plagiarize from ground truth data (ld_sim>0.95) are excluded;
242
+ - Sheets where staves for the same instrument are not grouped together are excluded.
243
+
244
+ The preference data file will be named as ```data_index_path```, which records the file paths in chosen and rejected sets.
245
+
246
+ Run this script:
247
+ ```
248
+ cd RL/
249
+ python data.py
250
+ ```
251
+
252
+ #### 5. DPO Training
253
+
254
+ Modify the parameters in ```RL/config.py```:
255
+ ```python
256
+ DATA_INDEX_PATH = 'schubert_RL1.json' # Preference data path
257
+ PRETRAINED_PATH = '../finetune/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1.pth' # The model to go through DPO optimization
258
+ EXP_TAG = 'schubert-RL1' # Model tag for differentiation
259
+ ```
260
+ You can also modify other parameters like ```OPTIMATION_STEPS``` and DPO hyper-parameters.
261
+
262
+ Run this script:
263
+ ```
264
+ cd RL/
265
+ CUDA_VISIBLE_DEVICES=0 python train.py
266
+ ```
267
+ After training, a model named ```weights_notagen_schubert-RL1_beta_0.1_lambda_10_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-06.pth``` will be saved under ```RL/```. For the second round of CLaMP-DPO, please go back to the first inference stage, and let the new model generate pieces.
268
+
269
+ For this small experiment on **Schubert's lieder** data, we post our **Average CLaMP 2 Score** here for the fine-tuned model and models after each iteration of CLaMP-DPO, as a reference:
270
+
271
+ | CLaMP-DPO Iteration (K) | Average CLaMP 2 Score |
272
+ | ---- | ---- |
273
+ | 0 (fine-tuned) | 0.324 |
274
+ | 1 | 0.579 |
275
+ | 2 | 0.778 |
276
+
277
+ If you are interested in this method, have a try on your own style-specific dataset :D
278
+
279
+ ## 📚 Citation
280
+
281
+ If you find **NotaGen** or **CLaMP-DPO** useful in your work, please cite our paper.
282
+
283
+ ```bibtex
284
+ @misc{wang2025notagenadvancingmusicalitysymbolic,
285
+ title={NotaGen: Advancing Musicality in Symbolic Music Generation with Large Language Model Training Paradigms},
286
+ author={Yashan Wang and Shangda Wu and Jianhuai Hu and Xingjian Du and Yueqi Peng and Yongxin Huang and Shuai Fan and Xiaobing Li and Feng Yu and Maosong Sun},
287
+ year={2025},
288
+ eprint={2502.18008},
289
+ archivePrefix={arXiv},
290
+ primaryClass={cs.SD},
291
+ url={https://arxiv.org/abs/2502.18008},
292
+ }
293
+ ```
README.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Data Pre-processing
2
+
3
+ ### Convert from MusicXML
4
+
5
+ - Navigate to the data folder ```cd data/```
6
+ - Modify the ```ORI_FOLDER``` and ```DES_FOLDER``` in ```1_batch_xml2abc.py```, then run this script:
7
+ ```
8
+ python 1_batch_xml2abc.py
9
+ ```
10
+ This will convert the MusicXML files into standard ABC notation files.
11
+ - Modify the ```ORI_FOLDER```, ```INTERLEAVED_FOLDER```, ```AUGMENTED_FOLDER```, and ```EVAL_SPLIT``` in ```2_data_preprocess.py```:
12
+
13
+ ```python
14
+ ORI_FOLDER = '' # Folder containing standard ABC notation files
15
+ INTERLEAVED_FOLDER = '' # Output interleaved ABC notation files that are compatible with CLaMP 2 to this folder
16
+ AUGMENTED_FOLDER = '' # On the basis of interleaved ABC, output key-augmented and rest-omitted files that are compatible with NotaGen to this folder
17
+ EVAL_SPLIT = 0.1 # Evaluation data ratio
18
+ ```
19
+ then run this script:
20
+ ```
21
+ python 2_data_preprocess.py
22
+ ```
23
+ - The script will convert the standard ABC to interleaved ABC, which is compatible with CLaMP 2. The files will be under ```INTERLEAVED_FOLDER```.
24
+
25
+ - This script will make 15 key signature folders under the ```AUGMENTED_FOLDER```, and output interleaved ABC notation files with rest bars omitted. This is the data representation that NotaGen adopts.
26
+
27
+ - This script will also generate data index files for training NotaGen. It will randomly split train and eval sets according to the proportion ```EVAL_SPLIT``` defines. The index files will be named as ```{AUGMENTED_FOLDER}_train.jsonl``` and ```{AUGMENTED_FOLDER}_eval.jsonl```.
28
+
29
+ ## Data Post-processing
30
+
31
+ ### Preview Sheets in ABC Notation
32
+
33
+ We recommend [EasyABC](https://sourceforge.net/projects/easyabc/), a nice software for ABC Notation previewing, composing and editing.
34
+
35
+ You need to add a line "X:1" before each piece to display the score image in EasyABC :D
36
+
37
+ ### Convert to MusicXML
38
+
39
+ - Go to the data folder ```cd data/```
40
+ - Modify the ```ORI_FOLDER``` and ```DES_FOLDER``` in ```3_batch_abc2xml.py```, then run this script:
41
+ ```
42
+ python 3_batch_abc2xml.py
43
+ ```
44
+ This will convert the standard/interleaved ABC notation files into MusicXML files.
abc2xml (1).py ADDED
The diff for this file is too large to render. See raw diff
 
abc2xml (2).py ADDED
The diff for this file is too large to render. See raw diff
 
abc2xml.py ADDED
The diff for this file is too large to render. See raw diff
 
config (1).py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ EVAL_SPLIT = 0.01 # Fraction of training data used for evaluation
2
+ WANDB_KEY = "<your_wandb_key>" # Set M3/CLaMP2_WANDB_LOG=False if no API key for Weights and Biases logging
3
+
4
+ # -------------------- Configuration for M3 Training --------------------
5
+ TRAIN_FOLDERS = [
6
+ "<path_to_training_data>" # Directory containing training data
7
+ ]
8
+
9
+ EVAL_FOLDERS = [
10
+ "" # (Optional) Directory containing evaluation data
11
+ ]
12
+
13
+ PATCH_SIZE = 64 # Size of each patch
14
+ PATCH_LENGTH = 512 # Length of the patches
15
+ PATCH_NUM_LAYERS = 12 # Number of layers in the encoder
16
+ TOKEN_NUM_LAYERS = 3 # Number of layers in the decoder
17
+ M3_HIDDEN_SIZE = 768 # Size of the hidden layer
18
+
19
+ M3_NUM_EPOCH = 100 # Maximum number of epochs for training
20
+ M3_LEARNING_RATE = 1e-4 # Learning rate for the optimizer
21
+ M3_BATCH_SIZE = 16 # Batch size per GPU (single card) during training
22
+ M3_MASK_RATIO = 0.45 # Ratio of masked elements during training
23
+ M3_DETERMINISTIC = True # Ensures deterministic results with random seeds
24
+ M3_WANDB_LOG = True # Enable logging to Weights and Biases
25
+ M3_LOAD_CKPT = True # Load model weights from a checkpoint if available
26
+
27
+ M3_WEIGHTS_PATH = (
28
+ "weights_m3_p_size_" + str(PATCH_SIZE) +
29
+ "_p_length_" + str(PATCH_LENGTH) +
30
+ "_t_layers_" + str(TOKEN_NUM_LAYERS) +
31
+ "_p_layers_" + str(PATCH_NUM_LAYERS) +
32
+ "_h_size_" + str(M3_HIDDEN_SIZE) +
33
+ "_lr_" + str(M3_LEARNING_RATE) +
34
+ "_batch_" + str(M3_BATCH_SIZE) +
35
+ "_mask_" + str(M3_MASK_RATIO) + ".pth"
36
+ ) # Path to store the model weights
37
+ M3_LOGS_PATH = M3_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt") # Path to save training logs
38
+
39
+ # -------------------- Configuration for CLaMP2 Training ----------------
40
+ TRAIN_JSONL = "<path_to_training_jsonl>" # Path to the JSONL file with training data
41
+ EVAL_JSONL = "" # (Optional) Path to the JSONL file with evaluation data
42
+
43
+ CLAMP2_HIDDEN_SIZE = 768 # Size of the hidden layer
44
+ TEXT_MODEL_NAME = "FacebookAI/xlm-roberta-base" # Name of the pre-trained text model
45
+
46
+ CLAMP2_NUM_EPOCH = 100 # Maximum number of epochs for training
47
+ CLAMP2_LEARNING_RATE = 5e-5 # Learning rate for the optimizer
48
+ CLAMP2_BATCH_SIZE = 128 # Batch size per GPU (single card) during training
49
+ LOGIT_SCALE = 1 # Scaling factor for contrastive loss
50
+ MAX_TEXT_LENGTH = 128 # Maximum allowed length for text input
51
+ TEXT_DROPOUT = True # Whether to apply dropout during text processing
52
+ CLAMP2_DETERMINISTIC = True # Ensures deterministic results with random seeds
53
+ CLAMP2_LOAD_M3 = True # Load weights from the M3 model
54
+ CLAMP2_WANDB_LOG = True # Enable logging to Weights and Biases
55
+ CLAMP2_LOAD_CKPT = True # Load weights from a checkpoint if available
56
+
57
+ CLAMP2_WEIGHTS_PATH = (
58
+ "weights_clamp2_h_size_" + str(CLAMP2_HIDDEN_SIZE) +
59
+ "_lr_" + str(CLAMP2_LEARNING_RATE) +
60
+ "_batch_" + str(CLAMP2_BATCH_SIZE) +
61
+ "_scale_" + str(LOGIT_SCALE) +
62
+ "_t_length_" + str(MAX_TEXT_LENGTH) +
63
+ "_t_model_" + TEXT_MODEL_NAME.replace("/", "_") +
64
+ "_t_dropout_" + str(TEXT_DROPOUT) +
65
+ "_m3_" + str(CLAMP2_LOAD_M3) + ".pth"
66
+ ) # Path to store CLaMP2 model weights
67
+ CLAMP2_LOGS_PATH = CLAMP2_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt") # Path to save training logs
config (2).py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os

# ---------------------------------------------------------------------------
# Fine-tuning configuration for NotaGen.
# ---------------------------------------------------------------------------

# Data: jsonl index files produced by the preprocessing scripts.
DATA_TRAIN_INDEX_PATH = ""
DATA_EVAL_INDEX_PATH = ""

# Model architecture.
PATCH_STREAM = True     # Stream training / inference
PATCH_SIZE = 16         # Patch Size
PATCH_LENGTH = 1024     # Patch Length
CHAR_NUM_LAYERS = 6     # Number of layers in the decoder
PATCH_NUM_LAYERS = 20   # Number of layers in the encoder
HIDDEN_SIZE = 1280      # Hidden Size

# Training hyper-parameters.
BATCH_SIZE = 1
LEARNING_RATE = 1e-5
NUM_EPOCHS = 64                # Epochs to train (unless early stopping intervenes)
ACCUMULATION_STEPS = 1         # Gradient accumulation to simulate a larger batch
PATCH_SAMPLING_BATCH_SIZE = 0  # Patch batch size during training; 0 uses the full context
LOAD_FROM_CHECKPOINT = False   # Whether to resume from a checkpoint
WANDB_LOGGING = False          # Whether to log to Weights & Biases
WANDB_KEY = '<your_wandb_key>'

PRETRAINED_PATH = ""  # Path of pretrained weights
EXP_TAG = ''          # Experiment tag for name differentiation

# Run name encodes the hyper-parameters so artifacts are self-describing.
NAME = (
    f"{EXP_TAG}"
    f"_p_size_{PATCH_SIZE}"
    f"_p_length_{PATCH_LENGTH}"
    f"_p_layers_{PATCH_NUM_LAYERS}"
    f"_c_layers_{CHAR_NUM_LAYERS}"
    f"_h_size_{HIDDEN_SIZE}"
    f"_lr_{LEARNING_RATE}"
    f"_batch_{BATCH_SIZE}"
)

WEIGHTS_PATH = f"weights_notagen_{NAME}.pth"  # Path to save weights
LOGS_PATH = f"logs_notagen_{NAME}.txt"        # Path to save logs
WANDB_NAME = NAME
config (3).py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os

# ---------------------------------------------------------------------------
# Inference configuration for NotaGen-X.
# ---------------------------------------------------------------------------

# Sampling.
INFERENCE_WEIGHTS_PATH = 'weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth'  # Checkpoint used for inference
TOP_K = 9          # Top-k sampling cutoff
TOP_P = 0.9        # Nucleus (top-p) sampling cutoff
TEMPERATURE = 1.2  # Sampling temperature

# Model architecture (must match the checkpoint above).
PATCH_STREAM = True     # Stream training / inference
PATCH_SIZE = 16         # Patch Size
PATCH_LENGTH = 1024     # Patch Length
CHAR_NUM_LAYERS = 6     # Number of layers in the decoder
PATCH_NUM_LAYERS = 20   # Number of layers in the encoder
HIDDEN_SIZE = 1280      # Hidden Size
config (4).py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os

# ---------------------------------------------------------------------------
# Batch-inference configuration for NotaGen.
# ---------------------------------------------------------------------------

# Sampling.
INFERENCE_WEIGHTS_PATH = ''  # Path to weights for inference
NUM_SAMPLES = 1000           # Number of samples to generate (only for generate mode)
TOP_K = 9                    # Top-k sampling cutoff
TOP_P = 0.9                  # Nucleus (top-p) sampling cutoff
TEMPERATURE = 1.2            # Sampling temperature

# Output folders are derived from the checkpoint name and the sampling settings.
_weights_stem = os.path.splitext(os.path.split(INFERENCE_WEIGHTS_PATH)[-1])[0]
_run_name = f"{_weights_stem}_k_{TOP_K}_p_{TOP_P}_temp_{TEMPERATURE}"
ORIGINAL_OUTPUT_FOLDER = os.path.join('../output/original', _run_name)
INTERLEAVED_OUTPUT_FOLDER = os.path.join('../output/interleaved', _run_name)

# Model architecture (must match the checkpoint above).
PATCH_STREAM = True     # Stream training / inference
PATCH_SIZE = 16         # Patch Size
PATCH_LENGTH = 1024     # Patch Length
CHAR_NUM_LAYERS = 6     # Number of layers in the decoder
PATCH_NUM_LAYERS = 20   # Number of layers in the encoder
HIDDEN_SIZE = 1280      # Hidden Size
config (5).py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os

# ---------------------------------------------------------------------------
# Pre-training configuration for NotaGen.
# ---------------------------------------------------------------------------

# Data: jsonl index files produced by the preprocessing scripts.
DATA_TRAIN_INDEX_PATH = ""
DATA_EVAL_INDEX_PATH = ""

# Model architecture.
PATCH_STREAM = True     # Stream training / inference
PATCH_SIZE = 16         # Patch Size
PATCH_LENGTH = 2048     # Patch Length
CHAR_NUM_LAYERS = 3     # Number of layers in the decoder
PATCH_NUM_LAYERS = 12   # Number of layers in the encoder
HIDDEN_SIZE = 768       # Hidden Size

# Training hyper-parameters.
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
NUM_EPOCHS = 128               # Epochs to train (unless early stopping intervenes)
ACCUMULATION_STEPS = 1         # Gradient accumulation to simulate a larger batch
PATCH_SAMPLING_BATCH_SIZE = 0  # Patch batch size during training; 0 uses the full context
LOAD_FROM_CHECKPOINT = False   # Whether to resume from a checkpoint
WANDB_LOGGING = False          # Whether to log to Weights & Biases
WANDB_KEY = '<your_wandb_key>'

EXP_TAG = 'pretrain'  # Experiment tag for differentiation

# Run name encodes the hyper-parameters so artifacts are self-describing.
NAME = (
    f"{EXP_TAG}"
    f"_p_size_{PATCH_SIZE}"
    f"_p_length_{PATCH_LENGTH}"
    f"_p_layers_{PATCH_NUM_LAYERS}"
    f"_c_layers_{CHAR_NUM_LAYERS}"
    f"_h_size_{HIDDEN_SIZE}"
    f"_lr_{LEARNING_RATE}"
    f"_batch_{BATCH_SIZE}"
)

WEIGHTS_PATH = f"weights_notagen_{NAME}.pth"  # Path to save weights
LOGS_PATH = f"logs_notagen_{NAME}.txt"        # Path to save logs
WANDB_NAME = NAME
38
+
39
+
config.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os

# ---------------------------------------------------------------------------
# CLaMP-DPO (reinforcement-learning) configuration for NotaGen.
# ---------------------------------------------------------------------------

# Data: preference index (chosen / rejected ABC paths) built by data.py.
DATA_INDEX_PATH = ''

# Model architecture.
PATCH_STREAM = True     # Stream training / inference
PATCH_SIZE = 16         # Patch Size
PATCH_LENGTH = 1024     # Patch Length
CHAR_NUM_LAYERS = 6     # Number of layers in the decoder
PATCH_NUM_LAYERS = 20   # Number of layers in the encoder
HIDDEN_SIZE = 1280      # Hidden Size

# DPO hyper-parameters.
BETA = 0.1     # beta in DPO's objective function
LAMBDA = 10    # lambda in DPOP's objective function
LEARNING_RATE = 1e-6
OPTIMIZATION_STEPS = 10000  # Optimization steps for DPO
WANDB_LOGGING = False       # Whether to log to Weights & Biases
WANDB_KEY = '<your_wandb_key>'

PRETRAINED_PATH = ''  # Model to be optimized with DPO
EXP_TAG = ''          # Experiment tag for name differentiation

# Run name encodes the hyper-parameters so artifacts are self-describing.
NAME = (
    f"{EXP_TAG}"
    f"_beta_{BETA}"
    f"_lambda_{LAMBDA}"
    f"_p_size_{PATCH_SIZE}"
    f"_p_length_{PATCH_LENGTH}"
    f"_p_layers_{PATCH_NUM_LAYERS}"
    f"_c_layers_{CHAR_NUM_LAYERS}"
    f"_h_size_{HIDDEN_SIZE}"
    f"_lr_{LEARNING_RATE}"
)

WEIGHTS_PATH = f"weights_notagen_{NAME}.pth"  # Path to save weights
WANDB_NAME = NAME
+ WANDB_NAME = NAME
data.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Build the preference dataset for the next CLaMP-DPO round: compare the
# CLaMP 2 features of generated pieces against ground-truth features and
# split the generations into "chosen" / "rejected" sets.

gt_feature_folder = '../clamp2/feature/schubert_interleaved'  # CLaMP 2 features of the ground-truth pieces
output_feature_folder = '../clamp2/feature/weights_notagen_schubert-RL2_beta_0.1_lambda_10_p_size_16_p_length_1024_p_layers_20_h_size_1280_lr_1e-06_k_9_p_0.9_temp_1.2'  # CLaMP 2 features of generated pieces
output_original_abc_folder = '../output/original/weights_notagen_schubert-RL2_beta_0.1_lambda_10_p_size_16_p_length_1024_p_layers_20_h_size_1280_lr_1e-06_k_9_p_0.9_temp_1.2'  # generated pieces, original ABC
output_interleaved_abc_folder = '../output/interleaved/weights_notagen_schubert-RL2_beta_0.1_lambda_10_p_size_16_p_length_1024_p_layers_20_h_size_1280_lr_1e-06_k_9_p_0.9_temp_1.2'  # generated pieces, interleaved ABC
data_index_path = 'schubert_RL3.json'  # output preference index (chosen / rejected path lists)
data_select_portion = 0.1  # fraction of generations placed in each of the chosen / rejected sets

import os
import re
import json
import random
import numpy as np
from config import *
from abctoolkit.check import check_alignment_rotated, check_alignment_unrotated
from abctoolkit.rotate import unrotate_abc
16
+
17
+
18
def load_npy_files(folder_path_list):
    """Load the first row of each ``.npy`` file in *folder_path_list*.

    Paths without a ``.npy`` extension are silently skipped.
    Returns a list of numpy arrays.
    """
    arrays = []
    for path in folder_path_list:
        if not path.endswith('.npy'):
            continue
        arrays.append(np.load(path)[0])
    return arrays
29
+
30
def average_npy(npy_list):
    """Return the element-wise mean of a list of equal-shape numpy arrays."""
    stacked = np.stack(npy_list)
    return stacked.mean(axis=0)
35
+
36
def cosine_similarity(vec1, vec2):
    """Cosine similarity between two 1-D numpy vectors.

    Note: no guard against zero-norm inputs (matches the original contract).
    """
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return np.dot(vec1, vec2) / denominator
48
+
49
+
50
def generate_preference_dict():
    """Select 'chosen' and 'rejected' generated pieces for CLaMP-DPO.

    Pieces are ranked by cosine similarity between their CLaMP 2 feature and
    the average ground-truth feature. The highest-ranked pieces that also pass
    the sanity checks below become 'chosen'; the lowest-ranked portion becomes
    'rejected'. Both path lists are written as JSON to ``data_index_path``.
    """
    # Average CLaMP 2 feature over all ground-truth pieces.
    gt_feature_paths = []
    for gt_feature_file in os.listdir(gt_feature_folder):
        gt_feature_paths.append(os.path.join(gt_feature_folder, gt_feature_file))
    gt_features = load_npy_files(gt_feature_paths)
    gt_avg_feature = average_npy(gt_features)

    # Similarity of every generated piece to the ground-truth average.
    output_feature_sim_dict = {}
    for file in os.listdir(output_feature_folder):
        output_feature_path = os.path.join(output_feature_folder, file)
        output_feature = np.load(output_feature_path)[0]
        sim = cosine_similarity(gt_avg_feature, output_feature)
        output_feature_sim_dict[file[:-4]] = sim  # key = filename without '.npy'

    # Number of pieces to place in each of the chosen / rejected sets.
    threshold = int(len(output_feature_sim_dict) * data_select_portion)
    sorted_output_files = sorted(output_feature_sim_dict.keys(), key=lambda item: output_feature_sim_dict[item], reverse=True)

    # Walk down the ranking until `threshold` pieces pass all checks.
    chosen_index = 0
    i = 0
    chosen_abc_paths = []
    while chosen_index < threshold and i < len(sorted_output_files):

        chosen_flag = True

        file = sorted_output_files[i]
        output_interleaved_abc_path = os.path.join(output_interleaved_abc_folder, file + '.abc')

        with open(output_interleaved_abc_path, 'r') as f:
            abc_lines = f.readlines()

        # check alignment: bars must line up across voices after unrotating
        try:
            abc_lines_unrotated = unrotate_abc(abc_lines)
            barline_equal_flag, bar_no_equal_flag, bar_dur_equal_flag = check_alignment_unrotated(abc_lines_unrotated)
            if not (barline_equal_flag and bar_no_equal_flag and bar_dur_equal_flag):
                raise Exception
        except:
            chosen_flag = False

        # check header: sheets where staves for the same instrument are not grouped together are excluded from the chosen set.
        appeared_inst = set()
        last_inst = ''
        for line in abc_lines:
            if line.startswith('V:') and 'nm=' in line:
                match = re.search(r'nm="([^"]+)"', line)
                if match:
                    inst = match.group(1)
                    # A repeated instrument name after a different one means
                    # the staves of one instrument are split apart.
                    if inst != last_inst and inst in appeared_inst:
                        chosen_flag = False
                        break
                    else:
                        last_inst = inst
                        appeared_inst.add(inst)

        # check plagiarism: sheets with sim > 0.95 to any single ground-truth piece are excluded
        output_feature_path = os.path.join(output_feature_folder, file + '.npy')
        output_feature = np.load(output_feature_path)[0]
        for gt_feature_file in os.listdir(gt_feature_folder):
            gt_feature_path = os.path.join(gt_feature_folder, gt_feature_file)
            gt_feature = np.load(gt_feature_path)[0]
            sim = cosine_similarity(output_feature, gt_feature)
            if sim > 0.95:
                chosen_flag = False
                break

        if chosen_flag:
            original_abc_path = os.path.join(output_original_abc_folder, file + '.abc')
            chosen_abc_paths.append(original_abc_path)
            chosen_index += 1
        else:
            print(file, 'skipped')

        i += 1

    # The lowest-similarity pieces form the rejected set (no sanity checks).
    rejected_abc_paths = [os.path.join(output_original_abc_folder, file + '.abc') for file in sorted_output_files[-threshold:]]
    preference_dict = {'chosen': chosen_abc_paths, 'rejected': rejected_abc_paths}

    with open(data_index_path, 'w') as w:
        json.dump(preference_dict, w, indent=4)


if __name__ == '__main__':

    generate_preference_dict()
135
+
136
+
demo.ipynb ADDED
@@ -0,0 +1,821 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "6e5cf1e7-c275-4929-9c44-ec48e26a2d4d",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os\n",
11
+ "import re\n",
12
+ "import time\n",
13
+ "import torch\n",
14
+ "import torch\n",
15
+ "import random\n",
16
+ "import bisect\n",
17
+ "import json\n",
18
+ "from pathlib import Path\n",
19
+ "from tokenizers import Tokenizer\n",
20
+ "from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config, LlamaModel, LlamaForCausalLM, PreTrainedModel \n",
21
+ "from samplings import top_p_sampling, top_k_sampling, temperature_sampling\n",
22
+ "from abctoolkit.utils import Exclaim_re, Quote_re, SquareBracket_re, Barline_regexPattern\n",
23
+ "from abctoolkit.transpose import Note_list, Pitch_sign_list\n",
24
+ "from abctoolkit.duration import calculate_bartext_duration"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "id": "00fd2ebb-6e53-4038-af85-f9c5f02fde0e",
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "# Configurations for inference\n",
35
+ "INFERENCE_WEIGHTS_PATH = '../weights/weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth' # Path to weights for inference# Folder to save output files\n",
36
+ "TOP_K = 9 # Top k for sampling\n",
37
+ "TOP_P = 0.9 # Top p for sampling\n",
38
+ "TEMPERATURE = 1.2 # Temperature for sampling\n",
39
+ "\n",
40
+ "# Configurations for model\n",
41
+ "PATCH_STREAM = True # Stream training / inference\n",
42
+ "PATCH_SIZE = 16 # Patch Size\n",
43
+ "PATCH_LENGTH = 1024 # Patch Length\n",
44
+ "CHAR_NUM_LAYERS = 6 # Number of layers in the decoder\n",
45
+ "PATCH_NUM_LAYERS = 20 # Number of layers in the encoder\n",
46
+ "HIDDEN_SIZE = 1280 # Hidden Size\n",
47
+ "\n",
48
+ "device = torch.device(\"cuda\")"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": null,
54
+ "id": "fb70eb19-8b9c-4864-b711-7a0395b42c49",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "class Patchilizer:\n",
59
+ " def __init__(self, stream=PATCH_STREAM):\n",
60
+ " self.stream = stream\n",
61
+ " self.delimiters = [\"|:\", \"::\", \":|\", \"[|\", \"||\", \"|]\", \"|\"]\n",
62
+ " self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'\n",
63
+ " self.bos_token_id = 1\n",
64
+ " self.eos_token_id = 2\n",
65
+ " self.special_token_id = 0\n",
66
+ "\n",
67
+ " def split_bars(self, body_lines):\n",
68
+ " \"\"\"\n",
69
+ " Split a body of music into individual bars.\n",
70
+ " \"\"\"\n",
71
+ " new_bars = []\n",
72
+ " try:\n",
73
+ " for line in body_lines:\n",
74
+ " line_bars = re.split(self.regexPattern, line)\n",
75
+ " line_bars = list(filter(None, line_bars))\n",
76
+ " new_line_bars = []\n",
77
+ "\n",
78
+ " if len(line_bars) == 1:\n",
79
+ " new_line_bars = line_bars\n",
80
+ " else:\n",
81
+ " if line_bars[0] in self.delimiters:\n",
82
+ " new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]\n",
83
+ " else:\n",
84
+ " new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]\n",
85
+ " if 'V' not in new_line_bars[-1]:\n",
86
+ " new_line_bars[-2] += new_line_bars[-1] # 吸收最后一个 小节线+\\n 的组合\n",
87
+ " new_line_bars = new_line_bars[:-1]\n",
88
+ " new_bars += new_line_bars\n",
89
+ " except:\n",
90
+ " pass\n",
91
+ "\n",
92
+ " return new_bars\n",
93
+ "\n",
94
+ " def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):\n",
95
+ " if not generate_last and len(abc_text) % patch_size != 0:\n",
96
+ " abc_text += chr(self.eos_token_id)\n",
97
+ " patches = [abc_text[i : i + patch_size] for i in range(0, len(abc_text), patch_size)]\n",
98
+ " return patches\n",
99
+ "\n",
100
+ " def patch2chars(self, patch):\n",
101
+ " \"\"\"\n",
102
+ " Convert a patch into a bar.\n",
103
+ " \"\"\"\n",
104
+ " bytes = ''\n",
105
+ " for idx in patch:\n",
106
+ " if idx == self.eos_token_id:\n",
107
+ " break\n",
108
+ " if idx < self.eos_token_id:\n",
109
+ " pass\n",
110
+ " bytes += chr(idx)\n",
111
+ " return bytes\n",
112
+ " \n",
113
+ "\n",
114
+ " def patchilize_metadata(self, metadata_lines):\n",
115
+ "\n",
116
+ " metadata_patches = []\n",
117
+ " for line in metadata_lines:\n",
118
+ " metadata_patches += self.split_patches(line)\n",
119
+ "\n",
120
+ " return metadata_patches\n",
121
+ " \n",
122
+ " def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):\n",
123
+ "\n",
124
+ " tunebody_patches = []\n",
125
+ " bars = self.split_bars(tunebody_lines)\n",
126
+ " if encode_mode == 'train':\n",
127
+ " for bar in bars:\n",
128
+ " tunebody_patches += self.split_patches(bar)\n",
129
+ " elif encode_mode == 'generate':\n",
130
+ " for bar in bars[:-1]:\n",
131
+ " tunebody_patches += self.split_patches(bar)\n",
132
+ " tunebody_patches += self.split_patches(bars[-1], generate_last=True)\n",
133
+ " \n",
134
+ " return tunebody_patches\n",
135
+ "\n",
136
+ " def encode_train(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):\n",
137
+ "\n",
138
+ " lines = abc_text.split('\\n')\n",
139
+ " lines = list(filter(None, lines))\n",
140
+ " lines = [line + '\\n' for line in lines]\n",
141
+ "\n",
142
+ " tunebody_index = -1\n",
143
+ " for i, line in enumerate(lines):\n",
144
+ " if '[V:' in line:\n",
145
+ " tunebody_index = i\n",
146
+ " break\n",
147
+ "\n",
148
+ " metadata_lines = lines[ : tunebody_index]\n",
149
+ " tunebody_lines = lines[tunebody_index : ]\n",
150
+ "\n",
151
+ " if self.stream:\n",
152
+ " tunebody_lines = ['[r:' + str(line_index) + '/' + str(len(tunebody_lines) - line_index - 1) + ']' + line for line_index, line in\n",
153
+ " enumerate(tunebody_lines)] \n",
154
+ "\n",
155
+ " metadata_patches = self.patchilize_metadata(metadata_lines)\n",
156
+ " tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')\n",
157
+ "\n",
158
+ " if add_special_patches:\n",
159
+ " bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)\n",
160
+ " eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)\n",
161
+ "\n",
162
+ " metadata_patches = [bos_patch] + metadata_patches\n",
163
+ " tunebody_patches = tunebody_patches + [eos_patch]\n",
164
+ "\n",
165
+ " if self.stream:\n",
166
+ " if len(metadata_patches) + len(tunebody_patches) > patch_length:\n",
167
+ " available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if '\\n' in patch]\n",
168
+ " line_index_for_cut_index = list(range(len(available_cut_indexes))) \n",
169
+ " end_index = len(metadata_patches) + len(tunebody_patches) - patch_length\n",
170
+ " biggest_index = bisect.bisect_left(available_cut_indexes, end_index) \n",
171
+ " available_cut_indexes = available_cut_indexes[:biggest_index + 1]\n",
172
+ "\n",
173
+ " if len(available_cut_indexes) == 1:\n",
174
+ " choices = ['head']\n",
175
+ " elif len(available_cut_indexes) == 2:\n",
176
+ " choices = ['head', 'tail']\n",
177
+ " else:\n",
178
+ " choices = ['head', 'tail', 'middle']\n",
179
+ " choice = random.choice(choices)\n",
180
+ " if choice == 'head':\n",
181
+ " patches = metadata_patches + tunebody_patches[0:]\n",
182
+ " else:\n",
183
+ " if choice == 'tail':\n",
184
+ " cut_index = len(available_cut_indexes) - 1\n",
185
+ " else:\n",
186
+ " cut_index = random.choice(range(1, len(available_cut_indexes) - 1))\n",
187
+ "\n",
188
+ " line_index = line_index_for_cut_index[cut_index] \n",
189
+ " stream_tunebody_lines = tunebody_lines[line_index : ]\n",
190
+ " \n",
191
+ " stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')\n",
192
+ " if add_special_patches:\n",
193
+ " stream_tunebody_patches = stream_tunebody_patches + [eos_patch]\n",
194
+ " patches = metadata_patches + stream_tunebody_patches\n",
195
+ " else:\n",
196
+ " patches = metadata_patches + tunebody_patches\n",
197
+ " else:\n",
198
+ " patches = metadata_patches + tunebody_patches\n",
199
+ "\n",
200
+ " if cut: \n",
201
+ " patches = patches[ : patch_length]\n",
202
+ " else: \n",
203
+ " pass\n",
204
+ "\n",
205
+ " # encode to ids\n",
206
+ " id_patches = []\n",
207
+ " for patch in patches:\n",
208
+ " id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))\n",
209
+ " id_patches.append(id_patch)\n",
210
+ "\n",
211
+ " return id_patches\n",
212
+ "\n",
213
+ " def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):\n",
214
+ "\n",
215
+ " lines = abc_code.split('\\n')\n",
216
+ " lines = list(filter(None, lines))\n",
217
+ " \n",
218
+ " tunebody_index = None\n",
219
+ " for i, line in enumerate(lines):\n",
220
+ " if line.startswith('[V:') or line.startswith('[r:'):\n",
221
+ " tunebody_index = i\n",
222
+ " break\n",
223
+ " \n",
224
+ " metadata_lines = lines[ : tunebody_index]\n",
225
+ " tunebody_lines = lines[tunebody_index : ] \n",
226
+ " \n",
227
+ " metadata_lines = [line + '\\n' for line in metadata_lines]\n",
228
+ " if self.stream:\n",
229
+ " if not abc_code.endswith('\\n'):\n",
230
+ " tunebody_lines = [tunebody_lines[i] + '\\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]]\n",
231
+ " else:\n",
232
+ " tunebody_lines = [tunebody_lines[i] + '\\n' for i in range(len(tunebody_lines))]\n",
233
+ " else:\n",
234
+ " tunebody_lines = [line + '\\n' for line in tunebody_lines]\n",
235
+ " \n",
236
+ " metadata_patches = self.patchilize_metadata(metadata_lines)\n",
237
+ " tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')\n",
238
+ " \n",
239
+ " if add_special_patches:\n",
240
+ " bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)\n",
241
+ "\n",
242
+ " metadata_patches = [bos_patch] + metadata_patches\n",
243
+ " \n",
244
+ " patches = metadata_patches + tunebody_patches\n",
245
+ " patches = patches[ : patch_length]\n",
246
+ "\n",
247
+ " # encode to ids\n",
248
+ " id_patches = []\n",
249
+ " for patch in patches:\n",
250
+ " if len(patch) < PATCH_SIZE and patch[-1] != chr(self.eos_token_id):\n",
251
+ " id_patch = [ord(c) for c in patch]\n",
252
+ " else:\n",
253
+ " id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))\n",
254
+ " id_patches.append(id_patch)\n",
255
+ " \n",
256
+ " return id_patches\n",
257
+ "\n",
258
+ " def decode(self, patches):\n",
259
+ " \"\"\"\n",
260
+ " Decode patches into music.\n",
261
+ " \"\"\"\n",
262
+ " return ''.join(self.patch2chars(patch) for patch in patches)\n",
263
+ "\n",
264
+ "\n",
265
+ "class PatchLevelDecoder(PreTrainedModel):\n",
266
+ " \"\"\"\n",
267
+ " A Patch-level Decoder model for generating patch features in an auto-regressive manner. \n",
268
+ " It inherits PreTrainedModel from transformers.\n",
269
+ " \"\"\"\n",
270
+ " def __init__(self, config):\n",
271
+ " super().__init__(config)\n",
272
+ " self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)\n",
273
+ " torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)\n",
274
+ " self.base = GPT2Model(config)\n",
275
+ "\n",
276
+ " def forward(self,\n",
277
+ " patches: torch.Tensor,\n",
278
+ " masks=None) -> torch.Tensor:\n",
279
+ " \"\"\"\n",
280
+ " The forward pass of the patch-level decoder model.\n",
281
+ " :param patches: the patches to be encoded\n",
282
+ " :param masks: the masks for the patches\n",
283
+ " :return: the encoded patches\n",
284
+ " \"\"\"\n",
285
+ " patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)\n",
286
+ " patches = patches.reshape(len(patches), -1, PATCH_SIZE * (128))\n",
287
+ " patches = self.patch_embedding(patches.to(self.device))\n",
288
+ "\n",
289
+ " if masks==None:\n",
290
+ " return self.base(inputs_embeds=patches)\n",
291
+ " else:\n",
292
+ " return self.base(inputs_embeds=patches,\n",
293
+ " attention_mask=masks)\n",
294
+ "\n",
295
+ "\n",
296
+ "class CharLevelDecoder(PreTrainedModel):\n",
297
+ " \"\"\"\n",
298
+ " A Char-level Decoder model for generating the chars within each patch in an auto-regressive manner\n",
299
+ " based on the encoded patch features. It inherits PreTrainedModel from transformers.\n",
300
+ " \"\"\"\n",
301
+ " def __init__(self, config):\n",
302
+ " super().__init__(config)\n",
303
+ " self.special_token_id = 0\n",
304
+ " self.bos_token_id = 1\n",
305
+ "\n",
306
+ " self.base = GPT2LMHeadModel(config)\n",
307
+ "\n",
308
+ " def forward(self,\n",
309
+ " encoded_patches: torch.Tensor,\n",
310
+ " target_patches: torch.Tensor):\n",
311
+ " \"\"\"\n",
312
+ " The forward pass of the char-level decoder model.\n",
313
+ " :param encoded_patches: the encoded patches\n",
314
+ " :param target_patches: the target patches\n",
315
+ " :return: the output of the model\n",
316
+ " \"\"\"\n",
317
+ " # preparing the labels for model training\n",
318
+ " target_patches = torch.cat((torch.ones_like(target_patches[:,0:1])*self.bos_token_id, target_patches), dim=1)\n",
319
+ " # print('target_patches shape:', target_patches.shape)\n",
320
+ "\n",
321
+ " target_masks = target_patches == self.special_token_id\n",
322
+ " labels = target_patches.clone().masked_fill_(target_masks, -100)\n",
323
+ "\n",
324
+ " # masking the labels for model training\n",
325
+ " target_masks = torch.ones_like(labels)\n",
326
+ " target_masks = target_masks.masked_fill_(labels == -100, 0)\n",
327
+ "\n",
328
+ " # select patches\n",
329
+ " if PATCH_SAMPLING_BATCH_SIZE!=0 and PATCH_SAMPLING_BATCH_SIZE<target_patches.shape[0]:\n",
330
+ " indices = list(range(len(target_patches)))\n",
331
+ " random.shuffle(indices)\n",
332
+ " selected_indices = sorted(indices[:PATCH_SAMPLING_BATCH_SIZE])\n",
333
+ "\n",
334
+ " target_patches = target_patches[selected_indices,:]\n",
335
+ " target_masks = target_masks[selected_indices,:]\n",
336
+ " encoded_patches = encoded_patches[selected_indices,:]\n",
337
+ "\n",
338
+ " # get input embeddings\n",
339
+ " inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)\n",
340
+ "\n",
341
+ " # concatenate the encoded patches with the input embeddings\n",
342
+ " inputs_embeds = torch.cat((encoded_patches.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1)\n",
343
+ "\n",
344
+ " output = self.base(inputs_embeds=inputs_embeds, \n",
345
+ " attention_mask=target_masks,\n",
346
+ " labels=labels)\n",
347
+ " # output_hidden_states=True=True)\n",
348
+ "\n",
349
+ " return output\n",
350
+ "\n",
351
+ " def generate(self,\n",
352
+ " encoded_patch: torch.Tensor, # [hidden_size]\n",
353
+ " tokens: torch.Tensor): # [1]\n",
354
+ " \"\"\"\n",
355
+ " The generate function for generating a patch based on the encoded patch and already generated tokens.\n",
356
+ " :param encoded_patch: the encoded patch\n",
357
+ " :param tokens: already generated tokens in the patch\n",
358
+ " :return: the probability distribution of next token\n",
359
+ " \"\"\"\n",
360
+ " encoded_patch = encoded_patch.reshape(1, 1, -1) # [1, 1, hidden_size]\n",
361
+ " tokens = tokens.reshape(1, -1)\n",
362
+ "\n",
363
+ " # Get input embeddings\n",
364
+ " tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)\n",
365
+ "\n",
366
+ " # Concatenate the encoded patch with the input embeddings\n",
367
+ " tokens = torch.cat((encoded_patch, tokens[:,1:,:]), dim=1)\n",
368
+ " \n",
369
+ " # Get output from model\n",
370
+ " outputs = self.base(inputs_embeds=tokens)\n",
371
+ " \n",
372
+ " # Get probabilities of next token\n",
373
+ " probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)\n",
374
+ "\n",
375
+ " return probs\n",
376
+ "\n",
377
+ "class NotaGenLMHeadModel(PreTrainedModel):\n",
378
+ " \"\"\"\n",
379
+ " NotaGen is a language model with a hierarchical structure.\n",
380
+ " It includes a patch-level decoder and a char-level decoder.\n",
381
+ " The patch-level decoder is used to generate patch features in an auto-regressive manner.\n",
382
+ " The char-level decoder is used to generate the chars within each patch in an auto-regressive manner.\n",
383
+ " It inherits PreTrainedModel from transformers.\n",
384
+ " \"\"\"\n",
385
+ " def __init__(self, encoder_config, decoder_config):\n",
386
+ " super().__init__(encoder_config)\n",
387
+ " self.special_token_id = 0\n",
388
+ " self.bos_token_id = 1\n",
389
+ " self.eos_token_id = 2\n",
390
+ " self.patch_level_decoder = PatchLevelDecoder(encoder_config)\n",
391
+ " self.char_level_decoder = CharLevelDecoder(decoder_config)\n",
392
+ "\n",
393
+ " def forward(self,\n",
394
+ " patches: torch.Tensor,\n",
395
+ " masks: torch.Tensor):\n",
396
+ " \"\"\"\n",
397
+ " The forward pass of the bGPT model.\n",
398
+ " :param patches: the patches to be encoded\n",
399
+ " :param masks: the masks for the patches\n",
400
+ " :return: the decoded patches\n",
401
+ " \"\"\"\n",
402
+ " patches = patches.reshape(len(patches), -1, PATCH_SIZE)\n",
403
+ " encoded_patches = self.patch_level_decoder(patches, masks)[\"last_hidden_state\"]\n",
404
+ " \n",
405
+ " left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)\n",
406
+ " masks[:, 0] = 0\n",
407
+ " \n",
408
+ " encoded_patches = encoded_patches[left_shift_masks == 1]\n",
409
+ " patches = patches[masks == 1] \n",
410
+ "\n",
411
+ " return self.char_level_decoder(encoded_patches, patches)\n",
412
+ " \n",
413
+ " def generate(self,\n",
414
+ " patches: torch.Tensor,\n",
415
+ " top_k=0,\n",
416
+ " top_p=1,\n",
417
+ " temperature=1.0):\n",
418
+ " \"\"\"\n",
419
+ " The generate function for generating patches based on patches.\n",
420
+ " :param patches: the patches to be encoded\n",
421
+ " :param top_k: the top k for sampling\n",
422
+ " :param top_p: the top p for sampling\n",
423
+ " :param temperature: the temperature for sampling\n",
424
+ " :return: the generated patches\n",
425
+ " \"\"\"\n",
426
+ " if patches.shape[-1] % PATCH_SIZE != 0:\n",
427
+ " tokens = patches[:,:,-(patches.shape[-1]%PATCH_SIZE):].squeeze(0, 1)\n",
428
+ " tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1)\n",
429
+ " patches = patches[:,:,:-(patches.shape[-1]%PATCH_SIZE)]\n",
430
+ " else:\n",
431
+ " tokens = torch.tensor([self.bos_token_id], device=self.device)\n",
432
+ "\n",
433
+ " patches = patches.reshape(len(patches), -1, PATCH_SIZE) # [bs, seq, patch_size]\n",
434
+ " encoded_patches = self.patch_level_decoder(patches)[\"last_hidden_state\"] # [bs, seq, hidden_size]\n",
435
+ " generated_patch = [] \n",
436
+ "\n",
437
+ " while True:\n",
438
+ " prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy() # [128]\n",
439
+ " prob = top_k_sampling(prob, top_k=top_k, return_probs=True) # [128]\n",
440
+ " prob = top_p_sampling(prob, top_p=top_p, return_probs=True) # [128]\n",
441
+ " token = temperature_sampling(prob, temperature=temperature) # int\n",
442
+ " char = chr(token)\n",
443
+ " generated_patch.append(token)\n",
444
+ "\n",
445
+ " if len(tokens) >= PATCH_SIZE:# or token == self.eos_token_id:\n",
446
+ " break\n",
447
+ " else:\n",
448
+ " tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)\n",
449
+ " \n",
450
+ " return generated_patch\n",
451
+ "\n",
452
+ "def clean_to_abc(raw_text, unreduce=True, output_path='output.abc'):\n",
453
+ " # Remove [r:x/y] tags\n",
454
+ " cleaned = re.sub(r'\\[r:\\d+/\\d+\\]', '', raw_text)\n",
455
+ "\n",
456
+ " # Add required ABC headers\n",
457
+ " lines = cleaned.strip().splitlines()\n",
458
+ " header_inserted = False\n",
459
+ " abc_lines = []\n",
460
+ " for line in lines:\n",
461
+ " if not header_inserted and line.startswith('%%score'):\n",
462
+ " abc_lines.insert(0, 'T:Generated\\n')\n",
463
+ " abc_lines.insert(0, 'X:1\\n')\n",
464
+ " header_inserted = True\n",
465
+ " abc_lines.append(line if line.endswith('\\n') else line + '\\n')\n",
466
+ "\n",
467
+ " # Optional: fill missing rests\n",
468
+ " if unreduce:\n",
469
+ " try:\n",
470
+ " abc_lines = rest_unreduce(abc_lines)\n",
471
+ " except Exception as e:\n",
472
+ " print(\"Unreduce failed:\", e)\n",
473
+ "\n",
474
+ " # Save to .abc file\n",
475
+ " Path(output_path).write_text(''.join(abc_lines), encoding='utf-8')\n",
476
+ " print(f\"Saved cleaned ABC to {output_path}\")\n",
477
+ " return output_path"
478
+ ]
479
+ },
480
+ {
481
+ "cell_type": "code",
482
+ "execution_count": null,
483
+ "id": "6d126533-a9a1-48a5-9b1b-be6da37a55ad",
484
+ "metadata": {},
485
+ "outputs": [],
486
+ "source": [
487
+ "Note_list = Note_list + ['z', 'x']\n",
488
+ "\n",
489
+ "patchilizer = Patchilizer()\n",
490
+ "\n",
491
+ "patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,\n",
492
+ " max_length=PATCH_LENGTH,\n",
493
+ " max_position_embeddings=PATCH_LENGTH,\n",
494
+ " n_embd=HIDDEN_SIZE,\n",
495
+ " num_attention_heads=HIDDEN_SIZE // 64,\n",
496
+ " vocab_size=1)\n",
497
+ "byte_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,\n",
498
+ " max_length=PATCH_SIZE + 1,\n",
499
+ " max_position_embeddings=PATCH_SIZE + 1,\n",
500
+ " hidden_size=HIDDEN_SIZE,\n",
501
+ " num_attention_heads=HIDDEN_SIZE // 64,\n",
502
+ " vocab_size=128)\n",
503
+ "\n",
504
+ "model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config).to(device)\n",
505
+ "\n",
506
+ "def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):\n",
507
+ " \"\"\"\n",
508
+ " Prepare model for k-bit training.\n",
509
+ " Features include:\n",
510
+ " 1. Convert model to mixed precision (FP16).\n",
511
+ " 2. Disable unnecessary gradient computations.\n",
512
+ " 3. Enable gradient checkpointing (optional).\n",
513
+ " \"\"\"\n",
514
+ " # Convert model to mixed precision\n",
515
+ " model = model.to(dtype=torch.float16)\n",
516
+ "\n",
517
+ " # Disable gradients for embedding layers\n",
518
+ " for param in model.parameters():\n",
519
+ " if param.dtype == torch.float32:\n",
520
+ " param.requires_grad = False\n",
521
+ "\n",
522
+ " # Enable gradient checkpointing\n",
523
+ " if use_gradient_checkpointing:\n",
524
+ " model.gradient_checkpointing_enable()\n",
525
+ "\n",
526
+ " return model\n",
527
+ "\n",
528
+ "\n",
529
+ "model = prepare_model_for_kbit_training(\n",
530
+ " model,\n",
531
+ " use_gradient_checkpointing=False \n",
532
+ ")\n",
533
+ "\n",
534
+ "print(\"Parameter Number: \" + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))\n",
535
+ "\n",
536
+ "checkpoint = torch.load(INFERENCE_WEIGHTS_PATH, map_location=torch.device(device))\n",
537
+ "model.load_state_dict(checkpoint['model'])\n",
538
+ "model = model.to(device)\n",
539
+ "model.eval()\n",
540
+ "\n",
541
+ "def complete_brackets(s):\n",
542
+ " stack = []\n",
543
+ " bracket_map = {'{': '}', '[': ']', '(': ')'}\n",
544
+ " \n",
545
+ " # Iterate through each character, handle bracket matching\n",
546
+ " for char in s:\n",
547
+ " if char in bracket_map:\n",
548
+ " stack.append(char)\n",
549
+ " elif char in bracket_map.values():\n",
550
+ " # Find the corresponding left bracket\n",
551
+ " for key, value in bracket_map.items():\n",
552
+ " if value == char:\n",
553
+ " if stack and stack[-1] == key:\n",
554
+ " stack.pop()\n",
555
+ " break # Found matching right bracket, process next character\n",
556
+ " \n",
557
+ " # Complete missing right brackets (in reverse order of remaining left brackets in stack)\n",
558
+ " completion = ''.join(bracket_map[c] for c in reversed(stack))\n",
559
+ " return s + completion\n",
560
+ "\n",
561
+ "\n",
562
+ "def rest_unreduce(abc_lines):\n",
563
+ "\n",
564
+ " tunebody_index = None\n",
565
+ " for i in range(len(abc_lines)):\n",
566
+ " if abc_lines[i].startswith('%%score'):\n",
567
+ " abc_lines[i] = complete_brackets(abc_lines[i])\n",
568
+ " if '[V:' in abc_lines[i]:\n",
569
+ " tunebody_index = i\n",
570
+ " break\n",
571
+ "\n",
572
+ " metadata_lines = abc_lines[: tunebody_index]\n",
573
+ " tunebody_lines = abc_lines[tunebody_index:]\n",
574
+ "\n",
575
+ " part_symbol_list = []\n",
576
+ " voice_group_list = []\n",
577
+ " for line in metadata_lines:\n",
578
+ " if line.startswith('%%score'):\n",
579
+ " for round_bracket_match in re.findall(r'\\((.*?)\\)', line):\n",
580
+ " voice_group_list.append(round_bracket_match.split())\n",
581
+ " existed_voices = [item for sublist in voice_group_list for item in sublist]\n",
582
+ " if line.startswith('V:'):\n",
583
+ " symbol = line.split()[0]\n",
584
+ " part_symbol_list.append(symbol)\n",
585
+ " if symbol[2:] not in existed_voices:\n",
586
+ " voice_group_list.append([symbol[2:]])\n",
587
+ " z_symbol_list = [] # voices that use z as rest\n",
588
+ " x_symbol_list = [] # voices that use x as rest\n",
589
+ " for voice_group in voice_group_list:\n",
590
+ " z_symbol_list.append('V:' + voice_group[0])\n",
591
+ " for j in range(1, len(voice_group)):\n",
592
+ " x_symbol_list.append('V:' + voice_group[j])\n",
593
+ "\n",
594
+ " part_symbol_list.sort(key=lambda x: int(x[2:]))\n",
595
+ "\n",
596
+ " unreduced_tunebody_lines = []\n",
597
+ "\n",
598
+ " for i, line in enumerate(tunebody_lines):\n",
599
+ " unreduced_line = ''\n",
600
+ "\n",
601
+ " line = re.sub(r'^\\[r:[^\\]]*\\]', '', line)\n",
602
+ "\n",
603
+ " pattern = r'\\[V:(\\d+)\\](.*?)(?=\\[V:|$)'\n",
604
+ " matches = re.findall(pattern, line)\n",
605
+ "\n",
606
+ " line_bar_dict = {}\n",
607
+ " for match in matches:\n",
608
+ " key = f'V:{match[0]}'\n",
609
+ " value = match[1]\n",
610
+ " line_bar_dict[key] = value\n",
611
+ "\n",
612
+ " # calculate duration and collect barline\n",
613
+ " dur_dict = {} \n",
614
+ " for symbol, bartext in line_bar_dict.items():\n",
615
+ " right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])\n",
616
+ " bartext = bartext[:-len(right_barline)]\n",
617
+ " try:\n",
618
+ " bar_dur = calculate_bartext_duration(bartext)\n",
619
+ " except:\n",
620
+ " bar_dur = None\n",
621
+ " if bar_dur is not None:\n",
622
+ " if bar_dur not in dur_dict.keys():\n",
623
+ " dur_dict[bar_dur] = 1\n",
624
+ " else:\n",
625
+ " dur_dict[bar_dur] += 1\n",
626
+ "\n",
627
+ " try:\n",
628
+ " ref_dur = max(dur_dict, key=dur_dict.get)\n",
629
+ " except:\n",
630
+ "                    pass # use last ref_dur (NOTE: ref_dur is undefined here if no bar duration could be parsed on the first line)\n",
631
+ "\n",
632
+ " if i == 0:\n",
633
+ " prefix_left_barline = line.split('[V:')[0]\n",
634
+ " else:\n",
635
+ " prefix_left_barline = ''\n",
636
+ "\n",
637
+ " for symbol in part_symbol_list:\n",
638
+ " if symbol in line_bar_dict.keys():\n",
639
+ " symbol_bartext = line_bar_dict[symbol]\n",
640
+ " else:\n",
641
+ " if symbol in z_symbol_list:\n",
642
+ " symbol_bartext = prefix_left_barline + 'z' + str(ref_dur) + right_barline\n",
643
+ " elif symbol in x_symbol_list:\n",
644
+ " symbol_bartext = prefix_left_barline + 'x' + str(ref_dur) + right_barline\n",
645
+ " unreduced_line += '[' + symbol + ']' + symbol_bartext\n",
646
+ "\n",
647
+ " unreduced_tunebody_lines.append(unreduced_line + '\\n')\n",
648
+ "\n",
649
+ " unreduced_lines = metadata_lines + unreduced_tunebody_lines\n",
650
+ "\n",
651
+ " return unreduced_lines\n",
652
+ "\n",
653
+ "\n",
654
+ "def inference_patch(period, composer, instrumentation):\n",
655
+ "\n",
656
+ " prompt_lines=[\n",
657
+ " '%' + period + '\\n',\n",
658
+ " '%' + composer + '\\n',\n",
659
+ " '%' + instrumentation + '\\n']\n",
660
+ "\n",
661
+ " while True:\n",
662
+ "\n",
663
+ " failure_flag = False\n",
664
+ "\n",
665
+ " bos_patch = [patchilizer.bos_token_id] * (PATCH_SIZE - 1) + [patchilizer.eos_token_id]\n",
666
+ "\n",
667
+ " start_time = time.time()\n",
668
+ "\n",
669
+ " prompt_patches = patchilizer.patchilize_metadata(prompt_lines)\n",
670
+ " byte_list = list(''.join(prompt_lines))\n",
671
+ " context_tunebody_byte_list = []\n",
672
+ " metadata_byte_list = []\n",
673
+ "\n",
674
+ " print(''.join(byte_list), end='')\n",
675
+ "\n",
676
+ " prompt_patches = [[ord(c) for c in patch] + [patchilizer.special_token_id] * (PATCH_SIZE - len(patch)) for patch\n",
677
+ " in prompt_patches]\n",
678
+ " prompt_patches.insert(0, bos_patch)\n",
679
+ "\n",
680
+ " input_patches = torch.tensor(prompt_patches, device=device).reshape(1, -1)\n",
681
+ "\n",
682
+ " end_flag = False\n",
683
+ " cut_index = None\n",
684
+ "\n",
685
+ " tunebody_flag = False\n",
686
+ "\n",
687
+ " with torch.inference_mode():\n",
688
+ " \n",
689
+ " while True:\n",
690
+ " with torch.autocast(device_type='cuda', dtype=torch.float16):\n",
691
+ " predicted_patch = model.generate(input_patches.unsqueeze(0),\n",
692
+ " top_k=TOP_K,\n",
693
+ " top_p=TOP_P,\n",
694
+ " temperature=TEMPERATURE)\n",
695
+ "                    if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'): # first time entering the tunebody: it must start with '[r:0/'\n",
696
+ " tunebody_flag = True\n",
697
+ " r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)\n",
698
+ " temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)\n",
699
+ " predicted_patch = model.generate(temp_input_patches.unsqueeze(0),\n",
700
+ " top_k=TOP_K,\n",
701
+ " top_p=TOP_P,\n",
702
+ " temperature=TEMPERATURE)\n",
703
+ " predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch\n",
704
+ " if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:\n",
705
+ " end_flag = True\n",
706
+ " break\n",
707
+ " next_patch = patchilizer.decode([predicted_patch])\n",
708
+ "\n",
709
+ " for char in next_patch:\n",
710
+ " byte_list.append(char)\n",
711
+ " if tunebody_flag:\n",
712
+ " context_tunebody_byte_list.append(char)\n",
713
+ " else:\n",
714
+ " metadata_byte_list.append(char)\n",
715
+ " print(char, end='')\n",
716
+ "\n",
717
+ " patch_end_flag = False\n",
718
+ " for j in range(len(predicted_patch)):\n",
719
+ " if patch_end_flag:\n",
720
+ " predicted_patch[j] = patchilizer.special_token_id\n",
721
+ " if predicted_patch[j] == patchilizer.eos_token_id:\n",
722
+ " patch_end_flag = True\n",
723
+ "\n",
724
+ " predicted_patch = torch.tensor([predicted_patch], device=device) # (1, 16)\n",
725
+ " input_patches = torch.cat([input_patches, predicted_patch], dim=1) # (1, 16 * patch_len)\n",
726
+ "\n",
727
+ " if len(byte_list) > 102400:\n",
728
+ " failure_flag = True\n",
729
+ " break\n",
730
+ " if time.time() - start_time > 10 * 60: \n",
731
+ " failure_flag = True\n",
732
+ " break\n",
733
+ "\n",
734
+ " if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:\n",
735
+ " print('Stream generating...')\n",
736
+ "\n",
737
+ " metadata = ''.join(metadata_byte_list)\n",
738
+ " context_tunebody = ''.join(context_tunebody_byte_list)\n",
739
+ "\n",
740
+ " if '\\n' not in context_tunebody:\n",
741
+ " break # Generated content is all metadata, abandon\n",
742
+ "\n",
743
+ " context_tunebody_liness = context_tunebody.split('\\n')\n",
744
+ " if not context_tunebody.endswith('\\n'):\n",
745
+ " context_tunebody_liness = [context_tunebody_liness[i] + '\\n' for i in range(len(context_tunebody_liness) - 1)] + [context_tunebody_liness[-1]]\n",
746
+ " else:\n",
747
+ " context_tunebody_liness = [context_tunebody_liness[i] + '\\n' for i in range(len(context_tunebody_liness))]\n",
748
+ "\n",
749
+ " cut_index = len(context_tunebody_liness) // 2\n",
750
+ " abc_code_slice = metadata + ''.join(context_tunebody_liness[-cut_index:])\n",
751
+ "\n",
752
+ " input_patches = patchilizer.encode_generate(abc_code_slice)\n",
753
+ "\n",
754
+ " input_patches = [item for sublist in input_patches for item in sublist]\n",
755
+ " input_patches = torch.tensor([input_patches], device=device)\n",
756
+ " input_patches = input_patches.reshape(1, -1)\n",
757
+ "\n",
758
+ "                    context_tunebody_byte_list = list(''.join(context_tunebody_liness[-cut_index:]))\n",
759
+ "\n",
760
+ " if not failure_flag:\n",
761
+ " abc_text = ''.join(byte_list)\n",
762
+ "\n",
763
+ " # unreduce\n",
764
+ " abc_lines = abc_text.split('\\n')\n",
765
+ " abc_lines = list(filter(None, abc_lines))\n",
766
+ " abc_lines = [line + '\\n' for line in abc_lines]\n",
767
+ " try:\n",
768
+ " unreduced_abc_lines = rest_unreduce(abc_lines)\n",
769
+ " except:\n",
770
+ " failure_flag = True\n",
771
+ " pass\n",
772
+ " else:\n",
773
+ " unreduced_abc_lines = [line for line in unreduced_abc_lines if not(line.startswith('%') and not line.startswith('%%'))]\n",
774
+ " unreduced_abc_lines = ['X:1\\n'] + unreduced_abc_lines\n",
775
+ " unreduced_abc_text = ''.join(unreduced_abc_lines)\n",
776
+ " return unreduced_abc_text"
777
+ ]
778
+ },
779
+ {
780
+ "cell_type": "code",
781
+ "execution_count": null,
782
+ "id": "502c4420-533b-43cc-80ca-b2c94cd4be04",
783
+ "metadata": {},
784
+ "outputs": [],
785
+ "source": [
786
+ "result = inference_patch('Classical', 'Beethoven, Ludwig van', 'Art Song')\n",
787
+ "\n",
788
+ "abc_lines = result.splitlines()\n",
789
+ "abc_lines = [line + '\\n' for line in abc_lines if line.strip()] # Add newlines and remove empty lines\n",
790
+ "\n",
791
+ "abc_lines = rest_unreduce(abc_lines)\n",
792
+ "\n",
793
+ "with open(\"output.abc\", \"w\", encoding=\"utf-8\") as f:\n",
794
+ " f.writelines(abc_lines)\n",
795
+ "\n",
796
+ "!python abc2xml.py -o . output.abc"
797
+ ]
798
+ }
799
+ ],
800
+ "metadata": {
801
+ "kernelspec": {
802
+ "display_name": "Python 3 (ipykernel)",
803
+ "language": "python",
804
+ "name": "python3"
805
+ },
806
+ "language_info": {
807
+ "codemirror_mode": {
808
+ "name": "ipython",
809
+ "version": 3
810
+ },
811
+ "file_extension": ".py",
812
+ "mimetype": "text/x-python",
813
+ "name": "python",
814
+ "nbconvert_exporter": "python",
815
+ "pygments_lexer": "ipython3",
816
+ "version": "3.10.0"
817
+ }
818
+ },
819
+ "nbformat": 4,
820
+ "nbformat_minor": 5
821
+ }
demo.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import sys
import threading
import queue
from io import TextIOBase
from inference import inference_patch
import datetime
import subprocess
import os

# Build the whitelist of prompt triples the model was fine-tuned on.
# Each line of prompts.txt has the form "Period_Composer_Instrumentation".
with open('prompts.txt', 'r') as f:
    prompts = f.readlines()
valid_combinations = set()
for line in prompts:
    fields = line.strip().split('_')
    valid_combinations.add((fields[0], fields[1], fields[2]))

# Dropdown option lists, each derived from one field of the whitelist.
periods = sorted({p for p, _, _ in valid_combinations})
composers = sorted({c for _, c, _ in valid_combinations})
instruments = sorted({i for _, _, i in valid_combinations})
24
+
25
# Dynamic component updates
def update_components(period, composer):
    """Refresh the Composer and Instrumentation dropdowns after a selection.

    Returns a two-element list of gr.Dropdown updates: the composers valid
    for `period`, and the instrumentations valid for (period, composer).
    """
    if not period:
        # Nothing selected yet — keep both dependent dropdowns empty and locked.
        return [
            gr.Dropdown(choices=[], value=None, interactive=False),
            gr.Dropdown(choices=[], value=None, interactive=False)
        ]

    valid_composers = sorted({c for p, c, _ in valid_combinations if p == period})
    if composer:
        valid_instruments = sorted(
            {i for p, c, i in valid_combinations if p == period and c == composer}
        )
    else:
        valid_instruments = []

    # Preserve the current composer only if it is still valid for this period.
    composer_update = gr.Dropdown(
        choices=valid_composers,
        value=composer if composer in valid_composers else None,
        interactive=True
    )
    instrument_update = gr.Dropdown(
        choices=valid_instruments,
        value=None,
        interactive=bool(valid_instruments)
    )
    return [composer_update, instrument_update]
48
+
49
+
50
class RealtimeStream(TextIOBase):
    """Write-only text stream that forwards every write into a queue.

    Used to capture print() output from a worker thread so it can be
    relayed to the UI as it is produced.
    """

    def __init__(self, queue):
        self.queue = queue

    def write(self, text):
        # Hand the chunk to the consumer and report how much was accepted.
        written = len(text)
        self.queue.put(text)
        return written
57
+
58
+
59
def save_and_convert(abc_content, period, composer, instrumentation):
    """Write the generated ABC text to disk and convert it with abc2xml.py.

    Returns a human-readable status string; raises gr.Error when the prompt
    is incomplete or the conversion subprocess exits non-zero.
    """
    if not all([period, composer, instrumentation]):
        raise gr.Error("Please complete a valid generation first before saving")

    # Timestamped, prompt-tagged basename keeps successive saves unique.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    prompt_str = f"{period}_{composer}_{instrumentation}"
    filename_base = f"{timestamp}_{prompt_str}"
    abc_filename = f"{filename_base}.abc"
    xml_filename = f"{filename_base}.xml"

    with open(abc_filename, "w", encoding="utf-8") as f:
        f.write(abc_content)

    try:
        # abc2xml.py writes the .xml next to the .abc ('-o .').
        subprocess.run(
            ["python", "abc2xml.py", "-o", ".", abc_filename],
            check=True,
            capture_output=True,
            text=True
        )
    except subprocess.CalledProcessError as e:
        error_msg = f"Conversion failed: {e.stderr}" if e.stderr else "Unknown error"
        raise gr.Error(f"ABC to XML conversion failed: {error_msg}. Please try to generate another composition.")

    return f"Saved successfully: {abc_filename} -> {xml_filename}"
84
+
85
+
86
+
87
def generate_music(period, composer, instrumentation):
    """Stream NotaGen generation to the UI.

    A generator that yields (process_output, final_result) pairs: while
    inference runs, final_result is None and process_output grows with the
    text printed by inference_patch; the last yield carries the finished
    ABC score (or "" if the worker produced no result).
    """
    if (period, composer, instrumentation) not in valid_combinations:
        raise gr.Error("Invalid prompt combination! Please re-select from the period options")

    # inference_patch() reports progress via print(); swap in a queue-backed
    # stream so its output can be relayed to the Gradio textbox in real time.
    # NOTE(review): sys.stdout is process-global, so concurrent generations
    # would interleave output — presumably acceptable for a single-user demo.
    output_queue = queue.Queue()
    original_stdout = sys.stdout
    sys.stdout = RealtimeStream(output_queue)

    result_container = []
    def run_inference():
        try:
            result_container.append(inference_patch(period, composer, instrumentation))
        finally:
            sys.stdout = original_stdout  # always restore, even on failure

    thread = threading.Thread(target=run_inference)
    thread.start()

    # Relay captured text while the worker is alive; poll with a short
    # timeout so thread exit is noticed promptly.
    process_output = ""
    while thread.is_alive():
        try:
            text = output_queue.get(timeout=0.1)
            process_output += text
            yield process_output, None
        except queue.Empty:
            continue

    # Drain anything printed between the last poll and thread exit.
    while not output_queue.empty():
        text = output_queue.get()
        process_output += text
        yield process_output, None

    # If the worker raised, result_container stays empty and "" is shown.
    final_result = result_container[0] if result_container else ""
    yield process_output, final_result
121
+
122
with gr.Blocks() as demo:
    gr.Markdown("## NotaGen")

    with gr.Row():
        # Left column: prompt selection and live generation log
        with gr.Column():
            period_dd = gr.Dropdown(
                choices=periods,
                value=None,
                label="Period",
                interactive=True
            )
            # Composer and Instrumentation start locked; update_components
            # unlocks them as the upstream selections are made.
            composer_dd = gr.Dropdown(
                choices=[],
                value=None,
                label="Composer",
                interactive=False
            )
            instrument_dd = gr.Dropdown(
                choices=[],
                value=None,
                label="Instrumentation",
                interactive=False
            )

            generate_btn = gr.Button("Generate!", variant="primary")

            process_output = gr.Textbox(
                label="Generation process",
                interactive=False,
                lines=15,
                max_lines=15,
                placeholder="Generation progress will be shown here...",
                elem_classes="process-output"
            )

        # Right column: final score plus save/convert controls
        with gr.Column():
            final_output = gr.Textbox(
                label="Post-processed ABC notation scores",
                interactive=True,
                lines=23,
                placeholder="Post-processed ABC scores will be shown here...",
                elem_classes="final-output"
            )

            with gr.Row():
                save_btn = gr.Button("💾 Save as ABC & XML files", variant="secondary")

            save_status = gr.Textbox(
                label="Save Status",
                interactive=False,
                visible=True,
                max_lines=2
            )

    # Cascading dropdowns: Period narrows Composer, which narrows
    # Instrumentation; both change handlers share update_components.
    period_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )
    composer_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )

    # generate_music is a generator, so the textboxes update incrementally.
    generate_btn.click(
        generate_music,
        inputs=[period_dd, composer_dd, instrument_dd],
        outputs=[process_output, final_output]
    )

    save_btn.click(
        save_and_convert,
        inputs=[final_output, period_dd, composer_dd, instrument_dd],
        outputs=[save_status]
    )
200
+
201
+
202
# CSS tweaks for the two output panes (attached below via demo.css).
css = """
.process-output {
    background-color: #f0f0f0;
    font-family: monospace;
    padding: 10px;
    border-radius: 5px;
}
.final-output {
    background-color: #ffffff;
    font-family: sans-serif;
    padding: 10px;
    border-radius: 5px;
}

.process-output textarea {
    max-height: 500px !important;
    overflow-y: auto !important;
    white-space: pre-wrap;
}

"""
# Hover highlight for the save button.
# NOTE(review): the selector assumes Gradio derives this element id from the
# button label — verify it actually matches the rendered DOM.
css += """
button#💾-save-convert:hover {
    background-color: #ffe6e6;
}
"""

demo.css = css

if __name__ == "__main__":

    # Bind to all interfaces so the demo is reachable from other machines.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7861
    )
extract_clamp2.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
input_dir = '' # interleaved abc folder
output_dir = '' # feature folder

import os
import json
import random
import torch
import numpy as np
from tqdm import tqdm
from config import *
from utils import *
from samplings import *
from accelerate import Accelerator
from transformers import BertConfig, AutoTokenizer
import argparse


# Whether extract_feature() returns normalized global vectors (True) or
# raw per-token hidden states (False).
normalize = True

# Start each run with a clean slate: remove log artifacts from previous runs.
os.makedirs("logs", exist_ok=True)
for file in ["logs/files_extract_clamp2.json",
             "logs/files_shuffle_extract_clamp2.json",
             "logs/log_extract_clamp2.txt",
             "logs/pass_extract_clamp2.txt",
             "logs/skip_extract_clamp2.txt"]:
    if os.path.exists(file):
        os.remove(file)

# Collect every candidate input file under input_dir
# (.txt = text metadata, .abc / .mtf = symbolic music).
files = []
for root, dirs, fs in os.walk(input_dir):
    for f in fs:
        if f.endswith(".txt") or f.endswith(".abc") or f.endswith(".mtf"):
            files.append(os.path.join(root, f))
print(f"Found {len(files)} files in total")
with open("logs/files_extract_clamp2.json", "w", encoding="utf-8") as f:
    json.dump(files, f)
random.shuffle(files)  # randomize processing order; shuffled list is logged too
with open("logs/files_shuffle_extract_clamp2.json", "w", encoding="utf-8") as f:
    json.dump(files, f)
40
+
41
accelerator = Accelerator()
device = accelerator.device
print("Using device:", device)
# FIX: was "logs/log_extract_clamp.txt" — every other artifact of this script
# uses the *_extract_clamp2 suffix (see the cleanup list above), so the old
# name was never reset between runs. Log to the file that actually gets cleaned.
with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f:
    f.write("Using device: " + str(device) + "\n")

# Symbolic-music (M3) encoder configuration; sizes/lengths come from config.py.
m3_config = BertConfig(vocab_size=1,
                       hidden_size=M3_HIDDEN_SIZE,
                       num_hidden_layers=PATCH_NUM_LAYERS,
                       num_attention_heads=M3_HIDDEN_SIZE//64,
                       intermediate_size=M3_HIDDEN_SIZE*4,
                       max_position_embeddings=PATCH_LENGTH)
model = CLaMP2Model(m3_config,
                    text_model_name=TEXT_MODEL_NAME,
                    hidden_size=CLAMP2_HIDDEN_SIZE,
                    load_m3=CLAMP2_LOAD_M3)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
patchilizer = M3Patchilizer()

# print parameter number
print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

# Inference only: switch to eval mode and load the released CLaMP 2 weights.
model.eval()
checkpoint = torch.load(CLAMP2_WEIGHTS_PATH, map_location='cpu', weights_only=True)
print(f"Successfully Loaded CLaMP 2 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
model.load_state_dict(checkpoint['model'])
68
+
69
def extract_feature(filename, get_normalized=normalize):
    """Extract a CLaMP 2 feature from a single file.

    Text files (``.txt``) are tokenized with the text tokenizer; any other
    file is treated as symbolic music and encoded with the M3 patchilizer.
    Inputs longer than the model context are split into segments.  When
    ``get_normalized`` is True a single global feature vector is returned
    (length-weighted average over segments); otherwise the token-level hidden
    states of all segments are concatenated.

    NOTE(review): relies on the module-level ``model``, ``tokenizer``,
    ``patchilizer``, ``device`` and the config constants MAX_TEXT_LENGTH,
    PATCH_LENGTH, PATCH_SIZE; the default for ``get_normalized`` comes from a
    module-level ``normalize`` flag defined earlier in the script.
    """
    with open(filename, "r", encoding="utf-8") as f:
        lines = f.readlines()

    # Drop single-'%' ABC comment lines but keep '%%' directives (e.g. %%score).
    filtered_lines = []
    for line in lines:
        if line.startswith('%') and not line.startswith('%%'):
            pass
        else:
            filtered_lines.append(line)

    item = ''.join(filtered_lines)

    if filename.endswith(".txt"):
        # Deduplicate text lines, drop empties, and join them with the SEP token.
        item = list(set(item.split("\n")))
        item = "\n".join(item)
        item = item.split("\n")
        item = [c for c in item if len(c) > 0]
        item = tokenizer.sep_token.join(item)
        input_data = tokenizer(item, return_tensors="pt")
        input_data = input_data['input_ids'].squeeze(0)
        max_input_length = MAX_TEXT_LENGTH
    else:
        input_data = patchilizer.encode(item, add_special_patches=True)
        input_data = torch.tensor(input_data)
        max_input_length = PATCH_LENGTH

    # Split into fixed-size segments; the last segment is right-aligned so it
    # is always full length (it may overlap the previous segment).
    segment_list = []
    for i in range(0, len(input_data), max_input_length):
        segment_list.append(input_data[i:i+max_input_length])
    segment_list[-1] = input_data[-max_input_length:]

    last_hidden_states_list = []

    for input_segment in segment_list:
        # Mask of real positions, then pad the segment to the full context.
        input_masks = torch.tensor([1]*input_segment.size(0))
        if filename.endswith(".txt"):
            pad_indices = torch.ones(MAX_TEXT_LENGTH - input_segment.size(0)).long() * tokenizer.pad_token_id
        else:
            pad_indices = torch.ones((PATCH_LENGTH - input_segment.size(0), PATCH_SIZE)).long() * patchilizer.pad_token_id
        input_masks = torch.cat((input_masks, torch.zeros(max_input_length - input_segment.size(0))), 0)
        input_segment = torch.cat((input_segment, pad_indices), 0)

        if filename.endswith(".txt"):
            last_hidden_states = model.get_text_features(text_inputs=input_segment.unsqueeze(0).to(device),
                                                         text_masks=input_masks.unsqueeze(0).to(device),
                                                         get_normalized=get_normalized)
        else:
            last_hidden_states = model.get_music_features(music_inputs=input_segment.unsqueeze(0).to(device),
                                                          music_masks=input_masks.unsqueeze(0).to(device),
                                                          get_normalized=get_normalized)
        if not get_normalized:
            # Token-level output: keep only the unpadded positions.
            last_hidden_states = last_hidden_states[:, :input_masks.sum().long().item(), :]
        last_hidden_states_list.append(last_hidden_states)

    if not get_normalized:
        # Strip the batch dim, de-overlap the right-aligned last segment, then
        # concatenate along the sequence dimension.
        last_hidden_states_list = [last_hidden_states[0] for last_hidden_states in last_hidden_states_list]
        last_hidden_states_list[-1] = last_hidden_states_list[-1][-(len(input_data)%max_input_length):]
        last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
    else:
        # Global feature: weight each segment by its true (unpadded) length
        # before averaging.
        full_chunk_cnt = len(input_data) // max_input_length
        remain_chunk_len = len(input_data) % max_input_length
        if remain_chunk_len == 0:
            feature_weights = torch.tensor([max_input_length] * full_chunk_cnt, device=device).view(-1, 1)
        else:
            feature_weights = torch.tensor([max_input_length] * full_chunk_cnt + [remain_chunk_len], device=device).view(-1, 1)

        last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
        last_hidden_states_list = last_hidden_states_list * feature_weights
        last_hidden_states_list = last_hidden_states_list.sum(dim=0) / feature_weights.sum()

    return last_hidden_states_list
141
+
142
def process_directory(input_dir, output_dir, files):
    """Extract CLaMP 2 features for ``files``, sharded across accelerator processes.

    The file list is split into contiguous slices, one per process; the last
    process absorbs the remainder.  Outputs mirror the input directory layout
    under ``output_dir``, one ``.npy`` per source file.  Existing outputs are
    skipped; progress and failures are appended to log files under logs/.
    """
    print(f"Found {len(files)} files in total")
    with open("logs/log_extract_clamp.txt", "a", encoding="utf-8") as f:
        f.write("Found " + str(len(files)) + " files in total\n")

    # calculate the number of files to process per GPU
    num_files_per_gpu = len(files) // accelerator.num_processes

    # calculate the start and end index for the current GPU
    start_idx = accelerator.process_index * num_files_per_gpu
    end_idx = start_idx + num_files_per_gpu
    # The last process takes the remainder of the division as well.
    if accelerator.process_index == accelerator.num_processes - 1:
        end_idx = len(files)

    files_to_process = files[start_idx:end_idx]

    # process the files
    for file in tqdm(files_to_process):
        # Mirror the source subdirectory structure under output_dir.
        output_subdir = output_dir + os.path.dirname(file)[len(input_dir):]
        try:
            os.makedirs(output_subdir, exist_ok=True)
        except Exception as e:
            print(output_subdir + " can not be created\n" + str(e))
            with open("logs/log_extract_clamp.txt", "a") as f:
                f.write(output_subdir + " can not be created\n" + str(e) + "\n")

        output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy")

        # Resume support: skip files whose feature file already exists.
        if os.path.exists(output_file):
            print(f"Skipping {file}, output already exists")
            with open("logs/skip_extract_clamp2.txt", "a", encoding="utf-8") as f:
                f.write(file + "\n")
            continue

        try:
            with torch.no_grad():
                # unsqueeze(0) restores the leading batch dimension of 1.
                features = extract_feature(file).unsqueeze(0)
            np.save(output_file, features.detach().cpu().numpy())
            with open("logs/pass_extract_clamp2.txt", "a", encoding="utf-8") as f:
                f.write(file + "\n")
        except Exception as e:
            # Best effort: log the failure and continue with the next file.
            print(f"Failed to process {file}: {e}")
            with open("logs/log_extract_clamp.txt", "a", encoding="utf-8") as f:
                f.write("Failed to process " + file + ": " + str(e) + "\n")
186
+
187
# Re-read the shuffled file list dumped earlier in this script so every
# process works from the same ordering.
with open("logs/files_shuffle_extract_clamp2.json", "r", encoding="utf-8") as f:
    files = json.load(f)

# process the files
process_directory(input_dir, output_dir, files)

with open("logs/log_extract_clamp.txt", "a", encoding="utf-8") as f:
    f.write("GPU ID: " + str(device) + "\n")
illustration.png ADDED

Git LFS Details

  • SHA256: 10e0d5742ed50035210c40983bdf56d038d0288ebd89881b895e1e50afe609a3
  • Pointer size: 131 Bytes
  • Size of remote file: 384 kB
illustration_online.png ADDED

Git LFS Details

  • SHA256: b13315492555c202c6ba8d0891014f486fee8cac8ca3c908a26686bdc9e27347
  • Pointer size: 131 Bytes
  • Size of remote file: 253 kB
inference (1).py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
import os
import time
import torch
from utils import *
from config import *
from transformers import GPT2Config, LlamaConfig
from abctoolkit.utils import Exclaim_re, Quote_re, SquareBracket_re, Barline_regexPattern
from abctoolkit.transpose import Note_list, Pitch_sign_list
from abctoolkit.duration import calculate_bartext_duration

# Extend the pitch table with the ABC rest symbols: 'z' (printed rest) and
# 'x' (invisible rest).
Note_list = Note_list + ['z', 'x']

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

os.makedirs(ORIGINAL_OUTPUT_FOLDER, exist_ok=True)
os.makedirs(INTERLEAVED_OUTPUT_FOLDER, exist_ok=True)

patchilizer = Patchilizer()

# Patch-level decoder: one position per PATCH_SIZE-byte patch.
# NOTE(review): vocab_size=1 — presumably patch inputs are embedded by
# projection inside NotaGenLMHeadModel rather than via a token table; confirm
# in utils.
patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
                          max_length=PATCH_LENGTH,
                          max_position_embeddings=PATCH_LENGTH,
                          n_embd=HIDDEN_SIZE,
                          num_attention_heads=HIDDEN_SIZE // 64,
                          vocab_size=1)
# Byte-level decoder: predicts the PATCH_SIZE bytes of each patch over an
# ASCII vocabulary (128).
byte_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
                         max_length=PATCH_SIZE + 1,
                         max_position_embeddings=PATCH_SIZE + 1,
                         hidden_size=HIDDEN_SIZE,
                         num_attention_heads=HIDDEN_SIZE // 64,
                         vocab_size=128)

model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config)

print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

# Load the fine-tuned weights; the checkpoint stores the state dict under 'model'.
checkpoint = torch.load(INFERENCE_WEIGHTS_PATH, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model.eval()
45
+
46
+
47
def rest_unreduce(abc_lines):
    """Expand a rest-reduced interleaved ABC score back to its full form.

    In the reduced form, a tunebody line only lists the voices that actually
    play in that bar.  This function re-inserts the missing voices as rests:
    the first voice of each staff group (from the %%score directive) gets a
    printed rest 'z', the remaining voices of a group get an invisible rest
    'x', each with the duration most common among the bars present on that
    line.

    Args:
        abc_lines: list of lines (each ending with '\\n'); metadata lines
            first, then tunebody lines starting at the first '[V:' line.

    Returns:
        The metadata lines followed by the unreduced tunebody lines.
    """

    # The tunebody starts at the first line containing a voice tag.
    tunebody_index = None
    for i in range(len(abc_lines)):
        if '[V:' in abc_lines[i]:
            tunebody_index = i
            break

    metadata_lines = abc_lines[: tunebody_index]
    tunebody_lines = abc_lines[tunebody_index:]

    # Parse voice groups from %%score (parenthesized groups share a staff)
    # and collect all declared voices from the V: lines.
    part_symbol_list = []
    voice_group_list = []
    for line in metadata_lines:
        if line.startswith('%%score'):
            for round_bracket_match in re.findall(r'\((.*?)\)', line):
                voice_group_list.append(round_bracket_match.split())
            existed_voices = [item for sublist in voice_group_list for item in sublist]
        if line.startswith('V:'):
            symbol = line.split()[0]
            part_symbol_list.append(symbol)
            # Voices not grouped in %%score form their own single-voice group.
            if symbol[2:] not in existed_voices:
                voice_group_list.append([symbol[2:]])
    z_symbol_list = []  # voices that use z as rest
    x_symbol_list = []  # voices that use x as rest
    for voice_group in voice_group_list:
        z_symbol_list.append('V:' + voice_group[0])
        for j in range(1, len(voice_group)):
            x_symbol_list.append('V:' + voice_group[j])

    # Emit voices in numeric order (V:1, V:2, ...).
    part_symbol_list.sort(key=lambda x: int(x[2:]))

    unreduced_tunebody_lines = []

    for i, line in enumerate(tunebody_lines):
        unreduced_line = ''

        # Strip the '[r:...]' bar-countdown prefix.
        line = re.sub(r'^\[r:[^\]]*\]', '', line)

        pattern = r'\[V:(\d+)\](.*?)(?=\[V:|$)'
        matches = re.findall(pattern, line)

        line_bar_dict = {}
        for match in matches:
            key = f'V:{match[0]}'
            value = match[1]
            line_bar_dict[key] = value

        # calculate duration and collect barline
        # NOTE(review): right_barline keeps the value from the LAST iteration
        # of this loop and is reused below when filling missing voices; if
        # line_bar_dict is empty it may be unbound — confirm inputs always
        # contain at least one voice per line.
        dur_dict = {}
        for symbol, bartext in line_bar_dict.items():
            right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])
            bartext = bartext[:-len(right_barline)]
            try:
                bar_dur = calculate_bartext_duration(bartext)
            except:
                bar_dur = None
            if bar_dur is not None:
                if bar_dur not in dur_dict.keys():
                    dur_dict[bar_dur] = 1
                else:
                    dur_dict[bar_dur] += 1

        # Reference duration = the most common bar duration on this line.
        try:
            ref_dur = max(dur_dict, key=dur_dict.get)
        except:
            pass  # use last ref_dur

        # Only the first tunebody line may carry a left barline before the
        # first voice tag; replicate it in front of inserted rests.
        if i == 0:
            prefix_left_barline = line.split('[V:')[0]
        else:
            prefix_left_barline = ''

        for symbol in part_symbol_list:
            if symbol in line_bar_dict.keys():
                symbol_bartext = line_bar_dict[symbol]
            else:
                # Voice missing from this line: synthesize a full-bar rest.
                if symbol in z_symbol_list:
                    symbol_bartext = prefix_left_barline + 'z' + str(ref_dur) + right_barline
                elif symbol in x_symbol_list:
                    symbol_bartext = prefix_left_barline + 'x' + str(ref_dur) + right_barline
            unreduced_line += '[' + symbol + ']' + symbol_bartext

        unreduced_tunebody_lines.append(unreduced_line + '\n')

    unreduced_lines = metadata_lines + unreduced_tunebody_lines

    return unreduced_lines
135
+
136
+
137
def inference_patch(prompt_lines=None, pieces=NUM_SAMPLES):
    """Autoregressively generate ``pieces`` ABC pieces with the NotaGen model.

    Args:
        prompt_lines: optional list of prompt lines (each ending with '\\n')
            fed to the model before generation.  Defaults to an empty prompt.
            (Changed from a mutable ``[]`` default to ``None`` to avoid the
            shared-mutable-default pitfall; behavior is unchanged.)
        pieces: number of pieces to generate.

    For each successful generation the rest-reduced text is written to
    ORIGINAL_OUTPUT_FOLDER and the unreduced (full interleaved) text to
    INTERLEAVED_OUTPUT_FOLDER.  A generation attempt is abandoned and retried
    when it exceeds 102400 bytes or 20 minutes.

    Relies on module-level ``model``, ``patchilizer``, ``device`` and the
    config constants (PATCH_SIZE, PATCH_LENGTH, TOP_K, TOP_P, TEMPERATURE,
    output folders).
    """
    if prompt_lines is None:
        prompt_lines = []

    file_no = 1

    # BOS patch marking the start of the sequence.
    bos_patch = [patchilizer.bos_token_id] * (PATCH_SIZE - 1) + [patchilizer.eos_token_id]

    while file_no <= pieces:

        start_time = time.time()

        prompt_patches = patchilizer.patchilize_metadata(prompt_lines)
        byte_list = list(''.join(prompt_lines))
        print(''.join(byte_list), end='')

        # Pad every prompt patch to PATCH_SIZE with the special token and
        # prepend the BOS patch.
        prompt_patches = [[ord(c) for c in patch] + [patchilizer.special_token_id] * (PATCH_SIZE - len(patch)) for patch
                          in prompt_patches]
        prompt_patches.insert(0, bos_patch)

        input_patches = torch.tensor(prompt_patches, device=device).reshape(1, -1)

        failure_flag = False
        end_flag = False
        cut_index = None

        tunebody_flag = False
        while True:
            predicted_patch = model.generate(input_patches.unsqueeze(0),
                                             top_k=TOP_K,
                                             top_p=TOP_P,
                                             temperature=TEMPERATURE)
            # The first tunebody line must start with '[r:0/': when the model
            # first enters the tunebody, force that prefix and re-generate.
            if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'):  # start with [r:0/
                tunebody_flag = True
                r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)
                temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
                predicted_patch = model.generate(temp_input_patches.unsqueeze(0),
                                                 top_k=TOP_K,
                                                 top_p=TOP_P,
                                                 temperature=TEMPERATURE)
                predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch
            # A BOS+EOS patch signals the end of the piece.
            if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
                end_flag = True
                break
            next_patch = patchilizer.decode([predicted_patch])

            for char in next_patch:
                byte_list.append(char)
                print(char, end='')

            # Everything after the first EOS inside a patch is padding.
            patch_end_flag = False
            for j in range(len(predicted_patch)):
                if patch_end_flag:
                    predicted_patch[j] = patchilizer.special_token_id
                if predicted_patch[j] == patchilizer.eos_token_id:
                    patch_end_flag = True

            predicted_patch = torch.tensor([predicted_patch], device=device)  # (1, 16)
            input_patches = torch.cat([input_patches, predicted_patch], dim=1)  # (1, 16 * patch_len)

            # Abandon overlong or overtime generations and retry.
            if len(byte_list) > 102400:
                failure_flag = True
                break
            if time.time() - start_time > 20 * 60:
                failure_flag = True
                break

            # Context window full: keep the metadata plus the most recent half
            # of the tunebody and continue generating from that slice.
            if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
                print('Stream generating...')
                abc_code = ''.join(byte_list)
                abc_lines = abc_code.split('\n')

                tunebody_index = None
                for i, line in enumerate(abc_lines):
                    if line.startswith('[r:') or line.startswith('[V:'):
                        tunebody_index = i
                        break
                if tunebody_index is None or tunebody_index == len(abc_lines) - 1:
                    break

                metadata_lines = abc_lines[:tunebody_index]
                tunebody_lines = abc_lines[tunebody_index:]

                metadata_lines = [line + '\n' for line in metadata_lines]
                # Do not append '\n' to a still-unfinished last line.
                if not abc_code.endswith('\n'):
                    tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [
                        tunebody_lines[-1]]
                else:
                    tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]

                # The cut size is fixed after the first truncation.
                if cut_index is None:
                    cut_index = len(tunebody_lines) // 2

                abc_code_slice = ''.join(metadata_lines + tunebody_lines[-cut_index:])
                input_patches = patchilizer.encode_generate(abc_code_slice)

                input_patches = [item for sublist in input_patches for item in sublist]
                input_patches = torch.tensor([input_patches], device=device)
                input_patches = input_patches.reshape(1, -1)

        if not failure_flag:
            generation_time_cost = time.time() - start_time

            abc_text = ''.join(byte_list)
            filename = time.strftime("%Y%m%d-%H%M%S") + \
                       "_" + format(generation_time_cost, '.2f') + '_' + str(file_no) + ".abc"

            # unreduce
            unreduced_output_path = os.path.join(INTERLEAVED_OUTPUT_FOLDER, filename)

            abc_lines = abc_text.split('\n')
            abc_lines = list(filter(None, abc_lines))
            abc_lines = [line + '\n' for line in abc_lines]
            try:
                abc_lines = rest_unreduce(abc_lines)

                with open(unreduced_output_path, 'w') as file:
                    file.writelines(abc_lines)
            except Exception:  # narrowed from a bare except; unreduce is best-effort
                pass
            else:
                # Write the original (rest-reduced) copy only when the
                # unreduced version was produced successfully.
                original_output_path = os.path.join(ORIGINAL_OUTPUT_FOLDER, filename)
                with open(original_output_path, 'w') as w:
                    w.write(abc_text)

            file_no += 1

        else:
            print('failed')
266
+
267
+
268
+
269
if __name__ == '__main__':

    # Generate NUM_SAMPLES pieces with an empty prompt.
    inference_patch()
inference.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
import os
import time
import torch
from utils import *
from config import *
from transformers import GPT2Config
from abctoolkit.utils import Exclaim_re, Quote_re, SquareBracket_re, Barline_regexPattern
from abctoolkit.transpose import Note_list, Pitch_sign_list
from abctoolkit.duration import calculate_bartext_duration

# Extend the pitch table with the ABC rest symbols: 'z' (printed rest) and
# 'x' (invisible rest).
Note_list = Note_list + ['z', 'x']

# Prefer CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

patchilizer = Patchilizer()

# Patch-level decoder: one position per PATCH_SIZE-byte patch.
# NOTE(review): vocab_size=1 — presumably patch inputs are embedded by
# projection inside NotaGenLMHeadModel rather than via a token table; confirm
# in utils.
patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
                          max_length=PATCH_LENGTH,
                          max_position_embeddings=PATCH_LENGTH,
                          n_embd=HIDDEN_SIZE,
                          num_attention_heads=HIDDEN_SIZE // 64,
                          vocab_size=1)
# Byte-level decoder: predicts the PATCH_SIZE bytes of each patch over an
# ASCII vocabulary (128).
byte_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
                         max_length=PATCH_SIZE + 1,
                         max_position_embeddings=PATCH_SIZE + 1,
                         hidden_size=HIDDEN_SIZE,
                         num_attention_heads=HIDDEN_SIZE // 64,
                         vocab_size=128)

model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config).to(device)
37
+
38
+
39
def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
    """Prepare a model for k-bit (mixed-precision) inference/training.

    Casts the weights to FP16, freezes any parameter that is still stored in
    FP32, and optionally enables gradient checkpointing.

    Args:
        model: the torch module to prepare.
        use_gradient_checkpointing: when True, call
            ``model.gradient_checkpointing_enable()``.

    Returns:
        The prepared model (cast in place and returned).
    """
    # Cast every weight to half precision.
    model = model.to(dtype=torch.float16)

    # Freeze whatever remains in full precision.
    # NOTE(review): after the FP16 cast no float32 parameters normally remain,
    # so this loop is usually a no-op — confirm intended.
    for weight in model.parameters():
        if weight.dtype == torch.float32:
            weight.requires_grad = False

    if use_gradient_checkpointing:
        model.gradient_checkpointing_enable()

    return model
60
+
61
+
62
# Checkpointing is unnecessary for pure inference.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=False
)

print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

# Load the fine-tuned weights; the checkpoint stores the state dict under 'model'.
checkpoint = torch.load(INFERENCE_WEIGHTS_PATH, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model.eval()
73
+
74
+
75
def complete_brackets(s):
    """Append the closing brackets needed to balance '{', '[' and '(' in s.

    Matched pairs are consumed; a closer that does not match the most recent
    unmatched opener is ignored.  The missing closers are appended in
    innermost-first order.

    >>> complete_brackets('%%score ( 1 ( 2 3 )')
    '%%score ( 1 ( 2 3 ))'
    """
    closer_of = {'{': '}', '[': ']', '(': ')'}
    opener_of = {v: k for k, v in closer_of.items()}
    open_stack = []

    for ch in s:
        if ch in closer_of:
            open_stack.append(ch)
        elif ch in opener_of:
            # Pop only when this closer matches the most recent opener;
            # a mismatched closer is simply skipped.
            if open_stack and open_stack[-1] == opener_of[ch]:
                open_stack.pop()

    # Close the still-open brackets from the innermost outward.
    return s + ''.join(closer_of[ch] for ch in reversed(open_stack))
94
+
95
+
96
def rest_unreduce(abc_lines):
    """Expand a rest-reduced interleaved ABC score back to its full form.

    In the reduced form, a tunebody line only lists the voices that actually
    play in that bar.  This function re-inserts the missing voices as rests:
    the first voice of each staff group (from the %%score directive) gets a
    printed rest 'z', the remaining voices of a group get an invisible rest
    'x', each with the duration most common among the bars present on that
    line.  A possibly truncated %%score line is repaired with
    ``complete_brackets`` before parsing.

    Args:
        abc_lines: list of lines (each ending with '\\n'); metadata lines
            first, then tunebody lines starting at the first '[V:' line.

    Returns:
        The metadata lines followed by the unreduced tunebody lines.
    """

    # The tunebody starts at the first line containing a voice tag.
    tunebody_index = None
    for i in range(len(abc_lines)):
        if abc_lines[i].startswith('%%score'):
            # Repair unbalanced brackets in a truncated %%score directive.
            abc_lines[i] = complete_brackets(abc_lines[i])
        if '[V:' in abc_lines[i]:
            tunebody_index = i
            break

    metadata_lines = abc_lines[: tunebody_index]
    tunebody_lines = abc_lines[tunebody_index:]

    # Parse voice groups from %%score (parenthesized groups share a staff)
    # and collect all declared voices from the V: lines.
    part_symbol_list = []
    voice_group_list = []
    for line in metadata_lines:
        if line.startswith('%%score'):
            for round_bracket_match in re.findall(r'\((.*?)\)', line):
                voice_group_list.append(round_bracket_match.split())
            existed_voices = [item for sublist in voice_group_list for item in sublist]
        if line.startswith('V:'):
            symbol = line.split()[0]
            part_symbol_list.append(symbol)
            # Voices not grouped in %%score form their own single-voice group.
            if symbol[2:] not in existed_voices:
                voice_group_list.append([symbol[2:]])
    z_symbol_list = []  # voices that use z as rest
    x_symbol_list = []  # voices that use x as rest
    for voice_group in voice_group_list:
        z_symbol_list.append('V:' + voice_group[0])
        for j in range(1, len(voice_group)):
            x_symbol_list.append('V:' + voice_group[j])

    # Emit voices in numeric order (V:1, V:2, ...).
    part_symbol_list.sort(key=lambda x: int(x[2:]))

    unreduced_tunebody_lines = []

    for i, line in enumerate(tunebody_lines):
        unreduced_line = ''

        # Strip the '[r:...]' bar-countdown prefix.
        line = re.sub(r'^\[r:[^\]]*\]', '', line)

        pattern = r'\[V:(\d+)\](.*?)(?=\[V:|$)'
        matches = re.findall(pattern, line)

        line_bar_dict = {}
        for match in matches:
            key = f'V:{match[0]}'
            value = match[1]
            line_bar_dict[key] = value

        # calculate duration and collect barline
        # NOTE(review): right_barline keeps the value from the LAST iteration
        # of this loop and is reused below when filling missing voices; if
        # line_bar_dict is empty it may be unbound — confirm inputs always
        # contain at least one voice per line.
        dur_dict = {}
        for symbol, bartext in line_bar_dict.items():
            right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])
            bartext = bartext[:-len(right_barline)]
            try:
                bar_dur = calculate_bartext_duration(bartext)
            except:
                bar_dur = None
            if bar_dur is not None:
                if bar_dur not in dur_dict.keys():
                    dur_dict[bar_dur] = 1
                else:
                    dur_dict[bar_dur] += 1

        # Reference duration = the most common bar duration on this line.
        try:
            ref_dur = max(dur_dict, key=dur_dict.get)
        except:
            pass  # use last ref_dur

        # Only the first tunebody line may carry a left barline before the
        # first voice tag; replicate it in front of inserted rests.
        if i == 0:
            prefix_left_barline = line.split('[V:')[0]
        else:
            prefix_left_barline = ''

        for symbol in part_symbol_list:
            if symbol in line_bar_dict.keys():
                symbol_bartext = line_bar_dict[symbol]
            else:
                # Voice missing from this line: synthesize a full-bar rest.
                if symbol in z_symbol_list:
                    symbol_bartext = prefix_left_barline + 'z' + str(ref_dur) + right_barline
                elif symbol in x_symbol_list:
                    symbol_bartext = prefix_left_barline + 'x' + str(ref_dur) + right_barline
            unreduced_line += '[' + symbol + ']' + symbol_bartext

        unreduced_tunebody_lines.append(unreduced_line + '\n')

    unreduced_lines = metadata_lines + unreduced_tunebody_lines

    return unreduced_lines
186
+
187
+
188
def inference_patch(period, composer, instrumentation):
    """Generate one interleaved ABC piece conditioned on the given prompt.

    Args:
        period: period label, e.g. 'Classical'.
        composer: composer label, e.g. 'Beethoven, Ludwig van'.
        instrumentation: instrumentation label, e.g. 'Keyboard'.

    Returns:
        The unreduced ABC text as a string, prefixed with 'X:1\\n' and with
        the '%' prompt lines stripped.  Retries generation until an attempt
        both finishes within the limits (102400 bytes / 10 minutes) and
        survives ``rest_unreduce``.

    Bug fixed: the streaming-truncation path previously read
    ``context_tunebody_lines`` while the variable was named
    ``context_tunebody_liness`` (a typo), raising NameError once the context
    window filled up; the name is now consistent throughout.
    """
    prompt_lines = [
        '%' + period + '\n',
        '%' + composer + '\n',
        '%' + instrumentation + '\n']

    while True:

        failure_flag = False

        # BOS patch marking the start of the sequence.
        bos_patch = [patchilizer.bos_token_id] * (PATCH_SIZE - 1) + [patchilizer.eos_token_id]

        start_time = time.time()

        prompt_patches = patchilizer.patchilize_metadata(prompt_lines)
        byte_list = list(''.join(prompt_lines))
        context_tunebody_byte_list = []
        metadata_byte_list = []

        print(''.join(byte_list), end='')

        # Pad every prompt patch to PATCH_SIZE with the special token and
        # prepend the BOS patch.
        prompt_patches = [[ord(c) for c in patch] + [patchilizer.special_token_id] * (PATCH_SIZE - len(patch)) for patch
                          in prompt_patches]
        prompt_patches.insert(0, bos_patch)

        input_patches = torch.tensor(prompt_patches, device=device).reshape(1, -1)

        end_flag = False
        cut_index = None

        tunebody_flag = False

        with torch.inference_mode():

            while True:
                # NOTE(review): autocast is hard-wired to 'cuda' even though the
                # model may run on MPS/CPU — confirm intended.
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    predicted_patch = model.generate(input_patches.unsqueeze(0),
                                                     top_k=TOP_K,
                                                     top_p=TOP_P,
                                                     temperature=TEMPERATURE)
                # On first entry into the tunebody the line must start with
                # '[r:0/': force that prefix and re-generate the patch.
                if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'):
                    tunebody_flag = True
                    r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)
                    temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
                    predicted_patch = model.generate(temp_input_patches.unsqueeze(0),
                                                     top_k=TOP_K,
                                                     top_p=TOP_P,
                                                     temperature=TEMPERATURE)
                    predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch
                # A BOS+EOS patch signals the end of the piece.
                if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
                    end_flag = True
                    break
                next_patch = patchilizer.decode([predicted_patch])

                for char in next_patch:
                    byte_list.append(char)
                    if tunebody_flag:
                        context_tunebody_byte_list.append(char)
                    else:
                        metadata_byte_list.append(char)
                    print(char, end='')

                # Everything after the first EOS inside a patch is padding.
                patch_end_flag = False
                for j in range(len(predicted_patch)):
                    if patch_end_flag:
                        predicted_patch[j] = patchilizer.special_token_id
                    if predicted_patch[j] == patchilizer.eos_token_id:
                        patch_end_flag = True

                predicted_patch = torch.tensor([predicted_patch], device=device)  # (1, 16)
                input_patches = torch.cat([input_patches, predicted_patch], dim=1)  # (1, 16 * patch_len)

                # Abandon overlong or overtime generations and retry.
                if len(byte_list) > 102400:
                    failure_flag = True
                    break
                if time.time() - start_time > 10 * 60:
                    failure_flag = True
                    break

                # Context window full: keep the metadata plus the most recent
                # half of the tunebody and continue generating from that slice.
                if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
                    print('Stream generating...')

                    metadata = ''.join(metadata_byte_list)
                    context_tunebody = ''.join(context_tunebody_byte_list)

                    if '\n' not in context_tunebody:
                        break  # Generated content is all metadata, abandon

                    context_tunebody_lines = context_tunebody.split('\n')
                    # Do not append '\n' to a still-unfinished last line.
                    if not context_tunebody.endswith('\n'):
                        context_tunebody_lines = [context_tunebody_lines[i] + '\n' for i in range(len(context_tunebody_lines) - 1)] + [context_tunebody_lines[-1]]
                    else:
                        context_tunebody_lines = [context_tunebody_lines[i] + '\n' for i in range(len(context_tunebody_lines))]

                    cut_index = len(context_tunebody_lines) // 2
                    abc_code_slice = metadata + ''.join(context_tunebody_lines[-cut_index:])

                    input_patches = patchilizer.encode_generate(abc_code_slice)

                    input_patches = [item for sublist in input_patches for item in sublist]
                    input_patches = torch.tensor([input_patches], device=device)
                    input_patches = input_patches.reshape(1, -1)

                    # Keep the byte-level context in sync with the kept slice.
                    context_tunebody_byte_list = list(''.join(context_tunebody_lines[-cut_index:]))

        if not failure_flag:
            abc_text = ''.join(byte_list)

            # unreduce
            abc_lines = abc_text.split('\n')
            abc_lines = list(filter(None, abc_lines))
            abc_lines = [line + '\n' for line in abc_lines]
            try:
                unreduced_abc_lines = rest_unreduce(abc_lines)
            except Exception:  # narrowed from a bare except; retry on failure
                failure_flag = True
            else:
                # Drop the '%' prompt lines (keep '%%' directives) and prepend
                # the tune header required by ABC renderers.
                unreduced_abc_lines = [line for line in unreduced_abc_lines if not(line.startswith('%') and not line.startswith('%%'))]
                unreduced_abc_lines = ['X:1\n'] + unreduced_abc_lines
                unreduced_abc_text = ''.join(unreduced_abc_lines)
                return unreduced_abc_text
311
+
312
+
313
+
314
+
315
+
316
+
317
if __name__ == '__main__':
    # Example: generate one Classical keyboard piece in the style of Beethoven.
    inference_patch('Classical', 'Beethoven, Ludwig van', 'Keyboard')
notagen.png ADDED

Git LFS Details

  • SHA256: 782948b2cd663b846ebbc03cb14112efdc65dd487a74a3e10fb484199f33b658
  • Pointer size: 131 Bytes
  • Size of remote file: 613 kB
prompts.txt ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Baroque_Bach, Johann Sebastian_Chamber
2
+ Baroque_Bach, Johann Sebastian_Choral
3
+ Baroque_Bach, Johann Sebastian_Keyboard
4
+ Baroque_Bach, Johann Sebastian_Orchestral
5
+ Baroque_Bach, Johann Sebastian_Vocal-Orchestral
6
+ Baroque_Corelli, Arcangelo_Chamber
7
+ Baroque_Corelli, Arcangelo_Orchestral
8
+ Baroque_Handel, George Frideric_Chamber
9
+ Baroque_Handel, George Frideric_Keyboard
10
+ Baroque_Handel, George Frideric_Orchestral
11
+ Baroque_Handel, George Frideric_Vocal-Orchestral
12
+ Baroque_Scarlatti, Domenico_Keyboard
13
+ Baroque_Vivaldi, Antonio_Chamber
14
+ Baroque_Vivaldi, Antonio_Orchestral
15
+ Baroque_Vivaldi, Antonio_Vocal-Orchestral
16
+ Classical_Beethoven, Ludwig van_Art Song
17
+ Classical_Beethoven, Ludwig van_Chamber
18
+ Classical_Beethoven, Ludwig van_Keyboard
19
+ Classical_Beethoven, Ludwig van_Orchestral
20
+ Classical_Haydn, Joseph_Chamber
21
+ Classical_Haydn, Joseph_Keyboard
22
+ Classical_Haydn, Joseph_Orchestral
23
+ Classical_Haydn, Joseph_Vocal-Orchestral
24
+ Classical_Mozart, Wolfgang Amadeus_Chamber
25
+ Classical_Mozart, Wolfgang Amadeus_Choral
26
+ Classical_Mozart, Wolfgang Amadeus_Keyboard
27
+ Classical_Mozart, Wolfgang Amadeus_Orchestral
28
+ Classical_Mozart, Wolfgang Amadeus_Vocal-Orchestral
29
+ Classical_Paradis, Maria Theresia von_Art Song
30
+ Classical_Reichardt, Louise_Art Song
31
+ Classical_Saint-Georges, Joseph Bologne_Chamber
32
+ Classical_Schroter, Corona_Art Song
33
+ Romantic_Bartok, Bela_Keyboard
34
+ Romantic_Berlioz, Hector_Choral
35
+ Romantic_Bizet, Georges_Art Song
36
+ Romantic_Boulanger, Lili_Art Song
37
+ Romantic_Boulton, Harold_Art Song
38
+ Romantic_Brahms, Johannes_Art Song
39
+ Romantic_Brahms, Johannes_Chamber
40
+ Romantic_Brahms, Johannes_Choral
41
+ Romantic_Brahms, Johannes_Keyboard
42
+ Romantic_Brahms, Johannes_Orchestral
43
+ Romantic_Burgmuller, Friedrich_Keyboard
44
+ Romantic_Butterworth, George_Art Song
45
+ Romantic_Chaminade, Cecile_Art Song
46
+ Romantic_Chausson, Ernest_Art Song
47
+ Romantic_Chopin, Frederic_Art Song
48
+ Romantic_Chopin, Frederic_Keyboard
49
+ Romantic_Cornelius, Peter_Art Song
50
+ Romantic_Debussy, Claude_Art Song
51
+ Romantic_Debussy, Claude_Keyboard
52
+ Romantic_Dvorak, Antonin_Chamber
53
+ Romantic_Dvorak, Antonin_Choral
54
+ Romantic_Dvorak, Antonin_Keyboard
55
+ Romantic_Dvorak, Antonin_Orchestral
56
+ Romantic_Faisst, Clara_Art Song
57
+ Romantic_Faure, Gabriel_Art Song
58
+ Romantic_Faure, Gabriel_Chamber
59
+ Romantic_Faure, Gabriel_Keyboard
60
+ Romantic_Franz, Robert_Art Song
61
+ Romantic_Gonzaga, Chiquinha_Art Song
62
+ Romantic_Grandval, Clemence de_Art Song
63
+ Romantic_Grieg, Edvard_Keyboard
64
+ Romantic_Grieg, Edvard_Orchestral
65
+ Romantic_Hensel, Fanny_Art Song
66
+ Romantic_Holmes, Augusta Mary Anne_Art Song
67
+ Romantic_Jaell, Marie_Art Song
68
+ Romantic_Kinkel, Johanna_Art Song
69
+ Romantic_Kralik, Mathilde_Art Song
70
+ Romantic_Lang, Josephine_Art Song
71
+ Romantic_Lehmann, Liza_Art Song
72
+ Romantic_Liszt, Franz_Keyboard
73
+ Romantic_Mayer, Emilie_Chamber
74
+ Romantic_Medtner, Nikolay_Keyboard
75
+ Romantic_Mendelssohn, Felix_Art Song
76
+ Romantic_Mendelssohn, Felix_Chamber
77
+ Romantic_Mendelssohn, Felix_Choral
78
+ Romantic_Mendelssohn, Felix_Keyboard
79
+ Romantic_Mendelssohn, Felix_Orchestral
80
+ Romantic_Munktell, Helena_Art Song
81
+ Romantic_Parratt, Walter_Choral
82
+ Romantic_Prokofiev, Sergey_Keyboard
83
+ Romantic_Rachmaninoff, Sergei_Choral
84
+ Romantic_Rachmaninoff, Sergei_Keyboard
85
+ Romantic_Ravel, Maurice_Art Song
86
+ Romantic_Ravel, Maurice_Chamber
87
+ Romantic_Ravel, Maurice_Keyboard
88
+ Romantic_Saint-Saens, Camille_Chamber
89
+ Romantic_Saint-Saens, Camille_Keyboard
90
+ Romantic_Saint-Saens, Camille_Orchestral
91
+ Romantic_Satie, Erik_Art Song
92
+ Romantic_Satie, Erik_Keyboard
93
+ Romantic_Schubert, Franz_Art Song
94
+ Romantic_Schubert, Franz_Chamber
95
+ Romantic_Schubert, Franz_Choral
96
+ Romantic_Schubert, Franz_Keyboard
97
+ Romantic_Schumann, Clara_Art Song
98
+ Romantic_Schumann, Robert_Art Song
99
+ Romantic_Schumann, Robert_Chamber
100
+ Romantic_Schumann, Robert_Choral
101
+ Romantic_Schumann, Robert_Keyboard
102
+ Romantic_Scriabin, Aleksandr_Keyboard
103
+ Romantic_Shostakovich, Dmitry_Chamber
104
+ Romantic_Shostakovich, Dmitry_Keyboard
105
+ Romantic_Sibelius, Jean_Keyboard
106
+ Romantic_Smetana, Bedrich_Keyboard
107
+ Romantic_Tchaikovsky, Pyotr_Keyboard
108
+ Romantic_Tchaikovsky, Pyotr_Orchestral
109
+ Romantic_Viardot, Pauline_Art Song
110
+ Romantic_Warlock, Peter_Art Song
111
+ Romantic_Wolf, Hugo_Art Song
112
+ Romantic_Zumsteeg, Emilie_Art Song
requirements (6).txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ transformers==4.40.0
2
+ numpy==1.26.4
3
+ wandb==0.17.2
4
+ abctoolkit==0.0.6
5
+ samplings==0.1.7
6
+ pyparsing==3.2.1
7
+ gradio==5.17.1
statistics.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gt_feature_folder = ''
2
+ output_feature_folder = ''
3
+
4
+ import os
5
+ import json
6
+ import random
7
+ import re
8
+ import numpy as np
9
+ from config import *
10
+
11
+ def load_npy_files(folder_path_list):
12
+ """
13
+ Load all .npy files from a specified folder and return a list of numpy arrays.
14
+ """
15
+ npy_list = []
16
+ for file_path in folder_path_list:
17
+ if file_path.endswith('.npy'):
18
+ # file_path = os.path.join(folder_path, file_name)
19
+ np_array = np.load(file_path)[0]
20
+ npy_list.append(np_array)
21
+ return npy_list
22
+
23
+ def average_npy(npy_list):
24
+ """
25
+ Compute the average of a list of numpy arrays.
26
+ """
27
+ return np.mean(npy_list, axis=0)
28
+
29
+ def cosine_similarity(vec1, vec2):
30
+ """
31
+ Compute cosine similarity between two numpy arrays.
32
+ """
33
+ dot_product = np.dot(vec1, vec2)
34
+
35
+ norm_vec1 = np.linalg.norm(vec1)
36
+ norm_vec2 = np.linalg.norm(vec2)
37
+
38
+ cosine_sim = dot_product / (norm_vec1 * norm_vec2)
39
+
40
+ return cosine_sim
41
+
42
+
43
+
44
+ def test_generated_results_similarity():
45
+
46
+ gt_feature_paths = []
47
+ for gt_feature_file in os.listdir(gt_feature_folder):
48
+ gt_feature_paths.append(os.path.join(gt_feature_folder, gt_feature_file))
49
+ gt_features = load_npy_files(gt_feature_paths)
50
+ gt_avg_feature = average_npy(gt_features)
51
+
52
+ clamp2score_list = []
53
+ for output_feature_file in os.listdir(output_feature_folder):
54
+ output_feature_path = os.path.join(output_feature_folder, output_feature_file)
55
+ output_feature = np.load(output_feature_path)[0]
56
+ clamp2score = cosine_similarity(gt_avg_feature, output_feature)
57
+ clamp2score_list.append(clamp2score)
58
+ avg_clampscore = sum(clamp2score_list) / len(clamp2score_list)
59
+
60
+ print('average clamp 2 score:', avg_clampscore)
61
+
62
+
63
+
64
+
65
+ if __name__ == '__main__':
66
+
67
+ test_generated_results_similarity()
68
+
train-gen (1).py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import time
4
+ import math
5
+ import json
6
+ import wandb
7
+ import torch
8
+ import random
9
+ import numpy as np
10
+ from utils import *
11
+ from config import *
12
+ from tqdm import tqdm
13
+ from copy import deepcopy
14
+ from torch.cuda.amp import autocast, GradScaler
15
+ from torch.utils.data import Dataset, DataLoader
16
+ from transformers import GPT2Config, LlamaConfig, get_scheduler, get_constant_schedule_with_warmup
17
+ import torch.distributed as dist
18
+ from torch.nn.parallel import DistributedDataParallel as DDP
19
+ from torch.utils.data.distributed import DistributedSampler
20
+
21
+ # Set up distributed training
22
+ world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
23
+ global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0
24
+ local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0
25
+
26
+ if world_size > 1:
27
+ torch.cuda.set_device(local_rank)
28
+ device = torch.device("cuda", local_rank)
29
+ dist.init_process_group(backend='nccl') if world_size > 1 else None
30
+ else:
31
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
32
+
33
+ # Set random seed
34
+ seed = 0 + global_rank
35
+ random.seed(seed)
36
+ np.random.seed(seed)
37
+ torch.manual_seed(seed)
38
+ torch.cuda.manual_seed_all(seed)
39
+ torch.backends.cudnn.deterministic = True
40
+ torch.backends.cudnn.benchmark = False
41
+
42
+ batch_size = BATCH_SIZE
43
+
44
+ patchilizer = Patchilizer()
45
+
46
+ patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
47
+ max_length=PATCH_LENGTH,
48
+ max_position_embeddings=PATCH_LENGTH,
49
+ n_embd=HIDDEN_SIZE,
50
+ num_attention_heads=HIDDEN_SIZE//64,
51
+ vocab_size=1)
52
+ char_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
53
+ max_length=PATCH_SIZE+1,
54
+ max_position_embeddings=PATCH_SIZE+1,
55
+ hidden_size=HIDDEN_SIZE,
56
+ num_attention_heads=HIDDEN_SIZE//64,
57
+ vocab_size=128)
58
+
59
+ model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=char_config)
60
+
61
+ model = model.to(device)
62
+
63
+ # print parameter number
64
+ print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
65
+
66
+ if world_size > 1:
67
+ model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
68
+
69
+ scaler = GradScaler()
70
+ is_autocast = True
71
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
72
+
73
+
74
+ def clear_unused_tensors():
75
+ gc.disable() # Temporarily disable garbage collection
76
+ try:
77
+ # Get the set of tensor ids used by the model
78
+ if hasattr(model, "module"):
79
+ model_tensors = {id(p) for p in model.module.parameters()}
80
+ else:
81
+ model_tensors = {id(p) for p in model.parameters()}
82
+
83
+ # Get the set of tensor ids used by the optimizer
84
+ optimizer_tensors = {
85
+ id(state)
86
+ for state_dict in optimizer.state.values()
87
+ for state in state_dict.values()
88
+ if isinstance(state, torch.Tensor) # Ensure only tensors are considered
89
+ }
90
+
91
+ # List of all CUDA tensors currently in memory
92
+ tensors = [obj for obj in gc.get_objects() if isinstance(obj, torch.Tensor) and obj.is_cuda]
93
+
94
+ # Create weak references to avoid interfering with garbage collection
95
+ tensor_refs = [weakref.ref(tensor) for tensor in tensors]
96
+
97
+ for tensor_ref in tensor_refs:
98
+ tensor = tensor_ref() # Dereference the weak reference
99
+ if tensor is not None and id(tensor) not in model_tensors and id(tensor) not in optimizer_tensors:
100
+ # Mark the tensor for deletion
101
+ tensor.detach_() # Detach from computation graph
102
+ del tensor # Delete the tensor reference
103
+ except:
104
+ pass
105
+
106
+ finally:
107
+ gc.enable() # Re-enable garbage collection
108
+ gc.collect() # Force a garbage collection
109
+ torch.cuda.empty_cache() # Clear the CUDA cache
110
+
111
+ def collate_batch(input_batches):
112
+
113
+ input_patches, input_masks = zip(*input_batches)
114
+ input_patches = torch.nn.utils.rnn.pad_sequence(input_patches, batch_first=True, padding_value=0)
115
+ input_masks = torch.nn.utils.rnn.pad_sequence(input_masks, batch_first=True, padding_value=0)
116
+
117
+ return input_patches.to(device), input_masks.to(device)
118
+
119
+ def split_into_minibatches(input_patches, input_masks, minibatch_size):
120
+ minibatches = []
121
+ for start_idx in range(0, len(input_patches), minibatch_size):
122
+ end_idx = start_idx + minibatch_size
123
+ minibatch_patches = input_patches[start_idx:end_idx]
124
+ minibatch_masks = input_masks[start_idx:end_idx]
125
+ minibatches.append((minibatch_patches, minibatch_masks))
126
+ return minibatches
127
+
128
+ class NotaGenDataset(Dataset):
129
+ def __init__(self, filenames):
130
+ self.filenames = filenames
131
+
132
+ def __len__(self):
133
+ return len(self.filenames)
134
+
135
+ def __getitem__(self, idx):
136
+
137
+ filepath = self.filenames[idx]['path']
138
+
139
+ key = random.choice(['C#', 'F#', 'B', 'E', 'A', 'D', 'G', 'C', 'F', 'Bb', 'Eb', 'Ab', 'Db', 'Gb', 'Cb'])
140
+
141
+ folder = os.path.dirname(filepath)
142
+ name = os.path.split(filepath)[-1]
143
+ des_filepath = os.path.join(folder, key, name + '_' + key + '.abc')
144
+
145
+ with open(des_filepath, 'r', encoding='utf-8') as f:
146
+ abc_text = f.read()
147
+
148
+ file_bytes = patchilizer.encode_train(abc_text)
149
+ file_masks = [1] * len(file_bytes)
150
+
151
+ file_bytes = torch.tensor(file_bytes, dtype=torch.long)
152
+ file_masks = torch.tensor(file_masks, dtype=torch.long)
153
+
154
+ return file_bytes, file_masks
155
+
156
+ def process_one_batch(batch):
157
+ input_patches, input_masks = batch
158
+ loss = model(input_patches, input_masks).loss
159
+
160
+ # Reduce the loss on GPU 0
161
+ if world_size > 1:
162
+ loss = loss.unsqueeze(0)
163
+ dist.reduce(loss, dst=0)
164
+ loss = loss / world_size
165
+ dist.broadcast(loss, src=0)
166
+
167
+ return loss
168
+
169
+
170
+ # do one epoch for training
171
+ def train_epoch(epoch):
172
+ tqdm_train_set = tqdm(train_set)
173
+ total_train_loss = 0
174
+ iter_idx = 1
175
+ model.train()
176
+ train_steps = (epoch-1)*len(train_set)
177
+
178
+ for batch in tqdm_train_set:
179
+ minibatches = split_into_minibatches(batch[0], batch[1], BATCH_SIZE//ACCUMULATION_STEPS)
180
+ for minibatch in minibatches:
181
+ with autocast():
182
+ loss = process_one_batch(minibatch) / ACCUMULATION_STEPS
183
+ scaler.scale(loss).backward()
184
+ total_train_loss += loss.item()
185
+ scaler.step(optimizer)
186
+ scaler.update()
187
+
188
+ lr_scheduler.step()
189
+ model.zero_grad(set_to_none=True)
190
+ tqdm_train_set.set_postfix({str(global_rank)+'_train_loss': total_train_loss / iter_idx})
191
+ train_steps += 1
192
+
193
+ # Log the training loss to wandb
194
+ if global_rank==0 and WANDB_LOGGING:
195
+ wandb.log({"train_loss": total_train_loss / iter_idx}, step=train_steps)
196
+
197
+ iter_idx += 1
198
+ if iter_idx % 1000 == 0:
199
+ clear_unused_tensors()
200
+
201
+ return total_train_loss / (iter_idx-1)
202
+
203
+ # do one epoch for eval
204
+ def eval_epoch():
205
+ tqdm_eval_set = tqdm(eval_set)
206
+ total_eval_loss = 0
207
+ total_eval_bpb = 0
208
+ iter_idx = 1
209
+ model.eval()
210
+
211
+ # Evaluate data for one epoch
212
+ for batch in tqdm_eval_set:
213
+ minibatches = split_into_minibatches(batch[0], batch[1], BATCH_SIZE//ACCUMULATION_STEPS)
214
+ for minibatch in minibatches:
215
+ with torch.no_grad():
216
+ loss = process_one_batch(minibatch) / ACCUMULATION_STEPS
217
+ total_eval_loss += loss.item()
218
+ tqdm_eval_set.set_postfix({str(global_rank)+'_eval_loss': total_eval_loss / iter_idx})
219
+ iter_idx += 1
220
+ return total_eval_loss / (iter_idx-1)
221
+
222
+ # train and eval
223
+ if __name__ == "__main__":
224
+
225
+ # Initialize wandb
226
+ if WANDB_LOGGING and global_rank==0:
227
+ wandb.login(key=WANDB_KEY)
228
+ wandb.init(project="notagen",
229
+ name=WANDB_NAME)
230
+
231
+ # load data
232
+ with open(DATA_TRAIN_INDEX_PATH, "r", encoding="utf-8") as f:
233
+ print("Loading Data...")
234
+ train_files = []
235
+ for line in f:
236
+ train_files.append(json.loads(line))
237
+
238
+ with open(DATA_EVAL_INDEX_PATH, "r", encoding="utf-8") as f:
239
+ print("Loading Data...")
240
+ eval_files = []
241
+ for line in f:
242
+ eval_files.append(json.loads(line))
243
+
244
+ train_batch_nums = int(len(train_files) / batch_size)
245
+ eval_batch_nums = int(len(eval_files) / batch_size)
246
+
247
+
248
+ random.shuffle(train_files)
249
+ random.shuffle(eval_files)
250
+
251
+ train_files = train_files[:train_batch_nums*batch_size]
252
+ eval_files = eval_files[:eval_batch_nums*batch_size]
253
+
254
+ train_set = NotaGenDataset(train_files)
255
+ eval_set = NotaGenDataset(eval_files)
256
+
257
+ train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=local_rank)
258
+ eval_sampler = DistributedSampler(eval_set, num_replicas=world_size, rank=local_rank)
259
+
260
+ train_set = DataLoader(train_set, batch_size=batch_size, collate_fn=collate_batch, sampler=train_sampler, shuffle = (train_sampler is None))
261
+ eval_set = DataLoader(eval_set, batch_size=batch_size, collate_fn=collate_batch, sampler=eval_sampler, shuffle = (train_sampler is None))
262
+
263
+ lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=1000)
264
+
265
+ model = model.to(device)
266
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
267
+
268
+ if LOAD_FROM_CHECKPOINT and os.path.exists(WEIGHTS_PATH):
269
+ # Load checkpoint to CPU
270
+ checkpoint = torch.load(WEIGHTS_PATH, map_location='cpu')
271
+
272
+ # Here, model is assumed to be on GPU
273
+ # Load state dict to CPU model first, then move the model to GPU
274
+ if torch.cuda.device_count() > 1:
275
+ # If you have a DataParallel model, you need to load to model.module instead
276
+ cpu_model = deepcopy(model.module)
277
+ cpu_model.load_state_dict(checkpoint['model'])
278
+ model.module.load_state_dict(cpu_model.state_dict())
279
+ else:
280
+ # Load to a CPU clone of the model, then load back
281
+ cpu_model = deepcopy(model)
282
+ cpu_model.load_state_dict(checkpoint['model'])
283
+ model.load_state_dict(cpu_model.state_dict())
284
+ optimizer.load_state_dict(checkpoint['optimizer'])
285
+ lr_scheduler.load_state_dict(checkpoint['lr_sched'])
286
+ pre_epoch = checkpoint['epoch']
287
+ best_epoch = checkpoint['best_epoch']
288
+ min_eval_loss = checkpoint['min_eval_loss']
289
+ print("Successfully Loaded Checkpoint from Epoch %d" % pre_epoch)
290
+ checkpoint = None
291
+
292
+ else:
293
+ pre_epoch = 0
294
+ best_epoch = 0
295
+ min_eval_loss = 100
296
+
297
+ for epoch in range(1+pre_epoch, NUM_EPOCHS+1):
298
+ train_sampler.set_epoch(epoch)
299
+ eval_sampler.set_epoch(epoch)
300
+ print('-' * 21 + "Epoch " + str(epoch) + '-' * 21)
301
+ train_loss = train_epoch(epoch)
302
+ eval_loss = eval_epoch()
303
+ if global_rank==0:
304
+ with open(LOGS_PATH,'a') as f:
305
+ f.write("Epoch " + str(epoch) + "\ntrain_loss: " + str(train_loss) + "\neval_loss: " +str(eval_loss) + "\ntime: " + time.asctime(time.localtime(time.time())) + "\n\n")
306
+ if eval_loss < min_eval_loss:
307
+ best_epoch = epoch
308
+ min_eval_loss = eval_loss
309
+ checkpoint = {
310
+ 'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
311
+ 'optimizer': optimizer.state_dict(),
312
+ 'lr_sched': lr_scheduler.state_dict(),
313
+ 'epoch': epoch,
314
+ 'best_epoch': best_epoch,
315
+ 'min_eval_loss': min_eval_loss
316
+ }
317
+ torch.save(checkpoint, WEIGHTS_PATH)
318
+
319
+ if world_size > 1:
320
+ dist.barrier()
321
+
322
+ if global_rank==0:
323
+ print("Best Eval Epoch : "+str(best_epoch))
324
+ print("Min Eval Loss : "+str(min_eval_loss))
325
+
train-gen.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import time
4
+ import math
5
+ import json
6
+ import wandb
7
+ import torch
8
+ import random
9
+ import numpy as np
10
+ from abctoolkit.transpose import Key2index, Key2Mode
11
+ from utils import *
12
+ from config import *
13
+ from tqdm import tqdm
14
+ from copy import deepcopy
15
+ from torch.cuda.amp import autocast, GradScaler
16
+ from torch.utils.data import Dataset, DataLoader
17
+ from transformers import GPT2Config, LlamaConfig, get_scheduler, get_constant_schedule_with_warmup
18
+ import torch.distributed as dist
19
+ from torch.nn.parallel import DistributedDataParallel as DDP
20
+ from torch.utils.data.distributed import DistributedSampler
21
+
22
+ Index2Key = {index: key for key, index in Key2index.items() if index not in [1, 11]}
23
+ Mode2Key = {mode: key for key, mode_list in Key2Mode.items() for mode in mode_list }
24
+
25
+ # Set up distributed training
26
+ world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
27
+ global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0
28
+ local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0
29
+
30
+ if world_size > 1:
31
+ torch.cuda.set_device(local_rank)
32
+ device = torch.device("cuda", local_rank)
33
+ dist.init_process_group(backend='nccl') if world_size > 1 else None
34
+ else:
35
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
36
+
37
+ # Set random seed
38
+ seed = 0 + global_rank
39
+ random.seed(seed)
40
+ np.random.seed(seed)
41
+ torch.manual_seed(seed)
42
+ torch.cuda.manual_seed_all(seed)
43
+ torch.backends.cudnn.deterministic = True
44
+ torch.backends.cudnn.benchmark = False
45
+
46
+ batch_size = BATCH_SIZE
47
+
48
+ patchilizer = Patchilizer()
49
+
50
+ patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
51
+ max_length=PATCH_LENGTH,
52
+ max_position_embeddings=PATCH_LENGTH,
53
+ n_embd=HIDDEN_SIZE,
54
+ num_attention_heads=HIDDEN_SIZE//64,
55
+ vocab_size=1)
56
+ char_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
57
+ max_length=PATCH_SIZE+1,
58
+ max_position_embeddings=PATCH_SIZE+1,
59
+ hidden_size=HIDDEN_SIZE,
60
+ num_attention_heads=HIDDEN_SIZE//64,
61
+ vocab_size=128)
62
+
63
+ model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=char_config)
64
+
65
+ model = model.to(device)
66
+
67
+ # print parameter number
68
+ print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
69
+
70
+ if world_size > 1:
71
+ model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
72
+
73
+ scaler = GradScaler()
74
+ is_autocast = True
75
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
76
+
77
+
78
+ def clear_unused_tensors():
79
+ gc.disable() # Temporarily disable garbage collection
80
+ try:
81
+ # Get the set of tensor ids used by the model
82
+ if hasattr(model, "module"):
83
+ model_tensors = {id(p) for p in model.module.parameters()}
84
+ else:
85
+ model_tensors = {id(p) for p in model.parameters()}
86
+
87
+ # Get the set of tensor ids used by the optimizer
88
+ optimizer_tensors = {
89
+ id(state)
90
+ for state_dict in optimizer.state.values()
91
+ for state in state_dict.values()
92
+ if isinstance(state, torch.Tensor) # Ensure only tensors are considered
93
+ }
94
+
95
+ # List of all CUDA tensors currently in memory
96
+ tensors = [obj for obj in gc.get_objects() if isinstance(obj, torch.Tensor) and obj.is_cuda]
97
+
98
+ # Create weak references to avoid interfering with garbage collection
99
+ tensor_refs = [weakref.ref(tensor) for tensor in tensors]
100
+
101
+ for tensor_ref in tensor_refs:
102
+ tensor = tensor_ref() # Dereference the weak reference
103
+ if tensor is not None and id(tensor) not in model_tensors and id(tensor) not in optimizer_tensors:
104
+ # Mark the tensor for deletion
105
+ tensor.detach_() # Detach from computation graph
106
+ del tensor # Delete the tensor reference
107
+ except:
108
+ pass
109
+
110
+ finally:
111
+ gc.enable() # Re-enable garbage collection
112
+ gc.collect() # Force a garbage collection
113
+ torch.cuda.empty_cache() # Clear the CUDA cache
114
+
115
+ def collate_batch(input_batches):
116
+
117
+ input_patches, input_masks = zip(*input_batches)
118
+ input_patches = torch.nn.utils.rnn.pad_sequence(input_patches, batch_first=True, padding_value=0)
119
+ input_masks = torch.nn.utils.rnn.pad_sequence(input_masks, batch_first=True, padding_value=0)
120
+
121
+ return input_patches.to(device), input_masks.to(device)
122
+
123
+ def split_into_minibatches(input_patches, input_masks, minibatch_size):
124
+ minibatches = []
125
+ for start_idx in range(0, len(input_patches), minibatch_size):
126
+ end_idx = start_idx + minibatch_size
127
+ minibatch_patches = input_patches[start_idx:end_idx]
128
+ minibatch_masks = input_masks[start_idx:end_idx]
129
+ minibatches.append((minibatch_patches, minibatch_masks))
130
+ return minibatches
131
+
132
+ class NotaGenDataset(Dataset):
133
+ def __init__(self, filenames):
134
+ self.filenames = filenames
135
+
136
+ def __len__(self):
137
+ return len(self.filenames)
138
+
139
+ def __getitem__(self, idx):
140
+
141
+ filepath = self.filenames[idx]['path']
142
+ ori_key = Mode2Key[self.filenames[idx]['key']]
143
+
144
+ # choose a key to transpose to, according to a probability distribution
145
+ ori_key_index = Key2index[ori_key]
146
+ available_index = [(ori_key_index + offset) % 12 for offset in range(-3, 4)]
147
+ index_prob = [1/16, 2/16, 3/16, 4/16, 3/16, 2/16, 1/16]
148
+ index_prob_range = [0] + [sum(index_prob[0 : i + 1]) for i in range(len(index_prob))]
149
+ random_number = random.random()
150
+ for i in range(len(index_prob_range) - 1):
151
+ if index_prob_range[i] <= random_number < index_prob_range[i + 1]:
152
+ des_key_index = available_index[i]
153
+ if des_key_index == 1:
154
+ des_key = 'Db' if random.random() < 0.8 else 'C#'
155
+ elif des_key_index == 11:
156
+ des_key = 'B' if random.random() < 0.8 else 'Cb'
157
+ elif des_key_index == 6:
158
+ des_key = 'F#' if random.random() < 0.5 else 'Gb'
159
+ else:
160
+ des_key = Index2Key[des_key_index]
161
+
162
+ folder = os.path.dirname(filepath)
163
+ name = os.path.split(filepath)[-1]
164
+ des_filepath = os.path.join(folder, des_key, name + '_' + des_key + '.abc')
165
+
166
+ with open(des_filepath, 'r', encoding='utf-8') as f:
167
+ abc_text = f.read()
168
+
169
+ file_bytes = patchilizer.encode_train(abc_text)
170
+ file_masks = [1] * len(file_bytes)
171
+
172
+ file_bytes = torch.tensor(file_bytes, dtype=torch.long)
173
+ file_masks = torch.tensor(file_masks, dtype=torch.long)
174
+
175
+ return file_bytes, file_masks
176
+
177
+
178
+ def process_one_batch(batch):
179
+ input_patches, input_masks = batch
180
+ loss = model(input_patches, input_masks).loss
181
+
182
+ # Reduce the loss on GPU 0
183
+ if world_size > 1:
184
+ loss = loss.unsqueeze(0)
185
+ dist.reduce(loss, dst=0)
186
+ loss = loss / world_size
187
+ dist.broadcast(loss, src=0)
188
+
189
+ return loss
190
+
191
+
192
+ # do one epoch for training
193
+ def train_epoch(epoch):
194
+ tqdm_train_set = tqdm(train_set)
195
+ total_train_loss = 0
196
+ iter_idx = 1
197
+ model.train()
198
+ train_steps = (epoch-1)*len(train_set)
199
+
200
+ for batch in tqdm_train_set:
201
+ minibatches = split_into_minibatches(batch[0], batch[1], BATCH_SIZE//ACCUMULATION_STEPS)
202
+ for minibatch in minibatches:
203
+ with autocast():
204
+ loss = process_one_batch(minibatch) / ACCUMULATION_STEPS
205
+ scaler.scale(loss).backward()
206
+ total_train_loss += loss.item()
207
+ scaler.step(optimizer)
208
+ scaler.update()
209
+
210
+ lr_scheduler.step()
211
+ model.zero_grad(set_to_none=True)
212
+ tqdm_train_set.set_postfix({str(global_rank)+'_train_loss': total_train_loss / iter_idx})
213
+ train_steps += 1
214
+
215
+ # Log the training loss to wandb
216
+ if global_rank==0 and WANDB_LOGGING:
217
+ wandb.log({"train_loss": total_train_loss / iter_idx}, step=train_steps)
218
+
219
+ iter_idx += 1
220
+ if iter_idx % 1000 == 0:
221
+ clear_unused_tensors()
222
+
223
+ return total_train_loss / (iter_idx-1)
224
+
225
+ # do one epoch for eval
226
+ def eval_epoch():
227
+ tqdm_eval_set = tqdm(eval_set)
228
+ total_eval_loss = 0
229
+ total_eval_bpb = 0
230
+ iter_idx = 1
231
+ model.eval()
232
+
233
+ # Evaluate data for one epoch
234
+ for batch in tqdm_eval_set:
235
+ minibatches = split_into_minibatches(batch[0], batch[1], BATCH_SIZE//ACCUMULATION_STEPS)
236
+ for minibatch in minibatches:
237
+ with torch.no_grad():
238
+ loss = process_one_batch(minibatch) / ACCUMULATION_STEPS
239
+ total_eval_loss += loss.item()
240
+ tqdm_eval_set.set_postfix({str(global_rank)+'_eval_loss': total_eval_loss / iter_idx})
241
+ iter_idx += 1
242
+ return total_eval_loss / (iter_idx-1)
243
+
244
+ # train and eval
245
+ if __name__ == "__main__":
246
+
247
+ # Initialize wandb
248
+ if WANDB_LOGGING and global_rank==0:
249
+ wandb.login(key=WANDB_KEY)
250
+ wandb.init(project="notagen",
251
+ name=WANDB_NAME)
252
+
253
+ # load data
254
+ with open(DATA_TRAIN_INDEX_PATH, "r", encoding="utf-8") as f:
255
+ print("Loading Data...")
256
+ train_files = []
257
+ for line in f:
258
+ train_files.append(json.loads(line))
259
+
260
+ with open(DATA_EVAL_INDEX_PATH, "r", encoding="utf-8") as f:
261
+ print("Loading Data...")
262
+ eval_files = []
263
+ for line in f:
264
+ eval_files.append(json.loads(line))
265
+
266
+ if len(eval_files) == 0:
267
+ train_files, eval_files = split_data(train_files)
268
+
269
+ train_batch_nums = int(len(train_files) / batch_size)
270
+ eval_batch_nums = int(len(eval_files) / batch_size)
271
+
272
+ random.shuffle(train_files)
273
+ random.shuffle(eval_files)
274
+
275
+ train_files = train_files[:train_batch_nums*batch_size]
276
+ eval_files = eval_files[:eval_batch_nums*batch_size]
277
+
278
+ train_set = NotaGenDataset(train_files)
279
+ eval_set = NotaGenDataset(eval_files)
280
+
281
+ train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=local_rank)
282
+ eval_sampler = DistributedSampler(eval_set, num_replicas=world_size, rank=local_rank)
283
+
284
+ train_set = DataLoader(train_set, batch_size=batch_size, collate_fn=collate_batch, sampler=train_sampler, shuffle = (train_sampler is None))
285
+ eval_set = DataLoader(eval_set, batch_size=batch_size, collate_fn=collate_batch, sampler=eval_sampler, shuffle = (train_sampler is None))
286
+
287
+ lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=1000)
288
+
289
+ model = model.to(device)
290
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
291
+
292
+ if not LOAD_FROM_CHECKPOINT:
293
+ if os.path.exists(PRETRAINED_PATH):
294
+ # Load pre-trained checkpoint to CPU
295
+ checkpoint = torch.load(PRETRAINED_PATH, map_location='cpu')
296
+
297
+ # Here, model is assumed to be on GPU
298
+ # Load state dict to CPU model first, then move the model to GPU
299
+ if torch.cuda.device_count() > 1:
300
+ # If you have a DataParallel model, you need to load to model.module instead
301
+ cpu_model = deepcopy(model.module)
302
+ cpu_model.load_state_dict(checkpoint['model'])
303
+ model.module.load_state_dict(cpu_model.state_dict())
304
+ else:
305
+ # Load to a CPU clone of the model, then load back
306
+ cpu_model = deepcopy(model)
307
+ cpu_model.load_state_dict(checkpoint['model'])
308
+ model.load_state_dict(cpu_model.state_dict())
309
+
310
+ print(f"Successfully Loaded Pretrained Checkpoint at Epoch {checkpoint['epoch']} with Loss {checkpoint['min_eval_loss']}")
311
+
312
+ pre_epoch = 0
313
+ best_epoch = 0
314
+ min_eval_loss = 100
315
+ else:
316
+ raise Exception('Pre-trained Checkpoint not found. Please check your pre-trained ckpt path.')
317
+
318
+ else:
319
+ if os.path.exists(WEIGHTS_PATH):
320
+ # Load checkpoint to CPU
321
+ checkpoint = torch.load(WEIGHTS_PATH, map_location='cpu')
322
+
323
+ # Here, model is assumed to be on GPU
324
+ # Load state dict to CPU model first, then move the model to GPU
325
+ if torch.cuda.device_count() > 1:
326
+ # If you have a DataParallel model, you need to load to model.module instead
327
+ cpu_model = deepcopy(model.module)
328
+ cpu_model.load_state_dict(checkpoint['model'])
329
+ model.module.load_state_dict(cpu_model.state_dict())
330
+ else:
331
+ # Load to a CPU clone of the model, then load back
332
+ cpu_model = deepcopy(model)
333
+ cpu_model.load_state_dict(checkpoint['model'])
334
+ model.load_state_dict(cpu_model.state_dict())
335
+ optimizer.load_state_dict(checkpoint['optimizer'])
336
+ lr_scheduler.load_state_dict(checkpoint['lr_sched'])
337
+ pre_epoch = checkpoint['epoch']
338
+ best_epoch = checkpoint['best_epoch']
339
+ min_eval_loss = checkpoint['min_eval_loss']
340
+ print("Successfully Loaded Checkpoint from Epoch %d" % pre_epoch)
341
+ checkpoint = None
342
+
343
+ else:
344
+ raise Exception('Checkpoint not found to continue training. Please check your parameter settings.')
345
+
346
+
347
+ for epoch in range(1+pre_epoch, NUM_EPOCHS+1):
348
+ train_sampler.set_epoch(epoch)
349
+ eval_sampler.set_epoch(epoch)
350
+ print('-' * 21 + "Epoch " + str(epoch) + '-' * 21)
351
+ train_loss = train_epoch(epoch)
352
+ eval_loss = eval_epoch()
353
+ if global_rank==0:
354
+ with open(LOGS_PATH,'a') as f:
355
+ f.write("Epoch " + str(epoch) + "\ntrain_loss: " + str(train_loss) + "\neval_loss: " +str(eval_loss) + "\ntime: " + time.asctime(time.localtime(time.time())) + "\n\n")
356
+ if eval_loss < min_eval_loss:
357
+ best_epoch = epoch
358
+ min_eval_loss = eval_loss
359
+ checkpoint = {
360
+ 'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
361
+ 'optimizer': optimizer.state_dict(),
362
+ 'lr_sched': lr_scheduler.state_dict(),
363
+ 'epoch': epoch,
364
+ 'best_epoch': best_epoch,
365
+ 'min_eval_loss': min_eval_loss
366
+ }
367
+ torch.save(checkpoint, WEIGHTS_PATH)
368
+
369
+ if world_size > 1:
370
+ dist.barrier()
371
+
372
+ if global_rank==0:
373
+ print("Best Eval Epoch : "+str(best_epoch))
374
+ print("Min Eval Loss : "+str(min_eval_loss))
train.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import time
4
+ import math
5
+ import json
6
+ import wandb
7
+ import torch
8
+ import random
9
+ import numpy as np
10
+ from abctoolkit.transpose import Key2index, Key2Mode
11
+ from utils import *
12
+ from config import *
13
+ from data import generate_preference_dict
14
+ from tqdm import tqdm
15
+ from copy import deepcopy
16
+ from torch.utils.data import Dataset, DataLoader
17
+ from transformers import GPT2Config, get_scheduler, get_constant_schedule_with_warmup
18
+
19
+
20
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
21
+
22
+ # Set random seed
23
+ seed = 0
24
+ random.seed(seed)
25
+ np.random.seed(seed)
26
+ torch.manual_seed(seed)
27
+ torch.cuda.manual_seed_all(seed)
28
+ torch.backends.cudnn.deterministic = True
29
+ torch.backends.cudnn.benchmark = False
30
+
31
+ patchilizer = Patchilizer()
32
+
33
+ patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
34
+ max_length=PATCH_LENGTH,
35
+ max_position_embeddings=PATCH_LENGTH,
36
+ n_embd=HIDDEN_SIZE,
37
+ num_attention_heads=HIDDEN_SIZE//64,
38
+ vocab_size=1)
39
+ char_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
40
+ max_length=PATCH_SIZE+1,
41
+ max_position_embeddings=PATCH_SIZE+1,
42
+ hidden_size=HIDDEN_SIZE,
43
+ num_attention_heads=HIDDEN_SIZE//64,
44
+ vocab_size=128)
45
+
46
+ model_ref = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=char_config)
47
+ model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=char_config)
48
+
49
+
50
+ model_ref = model_ref.to(device)
51
+ model = model.to(device)
52
+
53
+
54
+ # print parameter number
55
+ print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
56
+
57
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
58
+
59
+
60
def collate_batch(input_batches):
    """Wrap one (chosen, rejected) example into a batch of size 1.

    Takes the 4-tuple (pos_patches, pos_masks, neg_patches, neg_masks)
    produced by the dataset, adds a leading batch dimension to each tensor,
    pads, and moves everything to the training device. Returns the tensors
    in the same order as received.
    """
    batched = []
    for tensor in input_batches:
        with_batch_dim = tensor.unsqueeze(0)
        padded = torch.nn.utils.rnn.pad_sequence(with_batch_dim, batch_first=True, padding_value=0)
        batched.append(padded.to(device))
    return tuple(batched)
72
+
73
+
74
class NotaGenDataset(Dataset):
    """Preference-pair dataset for DPO fine-tuning.

    Builds the Cartesian product of the 'chosen' and 'rejected' file lists
    from ``preference_dict``; each item is a pair of encoded ABC files with
    all-ones attention masks.
    """

    def __init__(self, preference_dict):
        self.preference_dict = preference_dict
        # One pair per (chosen, rejected) combination, chosen-major order.
        self.pair_list = [
            {'chosen': chosen_path, 'rejected': rejected_path}
            for chosen_path in self.preference_dict['chosen']
            for rejected_path in self.preference_dict['rejected']
        ]

    def __len__(self):
        return len(self.pair_list)

    def __getitem__(self, idx):
        try:
            pair = self.pair_list[idx]

            with open(pair['chosen'], 'r', encoding='utf-8') as f:
                chosen_text = f.read()
            with open(pair['rejected'], 'r', encoding='utf-8') as f:
                rejected_text = f.read()

            # Encode both sides; masks are all ones (no padding at this stage).
            chosen_ids = patchilizer.encode(chosen_text)
            rejected_ids = patchilizer.encode(rejected_text)

            return (torch.tensor(chosen_ids, dtype=torch.long),
                    torch.tensor([1] * len(chosen_ids), dtype=torch.long),
                    torch.tensor(rejected_ids, dtype=torch.long),
                    torch.tensor([1] * len(rejected_ids), dtype=torch.long))
        except Exception as e:
            # Best-effort fallback: on any read/encode failure, try the next pair.
            print(e)
            return self.__getitem__((idx + 1) % len(self.pair_list))
110
+
111
+
112
def process_one_batch(batch):
    """Compute the DPO-style preference loss for one (chosen, rejected) pair.

    Runs the trainable policy (`model`) and the frozen reference (`model_ref`)
    on both the chosen (pos) and rejected (neg) sequences, then applies the
    DPO objective with an extra positive-margin regularizer weighted by LAMBDA.
    Returns the loss tensor (graph attached for backward()).
    """
    pos_input_patches, pos_input_masks, neg_input_patches, neg_input_masks = batch
    # Clone inputs for the reference forward pass so neither pass can observe
    # in-place modifications made by the other.
    pos_input_patches_ref = pos_input_patches.clone()
    pos_input_masks_ref = pos_input_masks.clone()
    neg_input_patches_ref = neg_input_patches.clone()
    neg_input_masks_ref = neg_input_masks.clone()

    policy_pos_logps = model(pos_input_patches, pos_input_masks)
    policy_neg_logps = model(neg_input_patches, neg_input_masks)
    with torch.no_grad():
        ref_pos_logps = model_ref(pos_input_patches_ref, pos_input_masks_ref).detach()
        ref_neg_logps = model_ref(neg_input_patches_ref, neg_input_masks_ref).detach()

    # DPO logits: policy-vs-reference log-ratio on chosen minus rejected.
    logits = (policy_pos_logps - policy_neg_logps) - (ref_pos_logps - ref_neg_logps)
    # Positive-margin regularizer: penalize only when the policy assigns the
    # chosen sample a lower log-prob than the reference does.
    # torch.clamp(..., min=0) replaces the original builtin max(0, tensor),
    # which only works on single-element tensors and raises for batched logps.
    margin = torch.clamp(ref_pos_logps - policy_pos_logps, min=0)
    loss = -torch.nn.functional.logsigmoid(BETA * (logits - LAMBDA * margin))
    return loss
126
+
127
+
128
+
129
# train
if __name__ == "__main__":

    # Initialize wandb experiment tracking (optional).
    if WANDB_LOGGING:
        wandb.login(key=WANDB_KEY)
        wandb.init(project="notagen",
                   name=WANDB_NAME)

    # Load preference data: {'chosen': [...paths...], 'rejected': [...paths...]}
    with open(DATA_INDEX_PATH, 'r') as f:
        preference_dict = json.loads(f.read())

    train_set = NotaGenDataset(preference_dict)

    # Load the same pre-trained weights into both the trainable policy (model)
    # and the frozen reference (model_ref). State dicts are loaded on CPU
    # copies first, then transferred, to avoid GPU OOM during loading.
    if os.path.exists(PRETRAINED_PATH):
        checkpoint = torch.load(PRETRAINED_PATH, map_location='cpu')
        cpu_model = deepcopy(model)
        cpu_model.load_state_dict(checkpoint['model'])
        model.load_state_dict(cpu_model.state_dict())
        cpu_model_ref = deepcopy(model_ref)
        cpu_model_ref.load_state_dict(checkpoint['model'])
        model_ref.load_state_dict(cpu_model_ref.state_dict())
    else:
        raise Exception('No pre-trained model loaded.')

    model.train()
    # The reference model is never updated: put it in eval mode so dropout
    # does not add noise to the reference log-probs.
    model_ref.eval()
    total_train_loss = 0
    iter_idx = 1

    tqdm_set = tqdm(range(OPTIMIZATION_STEPS))
    for i in tqdm_set:
        # Sample one preference pair per optimization step (batch size 1).
        idx = random.randint(0, len(train_set) - 1)
        batch = train_set[idx]
        batch = collate_batch(batch)

        loss = process_one_batch(batch)
        total_train_loss += loss.item()

        loss.backward()
        # clip_grad_norm_ is the in-place, non-deprecated replacement for the
        # deprecated torch.nn.utils.clip_grad_norm used previously.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        model.zero_grad(set_to_none=True)
        tqdm_set.set_postfix({'train_loss': total_train_loss / (i + 1)})

        # Log the running mean training loss to wandb
        if WANDB_LOGGING:
            wandb.log({"train_loss": total_train_loss / (i + 1)}, step=i+1)

    # Save only the model weights (unwrap DataParallel/DDP if present).
    checkpoint = {'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict()}

    torch.save(checkpoint, WEIGHTS_PATH)
183
+
184
+
185
+
186
+
utils (1).py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import math
4
+ import torch
5
+ import random
6
+ from config import *
7
+ from unidecode import unidecode
8
+ from torch.nn import functional as F
9
+ from transformers import AutoModel, BertModel, GPT2LMHeadModel, PreTrainedModel, GPT2Config
10
+
11
+ try:
12
+ import torch.distributed.nn
13
+ from torch import distributed as dist
14
+
15
+ has_distributed = True
16
+ except ImportError:
17
+ has_distributed = False
18
+
19
+ try:
20
+ import horovod.torch as hvd
21
+ except ImportError:
22
+ hvd = None
23
+
24
class ClipLoss(torch.nn.Module):
    """CLIP-style symmetric contrastive loss (InfoNCE in both directions).

    Supports multi-process training by gathering features across ranks via
    torch.distributed or horovod.

    Args:
        local_loss: compute the loss only against the local sub-batch
            (logits shaped local x global) instead of global x global.
        gather_with_grad: gather features with gradients flowing through the
            all-gather (torch.distributed.nn / horovod).
        cache_labels: memoize the arange ground-truth labels per device.
        rank: this process's rank.
        world_size: total number of processes.
        use_horovod: use horovod collectives instead of torch.distributed.
    """

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def gather_features(
            self,
            image_features,
            text_features,
            local_loss=False,
            gather_with_grad=False,
            rank=0,
            world_size=1,
            use_horovod=False
    ):
        """Gather image/text features from all ranks.

        Returns (all_image_features, all_text_features) concatenated along the
        batch dimension. When gathering without gradients and not using
        local_loss, this rank's own features are re-inserted so gradients can
        still flow for the local slice.
        """
        assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
        if use_horovod:
            assert hvd is not None, 'Please install horovod'
            if gather_with_grad:
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            else:
                with torch.no_grad():
                    all_image_features = hvd.allgather(image_features)
                    all_text_features = hvd.allgather(text_features)
                if not local_loss:
                    # ensure grads for local rank when all_* features don't have a gradient
                    gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                    gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                    gathered_image_features[rank] = image_features
                    gathered_text_features[rank] = text_features
                    all_image_features = torch.cat(gathered_image_features, dim=0)
                    all_text_features = torch.cat(gathered_text_features, dim=0)
        else:
            # We gather tensors from all gpus
            if gather_with_grad:
                all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
                all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
            else:
                gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
                gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
                dist.all_gather(gathered_image_features, image_features)
                dist.all_gather(gathered_text_features, text_features)
                if not local_loss:
                    # ensure grads for local rank when all_* features don't have a gradient
                    gathered_image_features[rank] = image_features
                    gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)

        return all_image_features, all_text_features

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return arange labels (the diagonal is the positive pair); offset by
        rank when using local_loss, and cache per device if enabled."""
        # calculated ground-truth and cache if enabled
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        """Compute scaled similarity logits in both directions, gathering
        features across ranks when world_size > 1."""
        if self.world_size > 1:
            all_image_features, all_text_features = self.gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        """Symmetric cross-entropy over image->text and text->image logits."""
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss
137
+
138
class M3Patchilizer:
    """Converts symbolic music text (ABC or MTF) to/from fixed-size character
    "patches" for the M3 encoder.

    A patch is a list of PATCH_SIZE token ids: ord() codes of the characters,
    framed by bos/eos tokens and right-padded with the pad token. Characters
    are assumed ASCII (ids < 128) after unidecode.
    """

    def __init__(self):
        # Bar delimiters used to split an ABC tune body into bars.
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def split_bars(self, body):
        """Split an ABC body into bar strings, each ending with its delimiter.

        re.split with a capturing group yields alternating [content, delim,
        ...]; a leading delimiter is merged into the first bar.
        """
        bars = re.split(self.regexPattern, ''.join(body))
        bars = list(filter(None, bars))  # remove empty strings
        if bars[0] in self.delimiters:
            # Opening delimiter belongs to the first bar.
            bars[1] = bars[0] + bars[1]
            bars = bars[1:]
        # Re-pair (content, delimiter) chunks into single bar strings.
        bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
        return bars

    def bar2patch(self, bar, patch_size=PATCH_SIZE):
        """Encode one bar/line as [bos, ord(c)..., eos], truncated and
        right-padded with pad tokens to patch_size ids."""
        patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
        patch = patch[:patch_size]
        patch += [self.pad_token_id] * (patch_size - len(patch))
        return patch

    def patch2bar(self, patch):
        """Decode a patch back to text, dropping all special tokens (ids <= mask)."""
        return ''.join(chr(idx) if idx > self.mask_token_id else '' for idx in patch)

    def encode(self,
               item,
               patch_size=PATCH_SIZE,
               add_special_patches=False,
               truncate=False,
               random_truncate=False):
        """Encode a full piece into a list of patches (lists of token ids).

        MTF input (first line starts with "ticks_per_beat") is segmented by
        merging consecutive lines that share a leading keyword into one
        tab-joined patch; ABC input is segmented into header lines and bars.
        Optionally wraps the sequence in bos/eos patches and truncates to
        PATCH_LENGTH (keeping head, or a random head/tail/middle window when
        random_truncate is set).
        """
        item = unidecode(item)  # fold to ASCII so ord(c) stays < 128
        lines = re.findall(r'.*?\n|.*$', item)
        lines = list(filter(None, lines))  # remove empty lines

        patches = []

        if lines[0].split(" ")[0] == "ticks_per_beat":
            # MTF: merge same-keyword lines while the merged patch still fits
            # within patch_size - 2 (reserving room for bos/eos).
            patch = ""
            for line in lines:
                if patch.startswith(line.split(" ")[0]) and (len(patch) + len(" ".join(line.split(" ")[1:])) <= patch_size-2):
                    patch = patch[:-1] + "\t" + " ".join(line.split(" ")[1:])
                else:
                    if patch:
                        patches.append(patch)
                    patch = line
            if patch!="":
                patches.append(patch)
        else:
            # ABC: metadata lines ("X:", "%%...") become whole patches; body
            # lines are split into per-bar patches.
            for line in lines:
                if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%')):
                    patches.append(line)
                else:
                    bars = self.split_bars(line)
                    if bars:
                        bars[-1] += '\n'
                        patches.extend(bars)

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * patch_size
            eos_patch = chr(self.eos_token_id) * patch_size
            patches = [bos_patch] + patches + [eos_patch]

        if len(patches) > PATCH_LENGTH and truncate:
            choices = ["head", "tail", "middle"]
            choice = random.choice(choices)
            if choice=="head" or random_truncate==False:
                patches = patches[:PATCH_LENGTH]
            elif choice=="tail":
                patches = patches[-PATCH_LENGTH:]
            else:
                start = random.randint(1, len(patches)-PATCH_LENGTH)
                patches = patches[start:start+PATCH_LENGTH]

        patches = [self.bar2patch(patch) for patch in patches]

        return patches

    def decode(self, patches):
        """Decode a list of patches back into a single text string."""
        return ''.join(self.patch2bar(patch) for patch in patches)
221
+
222
class M3PatchEncoder(PreTrainedModel):
    """BERT-based patch encoder.

    Each patch (PATCH_SIZE character ids) is one-hot encoded over 128 byte
    values, flattened, linearly projected to M3_HIDDEN_SIZE, and the patch
    sequence is then contextualized by a BERT encoder.
    """

    def __init__(self, config):
        super(M3PatchEncoder, self).__init__(config)
        # Projects a flattened one-hot patch (PATCH_SIZE * 128) to hidden size.
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE*128, M3_HIDDEN_SIZE)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = BertModel(config=config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                input_patches, # integer char ids, reshaped to [batch_size, seq_length, PATCH_SIZE*128] after one-hot
                input_masks): # [batch_size, seq_length]
        # Transform input_patches into embeddings
        input_patches = torch.nn.functional.one_hot(input_patches, num_classes=128)
        # NOTE(review): .type(torch.FloatTensor) moves the tensor to CPU float
        # before .to(self.device) below — appears intentional here; confirm.
        input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE*128).type(torch.FloatTensor)
        input_patches = self.patch_embedding(input_patches.to(self.device))

        # Apply BERT model to input_patches and input_masks
        return self.base(inputs_embeds=input_patches, attention_mask=input_masks)
243
+
244
class M3TokenDecoder(PreTrainedModel):
    """GPT-2 char-level decoder.

    Reconstructs the characters of a patch autoregressively, conditioned on
    the patch's encoder feature, which is injected as the first input
    embedding (replacing the bos embedding).
    """

    def __init__(self, config):
        super(M3TokenDecoder, self).__init__(config)
        self.base = GPT2LMHeadModel(config=config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                patch_features, # [batch_size, hidden_size]
                target_patches): # [batch_size, seq_length]
        """Teacher-forced decoding; returns the GPT-2 LM output (with .loss)."""
        # get input embeddings
        inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)

        # concatenate the encoded patches with the input embeddings
        # (the patch feature replaces the embedding at position 0)
        inputs_embeds = torch.cat((patch_features.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1)

        # preparing the labels for model training (pad positions become -100,
        # which the cross-entropy loss ignores)
        target_masks = target_patches == self.pad_token_id
        target_patches = target_patches.clone().masked_fill_(target_masks, -100)

        # get the attention mask (1 for real tokens, 0 for padding)
        target_masks = ~target_masks
        target_masks = target_masks.type(torch.int)

        return self.base(inputs_embeds=inputs_embeds,
                         attention_mask=target_masks,
                         labels=target_patches)

    def generate(self,
                 patch_feature,
                 tokens):
        """Return next-token probabilities (numpy array over 128 char ids)
        for one patch during incremental sampling."""
        # reshape the patch_feature and tokens
        patch_feature = patch_feature.reshape(1, 1, -1)
        tokens = tokens.reshape(1, -1)

        # get input embeddings
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)

        # concatenate the encoded patches with the input embeddings
        tokens = torch.cat((patch_feature, tokens[:,1:,:]), dim=1)

        # get the outputs from the model
        outputs = self.base(inputs_embeds=tokens)

        # get the probabilities of the next token
        probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)

        return probs.detach().cpu().numpy()
294
+
295
class M3Model(PreTrainedModel):
    """M3 masked-patch pretraining model: patch encoder + char decoder.

    The encoder contextualizes the (possibly corrupted) patch sequence; the
    decoder then reconstructs the original characters of the patches marked
    by selected_indices from their contextual features.
    """

    def __init__(self, encoder_config, decoder_config):
        super(M3Model, self).__init__(encoder_config)
        self.encoder = M3PatchEncoder(encoder_config)
        self.decoder = M3TokenDecoder(decoder_config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                input_patches, # corrupted patches, reshaped to [batch_size, seq_length, PATCH_SIZE]
                input_masks, # [batch_size, seq_length]
                selected_indices, # 0/1 mask of patches to reconstruct, [batch_size, seq_length]
                target_patches): # original patches, reshaped to [batch_size, seq_length, PATCH_SIZE]
        input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE).to(self.device)
        input_masks = input_masks.to(self.device)
        selected_indices = selected_indices.to(self.device)
        target_patches = target_patches.reshape(len(target_patches), -1, PATCH_SIZE).to(self.device)

        # Pass the input_patches and input_masks through the encoder
        outputs = self.encoder(input_patches, input_masks)["last_hidden_state"]

        # Use selected_indices to keep only the patches to be reconstructed
        target_patches = target_patches[selected_indices.bool()]
        patch_features = outputs[selected_indices.bool()]

        # Pass patch_features and target_patches through the decoder
        return self.decoder(patch_features, target_patches)
324
+
325
class CLaMP2Model(PreTrainedModel):
    """CLIP-style dual-encoder aligning text and symbolic music.

    A pretrained text transformer and an M3 patch encoder each feed a linear
    projection into a shared embedding space; training minimizes a symmetric
    contrastive (ClipLoss) objective. Optionally initializes the music
    encoder from an M3 pretraining checkpoint.
    """

    def __init__(self,
                 music_config,
                 global_rank=None,
                 world_size=None,
                 text_model_name=TEXT_MODEL_NAME,
                 hidden_size=CLAMP2_HIDDEN_SIZE,
                 load_m3=CLAMP2_LOAD_M3):
        super(CLaMP2Model, self).__init__(music_config)

        self.text_model = AutoModel.from_pretrained(text_model_name) # Load the text model
        self.text_proj = torch.nn.Linear(self.text_model.config.hidden_size, hidden_size) # Linear layer for text projections
        torch.nn.init.normal_(self.text_proj.weight, std=0.02) # Initialize weights with normal distribution

        self.music_model = M3PatchEncoder(music_config) # Initialize the music model
        self.music_proj = torch.nn.Linear(M3_HIDDEN_SIZE, hidden_size) # Linear layer for music projections
        torch.nn.init.normal_(self.music_proj.weight, std=0.02) # Initialize weights with normal distribution

        # Fall back to single-process settings when rank info is not given.
        if global_rank==None or world_size==None:
            global_rank = 0
            world_size = 1

        self.loss_fn = ClipLoss(local_loss=False,
                                gather_with_grad=True,
                                cache_labels=False,
                                rank=global_rank,
                                world_size=world_size,
                                use_horovod=False)

        # Optionally replace the randomly-initialized music encoder with the
        # encoder from an M3 pretraining checkpoint.
        if load_m3 and os.path.exists(M3_WEIGHTS_PATH):
            checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True)
            decoder_config = GPT2Config(vocab_size=128,
                                        n_positions=PATCH_SIZE,
                                        n_embd=M3_HIDDEN_SIZE,
                                        n_layer=TOKEN_NUM_LAYERS,
                                        n_head=M3_HIDDEN_SIZE//64,
                                        n_inner=M3_HIDDEN_SIZE*4)
            model = M3Model(music_config, decoder_config)
            model.load_state_dict(checkpoint['model'])
            self.music_model = model.encoder
            model = None  # drop the decoder half so it can be garbage-collected
            print(f"Successfully Loaded M3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")

    def avg_pooling(self, input_features, input_masks):
        """Masked mean-pool over the sequence dimension."""
        input_masks = input_masks.unsqueeze(-1).to(self.device) # add a dimension to match the feature dimension
        input_features = input_features * input_masks # apply mask to input_features
        avg_pool = input_features.sum(dim=1) / input_masks.sum(dim=1) # calculate average pooling

        return avg_pool

    def get_text_features(self,
                          text_inputs,
                          text_masks,
                          get_normalized=False):
        """Encode text; when get_normalized, pool and project into the shared
        embedding space (one vector per example)."""
        text_features = self.text_model(text_inputs.to(self.device),
                                        attention_mask=text_masks.to(self.device))['last_hidden_state']

        if get_normalized:
            text_features = self.avg_pooling(text_features, text_masks)
            text_features = self.text_proj(text_features)

        return text_features

    def get_music_features(self,
                           music_inputs,
                           music_masks,
                           get_normalized=False):
        """Encode music patches; when get_normalized, pool and project into
        the shared embedding space (one vector per example)."""
        music_features = self.music_model(music_inputs.to(self.device),
                                          music_masks.to(self.device))['last_hidden_state']

        if get_normalized:
            music_features = self.avg_pooling(music_features, music_masks)
            music_features = self.music_proj(music_features)

        return music_features

    def forward(self,
                text_inputs, # [batch_size, seq_length]
                text_masks, # [batch_size, seq_length]
                music_inputs, # [batch_size, seq_length, PATCH_SIZE]
                music_masks): # [batch_size, seq_length]
        """Return the contrastive loss between text and music embeddings."""
        # Compute the text features
        text_features = self.get_text_features(text_inputs, text_masks, get_normalized=True)

        # Compute the music features
        music_features = self.get_music_features(music_inputs, music_masks, get_normalized=True)

        return self.loss_fn(text_features,
                            music_features,
                            LOGIT_SCALE,
                            output_dict=False)
416
+
417
def split_data(data, eval_ratio=EVAL_SPLIT):
    """Shuffle `data` in place and split it into (train_set, eval_set).

    The first ``int(len(data) * eval_ratio)`` shuffled items become the
    evaluation set; the remainder is the training set.
    """
    random.shuffle(data)
    n_eval = int(len(data) * eval_ratio)
    return data[n_eval:], data[:n_eval]
423
+
424
def mask_patches(target_patches, patchilizer, mode):
    """Create a corrupted copy of target_patches for masked-patch training.

    Selects ceil(M3_MASK_RATIO * n) random patch positions. In training mode
    the whole selection is corrupted by one strategy chosen per call: full
    masking (80%), shuffling each patch's inner characters (10%), or leaving
    them unchanged (10%). In "eval" mode patches are always left unchanged.

    Returns (input_patches, selected_indices): the (possibly corrupted)
    patch tensor and a 0/1 float vector marking the selected positions.
    """
    indices = list(range(len(target_patches)))
    random.shuffle(indices)
    selected_indices = indices[:math.ceil(M3_MASK_RATIO*len(indices))]
    sorted_indices = sorted(selected_indices)
    input_patches = torch.tensor(target_patches)

    if mode=="eval":
        choice = "original"
    else:
        choice = random.choices(["mask", "shuffle", "original"], weights=[0.8, 0.1, 0.1])[0]

    if choice=="mask":
        # Replace every selected patch entirely with mask tokens.
        input_patches[sorted_indices] = torch.tensor([patchilizer.mask_token_id]*PATCH_SIZE)
    elif choice=="shuffle":
        for idx in sorted_indices:
            patch = input_patches[idx]
            try:
                index_eos = (patch == patchilizer.eos_token_id).nonzero().item()
            except:
                # no (single) eos found: shuffle up to the end of the patch
                index_eos = len(patch)

            # Shuffle only positions between bos (index 0) and eos/padding.
            indices = list(range(1, index_eos))
            random.shuffle(indices)
            indices = [0] + indices + list(range(index_eos, len(patch)))
            input_patches[idx] = patch[indices]

    # 0/1 indicator of which patch positions were selected.
    selected_indices = torch.zeros(len(target_patches))
    selected_indices[sorted_indices] = 1.

    return input_patches, selected_indices
455
+
456
def remove_instrument_info(item):
    """Strip instrument information from a symbolic-music string.

    Two text formats are supported, detected from the first line:
      * "mtf" (MIDI text format, first line starts with "ticks_per_beat"):
        the program number of every "program_change" line is reset to 0.
      * "abc" (anything else): " nm=..." / " snm=..." instrument-name
        suffixes are removed from "V:" voice lines.

    Returns the cleaned string. An empty input yields '' (the original
    implementation raised IndexError on lines[0]).
    """
    lines = re.findall(r'.*?\n|.*$', item)
    lines = list(filter(None, lines))
    if not lines:
        return ''

    # Detect format from the first line's leading keyword.
    # ("fmt" instead of "type" to avoid shadowing the builtin.)
    if lines[0].split(" ")[0] == "ticks_per_beat":
        fmt = "mtf"
    else:
        fmt = "abc"

    cleaned_lines = []
    for line in lines:
        if fmt == "abc" and line.startswith("V:"):
            # find the position of " nm=" or " snm="
            nm_pos = line.find(" nm=")
            snm_pos = line.find(" snm=")
            # keep the part before " nm=" or " snm="
            if nm_pos != -1:
                line = line[:nm_pos]
            elif snm_pos != -1:
                line = line[:snm_pos]
            # restore the newline lost by truncation
            if nm_pos != -1 or snm_pos != -1:
                line += "\n"
        elif fmt == "mtf" and line.startswith("program_change"):
            # zero out the program number (last space-separated field)
            line = " ".join(line.split(" ")[:-1]) + " 0\n"

        cleaned_lines.append(line)

    return ''.join(cleaned_lines)
utils (2).py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import bisect
4
+ import json
5
+ import re
6
+ import numpy as np
7
+ from config import *
8
+ from transformers import GPT2Model, GPT2LMHeadModel, LlamaModel, LlamaForCausalLM, PreTrainedModel
9
+ from samplings import top_p_sampling, top_k_sampling, temperature_sampling
10
+ from tokenizers import Tokenizer
11
+
12
+
13
class Patchilizer:
    """Converts interleaved-ABC text to/from fixed-size character patches for
    NotaGen.

    Patches are PATCH_SIZE-character chunks encoded as ord() ids; metadata
    lines and bars are each chunked separately. In "stream" mode, tunebody
    lines are prefixed with a '[r:done/left]' position tag and long pieces may
    be cut at line boundaries to fit PATCH_LENGTH.
    """

    def __init__(self, stream=PATCH_STREAM):
        self.stream = stream
        # Bar delimiters used to split tunebody lines into bars.
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.special_token_id = 0

    def split_bars(self, body_lines):
        """
        Split a body of music into individual bars.

        Each bar string keeps its trailing delimiter; a trailing fragment that
        contains no voice marker ('V') is merged into the previous bar. Any
        parsing error silently yields the bars collected so far.
        """
        new_bars = []
        try:
            for line in body_lines:
                line_bars = re.split(self.regexPattern, line)
                line_bars = list(filter(None, line_bars))
                new_line_bars = []

                if len(line_bars) == 1:
                    new_line_bars = line_bars
                else:
                    # re.split with capture yields alternating content/delimiter
                    # chunks; re-pair them into "content+delimiter" bars.
                    if line_bars[0] in self.delimiters:
                        new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]
                    else:
                        new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]
                    if 'V' not in new_line_bars[-1]:
                        # trailing fragment without a voice marker belongs to
                        # the previous bar
                        new_line_bars[-2] += new_line_bars[-1]
                        new_line_bars = new_line_bars[:-1]
                new_bars += new_line_bars
        except:
            pass

        return new_bars

    def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):
        """Chunk a string into patch_size pieces; unless generate_last, append
        an eos char so the final chunk is terminated."""
        if not generate_last and len(abc_text) % patch_size != 0:
            abc_text += chr(self.eos_token_id)
        patches = [abc_text[i : i + patch_size] for i in range(0, len(abc_text), patch_size)]
        return patches

    def patch2chars(self, patch):
        """
        Convert a patch into a bar.

        Decodes ids back to characters, stopping at the first eos token.
        """
        bytes = ''
        for idx in patch:
            if idx == self.eos_token_id:
                break
            if idx < self.eos_token_id:
                pass
            bytes += chr(idx)
        return bytes


    def patchilize_metadata(self, metadata_lines):
        """Chunk each metadata line into patches."""
        metadata_patches = []
        for line in metadata_lines:
            metadata_patches += self.split_patches(line)

        return metadata_patches

    def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):
        """Split the tunebody into bars and chunk each bar into patches.

        In 'generate' mode, the final bar is left unterminated (no eos) so
        generation can continue from it.
        """
        tunebody_patches = []
        bars = self.split_bars(tunebody_lines)
        if encode_mode == 'train':
            for bar in bars:
                tunebody_patches += self.split_patches(bar)
        elif encode_mode == 'generate':
            for bar in bars[:-1]:
                tunebody_patches += self.split_patches(bar)
            tunebody_patches += self.split_patches(bars[-1], generate_last=True)

        return tunebody_patches

    def encode_train(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):
        """Encode a training example into patch-id lists.

        Splits the piece into metadata (before the first '[V:' line) and
        tunebody; in stream mode each tunebody line is prefixed with a
        '[r:done/left]' tag, and over-long pieces are cut at a random line
        boundary (head / tail / middle) so metadata + tunebody fit into
        patch_length. Patches are finally padded with special_token_id to
        patch_size ids each.
        """
        lines = abc_text.split('\n')
        lines = list(filter(None, lines))
        lines = [line + '\n' for line in lines]

        # The tunebody starts at the first line containing '[V:'.
        tunebody_index = -1
        for i, line in enumerate(lines):
            if '[V:' in line:
                tunebody_index = i
                break

        metadata_lines = lines[ : tunebody_index]
        tunebody_lines = lines[tunebody_index : ]

        if self.stream:
            # Prefix each tunebody line with '[r:<lines done>/<lines left>]'.
            tunebody_lines = ['[r:' + str(line_index) + '/' + str(len(tunebody_lines) - line_index - 1) + ']' + line for line_index, line in
                              enumerate(tunebody_lines)]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)

            metadata_patches = [bos_patch] + metadata_patches
            tunebody_patches = tunebody_patches + [eos_patch]

        if self.stream:
            if len(metadata_patches) + len(tunebody_patches) > patch_length:
                # Candidate cut points: start of the tunebody and every patch
                # that ends a line; keep only cuts that still overflow.
                available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if '\n' in patch]
                line_index_for_cut_index = list(range(len(available_cut_indexes)))
                end_index = len(metadata_patches) + len(tunebody_patches) - patch_length
                biggest_index = bisect.bisect_left(available_cut_indexes, end_index)
                available_cut_indexes = available_cut_indexes[:biggest_index + 1]

                if len(available_cut_indexes) == 1:
                    choices = ['head']
                elif len(available_cut_indexes) == 2:
                    choices = ['head', 'tail']
                else:
                    choices = ['head', 'tail', 'middle']
                choice = random.choice(choices)
                if choice == 'head':
                    patches = metadata_patches + tunebody_patches[0:]
                else:
                    if choice == 'tail':
                        cut_index = len(available_cut_indexes) - 1
                    else:
                        cut_index = random.choice(range(1, len(available_cut_indexes) - 1))

                    line_index = line_index_for_cut_index[cut_index]
                    stream_tunebody_lines = tunebody_lines[line_index : ]

                    # Re-patchilize from the chosen line so bar boundaries stay intact.
                    stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')
                    if add_special_patches:
                        stream_tunebody_patches = stream_tunebody_patches + [eos_patch]
                    patches = metadata_patches + stream_tunebody_patches
            else:
                patches = metadata_patches + tunebody_patches
        else:
            patches = metadata_patches + tunebody_patches

        if cut:
            patches = patches[ : patch_length]
        else:
            pass

        # encode to ids (right-pad each patch with special_token_id)
        id_patches = []
        for patch in patches:
            id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):
        """Encode a (possibly partial) prompt for generation.

        Like encode_train, but the last bar is left unterminated and the final
        partial patch is left unpadded so the model can continue it. The
        tunebody starts at the first line beginning with '[V:' or '[r:'.
        """
        lines = abc_code.split('\n')
        lines = list(filter(None, lines))

        tunebody_index = None
        for i, line in enumerate(lines):
            if line.startswith('[V:') or line.startswith('[r:'):
                tunebody_index = i
                break

        metadata_lines = lines[ : tunebody_index]
        tunebody_lines = lines[tunebody_index : ]

        metadata_lines = [line + '\n' for line in metadata_lines]
        if self.stream:
            # Keep the prompt's last line open (no trailing '\n') unless the
            # original text ended with a newline.
            if not abc_code.endswith('\n'):
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]]
            else:
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]
        else:
            tunebody_lines = [line + '\n' for line in tunebody_lines]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)

            metadata_patches = [bos_patch] + metadata_patches

        patches = metadata_patches + tunebody_patches
        patches = patches[ : patch_length]

        # encode to ids; the trailing partial patch stays unpadded so the
        # model continues generating inside it
        id_patches = []
        for patch in patches:
            if len(patch) < PATCH_SIZE and patch[-1] != chr(self.eos_token_id):
                id_patch = [ord(c) for c in patch]
            else:
                id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def decode(self, patches):
        """
        Decode patches into music.
        """
        return ''.join(self.patch2chars(patch) for patch in patches)
218
+
219
+
220
+
221
+
222
class PatchLevelDecoder(PreTrainedModel):
    """
    A patch-level decoder that produces one feature vector per patch in an
    auto-regressive manner. Inherits PreTrainedModel from transformers.
    """
    def __init__(self, config):
        super().__init__(config)
        # Each patch is a PATCH_SIZE x 128 one-hot matrix flattened into a vector.
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = GPT2Model(config)

    def forward(self,
                patches: torch.Tensor,
                masks=None) -> torch.Tensor:
        """
        The forward pass of the patch-level decoder model.
        :param patches: patch character ids, reshaped to [batch, seq, PATCH_SIZE]
        :param masks: optional attention mask over patches
        :return: GPT2 output containing per-patch hidden states
        """
        patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)
        patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
        patches = self.patch_embedding(patches.to(self.device))

        # FIX: use `is None` rather than `== None` — comparing a tensor with
        # `==` attempts an element-wise comparison instead of an identity check.
        if masks is None:
            return self.base(inputs_embeds=patches)
        return self.base(inputs_embeds=patches,
                         attention_mask=masks)
251
+
252
+
253
class CharLevelDecoder(PreTrainedModel):
    """
    A char-level decoder that generates the characters within each patch
    auto-regressively, conditioned on the patch feature produced by the
    patch-level decoder. Inherits PreTrainedModel from transformers.
    """
    def __init__(self, config):
        super().__init__(config)
        self.special_token_id = 0
        self.bos_token_id = 1
        self.base = GPT2LMHeadModel(config)

    def forward(self,
                encoded_patches: torch.Tensor,
                target_patches: torch.Tensor):
        """
        Teacher-forced forward pass over a batch of patches.
        :param encoded_patches: one patch feature per target patch
        :param target_patches: target character ids for each patch
        :return: GPT2 LM output (includes the cross-entropy loss)
        """
        # Prepend a BOS id to every target patch.
        bos_column = torch.ones_like(target_patches[:, 0:1]) * self.bos_token_id
        target_patches = torch.cat((bos_column, target_patches), dim=1)

        # Padding positions get label -100 so they are excluded from the loss.
        pad_positions = target_patches == self.special_token_id
        labels = target_patches.clone().masked_fill_(pad_positions, -100)

        # Attention mask: 1 on real tokens, 0 wherever the label is ignored.
        target_masks = torch.ones_like(labels)
        target_masks = target_masks.masked_fill_(labels == -100, 0)

        # Optionally sub-sample patches to bound the char-level batch size.
        if PATCH_SAMPLING_BATCH_SIZE != 0 and PATCH_SAMPLING_BATCH_SIZE < target_patches.shape[0]:
            indices = list(range(len(target_patches)))
            random.shuffle(indices)
            selected_indices = sorted(indices[:PATCH_SAMPLING_BATCH_SIZE])

            target_patches = target_patches[selected_indices, :]
            target_masks = target_masks[selected_indices, :]
            encoded_patches = encoded_patches[selected_indices, :]

        # Token embeddings; the BOS slot is replaced by the patch feature.
        inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)
        inputs_embeds = torch.cat((encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1)

        return self.base(inputs_embeds=inputs_embeds,
                         attention_mask=target_masks,
                         labels=labels)

    def generate(self,
                 encoded_patch: torch.Tensor,  # [hidden_size]
                 tokens: torch.Tensor):        # [1]
        """
        Compute the next-token distribution for one patch being generated.
        :param encoded_patch: patch feature vector, shape [hidden_size]
        :param tokens: character ids generated so far within this patch
        :return: probability distribution over the 128 character ids
        """
        encoded_patch = encoded_patch.reshape(1, 1, -1)  # [1, 1, hidden_size]
        tokens = tokens.reshape(1, -1)

        # Embed the tokens, then substitute the patch feature for the BOS slot.
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
        tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)

        outputs = self.base(inputs_embeds=tokens)

        # Softmax over the logits at the final position.
        return torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
333
+
334
def safe_normalize_probs(probs):
    """
    Sanitize and renormalize a probability vector.

    NaN and negative entries are zeroed, a tiny epsilon is added to every
    entry (keeping the vector strictly positive), and the result is
    normalized to sum to 1. If normalization is impossible, all mass is
    placed on index 0.
    """
    tiny = 1e-12
    cleaned = np.array(probs, dtype=np.float64)
    cleaned = np.where(np.isnan(cleaned) | (cleaned < 0), 0, cleaned)
    cleaned = cleaned + tiny
    total = cleaned.sum()
    if total > 0:
        return cleaned / total
    fallback = np.zeros_like(cleaned)
    fallback[0] = 1.0
    return fallback
346
+
347
class NotaGenLMHeadModel(PreTrainedModel):
    """
    NotaGen is a language model with a hierarchical structure.
    It includes a patch-level decoder and a char-level decoder.
    The patch-level decoder generates patch features auto-regressively; the
    char-level decoder generates the chars within each patch auto-regressively.
    It inherits PreTrainedModel from transformers.
    """
    def __init__(self, encoder_config, decoder_config):
        super().__init__(encoder_config)
        self.special_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.patch_level_decoder = PatchLevelDecoder(encoder_config)
        self.char_level_decoder = CharLevelDecoder(decoder_config)

    def forward(self,
                patches: torch.Tensor,
                masks: torch.Tensor):
        """
        The forward pass of the NotaGen model.
        :param patches: patch ids, shape [batch, seq * PATCH_SIZE]
        :param masks: attention mask over patches
        :return: char-level decoder output (loss over next-patch chars)
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches, masks)["last_hidden_state"]

        # Align features and targets: each patch feature predicts the
        # characters of the *next* patch.
        left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)
        masks[:, 0] = 0

        encoded_patches = encoded_patches[left_shift_masks == 1]
        patches = patches[masks == 1]

        return self.char_level_decoder(encoded_patches, patches)

    def generate(self,
                 patches: torch.Tensor,
                 top_k=0,
                 top_p=1,
                 temperature=1.0):
        """
        Generate one patch continuing the given patch sequence.
        :param patches: prompt patch ids; a trailing partial patch is allowed
        :param top_k: top-k sampling cutoff (0 disables)
        :param top_p: nucleus sampling threshold
        :param temperature: sampling temperature
        :return: list of generated character ids (one patch)
        """
        if patches.shape[-1] % PATCH_SIZE != 0:
            # A trailing partial patch: keep generating inside it, seeded with BOS.
            tokens = patches[:, :, -(patches.shape[-1] % PATCH_SIZE):].squeeze(0, 1)
            tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1)
            patches = patches[:, :, :-(patches.shape[-1] % PATCH_SIZE)]
        else:
            tokens = torch.tensor([self.bos_token_id], device=self.device)

        patches = patches.reshape(len(patches), -1, PATCH_SIZE)  # [bs, seq, patch_size]
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]  # [bs, seq, hidden_size]
        generated_patch = []

        while True:
            # Probabilities are re-sanitized after each filtering step so the
            # samplers always receive a valid distribution.
            prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy()  # [128]
            prob = safe_normalize_probs(prob)
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)  # [128]
            prob = safe_normalize_probs(prob)
            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)  # [128]
            prob = safe_normalize_probs(prob)
            token = temperature_sampling(prob, temperature=temperature)  # int
            # FIX: removed dead local `char = chr(token)` — it was never used.
            generated_patch.append(token)

            if len(tokens) >= PATCH_SIZE:  # patch is full; EOS handling is the caller's job
                break
            tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)

        return generated_patch
utils (3).py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import bisect
4
+ import json
5
+ import re
6
+ import numpy as np
7
+ from config import *
8
+ from transformers import GPT2Model, GPT2LMHeadModel, LlamaModel, LlamaForCausalLM, PreTrainedModel
9
+ from samplings import top_p_sampling, top_k_sampling, temperature_sampling
10
+ from tokenizers import Tokenizer
11
+
12
+
13
class Patchilizer:
    """
    Converts ABC notation text to/from fixed-size character "patches"
    (lists of character ids).

    Id layout: 0 = padding/special, 1 = BOS, 2 = EOS; all other ids are the
    ordinal values of characters in the ABC text.
    """

    def __init__(self, stream=PATCH_STREAM):
        # Stream mode prefixes every tunebody line with an '[r:i/j]' tag
        # (i = lines before, j = lines after) so long scores can be trained
        # on truncated windows.
        self.stream = stream
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.special_token_id = 0

    def split_bars(self, body_lines):
        """
        Split a body of music into individual bars, keeping each barline
        delimiter attached to the bar text that follows it.
        """
        new_bars = []
        try:
            for line in body_lines:
                line_bars = re.split(self.regexPattern, line)
                line_bars = list(filter(None, line_bars))
                new_line_bars = []

                if len(line_bars) == 1:
                    new_line_bars = line_bars
                else:
                    if line_bars[0] in self.delimiters:
                        new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]
                    else:
                        new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]
                    if 'V' not in new_line_bars[-1]:
                        # Absorb a trailing "barline + newline" fragment into the previous bar.
                        new_line_bars[-2] += new_line_bars[-1]
                        new_line_bars = new_line_bars[:-1]
                new_bars += new_line_bars
        except Exception:
            # FIX: narrowed bare `except:`. Best-effort: malformed input
            # yields whatever bars were collected so far.
            pass

        return new_bars

    def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):
        """
        Chop abc_text into consecutive patch_size-character patches. Unless
        producing the last (open-ended) patch for generation, an EOS char is
        appended when the text does not fill the final patch exactly.
        """
        if not generate_last and len(abc_text) % patch_size != 0:
            abc_text += chr(self.eos_token_id)
        patches = [abc_text[i: i + patch_size] for i in range(0, len(abc_text), patch_size)]
        return patches

    def patch2chars(self, patch):
        """
        Convert a patch (list of character ids) back to text, stopping at the
        first EOS id. Other ids (including BOS/pad) pass through verbatim,
        matching the original behavior.
        """
        # FIX: renamed local `bytes` (shadowed the builtin) and removed the
        # dead `if idx < eos: pass` branch — both behavior-preserving.
        chars = ''
        for idx in patch:
            if idx == self.eos_token_id:
                break
            chars += chr(idx)
        return chars

    def patchilize_metadata(self, metadata_lines):
        """Encode the metadata lines (everything before the tunebody) as patches."""
        metadata_patches = []
        for line in metadata_lines:
            metadata_patches += self.split_patches(line)

        return metadata_patches

    def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):
        """
        Encode tunebody lines bar-by-bar. In 'generate' mode the last bar is
        left open (no EOS) so generation can continue from it.
        """
        tunebody_patches = []
        bars = self.split_bars(tunebody_lines)
        if encode_mode == 'train':
            for bar in bars:
                tunebody_patches += self.split_patches(bar)
        elif encode_mode == 'generate':
            for bar in bars[:-1]:
                tunebody_patches += self.split_patches(bar)
            tunebody_patches += self.split_patches(bars[-1], generate_last=True)

        return tunebody_patches

    def encode_train(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):
        """
        Encode a full ABC piece for training; returns id patches padded to
        patch_size with the special id.

        In stream mode, pieces longer than patch_length are randomly cut at a
        line boundary ('head'/'middle'/'tail') so the model sees different
        windows of long scores across epochs.
        """
        lines = abc_text.split('\n')
        lines = list(filter(None, lines))
        lines = [line + '\n' for line in lines]

        # The tunebody starts at the first voice line.
        tunebody_index = -1
        for i, line in enumerate(lines):
            if '[V:' in line:
                tunebody_index = i
                break

        metadata_lines = lines[: tunebody_index]
        tunebody_lines = lines[tunebody_index:]

        if self.stream:
            # Tag every line with its position: lines-before / lines-after.
            tunebody_lines = ['[r:' + str(line_index) + '/' + str(len(tunebody_lines) - line_index - 1) + ']' + line
                              for line_index, line in enumerate(tunebody_lines)]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)

            metadata_patches = [bos_patch] + metadata_patches
            tunebody_patches = tunebody_patches + [eos_patch]

        if self.stream:
            if len(metadata_patches) + len(tunebody_patches) > patch_length:
                # Valid cut points are patch index 0 and every index right
                # after a newline-bearing patch (i.e. line starts).
                available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if '\n' in patch]
                line_index_for_cut_index = list(range(len(available_cut_indexes)))
                end_index = len(metadata_patches) + len(tunebody_patches) - patch_length
                biggest_index = bisect.bisect_left(available_cut_indexes, end_index)
                available_cut_indexes = available_cut_indexes[:biggest_index + 1]

                if len(available_cut_indexes) == 1:
                    choices = ['head']
                elif len(available_cut_indexes) == 2:
                    choices = ['head', 'tail']
                else:
                    choices = ['head', 'tail', 'middle']
                choice = random.choice(choices)
                if choice == 'head':
                    patches = metadata_patches + tunebody_patches[0:]
                else:
                    if choice == 'tail':
                        cut_index = len(available_cut_indexes) - 1
                    else:
                        cut_index = random.choice(range(1, len(available_cut_indexes) - 1))

                    line_index = line_index_for_cut_index[cut_index]
                    stream_tunebody_lines = tunebody_lines[line_index:]

                    stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')
                    if add_special_patches:
                        stream_tunebody_patches = stream_tunebody_patches + [eos_patch]
                    patches = metadata_patches + stream_tunebody_patches
            else:
                patches = metadata_patches + tunebody_patches
        else:
            patches = metadata_patches + tunebody_patches

        if cut:
            patches = patches[: patch_length]

        # Encode each character patch to ids, right-padded with the special id.
        id_patches = []
        for patch in patches:
            id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):
        """
        Encode an ABC prompt for generation. A trailing partial patch that is
        not EOS-terminated is left unpadded so the model can keep generating
        characters inside it.
        """
        lines = abc_code.split('\n')
        lines = list(filter(None, lines))

        tunebody_index = None
        for i, line in enumerate(lines):
            if line.startswith('[V:') or line.startswith('[r:'):
                tunebody_index = i
                break

        metadata_lines = lines[: tunebody_index]
        tunebody_lines = lines[tunebody_index:]

        metadata_lines = [line + '\n' for line in metadata_lines]
        if self.stream:
            if not abc_code.endswith('\n'):
                # The last line is still growing: keep it open (no newline).
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]]
            else:
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]
        else:
            tunebody_lines = [line + '\n' for line in tunebody_lines]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            metadata_patches = [bos_patch] + metadata_patches

        patches = metadata_patches + tunebody_patches
        patches = patches[: patch_length]

        # Encode to ids. BUG FIX: the partial-patch check previously used the
        # module constant PATCH_SIZE instead of the patch_size parameter.
        id_patches = []
        for patch in patches:
            if len(patch) < patch_size and patch[-1] != chr(self.eos_token_id):
                id_patch = [ord(c) for c in patch]
            else:
                id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def decode(self, patches):
        """Decode id patches back into ABC text."""
        return ''.join(self.patch2chars(patch) for patch in patches)
218
+
219
+
220
+
221
+
222
class PatchLevelDecoder(PreTrainedModel):
    """
    A patch-level decoder that produces one feature vector per patch in an
    auto-regressive manner. Inherits PreTrainedModel from transformers.
    """
    def __init__(self, config):
        super().__init__(config)
        # Each patch is a PATCH_SIZE x 128 one-hot matrix flattened into a vector.
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = GPT2Model(config)

    def forward(self,
                patches: torch.Tensor,
                masks=None) -> torch.Tensor:
        """
        The forward pass of the patch-level decoder model.
        :param patches: patch character ids, reshaped to [batch, seq, PATCH_SIZE]
        :param masks: optional attention mask over patches
        :return: GPT2 output containing per-patch hidden states
        """
        patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)
        patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
        patches = self.patch_embedding(patches.to(self.device))

        # FIX: use `is None` rather than `== None` — comparing a tensor with
        # `==` attempts an element-wise comparison instead of an identity check.
        if masks is None:
            return self.base(inputs_embeds=patches)
        return self.base(inputs_embeds=patches,
                         attention_mask=masks)
251
+
252
+
253
class CharLevelDecoder(PreTrainedModel):
    """
    A char-level decoder that generates the characters within each patch
    auto-regressively, conditioned on the patch feature produced by the
    patch-level decoder. Inherits PreTrainedModel from transformers.
    """
    def __init__(self, config):
        super().__init__(config)
        self.special_token_id = 0
        self.bos_token_id = 1
        self.base = GPT2LMHeadModel(config)

    def forward(self,
                encoded_patches: torch.Tensor,
                target_patches: torch.Tensor):
        """
        Teacher-forced forward pass over a batch of patches.
        :param encoded_patches: one patch feature per target patch
        :param target_patches: target character ids for each patch
        :return: GPT2 LM output (includes the cross-entropy loss)
        """
        # Prepend a BOS id to every target patch.
        bos_column = torch.ones_like(target_patches[:, 0:1]) * self.bos_token_id
        target_patches = torch.cat((bos_column, target_patches), dim=1)

        # Padding positions get label -100 so they are excluded from the loss.
        pad_positions = target_patches == self.special_token_id
        labels = target_patches.clone().masked_fill_(pad_positions, -100)

        # Attention mask: 1 on real tokens, 0 wherever the label is ignored.
        target_masks = torch.ones_like(labels)
        target_masks = target_masks.masked_fill_(labels == -100, 0)

        # Optionally sub-sample patches to bound the char-level batch size.
        if PATCH_SAMPLING_BATCH_SIZE != 0 and PATCH_SAMPLING_BATCH_SIZE < target_patches.shape[0]:
            indices = list(range(len(target_patches)))
            random.shuffle(indices)
            selected_indices = sorted(indices[:PATCH_SAMPLING_BATCH_SIZE])

            target_patches = target_patches[selected_indices, :]
            target_masks = target_masks[selected_indices, :]
            encoded_patches = encoded_patches[selected_indices, :]

        # Token embeddings; the BOS slot is replaced by the patch feature.
        inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)
        inputs_embeds = torch.cat((encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1)

        return self.base(inputs_embeds=inputs_embeds,
                         attention_mask=target_masks,
                         labels=labels)

    def generate(self,
                 encoded_patch: torch.Tensor,  # [hidden_size]
                 tokens: torch.Tensor):        # [1]
        """
        Compute the next-token distribution for one patch being generated.
        :param encoded_patch: patch feature vector, shape [hidden_size]
        :param tokens: character ids generated so far within this patch
        :return: probability distribution over the 128 character ids
        """
        encoded_patch = encoded_patch.reshape(1, 1, -1)  # [1, 1, hidden_size]
        tokens = tokens.reshape(1, -1)

        # Embed the tokens, then substitute the patch feature for the BOS slot.
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
        tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)

        outputs = self.base(inputs_embeds=tokens)

        # Softmax over the logits at the final position.
        return torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
333
+
334
def safe_normalize_probs(probs):
    """
    Sanitize and renormalize a probability vector.

    NaN and negative entries are zeroed, a tiny epsilon is added to every
    entry (keeping the vector strictly positive and avoiding log(0)), and
    the result is normalized to sum to 1. If normalization is impossible,
    all mass is placed on index 0.
    """
    tiny = 1e-12
    cleaned = np.array(probs, dtype=np.float64)
    cleaned = np.where(np.isnan(cleaned) | (cleaned < 0), 0, cleaned)
    cleaned = cleaned + tiny
    total = cleaned.sum()
    if total > 0:
        return cleaned / total
    fallback = np.zeros_like(cleaned)
    fallback[0] = 1.0
    return fallback
346
+
347
class NotaGenLMHeadModel(PreTrainedModel):
    """
    NotaGen is a language model with a hierarchical structure.
    It includes a patch-level decoder and a char-level decoder.
    The patch-level decoder generates patch features auto-regressively; the
    char-level decoder generates the chars within each patch auto-regressively.
    It inherits PreTrainedModel from transformers.
    """
    def __init__(self, encoder_config, decoder_config):
        super().__init__(encoder_config)
        self.special_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.patch_level_decoder = PatchLevelDecoder(encoder_config)
        self.char_level_decoder = CharLevelDecoder(decoder_config)

    def forward(self,
                patches: torch.Tensor,
                masks: torch.Tensor):
        """
        The forward pass of the NotaGen model.
        :param patches: patch ids, shape [batch, seq * PATCH_SIZE]
        :param masks: attention mask over patches
        :return: char-level decoder output (loss over next-patch chars)
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches, masks)["last_hidden_state"]

        # Align features and targets: each patch feature predicts the
        # characters of the *next* patch.
        left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)
        masks[:, 0] = 0

        encoded_patches = encoded_patches[left_shift_masks == 1]
        patches = patches[masks == 1]

        return self.char_level_decoder(encoded_patches, patches)

    def generate(self,
                 patches: torch.Tensor,
                 top_k=0,
                 top_p=1,
                 temperature=1.0):
        """
        Generate one patch continuing the given patch sequence.
        :param patches: prompt patch ids; a trailing partial patch is allowed
        :param top_k: top-k sampling cutoff (0 disables)
        :param top_p: nucleus sampling threshold
        :param temperature: sampling temperature
        :return: list of generated character ids (one patch)
        """
        if patches.shape[-1] % PATCH_SIZE != 0:
            # A trailing partial patch: keep generating inside it, seeded with BOS.
            tokens = patches[:, :, -(patches.shape[-1] % PATCH_SIZE):].squeeze(0, 1)
            tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1)
            patches = patches[:, :, :-(patches.shape[-1] % PATCH_SIZE)]
        else:
            tokens = torch.tensor([self.bos_token_id], device=self.device)

        patches = patches.reshape(len(patches), -1, PATCH_SIZE)  # [bs, seq, patch_size]
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]  # [bs, seq, hidden_size]
        generated_patch = []

        while True:
            # Probabilities are re-sanitized after each filtering step so the
            # samplers always receive a valid distribution.
            prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy()  # [128]
            prob = safe_normalize_probs(prob)
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)  # [128]
            prob = safe_normalize_probs(prob)
            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)  # [128]
            prob = safe_normalize_probs(prob)
            token = temperature_sampling(prob, temperature=temperature)  # int
            # FIX: removed dead local `char = chr(token)` — it was never used.
            generated_patch.append(token)

            if len(tokens) >= PATCH_SIZE:  # patch is full; EOS handling is the caller's job
                break
            tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)

        return generated_patch
utils (4).py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import bisect
4
+ import json
5
+ import re
6
+ import numpy as np
7
+ from config import *
8
+ from transformers import GPT2Model, GPT2LMHeadModel, LlamaModel, LlamaForCausalLM, PreTrainedModel
9
+ from samplings import top_p_sampling, top_k_sampling, temperature_sampling
10
+ from tokenizers import Tokenizer
11
+
12
+
13
class Patchilizer:
    """
    Converts ABC notation text to/from fixed-size character "patches"
    (lists of character ids).

    Id layout: 0 = padding/special, 1 = BOS, 2 = EOS; all other ids are the
    ordinal values of characters in the ABC text.
    """

    def __init__(self, stream=PATCH_STREAM):
        # Stream mode prefixes every tunebody line with an '[r:i/j]' tag
        # (i = lines before, j = lines after) so long scores can be trained
        # on truncated windows.
        self.stream = stream
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.special_token_id = 0

    def split_bars(self, body_lines):
        """
        Split a body of music into individual bars, keeping each barline
        delimiter attached to the bar text that follows it.
        """
        new_bars = []
        try:
            for line in body_lines:
                line_bars = re.split(self.regexPattern, line)
                line_bars = list(filter(None, line_bars))
                new_line_bars = []

                if len(line_bars) == 1:
                    new_line_bars = line_bars
                else:
                    if line_bars[0] in self.delimiters:
                        new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]
                    else:
                        new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]
                    if 'V' not in new_line_bars[-1]:
                        # Absorb a trailing "barline + newline" fragment into the previous bar.
                        new_line_bars[-2] += new_line_bars[-1]
                        new_line_bars = new_line_bars[:-1]
                new_bars += new_line_bars
        except Exception:
            # FIX: narrowed bare `except:`. Best-effort: malformed input
            # yields whatever bars were collected so far.
            pass

        return new_bars

    def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):
        """
        Chop abc_text into consecutive patch_size-character patches. Unless
        producing the last (open-ended) patch for generation, an EOS char is
        appended when the text does not fill the final patch exactly.
        """
        if not generate_last and len(abc_text) % patch_size != 0:
            abc_text += chr(self.eos_token_id)
        patches = [abc_text[i: i + patch_size] for i in range(0, len(abc_text), patch_size)]
        return patches

    def patch2chars(self, patch):
        """
        Convert a patch (list of character ids) back to text, stopping at the
        first EOS id. Other ids (including BOS/pad) pass through verbatim,
        matching the original behavior.
        """
        # FIX: renamed local `bytes` (shadowed the builtin) and removed the
        # dead `if idx < eos: pass` branch — both behavior-preserving.
        chars = ''
        for idx in patch:
            if idx == self.eos_token_id:
                break
            chars += chr(idx)
        return chars

    def patchilize_metadata(self, metadata_lines):
        """Encode the metadata lines (everything before the tunebody) as patches."""
        metadata_patches = []
        for line in metadata_lines:
            metadata_patches += self.split_patches(line)

        return metadata_patches

    def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):
        """
        Encode tunebody lines bar-by-bar. In 'generate' mode the last bar is
        left open (no EOS) so generation can continue from it.
        """
        tunebody_patches = []
        bars = self.split_bars(tunebody_lines)
        if encode_mode == 'train':
            for bar in bars:
                tunebody_patches += self.split_patches(bar)
        elif encode_mode == 'generate':
            for bar in bars[:-1]:
                tunebody_patches += self.split_patches(bar)
            tunebody_patches += self.split_patches(bars[-1], generate_last=True)

        return tunebody_patches

    def encode_train(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):
        """
        Encode a full ABC piece for training; returns id patches padded to
        patch_size with the special id.

        In stream mode, pieces longer than patch_length are randomly cut at a
        line boundary ('head'/'middle'/'tail') so the model sees different
        windows of long scores across epochs.
        """
        lines = abc_text.split('\n')
        lines = list(filter(None, lines))
        lines = [line + '\n' for line in lines]

        # The tunebody starts at the first voice line.
        tunebody_index = -1
        for i, line in enumerate(lines):
            if '[V:' in line:
                tunebody_index = i
                break

        metadata_lines = lines[: tunebody_index]
        tunebody_lines = lines[tunebody_index:]

        if self.stream:
            # Tag every line with its position: lines-before / lines-after.
            tunebody_lines = ['[r:' + str(line_index) + '/' + str(len(tunebody_lines) - line_index - 1) + ']' + line
                              for line_index, line in enumerate(tunebody_lines)]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)

            metadata_patches = [bos_patch] + metadata_patches
            tunebody_patches = tunebody_patches + [eos_patch]

        if self.stream:
            if len(metadata_patches) + len(tunebody_patches) > patch_length:
                # Valid cut points are patch index 0 and every index right
                # after a newline-bearing patch (i.e. line starts).
                available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if '\n' in patch]
                line_index_for_cut_index = list(range(len(available_cut_indexes)))
                end_index = len(metadata_patches) + len(tunebody_patches) - patch_length
                biggest_index = bisect.bisect_left(available_cut_indexes, end_index)
                available_cut_indexes = available_cut_indexes[:biggest_index + 1]

                if len(available_cut_indexes) == 1:
                    choices = ['head']
                elif len(available_cut_indexes) == 2:
                    choices = ['head', 'tail']
                else:
                    choices = ['head', 'tail', 'middle']
                choice = random.choice(choices)
                if choice == 'head':
                    patches = metadata_patches + tunebody_patches[0:]
                else:
                    if choice == 'tail':
                        cut_index = len(available_cut_indexes) - 1
                    else:
                        cut_index = random.choice(range(1, len(available_cut_indexes) - 1))

                    line_index = line_index_for_cut_index[cut_index]
                    stream_tunebody_lines = tunebody_lines[line_index:]

                    stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')
                    if add_special_patches:
                        stream_tunebody_patches = stream_tunebody_patches + [eos_patch]
                    patches = metadata_patches + stream_tunebody_patches
            else:
                patches = metadata_patches + tunebody_patches
        else:
            patches = metadata_patches + tunebody_patches

        if cut:
            patches = patches[: patch_length]

        # Encode each character patch to ids, right-padded with the special id.
        id_patches = []
        for patch in patches:
            id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):
        """
        Encode an ABC prompt for generation. A trailing partial patch that is
        not EOS-terminated is left unpadded so the model can keep generating
        characters inside it.
        """
        lines = abc_code.split('\n')
        lines = list(filter(None, lines))

        tunebody_index = None
        for i, line in enumerate(lines):
            if line.startswith('[V:') or line.startswith('[r:'):
                tunebody_index = i
                break

        metadata_lines = lines[: tunebody_index]
        tunebody_lines = lines[tunebody_index:]

        metadata_lines = [line + '\n' for line in metadata_lines]
        if self.stream:
            if not abc_code.endswith('\n'):
                # The last line is still growing: keep it open (no newline).
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]]
            else:
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]
        else:
            tunebody_lines = [line + '\n' for line in tunebody_lines]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            metadata_patches = [bos_patch] + metadata_patches

        patches = metadata_patches + tunebody_patches
        patches = patches[: patch_length]

        # Encode to ids. BUG FIX: the partial-patch check previously used the
        # module constant PATCH_SIZE instead of the patch_size parameter.
        id_patches = []
        for patch in patches:
            if len(patch) < patch_size and patch[-1] != chr(self.eos_token_id):
                id_patch = [ord(c) for c in patch]
            else:
                id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def decode(self, patches):
        """Decode id patches back into ABC text."""
        return ''.join(self.patch2chars(patch) for patch in patches)
218
+
219
+
220
+
221
+
222
+ class PatchLevelDecoder(PreTrainedModel):
223
+ """
224
+ A Patch-level Decoder model for generating patch features in an auto-regressive manner.
225
+ It inherits PreTrainedModel from transformers.
226
+ """
227
+ def __init__(self, config):
228
+ super().__init__(config)
229
+ self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
230
+ torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
231
+ self.base = GPT2Model(config)
232
+
233
+ def forward(self,
234
+ patches: torch.Tensor,
235
+ masks=None) -> torch.Tensor:
236
+ """
237
+ The forward pass of the patch-level decoder model.
238
+ :param patches: the patches to be encoded
239
+ :param masks: the masks for the patches
240
+ :return: the encoded patches
241
+ """
242
+ patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)
243
+ patches = patches.reshape(len(patches), -1, PATCH_SIZE * (128))
244
+ patches = self.patch_embedding(patches.to(self.device))
245
+
246
+ if masks==None:
247
+ return self.base(inputs_embeds=patches)
248
+ else:
249
+ return self.base(inputs_embeds=patches,
250
+ attention_mask=masks)
251
+
252
+
253
+ class CharLevelDecoder(PreTrainedModel):
254
+ """
255
+ A Char-level Decoder model for generating the chars within each patch in an auto-regressive manner
256
+ based on the encoded patch features. It inherits PreTrainedModel from transformers.
257
+ """
258
+ def __init__(self, config):
259
+ super().__init__(config)
260
+ self.special_token_id = 0
261
+ self.bos_token_id = 1
262
+
263
+ self.base = GPT2LMHeadModel(config)
264
+
265
+ def forward(self,
266
+ encoded_patches: torch.Tensor,
267
+ target_patches: torch.Tensor):
268
+ """
269
+ The forward pass of the char-level decoder model.
270
+ :param encoded_patches: the encoded patches
271
+ :param target_patches: the target patches
272
+ :return: the output of the model
273
+ """
274
+ # preparing the labels for model training
275
+ target_patches = torch.cat((torch.ones_like(target_patches[:,0:1])*self.bos_token_id, target_patches), dim=1)
276
+ # print('target_patches shape:', target_patches.shape)
277
+
278
+ target_masks = target_patches == self.special_token_id
279
+ labels = target_patches.clone().masked_fill_(target_masks, -100)
280
+
281
+ # masking the labels for model training
282
+ target_masks = torch.ones_like(labels)
283
+ target_masks = target_masks.masked_fill_(labels == -100, 0)
284
+
285
+ # select patches
286
+ if PATCH_SAMPLING_BATCH_SIZE!=0 and PATCH_SAMPLING_BATCH_SIZE<target_patches.shape[0]:
287
+ indices = list(range(len(target_patches)))
288
+ random.shuffle(indices)
289
+ selected_indices = sorted(indices[:PATCH_SAMPLING_BATCH_SIZE])
290
+
291
+ target_patches = target_patches[selected_indices,:]
292
+ target_masks = target_masks[selected_indices,:]
293
+ encoded_patches = encoded_patches[selected_indices,:]
294
+
295
+ # get input embeddings
296
+ inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)
297
+
298
+ # concatenate the encoded patches with the input embeddings
299
+ inputs_embeds = torch.cat((encoded_patches.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1)
300
+
301
+ output = self.base(inputs_embeds=inputs_embeds,
302
+ attention_mask=target_masks,
303
+ labels=labels)
304
+ # output_hidden_states=True=True)
305
+
306
+ return output
307
+
308
+ def generate(self,
309
+ encoded_patch: torch.Tensor, # [hidden_size]
310
+ tokens: torch.Tensor): # [1]
311
+ """
312
+ The generate function for generating a patch based on the encoded patch and already generated tokens.
313
+ :param encoded_patch: the encoded patch
314
+ :param tokens: already generated tokens in the patch
315
+ :return: the probability distribution of next token
316
+ """
317
+ encoded_patch = encoded_patch.reshape(1, 1, -1) # [1, 1, hidden_size]
318
+ tokens = tokens.reshape(1, -1)
319
+
320
+ # Get input embeddings
321
+ tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
322
+
323
+ # Concatenate the encoded patch with the input embeddings
324
+ tokens = torch.cat((encoded_patch, tokens[:,1:,:]), dim=1)
325
+
326
+ # Get output from model
327
+ outputs = self.base(inputs_embeds=tokens)
328
+
329
+ # Get probabilities of next token
330
+ probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
331
+
332
+ return probs
333
+
334
+ def safe_normalize_probs(probs):
335
+ epsilon = 1e-12
336
+ probs = np.array(probs, dtype=np.float64)
337
+ probs = np.where(np.isnan(probs) | (probs < 0), 0, probs)
338
+ probs = probs + epsilon
339
+ s = probs.sum()
340
+ if s > 0:
341
+ probs = probs / s
342
+ else:
343
+ probs = np.zeros_like(probs)
344
+ probs[0] = 1.0
345
+ return probs
346
+
347
+ class NotaGenLMHeadModel(PreTrainedModel):
348
+ """
349
+ NotaGen is a language model with a hierarchical structure.
350
+ It includes a patch-level decoder and a char-level decoder.
351
+ The patch-level decoder is used to generate patch features in an auto-regressive manner.
352
+ The char-level decoder is used to generate the chars within each patch in an auto-regressive manner.
353
+ It inherits PreTrainedModel from transformers.
354
+ """
355
+ def __init__(self, encoder_config, decoder_config):
356
+ super().__init__(encoder_config)
357
+ self.special_token_id = 0
358
+ self.bos_token_id = 1
359
+ self.eos_token_id = 2
360
+ self.patch_level_decoder = PatchLevelDecoder(encoder_config)
361
+ self.char_level_decoder = CharLevelDecoder(decoder_config)
362
+
363
+ def forward(self,
364
+ patches: torch.Tensor,
365
+ masks: torch.Tensor):
366
+ """
367
+ The forward pass of the bGPT model.
368
+ :param patches: the patches to be encoded
369
+ :param masks: the masks for the patches
370
+ :return: the decoded patches
371
+ """
372
+ patches = patches.reshape(len(patches), -1, PATCH_SIZE)
373
+ encoded_patches = self.patch_level_decoder(patches, masks)["last_hidden_state"]
374
+
375
+ left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)
376
+ masks[:, 0] = 0
377
+
378
+ encoded_patches = encoded_patches[left_shift_masks == 1]
379
+ patches = patches[masks == 1]
380
+
381
+ return self.char_level_decoder(encoded_patches, patches)
382
+
383
+ def generate(self,
384
+ patches: torch.Tensor,
385
+ top_k=0,
386
+ top_p=1,
387
+ temperature=1.0):
388
+ """
389
+ The generate function for generating patches based on patches.
390
+ :param patches: the patches to be encoded
391
+ :param top_k: the top k for sampling
392
+ :param top_p: the top p for sampling
393
+ :param temperature: the temperature for sampling
394
+ :return: the generated patches
395
+ """
396
+ if patches.shape[-1] % PATCH_SIZE != 0:
397
+ tokens = patches[:,:,-(patches.shape[-1]%PATCH_SIZE):].squeeze(0, 1)
398
+ tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1)
399
+ patches = patches[:,:,:-(patches.shape[-1]%PATCH_SIZE)]
400
+ else:
401
+ tokens = torch.tensor([self.bos_token_id], device=self.device)
402
+
403
+ patches = patches.reshape(len(patches), -1, PATCH_SIZE) # [bs, seq, patch_size]
404
+ encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"] # [bs, seq, hidden_size]
405
+ generated_patch = []
406
+
407
+ while True:
408
+ prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy() # [128]
409
+ prob = safe_normalize_probs(prob)
410
+ prob = top_k_sampling(prob, top_k=top_k, return_probs=True) # [128]
411
+ prob = safe_normalize_probs(prob)
412
+ prob = top_p_sampling(prob, top_p=top_p, return_probs=True) # [128]
413
+ prob = safe_normalize_probs(prob)
414
+ token = temperature_sampling(prob, temperature=temperature) # int
415
+ char = chr(token)
416
+ generated_patch.append(token)
417
+
418
+ if len(tokens) >= PATCH_SIZE:# or token == self.eos_token_id:
419
+ break
420
+ else:
421
+ tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)
422
+
423
+ return generated_patch
utils (5).py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import bisect
4
+ import json
5
+ import re
6
+ import numpy as np
7
+ from config import *
8
+ from transformers import GPT2Model, GPT2LMHeadModel, LlamaModel, LlamaForCausalLM, PreTrainedModel
9
+ from samplings import top_p_sampling, top_k_sampling, temperature_sampling
10
+ from tokenizers import Tokenizer
11
+
12
+
13
+ class Patchilizer:
14
+ def __init__(self, stream=PATCH_STREAM):
15
+ self.stream = stream
16
+ self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
17
+ self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
18
+ self.bos_token_id = 1
19
+ self.eos_token_id = 2
20
+ self.special_token_id = 0
21
+
22
+ def split_bars(self, body_lines):
23
+ """
24
+ Split a body of music into individual bars.
25
+ """
26
+ new_bars = []
27
+ try:
28
+ for line in body_lines:
29
+ line_bars = re.split(self.regexPattern, line)
30
+ line_bars = list(filter(None, line_bars))
31
+ new_line_bars = []
32
+
33
+ if len(line_bars) == 1:
34
+ new_line_bars = line_bars
35
+ else:
36
+ if line_bars[0] in self.delimiters:
37
+ new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]
38
+ else:
39
+ new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]
40
+ if 'V' not in new_line_bars[-1]:
41
+ new_line_bars[-2] += new_line_bars[-1]
42
+ new_line_bars = new_line_bars[:-1]
43
+ new_bars += new_line_bars
44
+ except:
45
+ pass
46
+
47
+ return new_bars
48
+
49
+ def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):
50
+ if not generate_last and len(abc_text) % patch_size != 0:
51
+ abc_text += chr(self.eos_token_id)
52
+ patches = [abc_text[i : i + patch_size] for i in range(0, len(abc_text), patch_size)]
53
+ return patches
54
+
55
+ def patch2chars(self, patch):
56
+ """
57
+ Convert a patch into a bar.
58
+ """
59
+ bytes = ''
60
+ for idx in patch:
61
+ if idx == self.eos_token_id:
62
+ break
63
+ if idx < self.eos_token_id:
64
+ pass
65
+ bytes += chr(idx)
66
+ return bytes
67
+
68
+
69
+ def patchilize_metadata(self, metadata_lines):
70
+
71
+ metadata_patches = []
72
+ for line in metadata_lines:
73
+ metadata_patches += self.split_patches(line)
74
+
75
+ return metadata_patches
76
+
77
+ def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):
78
+
79
+ tunebody_patches = []
80
+ bars = self.split_bars(tunebody_lines)
81
+ if encode_mode == 'train':
82
+ for bar in bars:
83
+ tunebody_patches += self.split_patches(bar)
84
+ elif encode_mode == 'generate':
85
+ for bar in bars[:-1]:
86
+ tunebody_patches += self.split_patches(bar)
87
+ tunebody_patches += self.split_patches(bars[-1], generate_last=True)
88
+
89
+ return tunebody_patches
90
+
91
+ def encode_train(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):
92
+
93
+ lines = abc_text.split('\n')
94
+ lines = list(filter(None, lines))
95
+ lines = [line + '\n' for line in lines]
96
+
97
+ tunebody_index = -1
98
+ for i, line in enumerate(lines):
99
+ if line.startswith('[V:'):
100
+ tunebody_index = i
101
+ break
102
+
103
+ metadata_lines = lines[ : tunebody_index]
104
+ tunebody_lines = lines[tunebody_index : ]
105
+
106
+ if self.stream:
107
+ tunebody_lines = ['[r:' + str(line_index) + '/' + str(len(tunebody_lines) - line_index - 1) + ']' + line for line_index, line in
108
+ enumerate(tunebody_lines)] # [r:n/n]
109
+
110
+ metadata_patches = self.patchilize_metadata(metadata_lines)
111
+ tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')
112
+
113
+ if add_special_patches:
114
+ bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
115
+ eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)
116
+
117
+ metadata_patches = [bos_patch] + metadata_patches
118
+ tunebody_patches = tunebody_patches + [eos_patch]
119
+
120
+ if self.stream:
121
+ if len(metadata_patches) + len(tunebody_patches) > patch_length:
122
+ available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if '\n' in patch]
123
+ line_index_for_cut_index = list(range(len(available_cut_indexes)))
124
+ end_index = len(metadata_patches) + len(tunebody_patches) - patch_length
125
+ biggest_index = bisect.bisect_left(available_cut_indexes, end_index)
126
+ available_cut_indexes = available_cut_indexes[:biggest_index + 1]
127
+
128
+ if len(available_cut_indexes) == 1:
129
+ choices = ['head']
130
+ elif len(available_cut_indexes) == 2:
131
+ choices = ['head', 'tail']
132
+ else:
133
+ choices = ['head', 'tail', 'middle']
134
+ choice = random.choice(choices)
135
+ if choice == 'head':
136
+ patches = metadata_patches + tunebody_patches[0:]
137
+ else:
138
+ if choice == 'tail':
139
+ cut_index = len(available_cut_indexes) - 1
140
+ else:
141
+ cut_index = random.choice(range(1, len(available_cut_indexes) - 1))
142
+
143
+ line_index = line_index_for_cut_index[cut_index]
144
+ stream_tunebody_lines = tunebody_lines[line_index : ]
145
+
146
+ stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')
147
+ if add_special_patches:
148
+ stream_tunebody_patches = stream_tunebody_patches + [eos_patch]
149
+ patches = metadata_patches + stream_tunebody_patches
150
+ else:
151
+ patches = metadata_patches + tunebody_patches
152
+ else:
153
+ patches = metadata_patches + tunebody_patches
154
+
155
+ if cut:
156
+ patches = patches[ : patch_length]
157
+ else:
158
+ pass
159
+
160
+ # encode to ids
161
+ id_patches = []
162
+ for patch in patches:
163
+ id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
164
+ id_patches.append(id_patch)
165
+
166
+ return id_patches
167
+
168
+ def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):
169
+
170
+ lines = abc_code.split('\n')
171
+ lines = list(filter(None, lines))
172
+
173
+ tunebody_index = None
174
+ for i, line in enumerate(lines):
175
+ if line.startswith('[V:') or line.startswith('[r:'):
176
+ tunebody_index = i
177
+ break
178
+
179
+ metadata_lines = lines[ : tunebody_index]
180
+ tunebody_lines = lines[tunebody_index : ]
181
+
182
+ metadata_lines = [line + '\n' for line in metadata_lines]
183
+ if self.stream:
184
+ if not abc_code.endswith('\n'):
185
+ tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]]
186
+ else:
187
+ tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]
188
+ else:
189
+ tunebody_lines = [line + '\n' for line in tunebody_lines]
190
+
191
+ metadata_patches = self.patchilize_metadata(metadata_lines)
192
+ tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')
193
+
194
+ if add_special_patches:
195
+ bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
196
+
197
+ metadata_patches = [bos_patch] + metadata_patches
198
+
199
+ patches = metadata_patches + tunebody_patches
200
+ patches = patches[ : patch_length]
201
+
202
+ # encode to ids
203
+ id_patches = []
204
+ for patch in patches:
205
+ if len(patch) < PATCH_SIZE and patch[-1] != chr(self.eos_token_id):
206
+ id_patch = [ord(c) for c in patch]
207
+ else:
208
+ id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
209
+ id_patches.append(id_patch)
210
+
211
+ return id_patches
212
+
213
+ def decode(self, patches):
214
+ """
215
+ Decode patches into music.
216
+ """
217
+ return ''.join(self.patch2chars(patch) for patch in patches)
218
+
219
+
220
+
221
+
222
+ class PatchLevelDecoder(PreTrainedModel):
223
+ """
224
+ A Patch-level Decoder model for generating patch features in an auto-regressive manner.
225
+ It inherits PreTrainedModel from transformers.
226
+ """
227
+ def __init__(self, config):
228
+ super().__init__(config)
229
+ self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
230
+ torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
231
+ self.base = GPT2Model(config)
232
+
233
+ def forward(self,
234
+ patches: torch.Tensor,
235
+ masks=None) -> torch.Tensor:
236
+ """
237
+ The forward pass of the patch-level decoder model.
238
+ :param patches: the patches to be encoded
239
+ :param masks: the masks for the patches
240
+ :return: the encoded patches
241
+ """
242
+ patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)
243
+ patches = patches.reshape(len(patches), -1, PATCH_SIZE * (128))
244
+ patches = self.patch_embedding(patches.to(self.device))
245
+
246
+ if masks==None:
247
+ return self.base(inputs_embeds=patches)
248
+ else:
249
+ return self.base(inputs_embeds=patches,
250
+ attention_mask=masks)
251
+
252
+
253
+ class CharLevelDecoder(PreTrainedModel):
254
+ """
255
+ A Char-level Decoder model for generating the chars within each patch in an auto-regressive manner
256
+ based on the encoded patch features. It inherits PreTrainedModel from transformers.
257
+ """
258
+ def __init__(self, config):
259
+ super().__init__(config)
260
+ self.special_token_id = 0
261
+ self.bos_token_id = 1
262
+
263
+ self.base = GPT2LMHeadModel(config)
264
+
265
+ def forward(self,
266
+ encoded_patches: torch.Tensor,
267
+ target_patches: torch.Tensor):
268
+ """
269
+ The forward pass of the char-level decoder model.
270
+ :param encoded_patches: the encoded patches
271
+ :param target_patches: the target patches
272
+ :return: the output of the model
273
+ """
274
+ # preparing the labels for model training
275
+ target_patches = torch.cat((torch.ones_like(target_patches[:,0:1])*self.bos_token_id, target_patches), dim=1)
276
+
277
+ target_masks = target_patches == self.special_token_id
278
+ labels = target_patches.clone().masked_fill_(target_masks, -100)
279
+
280
+ # masking the labels for model training
281
+ target_masks = torch.ones_like(labels)
282
+ target_masks = target_masks.masked_fill_(labels == -100, 0)
283
+
284
+ # select patches
285
+ if PATCH_SAMPLING_BATCH_SIZE!=0 and PATCH_SAMPLING_BATCH_SIZE<target_patches.shape[0]:
286
+ indices = list(range(len(target_patches)))
287
+ random.shuffle(indices)
288
+ selected_indices = sorted(indices[:PATCH_SAMPLING_BATCH_SIZE])
289
+
290
+ target_patches = target_patches[selected_indices,:]
291
+ target_masks = target_masks[selected_indices,:]
292
+ encoded_patches = encoded_patches[selected_indices,:]
293
+
294
+ # get input embeddings
295
+ inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)
296
+
297
+ # concatenate the encoded patches with the input embeddings
298
+ inputs_embeds = torch.cat((encoded_patches.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1)
299
+
300
+ output = self.base(inputs_embeds=inputs_embeds,
301
+ attention_mask=target_masks,
302
+ labels=labels)
303
+
304
+ return output
305
+
306
+ def generate(self,
307
+ encoded_patch: torch.Tensor,
308
+ tokens: torch.Tensor):
309
+ """
310
+ The generate function for generating a patch based on the encoded patch and already generated tokens.
311
+ :param encoded_patch: the encoded patch
312
+ :param tokens: already generated tokens in the patch
313
+ :return: the probability distribution of next token
314
+ """
315
+ encoded_patch = encoded_patch.reshape(1, 1, -1)
316
+ tokens = tokens.reshape(1, -1)
317
+
318
+ # Get input embeddings
319
+ tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
320
+
321
+ # Concatenate the encoded patch with the input embeddings
322
+ tokens = torch.cat((encoded_patch, tokens[:,1:,:]), dim=1)
323
+
324
+ # Get output from model
325
+ outputs = self.base(inputs_embeds=tokens)
326
+
327
+ # Get probabilities of next token
328
+ probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
329
+
330
+ return probs
331
+
332
+ def safe_normalize_probs(probs):
333
+ epsilon = 1e-12
334
+ probs = np.array(probs, dtype=np.float64)
335
+ probs = np.where(np.isnan(probs) | (probs < 0), 0, probs)
336
+ probs = probs + epsilon
337
+ s = probs.sum()
338
+ if s > 0:
339
+ probs = probs / s
340
+ else:
341
+ probs = np.zeros_like(probs)
342
+ probs[0] = 1.0
343
+ return probs
344
+
345
+ class NotaGenLMHeadModel(PreTrainedModel):
346
+ """
347
+ NotaGen is a language model with a hierarchical structure.
348
+ It includes a patch-level decoder and a char-level decoder.
349
+ The patch-level decoder is used to generate patch features in an auto-regressive manner.
350
+ The char-level decoder is used to generate the chars within each patch in an auto-regressive manner.
351
+ It inherits PreTrainedModel from transformers.
352
+ """
353
+ def __init__(self, encoder_config, decoder_config):
354
+ super().__init__(encoder_config)
355
+ self.special_token_id = 0
356
+ self.bos_token_id = 1
357
+ self.eos_token_id = 2
358
+ self.patch_level_decoder = PatchLevelDecoder(encoder_config)
359
+ self.char_level_decoder = CharLevelDecoder(decoder_config)
360
+
361
+ def forward(self,
362
+ patches: torch.Tensor,
363
+ masks: torch.Tensor):
364
+ """
365
+ The forward pass of the bGPT model.
366
+ :param patches: the patches to be encoded
367
+ :param masks: the masks for the patches
368
+ :return: the decoded patches
369
+ """
370
+ patches = patches.reshape(len(patches), -1, PATCH_SIZE)
371
+ encoded_patches = self.patch_level_decoder(patches, masks)["last_hidden_state"]
372
+
373
+ left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)
374
+ masks[:, 0] = 0
375
+
376
+ encoded_patches = encoded_patches[left_shift_masks == 1]
377
+ patches = patches[masks == 1]
378
+
379
+ return self.char_level_decoder(encoded_patches, patches)
380
+
381
+ def generate(self,
382
+ patches: torch.Tensor,
383
+ top_k=0,
384
+ top_p=1,
385
+ temperature=1.0):
386
+ """
387
+ The generate function for generating patches based on patches.
388
+ :param patches: the patches to be encoded
389
+ :param top_k: the top k for sampling
390
+ :param top_p: the top p for sampling
391
+ :param temperature: the temperature for sampling
392
+ :return: the generated patches
393
+ """
394
+ if patches.shape[-1] % PATCH_SIZE != 0:
395
+ tokens = patches[:,:,-(patches.shape[-1]%PATCH_SIZE):].squeeze(0, 1)
396
+ tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1)
397
+ patches = patches[:,:,:-(patches.shape[-1]%PATCH_SIZE)]
398
+ else:
399
+ tokens = torch.tensor([self.bos_token_id], device=self.device)
400
+
401
+ patches = patches.reshape(len(patches), -1, PATCH_SIZE)
402
+ encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
403
+ generated_patch = []
404
+
405
+ while True:
406
+ prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy()
407
+ prob = safe_normalize_probs(prob)
408
+ prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
409
+ prob = safe_normalize_probs(prob)
410
+ prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
411
+ prob = safe_normalize_probs(prob)
412
+ token = temperature_sampling(prob, temperature=temperature)
413
+ char = chr(token)
414
+ generated_patch.append(token)
415
+
416
+ if len(tokens) >= PATCH_SIZE:
417
+ break
418
+ else:
419
+ tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)
420
+
421
+ return generated_patch
utils.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import bisect
4
+ import json
5
+ import re
6
+ import numpy as np
7
+ from config import *
8
+ from transformers import GPT2Model, GPT2LMHeadModel, LlamaModel, LlamaForCausalLM, PreTrainedModel
9
+ from samplings import top_p_sampling, top_k_sampling, temperature_sampling
10
+ from tokenizers import Tokenizer
11
+
12
+
13
+ class Patchilizer:
14
+ def __init__(self, stream=PATCH_STREAM):
15
+ self.stream = stream
16
+ self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
17
+ self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
18
+ self.bos_token_id = 1
19
+ self.eos_token_id = 2
20
+ self.special_token_id = 0
21
+
22
+ def split_bars(self, body_lines):
23
+ """
24
+ Split a body of music into individual bars.
25
+ """
26
+ new_bars = []
27
+ try:
28
+ for line in body_lines:
29
+ line_bars = re.split(self.regexPattern, line)
30
+ line_bars = list(filter(None, line_bars))
31
+ new_line_bars = []
32
+
33
+ if len(line_bars) == 1:
34
+ new_line_bars = line_bars
35
+ else:
36
+ if line_bars[0] in self.delimiters:
37
+ new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]
38
+ else:
39
+ new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]
40
+ if 'V' not in new_line_bars[-1]:
41
+ new_line_bars[-2] += new_line_bars[-1]
42
+ new_line_bars = new_line_bars[:-1]
43
+ new_bars += new_line_bars
44
+ except:
45
+ pass
46
+
47
+ return new_bars
48
+
49
+ def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):
50
+ if not generate_last and len(abc_text) % patch_size != 0:
51
+ abc_text += chr(self.eos_token_id)
52
+ patches = [abc_text[i : i + patch_size] for i in range(0, len(abc_text), patch_size)]
53
+ return patches
54
+
55
+ def patch2chars(self, patch):
56
+ """
57
+ Convert a patch into a bar.
58
+ """
59
+ bytes = ''
60
+ for idx in patch:
61
+ if idx == self.eos_token_id:
62
+ break
63
+ if idx < self.eos_token_id:
64
+ pass
65
+ bytes += chr(idx)
66
+ return bytes
67
+
68
+
69
+ def patchilize_metadata(self, metadata_lines):
70
+
71
+ metadata_patches = []
72
+ for line in metadata_lines:
73
+ metadata_patches += self.split_patches(line)
74
+
75
+ return metadata_patches
76
+
77
+ def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):
78
+
79
+ tunebody_patches = []
80
+ bars = self.split_bars(tunebody_lines)
81
+ if encode_mode == 'train':
82
+ for bar in bars:
83
+ tunebody_patches += self.split_patches(bar)
84
+ elif encode_mode == 'generate':
85
+ for bar in bars[:-1]:
86
+ tunebody_patches += self.split_patches(bar)
87
+ tunebody_patches += self.split_patches(bars[-1], generate_last=True)
88
+
89
+ return tunebody_patches
90
+
91
+ def encode(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):
92
+
93
+ lines = abc_text.split('\n')
94
+ lines = list(filter(None, lines))
95
+ lines = [line + '\n' for line in lines]
96
+
97
+ tunebody_index = -1
98
+ for i, line in enumerate(lines):
99
+ if line.startswith('[r:'):
100
+ tunebody_index = i
101
+ break
102
+
103
+ metadata_lines = lines[: tunebody_index]
104
+ tunebody_lines = lines[tunebody_index:]
105
+
106
+ metadata_patches = self.patchilize_metadata(metadata_lines)
107
+ tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')
108
+
109
+ if add_special_patches:
110
+ bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
111
+ eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)
112
+
113
+ metadata_patches = [bos_patch] + metadata_patches
114
+ tunebody_patches = tunebody_patches + [eos_patch]
115
+
116
+ if self.stream:
117
+ if len(metadata_patches) + len(tunebody_patches) > patch_length:
118
+ available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if
119
+ '\n' in patch]
120
+ line_index_for_cut_index = list(range(len(available_cut_indexes)))
121
+ end_index = len(metadata_patches) + len(tunebody_patches) - patch_length
122
+ biggest_index = bisect.bisect_left(available_cut_indexes, end_index)
123
+ available_cut_indexes = available_cut_indexes[:biggest_index + 1]
124
+
125
+ if len(available_cut_indexes) == 1:
126
+ choices = ['head']
127
+ elif len(available_cut_indexes) == 2:
128
+ choices = ['head', 'tail']
129
+ else:
130
+ choices = ['head', 'tail', 'middle']
131
+ choice = random.choice(choices)
132
+ if choice == 'head':
133
+ patches = metadata_patches + tunebody_patches[0:]
134
+ else:
135
+ if choice == 'tail':
136
+ cut_index = len(available_cut_indexes) - 1
137
+ else:
138
+ cut_index = random.choice(range(1, len(available_cut_indexes) - 1))
139
+
140
+ line_index = line_index_for_cut_index[cut_index]
141
+ stream_tunebody_lines = tunebody_lines[line_index:]
142
+
143
+ stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')
144
+ if add_special_patches:
145
+ stream_tunebody_patches = stream_tunebody_patches + [eos_patch]
146
+ patches = metadata_patches + stream_tunebody_patches
147
+ else:
148
+ patches = metadata_patches + tunebody_patches
149
+ else:
150
+ patches = metadata_patches + tunebody_patches
151
+
152
+ patches = patches[: patch_length]
153
+
154
+ # encode to ids
155
+ id_patches = []
156
+ for patch in patches:
157
+ id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
158
+ id_patches.append(id_patch)
159
+
160
+ return id_patches
161
+
162
+ def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):
163
+
164
+ lines = abc_code.split('\n')
165
+ lines = list(filter(None, lines))
166
+
167
+ tunebody_index = None
168
+ for i, line in enumerate(lines):
169
+ if line.startswith('[V:') or line.startswith('[r:'):
170
+ tunebody_index = i
171
+ break
172
+
173
+ metadata_lines = lines[ : tunebody_index]
174
+ tunebody_lines = lines[tunebody_index : ]
175
+
176
+ metadata_lines = [line + '\n' for line in metadata_lines]
177
+ if self.stream:
178
+ if not abc_code.endswith('\n'):
179
+ tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]]
180
+ else:
181
+ tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]
182
+ else:
183
+ tunebody_lines = [line + '\n' for line in tunebody_lines]
184
+
185
+ metadata_patches = self.patchilize_metadata(metadata_lines)
186
+ tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')
187
+
188
+ if add_special_patches:
189
+ bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
190
+
191
+ metadata_patches = [bos_patch] + metadata_patches
192
+
193
+ patches = metadata_patches + tunebody_patches
194
+ patches = patches[ : patch_length]
195
+
196
+ # encode to ids
197
+ id_patches = []
198
+ for patch in patches:
199
+ if len(patch) < PATCH_SIZE and patch[-1] != chr(self.eos_token_id):
200
+ id_patch = [ord(c) for c in patch]
201
+ else:
202
+ id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
203
+ id_patches.append(id_patch)
204
+
205
+ return id_patches
206
+
207
+ def decode(self, patches):
208
+ """
209
+ Decode patches into music.
210
+ """
211
+ return ''.join(self.patch2chars(patch) for patch in patches)
212
+
213
+
214
+
215
+
216
class PatchLevelDecoder(PreTrainedModel):
    """
    A Patch-level Decoder model for generating patch features in an
    auto-regressive manner.  It inherits PreTrainedModel from transformers.
    """
    def __init__(self, config):
        super().__init__(config)
        # Projects a flattened one-hot patch (PATCH_SIZE chars x 128 byte
        # values) into the transformer embedding space.
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = GPT2Model(config)

    def forward(self,
                patches: torch.Tensor,
                masks=None) -> torch.Tensor:
        """
        The forward pass of the patch-level decoder model.
        :param patches: the patches to be encoded (integer char ids in [0, 128))
        :param masks: optional attention mask for the patches
        :return: the encoded patches (GPT2Model output; use "last_hidden_state")
        """
        # One-hot encode each char id, flatten each patch, then embed.
        patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)
        patches = patches.reshape(len(patches), -1, PATCH_SIZE * (128))
        patches = self.patch_embedding(patches.to(self.device))

        # Fix: identity comparison with None (PEP 8); `==` on an array-like
        # mask would broadcast element-wise instead of testing for absence.
        if masks is None:
            return self.base(inputs_embeds=patches)
        return self.base(inputs_embeds=patches,
                         attention_mask=masks)
245
+
246
+
247
class CharLevelDecoder(PreTrainedModel):
    """
    A Char-level Decoder model for generating the chars within each patch in an
    auto-regressive manner based on the encoded patch features.
    It inherits PreTrainedModel from transformers.
    """
    def __init__(self, config):
        super().__init__(config)
        self.special_token_id = 0  # char id used for padding inside a patch
        self.bos_token_id = 1      # char id prepended as begin-of-patch marker

        self.base = GPT2LMHeadModel(config)

    def forward(self,
                encoded_patches: torch.Tensor,
                target_patches: torch.Tensor):
        """
        The forward pass of the char-level decoder model.
        :param encoded_patches: patch feature vectors, one per patch
        :param target_patches: target patches (char ids, padded with the special token)
        :return: the summed log-probability of all non-padding target chars
        """
        # Prepend a BOS column: [patch_len, patch_size] -> [patch_len, patch_size + 1]
        target_patches = torch.cat((torch.ones_like(target_patches[:, 0:1]) * self.bos_token_id,
                                    target_patches), dim=1)  # [patch_len, patch_size + 1]

        # Mark padding positions; labels use -100 at padding (ignored positions).
        target_masks = target_patches == self.special_token_id  # [patch_len, patch_size + 1]
        labels = target_patches.clone().masked_fill_(target_masks, -100)

        # Rebuild target_masks as a 0/1 attention mask (0 at padding).
        target_masks = torch.ones_like(labels)
        target_masks = target_masks.masked_fill_(labels == -100, 0)

        # Embed the targets, then substitute the encoded patch feature for the
        # BOS embedding so each patch is conditioned on its patch-level feature.
        input_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)
        input_embeds = torch.cat((encoded_patches.unsqueeze(1), input_embeds[:, 1:, :]), dim=1)
        logits = self.base(inputs_embeds=input_embeds,
                           attention_mask=target_masks).logits  # [patch_len, patch_size + 1, vocab_size]
        # Shift: position i predicts char i+1; gather each target char's log-prob.
        logits = logits[:, :-1, :]
        token_logps = torch.gather(logits.log_softmax(-1), dim=-1, index=target_patches[:, 1:].unsqueeze(-1)).squeeze(-1)  # [patch_len, patch_size]
        token_logps = token_logps[target_masks[:, 1:] == 1]  # keep only non-padding positions
        all_logps = token_logps.sum()

        return all_logps

    def generate(self,
                 encoded_patch: torch.Tensor,  # [hidden_size]
                 tokens: torch.Tensor):        # [1]
        """
        The generate function for generating a patch based on the encoded patch
        and already generated tokens.
        :param encoded_patch: the encoded patch
        :param tokens: already generated tokens in the patch
        :return: the probability distribution of next token
        """
        encoded_patch = encoded_patch.reshape(1, 1, -1)  # [1, 1, hidden_size]
        tokens = tokens.reshape(1, -1)

        # Get input embeddings
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)

        # Concatenate the encoded patch with the input embeddings: the first
        # token's embedding is replaced by the patch feature, mirroring forward().
        tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)

        # Get output from model
        outputs = self.base(inputs_embeds=tokens)

        # Get probabilities of next token (softmax over the last position)
        probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)

        return probs
313
+
314
def safe_normalize_probs(probs):
    """Sanitize a probability vector for sampling.

    NaNs and negative entries are zeroed, a tiny floor (1e-12) is added to
    every entry, and the result is renormalized to sum to 1.  If the total is
    not positive, a one-hot distribution on index 0 is returned instead.
    """
    EPS = 1e-12
    arr = np.asarray(probs, dtype=np.float64).copy()
    invalid = np.isnan(arr) | (arr < 0)
    arr[invalid] = 0.0
    arr += EPS
    total = arr.sum()
    if total > 0:
        return arr / total
    fallback = np.zeros_like(arr)
    fallback[0] = 1.0
    return fallback
326
+
327
class NotaGenLMHeadModel(PreTrainedModel):
    """
    NotaGen is a language model with a hierarchical structure.
    It includes a patch-level decoder and a char-level decoder.
    The patch-level decoder is used to generate patch features in an auto-regressive manner.
    The char-level decoder is used to generate the chars within each patch in an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """
    def __init__(self, encoder_config, decoder_config):
        super().__init__(encoder_config)
        self.special_token_id = 0  # padding char id
        self.bos_token_id = 1      # begin-of-sequence char id
        self.eos_token_id = 2      # end-of-sequence char id
        self.patch_level_decoder = PatchLevelDecoder(encoder_config)
        self.char_level_decoder = CharLevelDecoder(decoder_config)

    def forward(self,
                patches: torch.Tensor,
                masks: torch.Tensor):
        """
        The forward pass of the bGPT model.
        :param patches: the patches to be encoded
        :param masks: the masks for the patches
        :return: summed log-probability from the char-level decoder
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches, masks)["last_hidden_state"]

        # Keep every valid position except each sequence's last one, so the
        # feature at position i is paired with (predicts) patch i+1.
        left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)
        # NOTE(review): this mutates the caller's `masks` tensor in place.
        masks[:, 0] = 0

        encoded_patches = encoded_patches[left_shift_masks == 1]
        patches = patches[masks == 1]

        return self.char_level_decoder(encoded_patches, patches)

    def generate(self,
                 patches: torch.Tensor,
                 top_k=0,
                 top_p=1,
                 temperature=1.0):
        """
        The generate function for generating patches based on patches.
        :param patches: the patches to be encoded
        :param top_k: the top k for sampling (0 disables)
        :param top_p: the top p for sampling (1 disables)
        :param temperature: the temperature for sampling
        :return: the generated patch (list of char ids)
        """
        # If the last patch is incomplete, strip it off and continue it char by
        # char; otherwise start a fresh patch from BOS.
        if patches.shape[-1] % PATCH_SIZE != 0:
            tokens = patches[:, :, -(patches.shape[-1] % PATCH_SIZE):].squeeze(0, 1)
            tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1)
            patches = patches[:, :, :-(patches.shape[-1] % PATCH_SIZE)]
        else:
            tokens = torch.tensor([self.bos_token_id], device=self.device)

        patches = patches.reshape(len(patches), -1, PATCH_SIZE)  # [bs, seq, patch_size]
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]  # [bs, seq, hidden_size]
        generated_patch = []

        while True:
            # Sampling pipeline: sanitize -> top-k -> sanitize -> top-p ->
            # sanitize -> temperature sample.
            prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy()  # [128]
            prob = safe_normalize_probs(prob)
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)  # [128]
            prob = safe_normalize_probs(prob)
            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)  # [128]
            prob = safe_normalize_probs(prob)
            token = temperature_sampling(prob, temperature=temperature)  # int
            char = chr(token)  # NOTE(review): unused; presumably kept for debugging
            generated_patch.append(token)

            # Stop only when the patch is full; the EOS early-stop is
            # deliberately commented out below.
            if len(tokens) >= PATCH_SIZE:  # or token == self.eos_token_id:
                break
            else:
                tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)

        return generated_patch
404
+
405
+
406
+
xml2abc.py ADDED
@@ -0,0 +1,1609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=latin-1
3
+ '''
4
+ Copyright (C) 2012-2018: W.G. Vree
5
+ Contributions: M. Tarenskeen, N. Liberg, Paul Villiger, Janus Meuris, Larry Myerscough,
6
+ Dick Jackson, Jan Wybren de Jong, Mark Zealey.
7
+
8
+ This program is free software; you can redistribute it and/or modify it under the terms of the
9
+ Lesser GNU General Public License as published by the Free Software Foundation;
10
+
11
+ This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
12
+ without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
13
+ See the Lesser GNU General Public License for more details. <http://www.gnu.org/licenses/lgpl.html>.
14
+ '''
15
+
16
+ '''Small revisions made for NotaGen to improve the success rate of conversion.'''
17
+
18
+ try: import xml.etree.cElementTree as E
19
+ except: import xml.etree.ElementTree as E
20
+ import os, sys, types, re, math
21
+
22
VERSION = 143  # xml2abc script version number

# Python 2/3 compatibility aliases: pick the spelling of tuple/list type
# objects and the maximum integer that exists on the running interpreter.
python3 = sys.version_info.major > 2
if python3:
    tupletype = tuple
    listtype = list
    max_int = sys.maxsize
else:
    # Python 2 spellings of the same aliases
    tupletype = types.TupleType
    listtype = types.ListType
    max_int = sys.maxint
33
+
34
# MusicXML <notations> sub-element path -> ABC decoration string.
note_ornamentation_map = {  # for notations/, modified from EasyABC
    'ornaments/trill-mark': 'T',
    'ornaments/mordent': 'M',
    'ornaments/inverted-mordent': 'P',
    'ornaments/turn': '!turn!',
    'ornaments/inverted-turn': '!invertedturn!',
    'technical/up-bow': 'u',
    'technical/down-bow': 'v',
    'technical/harmonic': '!open!',
    'technical/open-string': '!open!',
    'technical/stopped': '!plus!',
    'technical/snap-pizzicato': '!snap!',
    'technical/thumb-position': '!thumb!',
    'articulations/accent': '!>!',
    'articulations/strong-accent': '!^!',
    'articulations/staccato': '.',
    'articulations/staccatissimo': '!wedge!',
    'articulations/scoop': '!slide!',
    'fermata': '!fermata!',
    'arpeggiate': '!arpeggio!',
    'articulations/tenuto': '!tenuto!',
    # NOTE(review): 'articulations/staccatissimo' is a duplicate key (same
    # value '!wedge!'); the later entry silently wins.
    'articulations/staccatissimo': '!wedge!',  # not sure whether this is the right translation
    'articulations/spiccato': '!wedge!',  # not sure whether this is the right translation
    'articulations/breath-mark': '!breath!',  # this may need to be tested to make sure it appears on the right side of the note
    'articulations/detached-legato': '!tenuto!.',
}

# MusicXML dynamics element name -> ABC dynamics decoration.
dynamics_map = {  # for direction/direction-type/dynamics/
    'p': '!p!',
    'pp': '!pp!',
    'ppp': '!ppp!',
    'pppp': '!pppp!',
    'f': '!f!',
    'ff': '!ff!',
    'fff': '!fff!',
    'ffff': '!ffff!',
    'mp': '!mp!',
    'mf': '!mf!',
    'sfz': '!sfz!',
}
74
+
75
# SVG glyph definitions emitted into the ABC output for percussion note heads
# (x, circle-x, triangle, square and diamond shapes; the '-'/'+' suffixed ids
# are presumably the open/filled duration variants — confirm against abcm2ps).
percSvg = '''%%beginsvg
<defs>
<text id="x" x="-3" y="0">&#xe263;</text>
<text id="x-" x="-3" y="0">&#xe263;</text>
<text id="x+" x="-3" y="0">&#xe263;</text>
<text id="normal" x="-3.7" y="0">&#xe0a3;</text>
<text id="normal-" x="-3.7" y="0">&#xe0a3;</text>
<text id="normal+" x="-3.7" y="0">&#xe0a4;</text>
<g id="circle-x"><text x="-3" y="0">&#xe263;</text><circle r="4" class="stroke"></circle></g>
<g id="circle-x-"><text x="-3" y="0">&#xe263;</text><circle r="4" class="stroke"></circle></g>
<path id="triangle" d="m-4 -3.2l4 6.4 4 -6.4z" class="stroke" style="stroke-width:1.4"></path>
<path id="triangle-" d="m-4 -3.2l4 6.4 4 -6.4z" class="stroke" style="stroke-width:1.4"></path>
<path id="triangle+" d="m-4 -3.2l4 6.4 4 -6.4z" class="stroke" style="fill:#000"></path>
<path id="square" d="m-3.5 3l0 -6.2 7.2 0 0 6.2z" class="stroke" style="stroke-width:1.4"></path>
<path id="square-" d="m-3.5 3l0 -6.2 7.2 0 0 6.2z" class="stroke" style="stroke-width:1.4"></path>
<path id="square+" d="m-3.5 3l0 -6.2 7.2 0 0 6.2z" class="stroke" style="fill:#000"></path>
<path id="diamond" d="m0 -3l4.2 3.2 -4.2 3.2 -4.2 -3.2z" class="stroke" style="stroke-width:1.4"></path>
<path id="diamond-" d="m0 -3l4.2 3.2 -4.2 3.2 -4.2 -3.2z" class="stroke" style="stroke-width:1.4"></path>
<path id="diamond+" d="m0 -3l4.2 3.2 -4.2 3.2 -4.2 -3.2z" class="stroke" style="fill:#000"></path>
</defs>
%%endsvg'''

# SVG preamble for tablature voices: small sans-serif font class plus white
# rectangles used to blank the staff line behind fret numbers.
tabSvg = '''%%beginsvg
<style type="text/css">
.bf {font-family:sans-serif; font-size:7px}
</style>
<defs>
<rect id="clr" x="-3" y="-1" width="6" height="5" fill="white"></rect>
<rect id="clr2" x="-3" y="-1" width="11" height="5" fill="white"></rect>'''

# Templates for one tablature note-head glyph, filled in with (id, text);
# kopSvg2 is the wider variant (uses the wider blanking rect "clr2").
kopSvg = '<g id="kop%s" class="bf"><use xlink:href="#clr"></use><text x="-2" y="3">%s</text></g>\n'
kopSvg2 = '<g id="kop%s" class="bf"><use xlink:href="#clr2"></use><text x="-2" y="3">%s</text></g>\n'
107
+
108
def info(s, warn=1):
    """Write message *s* to stderr, prefixed with '-- ' when *warn* is truthy."""
    prefix = '-- ' if warn else ''
    sys.stderr.write(prefix + s + '\n')
109
+
110
+ #-------------------
111
+ # data abstractions
112
+ #-------------------
113
class Measure:
    """Per-measure state of one part: indices, timing, and barline data."""

    def __init__(self, p):
        self.reset()
        self.ixp = p       # part number
        self.ixm = 0       # measure number
        self.mdur = 0      # measure duration (nominal metre value in divisions)
        self.divs = 0      # number of divisions per 1/4
        self.mtr = (4, 4)  # meter

    def reset(self):
        """Clear the fields that change from measure to measure."""
        self.attr = ''     # measure signatures, tempo
        self.lline = ''    # left barline; only ':' at start of repeat, else empty
        self.rline = '|'   # right barline
        self.lnum = ''     # (left) volta number
127
+
128
class Note:
    """One note/rest/chord event with its decorations, lyrics, and timing."""

    def __init__(self, dur=0, n=None):
        self.tijd = 0                   # the time in XML division units
        self.dur = dur                  # duration of a note in XML divisions
        self.fact = None                # time modification for tuplet notes (num, div)
        self.tup = ['']                 # start(s) and/or stop(s) of tuplet
        self.tupabc = ''                # abc tuplet string to issue before note
        self.beam = 0                   # 1 = beamed
        self.grace = 0                  # 1 = grace note
        self.before = []                # abc string that goes before the note/chord
        self.after = ''                 # the same after the note/chord
        self.ns = [n] if n else []      # notes in the chord
        self.lyrs = {}                  # {number -> syllable}
        self.tab = None                 # (string number, fret number)
        self.ntdec = ''                 # !string!, !courtesy!
143
+
144
class Elem:
    """Any non-note ABC fragment, positioned at an absolute time."""

    def __init__(self, string):
        self.tijd = 0       # the time in XML division units
        self.str = string   # any abc string that is not a note
148
+
149
class Counter:
    """Per-voice bookkeeping of note statistics ('note', 'nopr', 'nopt')."""

    def inc(self, key, voice):
        """Increment counter *key* for *voice*; missing entries start at 0."""
        self.counters[key][voice] = self.counters[key].get(voice, 0) + 1

    def clear(self, vnums):
        """Reset all three counters to zero for every voice id in *vnums*."""
        zeroed = dict.fromkeys(vnums.keys(), 0)
        self.counters = {'note': dict(zeroed), 'nopr': dict(zeroed), 'nopt': dict(zeroed)}

    def getv(self, key, voice):
        return self.counters[key][voice]

    def prcnt(self, ip):
        """Report (via info) all non-zero skip counters and empty voices of part *ip*."""
        for iv in self.counters['note']:
            if self.getv('nopr', iv) != 0:
                info('part %d, voice %d has %d skipped non printable notes' % (ip, iv, self.getv('nopr', iv)))
            if self.getv('nopt', iv) != 0:
                info('part %d, voice %d has %d notes without pitch' % (ip, iv, self.getv('nopt', iv)))
            if self.getv('note', iv) == 0:  # no real notes counted in this voice
                info('part %d, skipped empty voice %d' % (ip, iv))
163
+
164
class Music:
    """Accumulates all measures/voices/lyrics of the score being translated and
    writes them out as ABC voices.  Collaborates with the module globals
    sortMeasure, abcLyr, mkBroken, outVoice, compUnitLength, checkMelismas and
    the global ABCoutput instance `abcOut`."""
    def __init__(s, options):
        s.tijd = 0              # the current time
        s.maxtime = 0           # maximum time in a measure
        s.gMaten = []           # [voices,.. for all measures in a part]
        s.gLyrics = []          # [{num: (abc_lyric_string, melis)},.. for all measures in a part]
        s.vnums = {}            # all used voice id's in a part (xml voice id's == numbers)
        s.cnt = Counter ()      # global counter object
        s.vceCnt = 1            # the global voice count over all parts
        s.lastnote = None       # the last real note record inserted in s.voices
        s.bpl = options.b       # the max number of bars per line when writing abc
        s.cpl = options.n       # the number of chars per line when writing abc
        s.repbra = 0            # true if volta is used somewhere
        s.nvlt = options.v      # no volta on higher voice numbers
        s.jscript = options.j   # compatibility with javascript version

    def initVoices (s, newPart=0):
        """Reset the per-measure voice buffers (and counters at a part start)."""
        s.vtimes, s.voices, s.lyrics = {}, {}, {}
        for v in s.vnums:
            s.vtimes [v] = 0    # {voice: the end time of the last item in each voice}
            s.voices [v] = []   # {voice: [Note|Elem, ..]}
            s.lyrics [v] = []   # {voice: [{num: syl}, ..]}
        if newPart: s.cnt.clear (s.vnums)   # clear counters once per part

    def incTime (s, dt):
        """Advance current time by dt divisions (clamped at 0) and track the maximum."""
        s.tijd += dt
        if s.tijd < 0: s.tijd = 0   # erroneous <backup> element
        if s.tijd > s.maxtime: s.maxtime = s.tijd

    def appendElemCv (s, voices, elem):
        """Append the same element to every voice in *voices*."""
        for v in voices:
            s.appendElem (v, elem) # insert element in all voices

    def insertElem (s, v, elem): # insert at the start of voice v in the current measure
        obj = Elem (elem)
        obj.tijd = 0 # because voice is sorted later
        s.voices [v].insert (0, obj)

    def appendObj (s, v, obj, dur):
        """Append *obj* at the current time to voice v and advance time by dur."""
        obj.tijd = s.tijd
        s.voices [v].append (obj)
        s.incTime (dur)
        if s.tijd > s.vtimes[v]: s.vtimes[v] = s.tijd # don't update for inserted earlier items

    def appendElem (s, v, elem, tel=0):
        s.appendObj (v, Elem (elem), 0)
        if tel: s.cnt.inc ('note', v) # count number of certain elements in each voice (in addition to notes)

    def appendElemT (s, v, elem, tijd): # insert element at specified time
        obj = Elem (elem)
        obj.tijd = tijd
        s.voices [v].append (obj)

    def appendNote (s, v, note, noot):
        """Append a note/rest (*noot* is its ABC spelling) to voice v."""
        note.ns.append (note.ntdec + noot)
        s.appendObj (v, note, int (note.dur))
        s.lastnote = note # remember last note/rest for later modifications (chord, grace)
        if noot != 'z' and noot != 'x': # real notes and grace notes
            s.cnt.inc ('note', v) # count number of real notes in each voice
            if not note.grace: # for every real note
                s.lyrics[v].append (note.lyrs) # even when it has no lyrics

    def getLastRec (s, voice):
        """Return the last record of *voice* in the previous measure, or None."""
        if s.gMaten: return s.gMaten[-1][voice][-1] # the last record in the last measure
        return None # no previous records in the first measure

    def getLastMelis (s, voice, num): # get melisma of last measure
        if s.gLyrics:
            lyrdict = s.gLyrics[-1][voice] # the previous lyrics dict in this voice
            if num in lyrdict: return lyrdict[num][1] # lyrdict = num -> (lyric string, melisma)
        return 0 # no previous lyrics in voice or line number

    def addChord (s, note, noot): # careful: we assume that chord notes follow immediately
        for d in note.before: # put all decorations before chord
            if d not in s.lastnote.before:
                s.lastnote.before += [d]
        s.lastnote.ns.append (note.ntdec + noot)

    def addBar (s, lbrk, m): # linebreak, measure data
        """Close the current measure: fix barlines/voltas, sort voices, build
        the lyric lines, and push everything onto gMaten/gLyrics."""
        if m.mdur and s.maxtime > m.mdur: info ('measure %d in part %d longer than metre' % (m.ixm+1, m.ixp+1))
        s.tijd = s.maxtime # the time of the bar lines inserted here
        for v in s.vnums:
            if m.lline or m.lnum: # if left barline or left volta number
                p = s.getLastRec (v) # get the previous barline record
                if p: # in measure 1 no previous measure is available
                    x = p.str # p.str is the ABC barline string
                    if m.lline: # append begin of repeat, m.lline == ':'
                        x = (x + m.lline).replace (':|:','::').replace ('||','|')
                    if s.nvlt == 3: # add volta number only to lowest voice in part 0
                        if m.ixp + v == min (s.vnums): x += m.lnum
                    elif m.lnum: # new behaviour with I:repbra 0
                        x += m.lnum # add volta number(s) or text to all voices
                        s.repbra = 1 # signal occurrence of a volta
                    p.str = x # modify previous right barline
                elif m.lline: # begin of new part and left repeat bar is required
                    s.insertElem (v, '|:')
            if lbrk:
                p = s.getLastRec (v) # get the previous barline record
                if p: p.str += lbrk # insert linebreak char after the barlines+volta
            if m.attr: # insert signatures at front of buffer
                s.insertElem (v, '%s' % m.attr)
            s.appendElem (v, ' %s' % m.rline) # insert current barline record at time maxtime
            s.voices[v] = sortMeasure (s.voices[v], m) # make all times consistent
            lyrs = s.lyrics[v] # [{number: syllable}, .. for all notes]
            lyrdict = {} # {number: (abc_lyric_string, melis)} for this voice
            nums = [num for d in lyrs for num in d.keys ()] # the lyrics numbers in this measure
            maxNums = max (nums + [0]) # the highest lyrics number in this measure
            for i in range (maxNums, 0, -1):
                xs = [syldict.get (i, '') for syldict in lyrs] # collect the syllabi with number i
                melis = s.getLastMelis (v, i) # get melisma from last measure
                lyrdict [i] = abcLyr (xs, melis)
            s.lyrics[v] = lyrdict # {number: (abc_lyric_string, melis)} for this measure
            mkBroken (s.voices[v])
        s.gMaten.append (s.voices)
        s.gLyrics.append (s.lyrics)
        s.tijd = s.maxtime = 0
        s.initVoices ()

    def outVoices (s, divs, ip, isSib): # output all voices of part ip
        """Emit all non-empty voices of part *ip* to abcOut, wrapping lines by
        chars (-n) or bars (-b); returns {xml voice number: abc voice number}."""
        vvmap = {} # xml voice number -> abc voice number (one part)
        vnum_keys = list (s.vnums.keys ())
        if s.jscript or isSib: vnum_keys.sort ()
        lvc = min (vnum_keys or [1]) # lowest xml voice number of this part
        for iv in vnum_keys:
            if s.cnt.getv ('note', iv) == 0: # no real notes counted in this voice
                continue # skip empty voices
            if abcOut.denL: unitL = abcOut.denL # take the unit length from the -d option
            else: unitL = compUnitLength (iv, s.gMaten, divs) # compute the best unit length for this voice
            abcOut.cmpL.append (unitL) # remember for header output
            vn, vl = [], {} # for voice iv: collect all notes to vn and all lyric lines to vl
            for im in range (len (s.gMaten)):
                measure = s.gMaten [im][iv]
                vn.append (outVoice (measure, divs [im], im, ip, unitL))
                checkMelismas (s.gLyrics, s.gMaten, im, iv)
                for n, (lyrstr, melis) in s.gLyrics [im][iv].items ():
                    if n in vl:
                        while len (vl[n]) < im: vl[n].append ('') # fill in skipped measures
                        vl[n].append (lyrstr)
                    else:
                        vl[n] = im * [''] + [lyrstr] # must skip im measures
            for n, lyrs in vl.items (): # fill up possibly empty lyric measures at the end
                mis = len (vn) - len (lyrs)
                lyrs += mis * ['']
            abcOut.add ('V:%d' % s.vceCnt)
            if s.repbra:
                if s.nvlt == 1 and s.vceCnt > 1: abcOut.add ('I:repbra 0') # only volta on first voice
                if s.nvlt == 2 and iv > lvc: abcOut.add ('I:repbra 0') # only volta on first voice of each part
            if s.cpl > 0: s.bpl = 0 # option -n (max chars per line) overrules -b (max bars per line)
            elif s.bpl == 0: s.cpl = 100 # the default: 100 chars per line
            bn = 0 # count bars
            while vn: # while still measures available
                ib = 1
                chunk = vn [0]
                while ib < len (vn):
                    if s.cpl > 0 and len (chunk) + len (vn [ib]) >= s.cpl: break # line full (number of chars)
                    if s.bpl > 0 and ib >= s.bpl: break # line full (number of bars)
                    chunk += vn [ib]
                    ib += 1
                bn += ib
                abcOut.add (chunk + ' %%%d' % bn) # line with bar number
                del vn[:ib] # chop ib bars
                lyrlines = sorted (vl.items ()) # order the numbered lyric lines for output
                for n, lyrs in lyrlines:
                    abcOut.add ('w: ' + '|'.join (lyrs[:ib]) + '|')
                    del lyrs[:ib]
            vvmap [iv] = s.vceCnt # xml voice number -> abc voice number
            s.vceCnt += 1 # count voices over all parts
        s.gMaten = [] # reset the following instance vars for each part
        s.gLyrics = []
        s.cnt.prcnt (ip+1) # print summary of skipped items in this part
        return vvmap
335
+
336
+ class ABCoutput:
337
+ pagekeys = 'scale,pageheight,pagewidth,leftmargin,rightmargin,topmargin,botmargin'.split (',')
338
+ def __init__ (s, fnmext, pad, X, options):
339
+ s.fnmext = fnmext
340
+ s.outlist = [] # list of ABC strings
341
+ s.title = 'T:Title'
342
+ s.key = 'none'
343
+ s.clefs = {} # clefs for all abc-voices
344
+ s.mtr = 'none'
345
+ s.tempo = 0 # 0 -> no tempo field
346
+ s.tempo_units = (1,4) # note type of tempo direction
347
+ s.pad = pad # the output path or none
348
+ s.X = X + 1 # the abc tune number
349
+ s.denL = options.d # denominator of the unit length (L:) from -d option
350
+ s.volpan = int (options.m) # 0 -> no %%MIDI, 1 -> only program, 2 -> all %%MIDI
351
+ s.cmpL = [] # computed optimal unit length for all voices
352
+ s.jscript = options.j # compatibility with javascript version
353
+ s.tstep = options.t # translate percmap to voicemap
354
+ s.stemless = 0 # use U:s=!stemless!
355
+ s.shiftStem = options.s # shift note heads 3 units left
356
+ if pad:
357
+ _, base_name = os.path.split (fnmext)
358
+ s.outfile = open (os.path.join (pad, base_name), 'w', encoding='utf-8')
359
+ else: s.outfile = sys.stdout
360
+ if s.jscript: s.X = 1 # always X:1 in javascript version
361
+ s.pageFmt = {}
362
+ for k in s.pagekeys: s.pageFmt [k] = None
363
+ if len (options.p) == 7:
364
+ for k, v in zip (s.pagekeys, options.p):
365
+ try: s.pageFmt [k] = float (v)
366
+ except: info ('illegal float %s for %s', (k, v)); continue
367
+
368
+ def add (s, str):
369
+ s.outlist.append (str + '\n') # collect all ABC output
370
+
371
+ def mkHeader (s, stfmap, partlist, midimap, vmpdct, koppen): # stfmap = [parts], part = [staves], stave = [voices]
372
+ accVce, accStf, staffs = [], [], stfmap[:] # staffs is consumed
373
+ for x in partlist: # collect partnames into accVce and staff groups into accStf
374
+ try: prgroupelem (x, ('', ''), '', stfmap, accVce, accStf)
375
+ except: info ('lousy musicxml: error in part-list')
376
+ staves = ' '.join (accStf)
377
+ clfnms = {}
378
+ for part, (partname, partabbrv) in zip (staffs, accVce):
379
+ if not part: continue # skip empty part
380
+ firstVoice = part[0][0] # the first voice number in this part
381
+ nm = partname.replace ('\n','\\n').replace ('.:','.').strip (':')
382
+ snm = partabbrv.replace ('\n','\\n').replace ('.:','.').strip (':')
383
+ clfnms [firstVoice] = (nm and 'nm="%s"' % nm or '') + (snm and ' snm="%s"' % snm or '')
384
+ hd = ['X:%d\n%s\n' % (s.X, s.title)]
385
+ for i, k in enumerate (s.pagekeys):
386
+ if s.jscript and k in ['pageheight','topmargin', 'botmargin']: continue
387
+ if s.pageFmt [k] != None: hd.append ('%%%%%s %.2f%s\n' % (k, s.pageFmt [k], i > 0 and 'cm' or ''))
388
+ if staves and len (accStf) > 1: hd.append ('%%score ' + staves + '\n')
389
+ tempo = s.tempo and 'Q:%d/%d=%s\n' % (s.tempo_units [0], s.tempo_units [1], s.tempo) or '' # default no tempo field
390
+ d = {} # determine the most frequently occurring unit length over all voices
391
+ for x in s.cmpL: d[x] = d.get (x, 0) + 1
392
+ if s.jscript: defLs = sorted (d.items (), key=lambda x: (-x[1], x[0])) # when tie (1) sort on key (0)
393
+ else: defLs = sorted (d.items (), key=lambda x: -x[1])
394
+ defL = s.denL and s.denL or defLs [0][0] # override default unit length with -d option
395
+ hd.append ('L:1/%d\n%sM:%s\n' % (defL, tempo, s.mtr))
396
+ hd.append ('K:%s\n' % s.key)
397
+ if s.stemless: hd.append ('U:s=!stemless!\n')
398
+ vxs = sorted (vmpdct.keys ())
399
+ for vx in vxs: hd.extend (vmpdct [vx])
400
+ s.dojef = 0 # translate percmap to voicemap
401
+ for vnum, clef in s.clefs.items ():
402
+ ch, prg, vol, pan = midimap [vnum-1][:4]
403
+ dmap = midimap [vnum - 1][4:] # map of abc percussion notes to midi notes
404
+ if dmap and 'perc' not in clef: clef = (clef + ' map=perc').strip ();
405
+ hd.append ('V:%d %s %s\n' % (vnum, clef, clfnms.get (vnum, '')))
406
+ if vnum in vmpdct:
407
+ hd.append ('%%%%voicemap tab%d\n' % vnum)
408
+ hd.append ('K:none\nM:none\n%%clef none\n%%staffscale 1.6\n%%flatbeams true\n%%stemdir down\n')
409
+ if 'perc' in clef: hd.append ('K:none\n'); # no key for a perc voice
410
+ if s.volpan > 1: # option -m 2 -> output all recognized midi commands when needed and present in xml
411
+ if ch > 0 and ch != vnum: hd.append ('%%%%MIDI channel %d\n' % ch)
412
+ if prg > 0: hd.append ('%%%%MIDI program %d\n' % (prg - 1))
413
+ if vol >= 0: hd.append ('%%%%MIDI control 7 %.0f\n' % vol) # volume == 0 is possible ...
414
+ if pan >= 0: hd.append ('%%%%MIDI control 10 %.0f\n' % pan)
415
+ elif s.volpan > 0: # default -> only output midi program command when present in xml
416
+ if dmap and ch > 0: hd.append ('%%%%MIDI channel %d\n' % ch) # also channel if percussion part
417
+ if prg > 0: hd.append ('%%%%MIDI program %d\n' % (prg - 1))
418
+ for abcNote, step, midiNote, notehead in dmap:
419
+ if not notehead: notehead = 'normal'
420
+ if abcMid (abcNote) != midiNote or abcNote != step:
421
+ if s.volpan > 0: hd.append ('%%%%MIDI drummap %s %s\n' % (abcNote, midiNote))
422
+ hd.append ('I:percmap %s %s %s %s\n' % (abcNote, step, midiNote, notehead))
423
+ s.dojef = s.tstep
424
+ if defL != s.cmpL [vnum-1]: # only if computed unit length different from header
425
+ hd.append ('L:1/%d\n' % s.cmpL [vnum-1])
426
+ s.outlist = hd + s.outlist
427
+ if koppen: # output SVG stuff needed for tablature
428
+ k1 = kopSvg.replace ('-2','-5') if s.shiftStem else kopSvg # shift note heads 3 units left
429
+ k2 = kopSvg2.replace ('-2','-5') if s.shiftStem else kopSvg2
430
+ tb = tabSvg.replace ('-3','-6') if s.shiftStem else tabSvg
431
+ ks = sorted (koppen.keys ()) # javascript compatibility
432
+ ks = [k2 % (k, k) if len (k) == 2 else k1 % (k, k) for k in ks]
433
+ tbs = map (lambda x: x.strip () + '\n', tb.splitlines ()) # javascript compatibility
434
+ s.outlist = tbs + ks + ['</defs>\n%%endsvg\n'] + s.outlist
435
+
436
+ def writeall (s): # determine the required encoding of the entire ABC output
437
+ str = ''.join (s.outlist)
438
+ # print(str)
439
+ if s.dojef: str = perc2map (str)
440
+ if python3: s.outfile.write (str)
441
+ else: s.outfile.write (str)
442
+ if s.pad: s.outfile.close () # close each file with -o option
443
+ else: s.outfile.write ('\n') # add empty line between tunes on stdout
444
+ info ('%s written with %d voices' % (s.fnmext, len (s.clefs)), warn=0)
445
+
446
+ #----------------
447
+ # functions
448
+ #----------------
449
def abcLyr (xs, melis): # Convert list xs to abc lyrics.
    """Join the per-note syllables xs into one abc lyrics line.

    melis flags a melisma left open by the previous measure; empty slots
    become '_' (melisma) or '*' (skip).  Returns (line, melis)."""
    if not ''.join (xs): return '', 0           # no lyrics in this measure at all
    out = []
    for syl in xs:                              # one entry (possibly '') per note
        if not syl:                             # note without lyrics
            out.append ('_' if melis else '*')
        elif syl.endswith ('_') and not syl.endswith ('\\_'):   # start of new melisma
            out.append (syl.replace ('_', ''))  # strip marker, remember state
            melis = 1                           # following skips become melisma
        else:
            out.append (syl)
            melis = 0                           # melisma stops on first syllable
    return (' '.join (out), melis)
462
+
463
def simplify (a, b):
    """Reduce the fraction a/b by their greatest common divisor (Euclid)."""
    num, den = a, b
    while b:
        a, b = b, a % b
    # a now holds gcd(num, den)
    return num // a, den // a
467
+
468
def abcdur (nx, divs, uL): # convert a musicXML duration to abc units with L:1/uL
    """ABC duration string for note nx (nx.dur in xml divisions, divs per
    quarter note); nx.fact carries a tuplet time-modification."""
    def _red (p, q): # reduce p/q by their gcd (inlined helper)
        g, r = p, q
        while r: g, r = r, g % r
        return p // g, q // g
    if nx.dur == 0: return ''               # elements without duration
    num, den = _red (uL * nx.dur, divs * 4) # L:1/8 -> uL = 8 units
    if nx.fact:                             # apply tuplet time modification
        fn, fd = nx.fact
        num, den = _red (num * fn, den * fd)
    if den > 64:                            # clamp the denominator to 64
        f = float (num) / den
        n = math.floor (f)
        if f - n < 0.1 * f: num, den = n, 1 # just above an integer: round to it
        num64 = 64. * num / den + 1.0e-15   # Python2-style rounding behaviour
        num, den = _red (int (round (num64)), 64)
    if num == 1:
        if den == 1: return ''
        if den == 2: return '/'
        return '/%d' % den
    if den == 1: return '%d' % num
    return '%d/%d' % (num, den)
486
+
487
def abcMid (note): # abc note -> midi pitch
    """Midi pitch number of an abc note string, or -1 when unparsable."""
    m = re.search (r"([_^]*)([A-Ga-g])([',]*)", note)
    if not m: return -1
    acc, nm, octs = m.groups ()
    base = nm.upper ()
    pitch = 60 + [0,2,4,5,7,9,11]['CDEFGAB'.index (base)]
    if base != nm: pitch += 12                          # lowercase: one octave up
    if acc:                                             # sign taken from first char
        pitch += (1 if acc[0] == '^' else -1) * len (acc)
    if octs:
        pitch += (12 if octs[0] == "'" else -12) * len (octs)
    return pitch
496
+
497
def staffStep (ptc, o, clef, tstep):
    """ABC pitch string for xml step ptc / octave o, diatonically shifted
    for one-line staves and (unless tstep) for bass clefs."""
    shift = 0
    if 'stafflines=1' in clef: shift += 4               # one line: E (xml) -> B (abc)
    if not tstep and clef.startswith ('bass'): shift += 12  # bass -> treble (C3 -> A4)
    if shift:                                           # diatonic transposition mod 7
        steps = 'C,D,E,F,G,A,B'.split (',')
        i = steps.index (ptc) + shift
        ptc, o = steps [i % 7], o + i // 7
    if o > 4: ptc = ptc.lower ()
    if o > 5: ptc = ptc + (o-5) * "'"
    if o < 4: ptc = ptc + (4-o) * ","
    return ptc
509
+
510
def setKey (fifths, mode):
    """Translate a circle-of-fifths count and xml mode into the abc key name
    plus a dict of note names altered by that key signature."""
    names = ['Fb', 'Cb','Gb','Db','Ab','Eb','Bb','F','C','G','D','A', 'E', 'B', 'F#','C#','G#','D#','A#','E#','B#']
    offs  = {'maj':8, 'ion':8, 'm':11, 'min':11, 'aeo':11, 'mix':9, 'dor':10, 'phr':12, 'lyd':7, 'loc':13, 'non':8}
    mode = mode.lower ()[:3]            # first three chars, case insensitive
    off = offs [mode]
    key = names [off + fifths]
    if off != 8: key += mode            # major modes need no suffix
    order = ['F','C','G','D','A','E','B']   # order in which accidentals appear
    if fifths >= 0:
        alts = dict (zip (order [:fifths], [1] * fifths))
    else:
        alts = dict (zip (order [fifths:], [-1] * -fifths))
    return key, alts
519
+
520
def insTup (ix, notes, fact): # read one nested tuplet
    """Translate one (possibly nested) xml tuplet starting at notes[ix] into
    abc tuplet notation, prefixed to the first tuplet note's tupabc.
    fact is the time-modification of the enclosing level; returns
    (index of last tuplet note, number of tuplet notes counted)."""
    tupcnt = 0
    nx = notes [ix]
    if 'start' in nx.tup:
        nx.tup.remove ('start') # do recursive calls when starts remain
    tix = ix # index of first tuplet note
    fn, fd = fact # xml time-mod of the higher level
    fnum, fden = nx.fact # xml time-mod of the current level
    tupfact = fnum//fn, fden//fd # abc time mod of this level
    while ix < len (notes):
        nx = notes [ix]
        if isinstance (nx, Elem) or nx.grace:
            ix += 1 # skip all non tuplet elements
            continue
        if 'start' in nx.tup: # more nested tuplets to start
            ix, tupcntR = insTup (ix, notes, tupfact) # ix is on the stop note!
            tupcnt += tupcntR
        elif nx.fact:
            tupcnt += 1 # count tuplet elements
        if 'stop' in nx.tup:
            nx.tup.remove ('stop')
            break
        if not nx.fact: # stop on first non tuplet note
            # NOTE(review): lastix is only bound after one tuplet iteration; the
            # caller guarantees notes[ix] starts a tuplet, so a binding exists
            # by the time this runs — confirm for malformed xml input.
            ix = lastix # back to last tuplet note
            break
        lastix = ix
        ix += 1
    # put abc tuplet notation before the recursive ones
    tup = (tupfact[0], tupfact[1], tupcnt)
    if tup == (3, 2, 3): tupPrefix = '(3' # common triplet shorthand
    else: tupPrefix = '(%d:%d:%d' % tup
    notes [tix].tupabc = tupPrefix + notes [tix].tupabc
    return ix, tupcnt # ix is on the last tuplet note
553
+
554
def mkBroken (vs): # introduce broken rhythms (vs: one voice, one measure)
    """Rewrite adjacent note pairs with a 1:3 or 3:1 duration ratio as abc
    broken-rhythm notation ('<' / '>'), mutating the Note objects in place.
    Only beamed, non-tuplet pairs are considered; pairs are never chained."""
    vs = [n for n in vs if isinstance (n, Note)]
    i = 0
    while i < len (vs) - 1:
        n1, n2 = vs[i], vs[i+1] # scan all adjacent pairs
        # skip if note in tuplet or has no duration or outside beam
        if not n1.fact and not n2.fact and n1.dur > 0 and n2.beam:
            if n1.dur * 3 == n2.dur: # short-long pair -> '<'
                n2.dur = (2 * n2.dur) // 3
                n1.dur = n1.dur * 2
                n1.after = '<' + n1.after
                i += 1 # do not chain broken rhythms
            elif n2.dur * 3 == n1.dur: # long-short pair -> '>'
                n1.dur = (2 * n1.dur) // 3
                n2.dur = n2.dur * 2
                n1.after = '>' + n1.after
                i += 1 # do not chain broken rhythms
        i += 1
572
+
573
def outVoice (measure, divs, im, ip, unitL): # note/elem objects of one measure in one voice
    """Render one measure of one voice to its abc string: annotate tuplets,
    assemble chords/ties/durations per note, and strip duplicated pedal and
    empty ottava directions.  im/ip identify measure/part (currently unused
    here beyond the signature)."""
    ix = 0
    while ix < len (measure): # set all (nested) tuplet annotations
        nx = measure [ix]
        if isinstance (nx, Note) and nx.fact and not nx.grace:
            ix, tupcnt = insTup (ix, measure, (1, 1)) # read one tuplet, insert annotation(s)
        ix += 1
    vs = []
    for nx in measure:
        if isinstance (nx, Note):
            durstr = abcdur (nx, divs, unitL) # xml -> abc duration string
            chord = len (nx.ns) > 1
            cns = [nt[:-1] for nt in nx.ns if nt.endswith ('-')]
            tie = ''
            if chord and len (cns) == len (nx.ns): # all chord notes tied
                nx.ns = cns # chord notes without tie
                tie = '-' # one tie for whole chord
            s = nx.tupabc + ''.join (nx.before)
            if chord: s += '['
            for nt in nx.ns: s += nt
            if chord: s += ']' + tie
            if s.endswith ('-'): s, tie = s[:-1], '-' # split off tie
            s += durstr + tie # and put it back again
            s += nx.after
            nospace = nx.beam # beamed notes are written without separating space
        else:
            if isinstance (nx.str, listtype): nx.str = nx.str [0]
            s = nx.str
            nospace = 1
        if nospace: vs.append (s)
        else: vs.append (' ' + s)
    vs = ''.join (vs) # ad hoc: remove multiple pedal directions
    while vs.find ('!ped!!ped!') >= 0: vs = vs.replace ('!ped!!ped!','!ped!')
    while vs.find ('!ped-up!!ped-up!') >= 0: vs = vs.replace ('!ped-up!!ped-up!','!ped-up!')
    while vs.find ('!8va(!!8va)!') >= 0: vs = vs.replace ('!8va(!!8va)!','') # remove empty ottava's
    return vs
609
+
610
def sortMeasure (voice, m):
    """Sort the notes/elems of one voice/measure on start time, fill time
    gaps with invisible rests, and resolve overlapping notes by shortening
    rests, building chords, or discarding.  Returns the sequential list."""
    voice.sort (key=lambda o: o.tijd) # sort on time
    time = 0
    v = []
    rs = [] # holds rests in between notes
    for i, nx in enumerate (voice): # establish sequentiality
        if nx.tijd > time and chkbug (nx.tijd - time, m):
            v.append (Note (nx.tijd - time, 'x')) # fill hole with invisble rest
            rs.append (len (v) - 1)
        if isinstance (nx, Elem):
            if nx.tijd < time: nx.tijd = time # shift elems without duration to where they fit
            v.append (nx)
            time = nx.tijd
            continue
        if nx.tijd < time: # overlapping element
            if nx.ns[0] == 'z': continue # discard overlapping rest
            if v[-1].tijd <= nx.tijd: # we can do something
                if v[-1].ns[0] == 'z': # shorten rest
                    v[-1].dur = nx.tijd - v[-1].tijd
                    if v[-1].dur == 0: del v[-1] # nothing left
                    info ('overlap in part %d, measure %d: rest shortened' % (m.ixp+1, m.ixm+1))
                else: # make a chord of overlap
                    v[-1].ns += nx.ns
                    info ('overlap in part %d, measure %d: added chord' % (m.ixp+1, m.ixm+1))
                    nx.dur = (nx.tijd + nx.dur) - time # the remains
                    if nx.dur <= 0: continue # nothing left
                    nx.tijd = time # append remains
            else: # give up
                info ('overlapping notes in one voice! part %d, measure %d, note %s discarded' % (m.ixp+1, m.ixm+1, isinstance (nx, Note) and nx.ns or nx.str))
                continue
        v.append (nx)
        if isinstance (nx, Note):
            if nx.ns [0] in 'zx':
                rs.append (len (v) - 1) # remember rests between notes
            elif len (rs):
                if nx.beam and not nx.grace: # copy beam into rests
                    for j in rs: v[j].beam = nx.beam
                rs = [] # clear rests on each note
        time = nx.tijd + nx.dur
    # when a measure contains no elements and no forwards -> no incTime -> s.maxtime = 0 -> right barline
    # is inserted at time == 0 (in addbar) and is only element in the voice when sortMeasure is called
    if time == 0: info ('empty measure in part %d, measure %d, it should contain at least a rest to advance the time!' % (m.ixp+1, m.ixm+1))
    return v
653
+
654
def getPartlist (ps): # correct part-list (from buggy xml-software)
    """Sanitise a musicXML part-list: insert missing part-group stops, drop
    duplicate stops, and close any group still open at the end."""
    fixed = []          # the corrected part-list
    open_groups = []    # stack of opened part-group numbers
    for elem in list (ps):
        if elem.tag != 'part-group':
            fixed.append (elem)
            continue
        num, typ = elem.get ('number'), elem.get ('type')
        if typ == 'start':
            if num in open_groups:  # start while still open: insert missing stop
                fixed.append (E.Element ('part-group', number = num, type = 'stop'))
                fixed.append (elem)
            else:                   # normal start
                fixed.append (elem)
                open_groups.append (num)
        elif num in open_groups:    # normal stop
            open_groups.remove (num)
            fixed.append (elem)
        # else: double stop, silently dropped
    for num in reversed (open_groups):  # close whatever is still open
        fixed.append (E.Element ('part-group', number = num, type = 'stop'))
    return fixed
676
+
677
def parseParts (xs, d, e): # -> [elems on current level], rest of xs
    """Recursively turn the flat, sanitised part-list xs into a nested tree:
    part-name tuples on each level, with each closed group's data appended
    as the last element.  d maps group number -> group data, e is the stack
    of open group numbers."""
    if not xs: return [],[]
    x = xs.pop (0)
    if x.tag == 'part-group':
        num, type = x.get ('number'), x.get ('type')
        if type == 'start': # go one level deeper
            s = [x.findtext (n, '') for n in ['group-symbol','group-barline','group-name','group-abbreviation']]
            d [num] = s # remember groupdata by group number
            e.append (num) # make stack of open group numbers
            elemsnext, rest1 = parseParts (xs, d, e) # parse one level deeper to next stop
            elems, rest2 = parseParts (rest1, d, e) # parse the rest on this level
            return [elemsnext] + elems, rest2
        else: # stop: close level and return group-data
            nums = e.pop () # last open group number in stack order
            if xs and xs[0].get ('type') == 'stop': # two consequetive stops
                if num != nums: # in the wrong order (tempory solution)
                    # NOTE(review): swapping d entries only repairs a pair of
                    # out-of-order stops; three or more remain wrong (per the
                    # original author's comment) — confirm acceptable.
                    d[nums], d[num] = d[num], d[nums] # exchange values (only works for two stops!!!)
            sym = d[num] # retrieve an return groupdata as last element of the group
            return [sym], xs
    else:
        elems, rest = parseParts (xs, d, e) # parse remaining elements on current level
        name = x.findtext ('part-name',''), x.findtext ('part-abbreviation','')
        return [name] + elems, rest
700
+
701
def bracePart (part): # put a brace on multistaff part and group voices
    """%%score fragment for one part: parenthesise multi-voice staves,
    separate staves with '|', brace multi-stave parts."""
    if not part: return []              # empty part in the score
    acc = []
    for voices in part:
        if len (voices) == 1:           # stave with one voice
            acc.append ('%s' % voices[0])
        else:                           # stave with multiple voices
            acc.append ('(')
            acc.extend ('%s' % v for v in voices)
            acc.append (')')
        acc.append ('|')
    acc.pop ()                          # no barline at the end
    if len (part) > 1:
        return ['{'] + acc + ['}']
    return acc
714
+
715
def prgroupelem (x, gnm, bar, pmap, accVce, accStf): # collect partnames (accVce) and %%score map (accStf)
    """Dispatch one node of the parsed part-list tree: a partname tuple, a
    group misused as an extra stave name, or a nested group list."""
    if type (x) == tupletype:           # partname-tuple = (part-name, part-abbrev)
        staves = pmap.pop (0)
        if gnm[0]:                      # prefix the group name to the part name
            x = [g + ':' + p for g, p in zip (gnm, x)]
        accVce.append (x)
        accStf.extend (bracePart (staves))
    elif len (x) == 2 and type (x[0]) == tupletype:  # group just adds an extra name to one stave
        staves = pmap.pop (0)
        # x[0] = partname-tuple, x[1][2:] = groupname-tuple
        accVce.append ([a + ':' + b for a, b in zip (x[0], x[1][2:])])
        accStf.extend (bracePart (staves))
    else:                               # a real nested group
        prgrouplist (x, bar, pmap, accVce, accStf)
728
+
729
def prgrouplist (x, pbar, pmap, accVce, accStf): # collect partnames, scoremap for a part-group
    """Emit the %%score brackets for one part-group; x[-1] carries the
    (symbol, barline, group-name, group-abbrev) data of the group."""
    sym, bar, gnm, gabbr = x[-1]        # bracket symbol, continue barline, group-name-tuple
    bar = bar == 'yes' or pbar          # pbar -> the parent has bar
    accStf.append ('{' if sym == 'brace' else '[')
    for elem in x[:-1]:
        prgroupelem (elem, (gnm, gabbr), bar, pmap, accVce, accStf)
        if bar: accStf.append ('|')
    if bar: del accStf [-1]             # remove last barline before close
    accStf.append ('}' if sym == 'brace' else ']')
738
+
739
def compUnitLength (iv, maten, divs): # compute optimal unit length
    """Pick the unit length (1/4, 1/8 or 1/16) minimising the total length
    of all abc duration strings of voice iv over all measures maten;
    divs[im] holds the xml divisions of measure im."""
    uLmin, minLen = 0, max_int
    for uL in [4,8,16]: # try 1/4, 1/8 and 1/16
        vLen = 0 # total length of abc duration strings in this voice
        for im, m in enumerate (maten): # all measures
            for e in m[iv]: # all notes in voice iv
                if isinstance (e, Elem) or e.dur == 0: continue # no real durations
                vLen += len (abcdur (e, divs [im], uL)) # add len of duration string
        if vLen < minLen: uLmin, minLen = uL, vLen # remember the smallest
    return uLmin
749
+
750
def doSyllable (syl):
    """One xml <lyric> element -> abc syllable text with ~ (elision/space),
    trailing - (begin/middle syllable) and _ (extend) markup."""
    txt = ''
    for child in syl:
        if child.tag == 'elision':
            txt += '~'
        elif child.tag == 'text':   # escape - and space characters
            raw = child.text or ''
            txt += raw.replace ('_','\\_').replace('-', r'\-').replace(' ', '~')
    if not txt: return txt
    if syl.findtext('syllabic') in ['begin', 'middle']:
        txt += '-'
    if syl.find('extend') is not None:
        txt += '_'
    return txt
760
+
761
def checkMelismas (lyrics, maten, im, iv):
    """If the previous measure left a melisma open in lyric line n and the
    current measure has no lyrics n, synthesise an underscore melisma."""
    if im == 0: return
    notes = maten [im][iv]          # notes of the current measure
    cur   = lyrics [im][iv]         # lyrics dict of current measure
    prev  = lyrics [im-1][iv]       # lyrics dict of previous measure
    for n, (lyrstr, melis) in prev.items ():
        if melis and n not in cur:  # open melisma but no lyrics -> make one!
            ms = getMelisma (notes)
            if ms: cur [n] = (ms, 0)    # becomes the n-th lyrics line of this measure
770
+
771
def getMelisma (maat): # get melisma from notes in maat
    """Underscore per real note of measure maat, stopping at the first rest;
    Elems and grace notes are skipped."""
    skips = []
    for nt in maat:
        if not isinstance (nt, Note): continue  # skip Elem's
        if nt.grace: continue                   # skip grace notes
        if nt.ns [0] in 'zx': break             # stop on first rest
        skips.append ('_')
    return ' '.join (skips)
779
+
780
def perc2map (abcIn):
    """Post-process abc text: turn I:percmap lines into %%map / %%voicemap
    definitions and re-insert the %%MIDI commands at the first occurrence
    of each voice.  Fix: abc was built with map(), whose lazy iterator
    breaks the += / append calls below on Python 3; a list is used now."""
    fillmap = {'diamond':1, 'triangle':1, 'square':1, 'normal':1};
    abc = [x.strip () for x in percSvg.splitlines ()] # list, not map(): it is extended below
    id='default'
    maps = {'default': []};
    dmaps = {'default': []}
    r1 = re.compile (r'V:\s*(\S+)')
    ls = abcIn.splitlines ()
    for x in ls: # first pass: collect maps and midi commands per voice id
        if 'I:percmap' in x:
            noot, step, midi, kop = map (lambda x: x.strip (), x.split ()[1:])
            if kop in fillmap: kop = kop + '+' + ',' + kop
            x = '%%%%map perc%s %s print=%s midi=%s heads=%s' % (id, noot, step, midi, kop)
            maps [id].append (x)
        if '%%MIDI' in x: dmaps [id].append (x)
        if 'V:' in x:
            r = r1.match (x)
            if r:
                id = r.group (1);
                if id not in maps: maps [id] = []; dmaps [id] = []
    ids = sorted (maps.keys ())
    for id in ids: abc += maps [id]
    id='default'
    for x in ls: # second pass: emit the tune with maps/midi in place
        if 'I:percmap' in x: continue
        if '%%MIDI' in x: continue
        if 'V:' in x or 'K:' in x:
            r = r1.match (x)
            if r: id = r.group (1)
            abc.append (x)
            if id in dmaps and len (dmaps [id]) > 0: abc.extend (dmaps [id]); del dmaps [id]
            if 'perc' in x and 'map=' not in x: x += ' map=perc';
            if 'map=perc' in x and len (maps [id]) > 0: abc.append ('%%voicemap perc' + id);
            if 'map=off' in x: abc.append ('%%voicemap');
        else:
            abc.append (x)
    return '\n'.join (abc) + '\n'
817
+
818
def addoct (ptc, o): # xml staff step, xml octave number
    """ABC pitch (without accidental) for step ptc in octave o."""
    p = ptc.lower () if o > 4 else ptc
    if o > 5:
        p += "'" * (o - 5)
    elif o < 4:
        p += ',' * (4 - o)
    return p
824
+
825
def chkbug (dt, m):
    """Return 1 when gap dt (in xml divisions of measure m) is a realistic
    duration, i.e. longer than a 1/64 note; otherwise report the MuseScore
    rounding bug and return 0.  Fix: typo in the message ('then' -> 'than')."""
    if dt > m.divs / 16: return 1 # duration should be > 1/64 note
    info ('MuseScore bug: incorrect duration, smaller than 1/64! in measure %d, part %d' % (m.ixm, m.ixp))
    return 0
829
+
830
+ #----------------
831
+ # parser
832
+ #----------------
833
class Parser:
    """MusicXML parser: walks the document and feeds notes, directions and
    structure into the Music/ABC-output abstractions."""
    # 3 alternative notations of the same note for tablature mapping
    note_alts = [
        [x.strip () for x in '=C, ^C, =D, ^D, =E, =F, ^F, =G, ^G, =A, ^A, =B'.split (',')],
        [x.strip () for x in '^B, _D,^^C, _E, _F, ^E, _G,^^F, _A,^^G, _B, _C'.split (',')],
        [x.strip () for x in '__D,^^B,__E,__F,^^D,__G,^^E,__A,_/A,__B,__C,^^A'.split (',')] ]
    # semitone offset of each natural step within the octave
    step_map = {'C':0,'D':2,'E':4,'F':5,'G':7,'A':9,'B':11}
839
    def __init__ (s, options):
        """Initialise all per-score parser state from the command line options
        (unfold repeats, chars per line, credit filter level, volta option...)."""
        s.slurBuf = {} # dict of open slurs keyed by slur number
        s.dirStk = {} # {direction-type + number -> (type, voice | time)} dict for proper closing
        s.ingrace = 0 # marks a sequence of grace notes
        s.msc = Music (options) # global music data abstraction
        s.unfold = options.u # turn unfolding repeats on
        s.ctf = options.c # credit text filter level
        s.gStfMap = [] # [[abc voice numbers] for all parts]
        s.midiMap = [] # midi-settings for each abc voice, in order
        s.drumInst = {} # inst_id -> midi pitch for channel 10 notes
        s.drumNotes = {} # (xml voice, abc note) -> (midi note, note head)
        s.instMid = [] # [{inst id -> midi-settings} for all parts]
        s.midDflt = [-1,-1,-1,-91] # default midi settings for channel, program, volume, panning
        s.msralts = {} # xml-notenames (without octave) with accidentals from the key
        s.curalts = {} # abc-notenames (with voice number) with passing accidentals
        s.stfMap = {} # xml staff number -> [xml voice number]
        s.vce2stf = {} # xml voice number -> allocated staff number
        s.clefMap = {} # xml staff number -> abc clef (for header only)
        s.curClef = {} # xml staff number -> current abc clef
        s.stemDir = {} # xml voice number -> current stem direction
        s.clefOct = {} # xml staff number -> current clef-octave-change
        s.curStf = {} # xml voice number -> current xml staff number
        s.nolbrk = options.x; # generate no linebreaks ($)
        s.jscript = options.j # compatibility with javascript version
        s.ornaments = sorted (note_ornamentation_map.items ())
        s.doPageFmt = len (options.p) == 1 # translate xml page format
        s.tstep = options.t # clef determines step on staff (percussion)
        s.dirtov1 = options.v1 # all directions to first voice of staff
        s.ped = options.ped # render pedal directions
        s.wstems = options.stm # translate stem elements
        s.pedVce = None # voice for pedal directions
        s.repeat_str = {} # staff number -> [measure number, repeat-text]
        s.tabVceMap = {} # abc voice num -> [%%map ...] for tab voices
        s.koppen = {} # noteheads needed for %%map
874
+
875
    def matchSlur (s, type2, n, v2, note2, grace, stopgrace): # match slur number n in voice v2, add abc code to before/after
        """Pair slur start/stop events via s.slurBuf; when a pair completes
        within one voice, wrap the span in '(' ... ')' on the two notes.
        Grace-note boundary slurs and stave-spanning slurs are dropped."""
        if type2 not in ['start', 'stop']: return # slur type continue has no abc equivalent
        if n == None: n = '1'
        if n in s.slurBuf:
            type1, v1, note1, grace1 = s.slurBuf [n]
            if type2 != type1: # slur complete, now check the voice
                if v2 == v1: # begins and ends in the same voice: keep it
                    if type1 == 'start' and (not grace1 or not stopgrace): # normal slur: start before stop and no grace slur
                        note1.before = ['('] + note1.before # keep left-right order!
                        note2.after += ')'
                # no else: don't bother with reversed stave spanning slurs
                del s.slurBuf [n] # slur finished, remove from stack
            else: # double definition, keep the last
                info ('double slur numbers %s-%s in part %d, measure %d, voice %d note %s, first discarded' % (type2, n, s.msr.ixp+1, s.msr.ixm+1, v2, note2.ns))
                s.slurBuf [n] = (type2, v2, note2, grace)
        else: # unmatched slur, put in dict
            s.slurBuf [n] = (type2, v2, note2, grace)
892
+
893
    def doNotations (s, note, nttn, isTab):
        """Translate a <notations> element into abc decorations on note:
        ornaments, tremolos, fingering, string/fret (tablature), wavy
        lines (trills) and glissando/slide spans."""
        for key, val in s.ornaments:
            if nttn.find (key) != None: note.before += [val] # just concat all ornaments
        trem = nttn.find ('ornaments/tremolo')
        if trem != None:
            type = trem.get ('type')
            if type == 'single':
                note.before.insert (0, '!%s!' % (int (trem.text) * '/'))
            else:
                note.fact = None # no time modification in ABC
                if s.tstep: # abc2svg version
                    if type == 'stop': note.before.insert (0, '!trem%s!' % trem.text);
                else: # abc2xml version
                    if type == 'start': note.before.insert (0, '!%s-!' % (int (trem.text) * '/'));
        fingering = nttn.findall ('technical/fingering')
        for finger in fingering: # handle multiple finger annotations
            if not isTab: note.before += ['!%s!' % finger.text] # fingering goes before chord (addChord)
        snaar = nttn.find ('technical/string')
        if snaar != None and isTab:
            if s.tstep:
                fret = nttn.find ('technical/fret')
                if fret != None: note.tab = (snaar.text, fret.text)
            else:
                deco = '!%s!' % snaar.text # no double string decos (bug in musescore)
                if deco not in note.ntdec: note.ntdec += deco
        wvln = nttn.find ('ornaments/wavy-line')
        if wvln != None:
            if wvln.get ('type') == 'start': note.before = ['!trill(!'] + note.before # keep left-right order!
            elif wvln.get ('type') == 'stop': note.before = ['!trill)!'] + note.before
        glis = nttn.find ('glissando')
        if glis == None: glis = nttn.find ('slide') # treat slide as glissando
        if glis != None:
            lt = '~' if glis.get ('line-type') =='wavy' else '-'
            if glis.get ('type') == 'start': note.before = ['!%s(!' % lt] + note.before # keep left-right order!
            elif glis.get ('type') == 'stop': note.before = ['!%s)!' % lt] + note.before
928
+
929
    def tabnote (s, alt, ptc, oct, v, ntrec):
        """Allocate an abc notation for a tablature note: try up to four
        enharmonic spellings of the same pitch so the same (voice, note)
        pair always maps to the same (string, fret) in s.tabmap."""
        p = s.step_map [ptc] + int (alt or '0') # p in -2 .. 13
        if p > 11: oct += 1 # octave correction
        if p < 0: oct -= 1
        p = p % 12 # remap p into 0..11
        snaar_nw, fret_nw = ntrec.tab # the computed/annotated allocation of nt
        for i in range (4): # support same note on 4 strings
            na = s.note_alts [i % 3] [p] # get alternative representation of same note
            o = oct
            if na in ['^B', '^^B']: o -= 1 # because in adjacent octave
            if na in ['_C', '__C']: o += 1
            if '/' in na or i == 3: o = 9 # emergency notation for 4th string case
            nt = addoct (na, o)
            snaar, fret = s.tabmap.get ((v, nt), ('', '')) # the current allocation of nt
            if not snaar: break # note not yet allocated
            if snaar_nw == snaar: return nt # use present allocation
            if i == 3: # new allocaion needed but none is free
                fmt = 'rejected: voice %d note %3s string %s fret %2s remains: string %s fret %s'
                info (fmt % (v, nt, snaar_nw, fret_nw, snaar, fret), 1)
                ntrec.tab = (snaar, fret)
        s.tabmap [v, nt] = ntrec.tab # for tablature map (voice, note) -> (string, fret)
        return nt # ABC code always in key C (with midi pitch alterations)
951
+
952
    def ntAbc (s, ptc, oct, note, v, ntrec, isTab): # pitch, octave -> abc notation
        """Translate an xml pitch to its abc spelling, tracking passing
        accidentals per (pitch, voice) in s.curalts and the key signature in
        s.msralts so accidentals are only written when actually needed."""
        acc2alt = {
            'double-flat': -2,
            'flat-flat': -2,
            'flat': -1,
            'natural-flat': -1,
            'natural': 0,
            'sharp': 1,
            'natural-sharp': 1,
            'sharp-sharp': 2,
            'double-sharp': 2
        }
        oct += s.clefOct.get (s.curStf [v], 0) # minus clef-octave-change value
        acc = note.findtext ('accidental') # should be the notated accidental
        alt = note.findtext ('pitch/alter') # pitch alteration (midi)
        if ntrec.tab: return s.tabnote (alt, ptc, oct, v, ntrec) # implies s.tstep is true (options.t was given)
        elif isTab and s.tstep:
            nt = ['__','_','','^','^^'][int (alt or '0') + 2] + addoct (ptc, oct)
            info ('no string notation found for note %s in voice %d' % (nt, v), 1)
        p = addoct (ptc, oct)
        if alt == None and s.msralts.get (ptc, 0): alt = 0 # no alt but key implies alt -> natural!!
        if alt == None and (p, v) in s.curalts: alt = 0 # no alt but previous note had one -> natural!!
        if acc == None and alt == None: return p # no acc, no alt
        elif acc != None:
            alt = acc2alt [acc] # acc takes precedence over the pitch here!
        else: # now see if we really must add an accidental
            alt = int (float (alt))
            if (p, v) in s.curalts: # the note in this voice has been altered before
                if alt == s.curalts [(p, v)]: return p # alteration still the same
            elif alt == s.msralts.get (ptc, 0): return p # alteration implied by the key
            tieElms = note.findall ('tie') + note.findall ('notations/tied') # in xml we have separate notated ties and playback ties
            if 'stop' in [e.get ('type') for e in tieElms]: return p # don't alter tied notes
            info ('accidental %d added in part %d, measure %d, voice %d note %s' % (alt, s.msr.ixp+1, s.msr.ixm+1, v+1, p))
        s.curalts [(p, v)] = alt
        p = ['__','_','=','^','^^'][alt+2] + p # and finally ... prepend the accidental
        return p
988
+
989
    def doNote (s, n): # parse a musicXML note tag
        """Parse one <note> element: voice, pitch/rest, duration, tuplets,
        grace sequences, ornaments, ties, percussion noteheads, lyrics,
        stem direction and staff changes; append the result to s.msc as a
        new Note or as an extension of the current chord."""
        note = Note ()
        v = int (n.findtext ('voice', '1'))
        if s.isSib: v += 100 * int (n.findtext ('staff', '1')) # repair bug in Sibelius
        chord = n.find ('chord') != None
        p = n.findtext ('pitch/step') or n.findtext ('unpitched/display-step')
        o = n.findtext ('pitch/octave') or n.findtext ('unpitched/display-octave')
        r = n.find ('rest')
        numer = n.findtext ('time-modification/actual-notes')
        if numer:
            denom = n.findtext ('time-modification/normal-notes')
            note.fact = (int (numer), int (denom))
        note.tup = [x.get ('type') for x in n.findall ('notations/tuplet')]
        dur = n.findtext ('duration')
        grc = n.find ('grace')
        note.grace = grc != None
        note.before, note.after = [], '' # strings with ABC stuff that goes before or after a note/chord
        if note.grace and not s.ingrace: # open a grace sequence
            s.ingrace = 1
            note.before = ['{']
            if grc.get ('slash') == 'yes': note.before += ['/'] # acciaccatura
        stopgrace = not note.grace and s.ingrace
        if stopgrace: # close the grace sequence
            s.ingrace = 0
            s.msc.lastnote.after += '}' # close grace on lastenote.after
        if dur == None or note.grace: dur = 0
        if r == None and n.get ('print-object') == 'no':
            if chord: return
            r = 1 # turn invisible notes (that advance the time) into invisible rests
        note.dur = int (dur)
        if r == None and (not p or not o): # not a rest and no pitch
            s.msc.cnt.inc ('nopt', v) # count unpitched notes
            o, p = 5,'E' # make it an E5 ??
        isTab = s.curClef and s.curClef.get (s.curStf [v], '').startswith ('tab')
        nttn = n.find ('notations') # add ornaments
        if nttn != None: s.doNotations (note, nttn, isTab)
        e = n.find ('stem') if r == None else None # no !stemless! before rest
        if e != None and e.text == 'none' and (not isTab or v in s.hasStems or s.tstep):
            # NOTE(review): diverges from upstream xml2abc (which emits 's' and
            # sets stemless=1, see commented line); garbled comment replaced - confirm intent
            note.before += ['!stemless!']; abcOut.stemless = 0;
            # note.before += ['s']; abcOut.stemless = 1;
        e = n.find ('accidental')
        if e != None and e.get ('parentheses') == 'yes': note.ntdec += '!courtesy!'
        if r != None: noot = 'x' if n.get ('print-object') == 'no' or isTab else 'z'
        else: noot = s.ntAbc (p, int (o), n, v, note, isTab)
        if n.find ('unpitched') != None:
            clef = s.curClef [s.curStf [v]] # the current clef for this voice
            step = staffStep (p, int (o), clef, s.tstep) # (clef independent) step value of note on the staff
            instr = n.find ('instrument')
            instId = instr.get ('id') if instr != None else 'dummyId'
            midi = s.drumInst.get (instId, abcMid (noot))
            nh = n.findtext ('notehead', '').replace (' ','-') # replace spaces in xml notehead names for percmap
            if nh == 'x': noot = '^' + noot.replace ('^','').replace ('_','')
            if nh in ['circle-x','diamond','triangle']: noot = '_' + noot.replace ('^','').replace ('_','')
            if nh and n.find ('notehead').get ('filled','') == 'yes': nh += '+'
            if nh and n.find ('notehead').get ('filled','') == 'no': nh += '-'
            s.drumNotes [(v, noot)] = (step, midi, nh) # keep data for percussion map
        tieElms = n.findall ('tie') + n.findall ('notations/tied') # in xml we have separate notated ties and playback ties
        if 'start' in [e.get ('type') for e in tieElms]: # n can have stop and start tie
            noot = noot + '-'
        note.beam = sum ([1 for b in n.findall('beam') if b.text in ['continue', 'end']]) + int (note.grace)
        lyrlast = 0; rsib = re.compile (r'^.*verse')
        for e in n.findall ('lyric'):
            lyrnum = int (rsib.sub ('', e.get ('number', '1'))) # also do Sibelius numbers
            if lyrnum == 0: lyrnum = lyrlast + 1 # and correct Sibelius bugs
            else: lyrlast = lyrnum
            note.lyrs [lyrnum] = doSyllable (e)
        stemdir = n.findtext ('stem')
        if s.wstems and (stemdir == 'up' or stemdir == 'down'):
            if stemdir != s.stemDir.get (v, ''):
                s.stemDir [v] = stemdir
                s.msc.appendElem (v, '[I:stemdir %s]' % stemdir)
        if chord: s.msc.addChord (note, noot)
        else:
            xmlstaff = int (n.findtext ('staff', '1'))
            if s.curStf [v] != xmlstaff: # the note should go to another staff
                dstaff = xmlstaff - s.curStf [v] # relative new staff number
                s.curStf [v] = xmlstaff # remember the new staff for this voice
                s.msc.appendElem (v, '[I:staff %+d]' % dstaff) # insert a move before the note
            s.msc.appendNote (v, note, noot)
        for slur in n.findall ('notations/slur'): # s.msc.lastnote points to the last real note/chord inserted above
            s.matchSlur (slur.get ('type'), slur.get ('number'), v, s.msc.lastnote, note.grace, stopgrace) # match slur definitions
1070
+
1071
def doAttr (s, e):  # parse a musicXML attribute tag
    """Parse a musicXML <attributes> element.

    Translates divisions, key, meter, measure-style (measure-repeat) and
    clef/staff-details into ABC: first-measure values go to the header
    (abcOut), later changes become inline [K:]/[M:] fields in the voice.
    """
    # map musicXML clef sign+line to the corresponding ABC clef name
    teken = {'C1':'alto1','C2':'alto2','C3':'alto','C4':'tenor','F4':'bass','F3':'bass3','G2':'treble','TAB':'tab','percussion':'perc'}
    dvstxt = e.findtext ('divisions')
    if dvstxt: s.msr.divs = int (dvstxt)
    steps = int (e.findtext ('transpose/chromatic', '0'))  # for transposing instrument
    fifths = e.findtext ('key/fifths')
    first = s.msc.tijd == 0 and s.msr.ixm == 0  # first attributes in first measure
    if fifths:
        key, s.msralts = setKey (int (fifths), e.findtext ('key/mode','major'))
        if first and not steps and abcOut.key == 'none':
            abcOut.key = key  # first measure -> header, if not transposing instrument or percussion part!
        elif key != abcOut.key or not first:
            s.msr.attr += '[K:%s]' % key  # otherwise -> voice
    beats = e.findtext ('time/beats')
    if beats:
        unit = e.findtext ('time/beat-type')
        mtr = beats + '/' + unit
        if first: abcOut.mtr = mtr  # first measure -> header
        else: s.msr.attr += '[M:%s]' % mtr  # otherwise -> voice
        s.msr.mtr = int (beats), int (unit)
    s.msr.mdur = (s.msr.divs * s.msr.mtr[0] * 4) // s.msr.mtr[1]  # duration of measure in xml-divisions
    for ms in e.findall('measure-style'):
        n = int (ms.get ('number', '1'))  # staff number
        voices = s.stfMap [n]  # all voices of staff n
        for mr in ms.findall('measure-repeat'):
            ty = mr.get('type')
            if ty == 'start':  # remember start measure number and text for each staff
                s.repeat_str [n] = [s.msr.ixm, mr.text]
                for v in voices:  # insert repeat into all voices, value will be overwritten at stop
                    s.msc.insertElem (v, s.repeat_str [n])
            elif ty == 'stop':  # calculate repeat measure count for this staff n
                start_ix, text_ = s.repeat_str [n]
                repeat_count = s.msr.ixm - start_ix
                if text_:
                    mid_str = "%s " % text_
                    # NOTE(review): '/=' yields a float in Python 3; '%d' below truncates it — confirm counts divide evenly
                    repeat_count /= int (text_)
                else:
                    mid_str = ""  # overwrite repeat with final string
                s.repeat_str [n][0] = '[I:repeat %s%d]' % (mid_str, repeat_count)
                del s.repeat_str [n]  # remove closed repeats
    toct = e.findtext ('transpose/octave-change', '')
    if toct: steps += 12 * int (toct)  # extra transposition of toct octaves
    for clef in e.findall ('clef'):  # a part can have multiple staves
        n = int (clef.get ('number', '1'))  # local staff number for this clef
        sgn = clef.findtext ('sign')
        line = clef.findtext ('line', '') if sgn not in ['percussion','TAB'] else ''
        cs = teken.get (sgn + line, '')
        oct = clef.findtext ('clef-octave-change', '') or '0'
        if oct: cs += {-2:'-15', -1:'-8', 1:'+8', 2:'+15'}.get (int (oct), '')
        s.clefOct [n] = -int (oct);  # xml playback pitch -> abc notation pitch
        if steps: cs += ' transpose=' + str (steps)
        stfdtl = e.find ('staff-details')
        # NOTE(review): 'if stfdtl and ...' relies on Element truthiness (len of children);
        # 'stfdtl != None' is probably intended — confirm against empty <staff-details/>
        if stfdtl and int (stfdtl.get ('number', '1')) == n:
            lines = stfdtl.findtext ('staff-lines')
            if lines:
                lns= '|||' if lines == '3' and sgn == 'TAB' else lines
                cs += ' stafflines=%s' % lns
                s.stafflines = int (lines)  # remember for tab staves
            strings = stfdtl.findall ('staff-tuning')
            if strings:
                tuning = [st.findtext ('tuning-step') + st.findtext ('tuning-octave') for st in strings]
                cs += ' strings=%s' % ','.join (tuning)
            capo = stfdtl.findtext ('capo')
            if capo: cs += ' capo=%s' % capo
        s.curClef [n] = cs  # keep track of current clef (for percmap)
        if first: s.clefMap [n] = cs  # clef goes to header (where it is mapped to voices)
        else:
            voices = s.stfMap[n]  # clef change to all voices of staff n
            for v in voices:
                if n != s.curStf [v]:  # voice is not at its home staff n
                    dstaff = n - s.curStf [v]
                    s.curStf [v] = n  # reset current staff at start of measure to home position
                    s.msc.appendElem (v, '[I:staff %+d]' % dstaff)
                s.msc.appendElem (v, '[K:%s]' % cs)
1145
+
1146
def findVoice (s, i, es):
    """Decide which voice the direction at position i in element list es belongs to.

    Returns (staff number, target voice, first voice of the staff).
    """
    stfnum = int (es[i].findtext ('staff', 1))      # directions belong to a staff
    staff_voices = s.stfMap [stfnum]                # voices allocated to that staff
    v1 = staff_voices [0] if staff_voices else 1    # first voice of the staff
    if s.dirtov1:                                   # option --v1: always attach to the first voice
        return stfnum, v1, v1
    for el in es [i+1:]:                            # otherwise attach to the voice of the next note
        if el.tag == 'backup':
            break
        if el.tag != 'note':
            continue
        v = int (el.findtext ('voice', '1'))
        if s.isSib:                                 # repair voice numbering bug in Sibelius output
            v += 100 * int (el.findtext ('staff', '1'))
        return s.vce2stf [v], v, v1                 # use our own staff allocation for the note's voice
    return stfnum, v1, v1                           # no note found, fall back to v1
1159
+
1160
def doDirection (s, e, i, es):  # parse a musicXML direction tag
    """Parse a musicXML <direction> element: tempo, wedges, octave-shifts,
    pedal, dynamics, coda/segno, jump marks and free text annotations.

    Fixes vs. original: removed unused debug local, hoisted the duplicated
    findtext('per-minute') call, raw regex string, and explicit None-test
    instead of deprecated Element truthiness for <midi-instrument>.
    """
    def addDirection (x, vs, tijd, stfnum):
        # Emit decoration x into voice(s); when tijd is given, insert at that
        # earlier time (the closing half of an out-of-order start/stop pair).
        if not x: return
        vs = s.stfMap [stfnum] if '!8v' in x else [vs]  # ottava's go to all voices of staff
        for v in vs:
            if tijd != None:  # insert at time of encounter
                s.msc.appendElemT (v, x.replace ('(',')').replace ('ped','ped-up'), tijd)
            else:
                s.msc.appendElem (v, x)
    def startStop (dtype, vs, stfnum=1):
        # Match paired start/stop directions (wedge, octave-shift, pedal) via s.dirStk.
        typmap = {'down':'!8va(!', 'up':'!8vb(!', 'crescendo':'!<(!', 'diminuendo':'!>(!', 'start':'!ped!'}
        type = t.get ('type', '')
        k = dtype + t.get ('number', '1')  # key to match the closing direction
        if type in typmap:  # opening the direction
            x = typmap [type]
            if k in s.dirStk:  # closing direction already encountered
                stype, tijd = s.dirStk [k]; del s.dirStk [k]
                if stype == 'stop':
                    addDirection (x, vs, tijd, stfnum)
                else:
                    info ('%s direction %s has no stop in part %d, measure %d, voice %d' % (dtype, stype, s.msr.ixp+1, s.msr.ixm+1, vs+1))
                    s.dirStk [k] = ((type , vs))  # remember voice and type for closing
            else:
                s.dirStk [k] = ((type , vs))  # remember voice and type for closing
        elif type == 'stop':
            if k in s.dirStk:  # matching open direction found
                type, vs = s.dirStk [k]; del s.dirStk [k]  # into the same voice
                if type == 'stop':
                    info ('%s direction %s has double stop in part %d, measure %d, voice %d' % (dtype, type, s.msr.ixp+1, s.msr.ixm+1, vs+1))
                    x = ''
                else:
                    x = typmap [type].replace ('(',')').replace ('ped','ped-up')
            else:  # closing direction found before opening
                s.dirStk [k] = ('stop', s.msc.tijd)
                x = ''  # delay code generation until opening found
        elif type in ['continue', 'resume', 'discontinue', 'change']:
            # intermediate states (e.g. pedal 'change') are not representable in ABC: ignore
            x = ''
        else: raise ValueError ('wrong direction type')
        addDirection (x, vs, None, stfnum)
    tempo, wrdstxt = None, ''
    plcmnt = e.get ('placement')
    stf, vs, v1 = s.findVoice (i, es)
    jmp = ''  # for jump sound elements: dacapo, dalsegno and family
    jmps = [('dacapo','D.C.'),('dalsegno','D.S.'),('tocoda','dacoda'),('fine','fine'),('coda','O'),('segno','S')]
    t = e.find ('sound')  # there are many possible attributes for sound
    if t != None:
        minst = t.find ('midi-instrument')
        if minst != None:  # explicit test: Element truthiness is deprecated and False for childless elements
            prg = t.findtext ('midi-instrument/midi-program')
            chn = t.findtext ('midi-instrument/midi-channel')
            vids = [v for v, id in s.vceInst.items () if id == minst.get ('id')]
            if vids: vs = vids [0]  # direction for the identified voice, not the staff
            parm, inst = ('program', str (int (prg) - 1)) if prg else ('channel', chn)
            if inst and abcOut.volpan > 0: s.msc.appendElem (vs, '[I:MIDI= %s %s]' % (parm, inst))
        tempo = t.get ('tempo')  # look for tempo attribute
        if tempo:
            tempo = '%.0f' % float (tempo)  # hope it is a number and insert in voice 1
            tempo_units = (1,4)  # always 1/4 for sound elements!
        for r, v in jmps:
            if t.get (r, ''): jmp = v; break
    dirtypes = e.findall ('direction-type')
    for dirtyp in dirtypes:
        units = { 'whole': (1,1), 'half': (1,2), 'quarter': (1,4), 'eighth': (1,8) }
        metr = dirtyp.find ('metronome')
        if metr != None:
            t = metr.findtext ('beat-unit', '')
            if t in units: tempo_units = units [t]
            else: tempo_units = units ['quarter']
            if metr.find ('beat-unit-dot') != None:
                tempo_units = simplify (tempo_units [0] * 3, tempo_units [1] * 2)
            per_min = metr.findtext ('per-minute')  # hoisted: was queried twice via an unused debug local
            if per_min:
                tmpro = re.search (r'[.\d]+', per_min)  # look for a number
                if tmpro: tempo = tmpro.group ()  # overwrites the value set by the sound element of this direction
        t = dirtyp.find ('wedge')
        if t != None: startStop ('wedge', vs)
        allwrds = dirtyp.findall ('words')  # insert text annotations
        if not allwrds: allwrds = dirtyp.findall ('rehearsal')  # treat rehearsal mark as text annotation
        for wrds in allwrds:
            if jmp:  # ignore the words when a jump sound element is present in this direction
                s.msc.appendElem (vs, '!%s!' % jmp , 1)  # to voice
                break
            plc = plcmnt == 'below' and '_' or '^'
            if float (wrds.get ('default-y', '0')) < 0: plc = '_'
            wrdstxt += (wrds.text or '').replace ('"','\\"').replace ('\n', '\\n')
            wrdstxt = wrdstxt.strip ()
        for key, val in dynamics_map.items ():
            if dirtyp.find ('dynamics/' + key) != None:
                s.msc.appendElem (vs, val, 1)  # to voice
        if dirtyp.find ('coda') != None: s.msc.appendElem (vs, 'O', 1)
        if dirtyp.find ('segno') != None: s.msc.appendElem (vs, 'S', 1)
        t = dirtyp.find ('octave-shift')
        if t != None: startStop ('octave-shift', vs, stf)  # assume size == 8 for the time being
        t = dirtyp.find ('pedal')
        if t != None and s.ped:
            if not s.pedVce: s.pedVce = vs
            startStop ('pedal', s.pedVce)
        if dirtyp.findtext ('other-direction') == 'diatonic fretting': s.diafret = 1;
    if tempo:
        tempo = '%.0f' % float (tempo)  # hope it is a number and insert in voice 1
        if s.msc.tijd == 0 and s.msr.ixm == 0:  # first measure -> header
            abcOut.tempo = tempo
            abcOut.tempo_units = tempo_units
        else:
            s.msc.appendElem (v1, '[Q:%d/%d=%s]' % (tempo_units [0], tempo_units [1], tempo))  # otherwise -> 1st voice
    if wrdstxt: s.msc.appendElem (vs, '"%s%s"' % (plc, wrdstxt), 1)  # to voice, but after tempo
1271
+
1272
def doHarmony (s, e, i, es):  # parse a musicXML harmony tag
    """Translate a <harmony> element into an ABC chord-symbol annotation,
    appended to the voice of the next note."""
    _, vt, _ = s.findVoice (i, es)
    short = {'major':'', 'minor':'m', 'augmented':'+', 'diminished':'dim', 'dominant':'7', 'half-diminished':'m7b5'}
    accmap = {'major':'maj', 'dominant':'', 'minor':'m', 'diminished':'dim', 'augmented':'+', 'suspended':'sus'}
    modmap = {'second':'2', 'fourth':'4', 'seventh':'7', 'sixth':'6', 'ninth':'9', '11th':'11', '13th':'13'}
    altmap = {'1':'#', '0':'', '-1':'b'}
    root = e.findtext ('root/root-step', '')
    alt = altmap.get (e.findtext ('root/root-alter'), '')
    sus = ''
    kind = e.findtext ('kind', '')
    if kind in short:
        kind = short [kind]
    elif '-' in kind:  # xml chord names: <triad name>-<modification>
        triad, mod = kind.split ('-')
        kind = accmap.get (triad, '') + modmap.get (mod, '')
        if kind.startswith ('sus'):  # sus-suffix goes to the end
            kind, sus = '', kind
    elif kind == 'none':
        kind = e.find ('kind').get ('text', '')
    for d in e.findall ('degree'):  # chord alterations
        kind += altmap.get (d.findtext ('degree-alter'), '') + d.findtext ('degree-value', '')
    kind = kind.replace ('79', '9').replace ('713', '13').replace ('maj6', '6')
    bass = e.findtext ('bass/bass-step', '') + altmap.get (e.findtext ('bass/bass-alter'), '')
    chord_sym = '%s%s%s%s%s' % (root, alt, kind, sus, bass and '/' + bass)
    s.msc.appendElem (vt, '"%s"' % chord_sym, 1)
1294
+
1295
def doBarline (s, e):
    """Translate a <barline> element into measure barline styles and volta info.

    Returns a repeat code: 0 = no repeat, 1 = begin repeat, 2 = end repeat.
    """
    rep = e.find ('repeat')
    if rep != None: rep = rep.get ('direction')
    if s.unfold:                        # when unfolding repeats, don't translate barlines
        if rep == 'forward': return 1
        return 2 if rep else 0
    if e.get ('location', 'right') == 'right':  # only change style for the right side
        barstyle = e.findtext ('bar-style')
        if barstyle == 'light-light': s.msr.rline = '||'
        elif barstyle == 'light-heavy': s.msr.rline = '|]'
    if rep != None:                     # a repeat overrides the barline style
        if rep == 'forward': s.msr.lline = ':'
        else: s.msr.rline = ':|'
    ending = e.find ('ending')
    if ending != None:
        if ending.get ('type') == 'start':
            num = ending.get ('number', '1').replace ('.', '').replace (' ', '')
            try: list (map (int, num.split (',')))  # should be a list of integers
            except: num = '"%s"' % num.strip ()     # illegal musicXML: quote the free text
            s.msr.lnum = num            # assume a start is always at the beginning of a measure
        elif s.msr.rline == '|':        # stop/discontinue: use || to make the end visible in ABC
            s.msr.rline = '||'
    return 0
1318
+
1319
def doPrint (s, e):
    """On a <print> element that starts a new system or page, return '$'
    (an ABC line break), unless line breaks are suppressed (-x option)."""
    wants_break = e.get ('new-system') == 'yes' or e.get ('new-page') == 'yes'
    if wants_break and not s.nolbrk:
        return '$'
1322
+
1323
def doPartList (s, e):
    """Collect per-part midi-instrument settings into s.instMid and translate
    the start/stop-event-based xml part-list into a proper tree structure."""
    for sp in e.findall ('part-list/score-part'):
        midi = {}
        for mi in sp.findall ('midi-instrument'):
            fields = ['midi-channel', 'midi-program', 'volume', 'pan']
            vals = [mi.findtext (tag, s.midDflt [ix]) for ix, tag in enumerate (fields)]
            pan = float (vals [3])
            if -90 <= pan <= 90:            # xml pans between -90 and +90; behind-pannings stay as-is
                pan = (float (vals [3]) + 90) / 180 * 127
            # volume 100 -> midi 127
            midi [mi.get ('id')] = [int (vals [0]), int (vals [1]), float (vals [2]) * 1.27, pan]
            unpitched = mi.findtext ('midi-unpitched')
            if unpitched:                   # store midi-pitch for channel 10 notes
                s.drumInst [mi.get ('id')] = int (unpitched) - 1
        s.instMid.append (midi)
    ps = e.find ('part-list')               # partlist = [groupelem]
    xs = getPartlist (ps)                   # groupelem = partname | grouplist
    partlist, _ = parseParts (xs, {}, [])   # grouplist = [groupelem, ..., groupdata]
    return partlist                         # groupdata = [group-symbol, group-barline, group-name, group-abbrev]
1339
+
1340
def mkTitle (s, e):
    """Collect work/movement titles, creators and credit texts from the score
    header and render them as ABC T:/C:/Z: header lines on abcOut.title.
    Also detects Sibelius output (s.isSib) to enable bug workarounds later.
    """
    def filterCredits (y):  # y == filter level, higher filters less
        cs = []
        for x in credits:  # skip redundant credit lines
            if y < 6 and (x in title or x in mvttl): continue  # sure skip
            if y < 5 and (x in composer or x in lyricist): continue  # almost sure skip
            if y < 4 and ((title and title in x) or (mvttl and mvttl in x)): continue  # may skip too much
            if y < 3 and ([1 for c in composer if c in x] or [1 for c in lyricist if c in x]): continue  # skips too much
            if y < 2 and re.match (r'^[\d\W]*$', x): continue  # line only contains numbers and punctuation
            cs.append (x)
        # default: only output credits when no title is set; '' (empty string) disables credits below
        if y == 0 and (title + mvttl): cs = ''
        return cs
    title = e.findtext ('work/work-title', '').strip ()
    mvttl = e.findtext ('movement-title', '').strip ()
    composer, lyricist, credits = [], [], []
    for creator in e.findall ('identification/creator'):
        if creator.text:
            if creator.get ('type') == 'composer':
                composer += [line.strip () for line in creator.text.split ('\n')]
            elif creator.get ('type') in ('lyricist', 'transcriber'):
                lyricist += [line.strip () for line in creator.text.split ('\n')]
    for rights in e.findall ('identification/rights'):
        if rights.text:
            lyricist += [line.strip () for line in rights.text.split ('\n')]
    for credit in e.findall('credit'):
        # NOTE(review): the generator variable e shadows the parameter e here; harmless but confusing
        cs = ''.join (e.text or '' for e in credit.findall('credit-words'))
        credits += [re.sub (r'\s*[\r\n]\s*', ' ', cs)]  # collapse internal newlines to spaces
    credits = filterCredits (s.ctf)
    if title: title = 'T:%s\n' % title.replace ('\n', '\nT:')
    if mvttl: title += 'T:%s\n' % mvttl.replace ('\n', '\nT:')
    if credits: title += '\n'.join (['T:%s' % c for c in credits]) + '\n'
    if composer: title += '\n'.join (['C:%s' % c for c in composer]) + '\n'
    if lyricist: title += '\n'.join (['Z:%s' % c for c in lyricist]) + '\n'
    if title: abcOut.title = title[:-1]  # strip the trailing newline
    s.isSib = 'Sibelius' in (e.findtext ('identification/encoding/software') or '')
    if s.isSib: info ('Sibelius MusicXMl is unreliable')
1376
+
1377
def doDefaults (s, e):
    """Translate the <defaults> page layout (scaling, page size, margins)
    into abcm2ps page-format settings on abcOut, when the -pf option is given.

    Fix: the illegal-value message was passed to info() with a separate tuple
    argument (which landed in info's second parameter) instead of being
    %-formatted into the string.
    """
    if not s.doPageFmt: return              # return if -pf option absent
    d = e.find ('defaults')
    if d is None: return
    mils = d.findtext ('scaling/millimeters')   # mils == staff height (mm)
    tenths = d.findtext ('scaling/tenths')      # staff height in tenths
    if not mils or not tenths: return
    xmlScale = float (mils) / float (tenths) / 10   # tenths -> mm
    space = 10 * xmlScale                   # space between staff lines == 10 tenths
    abcScale = space / 0.2117               # 0.2117 cm = 6pt = space between staff lines for scale = 1.0 in abcm2ps
    abcOut.pageFmt ['scale'] = abcScale
    eks = 2 * ['page-layout/'] + 4 * ['page-layout/page-margins/']
    eks = [a+b for a,b in zip (eks, 'page-height,page-width,left-margin,right-margin,top-margin,bottom-margin'.split (','))]
    for i in range (6):
        v = d.findtext (eks [i])
        k = abcOut.pagekeys [i+1]           # pagekeys [0] == scale, already done above
        if not abcOut.pageFmt [k] and v:
            try: abcOut.pageFmt [k] = float (v) * xmlScale  # -> cm
            except ValueError:              # just skip illegal values
                info ('illegal value %s for xml element %s' % (v, eks [i]))
                continue
1396
+
1397
def locStaffMap (s, part, maten):  # map voice to staff with majority voting
    """Allocate each xml voice of one part to the staff where most of its
    notes occur (majority vote), and collect per-voice bookkeeping:
    voice ids, instrument ids, stem presence, and the staff->voices map.
    """
    vmap = {}  # {voice -> {staff -> n}} count occurrences of voice in staff
    s.vceInst = {}  # {voice -> instrument id} for this part
    s.msc.vnums = {}  # voice id's
    s.hasStems = {}  # XML voice nums with at least one note with a stem (for tab key)
    s.stfMap, s.clefMap = {}, {}  # staff -> [voices], staff -> clef
    ns = part.findall ('measure/note')
    for n in ns:  # count staff allocations for all notes
        v = int (n.findtext ('voice', '1'))
        if s.isSib: v += 100 * int (n.findtext ('staff', '1'))  # repair bug in Sibelius
        s.msc.vnums [v] = 1  # collect all used voice id's in this part
        sn = int (n.findtext ('staff', '1'))
        s.stfMap [sn] = []
        if v not in vmap:
            vmap [v] = {sn:1}
        else:
            d = vmap[v]  # counter for voice v
            d[sn] = d.get (sn, 0) + 1  # ++ number of allocations for staff sn
        x = n.find ('instrument')
        if x != None: s.vceInst [v] = x.get ('id')
        x, noRest = n.findtext ('stem'), n.find ('rest') == None
        if noRest and (not x or x != 'none'): s.hasStems [v] = 1  # XML voice v has at least one stem
    vks = list (vmap.keys ())
    if s.jscript or s.isSib: vks.sort ()  # deterministic voice order for these modes
    for v in vks:  # choose staff with most allocations for each voice
        xs = [(n, sn) for sn, n in vmap[v].items ()]
        xs.sort ()
        # NOTE(review): on a tie in note counts, xs.sort() picks the highest staff number — confirm intended
        stf = xs[-1][1]  # the winner: staff with most notes of voice v
        s.stfMap [stf].append (v)
        s.vce2stf [v] = stf  # reverse map
        s.curStf [v] = stf  # current staff of XML voice v
1428
+
1429
def addStaffMap (s, vvmap):
    """Append this part's staff grouping to the global staff map and assign
    a clef to every abc voice (vvmap: xml voice number -> abc voice number).
    """
    braced = []  # abc voice numbers of this part, grouped per staff (braced in the header)
    for staff_num, xml_voices in sorted (s.stfMap.items ()):  # xml staff and voice numbers
        abc_voices = [xv and vvmap [xv] for xv in xml_voices if xv in vvmap]
        stemless = [(xv not in s.hasStems) for xv in xml_voices if xv in vvmap]  # same order as abc_voices
        if not abc_voices:
            continue
        braced.append (abc_voices)
        clef = s.clefMap.get (staff_num, 'treble')  # {xml staff number -> clef}
        for ix, av in enumerate (abc_voices):
            extra = ''
            if clef.startswith ('tab'):
                if stemless [ix] and 'nostems' not in clef: extra = ' nostems'
                if s.diafret and 'diafret' not in clef: extra += ' diafret'  # for all voices in the part
            abcOut.clefs [av] = clef + extra  # add nostems when all notes of voice had no stem
    s.gStfMap.append (braced)
1444
+
1445
def addMidiMap (s, ip, vvmap):  # map abc voices to midi settings
    """Attach midi channel/program/volume/pan settings (and percussion /
    tablature note maps) to each abc voice of part ip.
    vvmap: xml voice number -> abc voice number.
    """
    instr = s.instMid [ip]  # get the midi settings for this part
    if instr.values (): defInstr = list(instr.values ())[0]  # default settings = first instrument
    else: defInstr = s.midDflt  # no instruments defined
    xs = []
    for v, vabc in vvmap.items ():  # xml voice num, abc voice num
        ks = sorted (s.drumNotes.items ())
        ds = [(nt, step, midi, head) for (vd, nt), (step, midi, head) in ks if v == vd]  # map perc notes
        id = s.vceInst.get (v, '')  # get the instrument-id for part with multiple instruments
        if id in instr:  # id is defined as midi-instrument in part-list
            xs.append ((vabc, instr [id] + ds))  # get midi settings for id
        else: xs.append ((vabc, defInstr + ds))  # only one instrument for this part
    xs.sort ()  # put abc voices in order
    s.midiMap.extend ([midi for v, midi in xs])
    snaarmap = ['E','G','B','d', 'f', 'a', "c'", "e'"]  # open string pitches ('snaar' = Dutch for string)
    diamap = '0,1-,1,1+,2,3,3,4,4,5,6,6+,7,8-,8,8+,9,10,10,11,11,12,13,13+,14'.split (',')
    for k in sorted (s.tabmap.keys ()):  # add %%map's for all tab voices
        v, noot = k;
        snaar, fret = s.tabmap [k];
        if s.diafret: fret = diamap [int (fret)]  # translate chromatic fret to diatonic notation
        vabc = vvmap [v]
        snaar = s.stafflines - int (snaar)  # xml string number counts from the top
        xs = s.tabVceMap.get (vabc, [])
        xs.append ('%%%%map tab%d %s print=%s heads=kop%s\n' % (vabc, noot, snaarmap [snaar], fret))
        s.tabVceMap [vabc] = xs
        s.koppen [fret] = 1  # collect noteheads for SVG defs
1471
+
1472
def parse (s, fobj):
    """Parse a whole MusicXML document from file object fobj and write the
    translated ABC via the global abcOut.

    Fixes vs. original: removed leftover debug residue
    (`if ip == 31 and s.msr.ixm == 405: print('')`) and renamed the measure
    loop variable so it no longer shadows the parsed document root `e`.
    """
    vvmapAll = {}  # collect xml->abc voice maps (vvmap) of all parts
    e = E.parse (fobj)
    s.mkTitle (e)
    s.doDefaults (e)
    partlist = s.doPartList (e)
    parts = e.findall ('part')
    for ip, p in enumerate (parts):
        maten = p.findall ('measure')
        s.locStaffMap (p, maten)  # {voice -> staff} for this part
        s.drumNotes = {}  # (xml voice, abc note) -> (midi note, note head)
        s.clefOct = {}  # xml staff number -> current clef-octave-change
        s.curClef = {}  # xml staff number -> current abc clef
        s.stemDir = {}  # xml voice number -> current stem direction
        s.tabmap = {}  # (xml voice, abc note) -> (string, fret)
        s.diafret = 0  # use diatonic fretting
        s.stafflines = 5
        s.msc.initVoices (newPart = 1)  # create all voices
        aantalHerhaald = 0  # keep track of number of repetitions
        herhaalMaat = 0  # target measure of the repetition
        divisions = []  # current value of <divisions> for each measure
        s.msr = Measure (ip)  # various measure data
        while s.msr.ixm < len (maten):
            maat = maten [s.msr.ixm]
            herhaal, lbrk = 0, ''
            s.msr.reset ()
            s.curalts = {}  # passing accidentals are reset each measure
            es = list (maat)
            for i, el in enumerate (es):
                if el.tag == 'note': s.doNote (el)
                elif el.tag == 'attributes': s.doAttr (el)
                elif el.tag == 'direction': s.doDirection (el, i, es)
                elif el.tag == 'sound': s.doDirection (maat, i, es)  # sound element directly in measure!
                elif el.tag == 'harmony': s.doHarmony (el, i, es)
                elif el.tag == 'barline': herhaal = s.doBarline (el)
                elif el.tag == 'backup':
                    dt = int (el.findtext ('duration'))
                    if chkbug (dt, s.msr): s.msc.incTime (-dt)
                elif el.tag == 'forward':
                    dt = int (el.findtext ('duration'))
                    if chkbug (dt, s.msr): s.msc.incTime (dt)
                elif el.tag == 'print': lbrk = s.doPrint (el)
            s.msc.addBar (lbrk, s.msr)
            divisions.append (s.msr.divs)
            if herhaal == 1:  # begin repeat: remember the jump target
                herhaalMaat = s.msr.ixm
                s.msr.ixm += 1
            elif herhaal == 2:  # end repeat
                if aantalHerhaald < 1:  # jump back once
                    s.msr.ixm = herhaalMaat
                    aantalHerhaald += 1
                else:
                    aantalHerhaald = 0  # reset
                    s.msr.ixm += 1  # just continue
            else: s.msr.ixm += 1  # on to the next measure
        for rv in s.repeat_str.values ():  # close hanging measure-repeats without stop
            rv [0] = '[I:repeat %s %d]' % (rv [1], 1)
        vvmap = s.msc.outVoices (divisions, ip, s.isSib)
        s.addStaffMap (vvmap)  # update global staff map
        s.addMidiMap (ip, vvmap)
        vvmapAll.update (vvmap)
    if vvmapAll:  # skip output if no part has any notes
        abcOut.mkHeader (s.gStfMap, partlist, s.midiMap, s.tabVceMap, s.koppen)
        abcOut.writeall ()
    else: info ('nothing written, %s has no notes ...' % abcOut.fnmext)
1541
+
1542
+ #----------------
1543
+ # Main Program
1544
+ #----------------
1545
+ if __name__ == '__main__':
1546
+ from optparse import OptionParser
1547
+ from glob import glob
1548
+ from zipfile import ZipFile
1549
+ ustr = '%prog [-h] [-u] [-m] [-c C] [-d D] [-n CPL] [-b BPL] [-o DIR] [-v V]\n'
1550
+ ustr += '[-x] [-p PFMT] [-t] [-s] [-i] [--v1] [--noped] [--stems] <file1> [<file2> ...]'
1551
+ parser = OptionParser (usage=ustr, version=str(VERSION))
1552
+ parser.add_option ("-u", action="store_true", help="unfold simple repeats")
1553
+ parser.add_option ("-m", action="store", help="0 -> no %%MIDI, 1 -> minimal %%MIDI, 2-> all %%MIDI", default=0)
1554
+ parser.add_option ("-c", action="store", type="int", help="set credit text filter to C", default=0, metavar='C')
1555
+ parser.add_option ("-d", action="store", type="int", help="set L:1/D", default=0, metavar='D') # ??????????????L
1556
+ parser.add_option ("-n", action="store", type="int", help="CPL: max number of characters per line (default 100)", default=0, metavar='CPL')
1557
+ parser.add_option ("-b", action="store", type="int", help="BPL: max number of bars per line", default=0, metavar='BPL')
1558
+ parser.add_option ("-o", action="store", help="store abc files in DIR", default='', metavar='DIR')
1559
+ parser.add_option ("-v", action="store", type="int", help="set volta typesetting behaviour to V", default=0, metavar='V')
1560
+ parser.add_option ("-x", action="store_true", help="output no line breaks")
1561
+ parser.add_option ("-p", action="store", help="pageformat PFMT (cm) = scale, pageheight, pagewidth, leftmargin, rightmargin, topmargin, botmargin", default='', metavar='PFMT')
1562
+ parser.add_option ("-j", action="store_true", help="switch for compatibility with javascript version")
1563
+ parser.add_option ("-t", action="store_true", help="translate perc- and tab-staff to ABC code with %%map, %%voicemap")
1564
+ parser.add_option ("-s", action="store_true", help="shift node heads 3 units left in a tab staff")
1565
+ parser.add_option ("--v1", action="store_true", help="start-stop directions allways to first voice of staff")
1566
+ parser.add_option ("--noped", action="store_false", help="skip all pedal directions", dest='ped', default=True)
1567
+ parser.add_option ("--stems", action="store_true", help="translate stem directions", dest='stm', default=False)
1568
+ parser.add_option ("-i", action="store_true", help="read xml file from standard input")
1569
+ options, args = parser.parse_args ()
1570
+ if options.n < 0: parser.error ('only values >= 0')
1571
+ if options.b < 0: parser.error ('only values >= 0')
1572
+ if options.d and options.d not in [2**n for n in range (10)]:
1573
+ parser.error ('D should be on of %s' % ','.join ([str(2**n) for n in range (10)]))
1574
+ options.p = options.p and options.p.split (',') or [] # ==> [] | [string]
1575
+ if len (args) == 0 and not options.i: parser.error ('no input file given')
1576
+ pad = options.o
1577
+ if pad:
1578
+ if not os.path.exists (pad): os.mkdir (pad)
1579
+ if not os.path.isdir (pad): parser.error ('%s is not a directory' % pad)
1580
+ fnmext_list = []
1581
+ for i in args: fnmext_list += glob (i)
1582
+ if options.i: fnmext_list = ['stdin.xml']
1583
+ if not fnmext_list: parser.error ('none of the input files exist')
1584
+ for X, fnmext in enumerate (fnmext_list):
1585
+ fnm, ext = os.path.splitext (fnmext)
1586
+ if ext.lower () not in ('.xml','.mxl','.musicxml'):
1587
+ info ('skipped input file %s, it should have extension .xml or .mxl' % fnmext)
1588
+ continue
1589
+ if os.path.isdir (fnmext):
1590
+ info ('skipped directory %s. Only files are accepted' % fnmext)
1591
+ continue
1592
+ if fnmext == 'stdin.xml':
1593
+ fobj = sys.stdin
1594
+ elif ext.lower () == '.mxl': # extract .xml file from .mxl file
1595
+ z = ZipFile(fnmext)
1596
+ for n in z.namelist(): # assume there is always an xml file in a mxl archive !!
1597
+ if (n[:4] != 'META') and (n[-4:].lower() == '.xml'):
1598
+ fobj = z.open (n)
1599
+ break # assume only one MusicXML file per archive
1600
+ else:
1601
+ fobj = open (fnmext, 'rb') # open regular xml file
1602
+
1603
+ abcOut = ABCoutput (fnm + '.abc', pad, X, options) # create global ABC output object
1604
+ psr = Parser (options) # xml parser
1605
+ try:
1606
+ psr.parse (fobj) # parse file fobj and write abc to <fnm>.abc
1607
+ except:
1608
+ etype, value, traceback = sys.exc_info () # works in python 2 & 3
1609
+ info ('** %s occurred: %s in %s' % (etype, value, fnmext), 0)