Upload 38 files
Browse files- .gitattributes +3 -0
- 1_batch_xml2abc.py +54 -0
- 2_data_preprocess.py +181 -0
- 3_batch_abc2xml.py +56 -0
- LICENSE.txt +21 -0
- README (1).md +31 -0
- README (2).md +293 -0
- README.md +44 -0
- abc2xml (1).py +0 -0
- abc2xml (2).py +0 -0
- abc2xml.py +0 -0
- config (1).py +67 -0
- config (2).py +38 -0
- config (3).py +15 -0
- config (4).py +18 -0
- config (5).py +39 -0
- config.py +35 -0
- data.py +136 -0
- demo.ipynb +821 -0
- demo.py +236 -0
- extract_clamp2.py +194 -0
- illustration.png +3 -0
- illustration_online.png +3 -0
- inference (1).py +271 -0
- inference.py +318 -0
- notagen.png +3 -0
- prompts.txt +112 -0
- requirements (6).txt +7 -0
- statistics.py +68 -0
- train-gen (1).py +325 -0
- train-gen.py +374 -0
- train.py +186 -0
- utils (1).py +483 -0
- utils (2).py +423 -0
- utils (3).py +423 -0
- utils (4).py +423 -0
- utils (5).py +421 -0
- utils.py +406 -0
- xml2abc.py +1609 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
illustration_online.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
illustration.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
notagen.png filter=lfs diff=lfs merge=lfs -text
|
1_batch_xml2abc.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ORI_FOLDER = "" # Replace with the path to your folder containing XML (.xml, .mxl, .musicxml) files
|
| 2 |
+
DES_FOLDER = "" # The script will convert the musicxml files and output standard abc notation files to this folder
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import math
|
| 6 |
+
import random
|
| 7 |
+
import subprocess
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from multiprocessing import Pool
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def convert_xml2abc(file_list):
    """Convert a batch of MusicXML files (.xml/.mxl/.musicxml) to ABC notation.

    Each file is piped through the external ``xml2abc.py`` script (expected in
    the current working directory); the ABC text captured from stdout is
    written to DES_FOLDER with the same base name and an ``.abc`` extension.
    Files that yield empty output or raise an exception are appended to
    ``logs/xml2abc_error_log.txt``.

    Intended as a ``multiprocessing.Pool`` worker: takes a list of file paths,
    returns nothing.
    """
    # Hoisted out of the loop: the destination only needs creating once.
    os.makedirs(DES_FOLDER, exist_ok=True)
    # Pass the input path as a separate argv element (shell=False) so file
    # names containing spaces, quotes or shell metacharacters are handled
    # safely -- the original string-built shell command was injection-prone.
    base_cmd = ['python', 'xml2abc.py', '-d', '8', '-c', '6', '-x']
    for file in tqdm(file_list):
        filename = os.path.basename(file)
        try:
            p = subprocess.Popen(base_cmd + [file], stdout=subprocess.PIPE)
            output = p.communicate()[0].decode('utf-8')

            if output == '':
                # Empty stdout means the converter failed silently; log and skip.
                with open("logs/xml2abc_error_log.txt", "a", encoding="utf-8") as f:
                    f.write(file + '\n')
                continue

            out_path = os.path.join(DES_FOLDER, filename.rsplit('.', 1)[0] + '.abc')
            with open(out_path, 'w', encoding='utf-8') as f:
                f.write(output)
        except Exception as e:
            # Record the failing file together with the exception message.
            with open("logs/xml2abc_error_log.txt", "a", encoding="utf-8") as f:
                f.write(file + ' ' + str(e) + '\n')
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if __name__ == '__main__':
    os.makedirs("logs", exist_ok=True)

    # Recursively collect every MusicXML-family file under ORI_FOLDER.
    file_list = []
    for root, dirs, files in os.walk(os.path.abspath(ORI_FOLDER)):
        for file in files:
            if file.endswith((".mxl", ".xml", ".musicxml")):
                # Normalize separators so paths are uniform across platforms.
                file_list.append(os.path.join(root, file).replace("\\", "/"))

    # Shuffle so each worker receives a roughly even mix of file sizes.
    random.shuffle(file_list)
    # os.cpu_count() may return None on some platforms; fall back to 1.
    num_processes = os.cpu_count() or 1
    # Round-robin split: one sub-list per worker process.
    file_lists = [file_list[i::num_processes] for i in range(num_processes)]

    # Context manager guarantees the pool is terminated on exit.
    with Pool(processes=num_processes) as pool:
        pool.map(convert_xml2abc, file_lists)
|
2_data_preprocess.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ORI_FOLDER = '' # Replace with the path to your folder containing standard ABC notation files
|
| 2 |
+
INTERLEAVED_FOLDER = '' # Output interleaved ABC notation files to this folder
|
| 3 |
+
AUGMENTED_FOLDER = '' # Output key-augmented and rest-omitted ABC notation files to this folder
|
| 4 |
+
EVAL_SPLIT = 0.1 # The ratio of eval data
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import json
|
| 9 |
+
import shutil
|
| 10 |
+
import random
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from abctoolkit.utils import (
|
| 13 |
+
remove_information_field,
|
| 14 |
+
remove_bar_no_annotations,
|
| 15 |
+
Quote_re,
|
| 16 |
+
Barlines,
|
| 17 |
+
extract_metadata_and_parts,
|
| 18 |
+
extract_global_and_local_metadata,
|
| 19 |
+
extract_barline_and_bartext_dict)
|
| 20 |
+
from abctoolkit.convert import unidecode_abc_lines
|
| 21 |
+
from abctoolkit.rotate import rotate_abc
|
| 22 |
+
from abctoolkit.check import check_alignment_unrotated
|
| 23 |
+
from abctoolkit.transpose import Key2index, transpose_an_abc_text
|
| 24 |
+
|
| 25 |
+
os.makedirs(INTERLEAVED_FOLDER, exist_ok=True)
|
| 26 |
+
os.makedirs(AUGMENTED_FOLDER, exist_ok=True)
|
| 27 |
+
for key in Key2index.keys():
|
| 28 |
+
key_folder = os.path.join(AUGMENTED_FOLDER, key)
|
| 29 |
+
os.makedirs(key_folder, exist_ok=True)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def abc_preprocess_pipeline(abc_path):
    """Clean, validate, rotate, and key-augment one standard ABC notation file.

    Pipeline, in order: drop blank lines, transliterate to ASCII, strip
    unwanted information fields, remove bar-number annotations and quoted
    annotations containing barlines, verify bar alignment across voices,
    write an interleaved (rotated) copy to INTERLEAVED_FOLDER, then for every
    key in Key2index write a transposed, rest-reduced copy to
    AUGMENTED_FOLDER/<key>/.

    Parameters
    ----------
    abc_path : str
        Path to a standard (non-interleaved) ABC notation file.

    Returns
    -------
    tuple
        (abc_name, ori_key): the file's base name without extension, and the
        original key from the metadata ('C' substituted when 'K: none').

    Raises
    ------
    Exception
        If voices have unequal bar counts or the alignment check fails.
    """

    with open(abc_path, 'r', encoding='utf-8') as f:
        abc_lines = f.readlines()

    # delete blank lines
    abc_lines = [line for line in abc_lines if line.strip() != '']

    # unidecode: transliterate non-ASCII characters (abctoolkit helper)
    abc_lines = unidecode_abc_lines(abc_lines)

    # clean information fields: drop title/composer/lyrics/transcription/MIDI lines
    abc_lines = remove_information_field(abc_lines=abc_lines, info_fields=['X:', 'T:', 'C:', 'W:', 'w:', 'Z:', '%%MIDI'])

    # delete bar number annotations
    abc_lines = remove_bar_no_annotations(abc_lines)

    # delete escaped quotes (\") from music lines; header ("X:"-style) and
    # comment ("%") lines are left untouched
    for i, line in enumerate(abc_lines):
        if re.search(r'^[A-Za-z]:', line) or line.startswith('%'):
            continue
        else:
            if r'\"' in line:
                abc_lines[i] = abc_lines[i].replace(r'\"', '')

    # delete quoted text annotations that contain barline symbols
    # (they would confuse bar-based parsing downstream)
    for i, line in enumerate(abc_lines):
        quote_contents = re.findall(Quote_re, line)
        for quote_content in quote_contents:
            for barline in Barlines:
                if barline in quote_content:
                    line = line.replace(quote_content, '')
                    abc_lines[i] = line

    # check bar alignment: every voice must contain the same number of bars
    try:
        _, bar_no_equal_flag, _ = check_alignment_unrotated(abc_lines)
        if not bar_no_equal_flag:
            print(abc_path, 'Unequal bar number')
            raise Exception
    except:
        # NOTE(review): re-raising a bare Exception loses the original error
        # details; callers only learn that pre-processing failed.
        raise Exception

    # deal with text annotations: remove too long text annotations; remove consecutive non-alphabet/number characters
    for i, line in enumerate(abc_lines):
        quote_matches = re.findall(r'"[^"]*"', line)
        for match in quote_matches:
            if match == '""':
                line = line.replace(match, '')
            # '^'/'_' right after the opening quote marks a text annotation
            # (above/below the staff) rather than a chord symbol
            if match[1] in ['^', '_']:
                sub_string = match
                # collapse runs of repeated non-alphanumeric characters to one
                pattern = r'([^a-zA-Z0-9])\1+'
                sub_string = re.sub(pattern, r'\1', sub_string)
                if len(sub_string) <= 40:
                    line = line.replace(match, sub_string)
                else:
                    # over-long annotations are dropped entirely
                    line = line.replace(match, '')
        abc_lines[i] = line

    # base name of the source file, without directory or extension
    abc_name = os.path.splitext(os.path.split(abc_path)[-1])[0]

    # transpose: read the original key from the global metadata
    metadata_lines, part_text_dict = extract_metadata_and_parts(abc_lines)
    global_metadata_dict, local_metadata_dict = extract_global_and_local_metadata(metadata_lines)
    if global_metadata_dict['K'][0] == 'none':
        # treat 'K: none' as C so transposition has a defined starting key
        global_metadata_dict['K'][0] = 'C'
    ori_key = global_metadata_dict['K'][0]

    # write the interleaved (voice-rotated) version of the cleaned score
    interleaved_abc = rotate_abc(abc_lines)
    interleaved_path = os.path.join(INTERLEAVED_FOLDER, abc_name + '.abc')
    with open(interleaved_path, 'w') as w:
        w.writelines(interleaved_abc)

    # one augmented copy per target key
    for key in Key2index.keys():
        transposed_abc_text = transpose_an_abc_text(abc_lines, key)
        transposed_abc_lines = transposed_abc_text.split('\n')
        transposed_abc_lines = list(filter(None, transposed_abc_lines))
        transposed_abc_lines = [line + '\n' for line in transposed_abc_lines]

        # rest reduction: rebuild the score bar by bar, keeping only voices
        # that actually play notes in each bar
        metadata_lines, prefix_dict, left_barline_dict, bar_text_dict, right_barline_dict = \
            extract_barline_and_bartext_dict(transposed_abc_lines)
        reduced_abc_lines = metadata_lines
        # bar_text_dict['V:1'] gives the bar count; presumably all voices have
        # the same count after the alignment check above
        for i in range(len(bar_text_dict['V:1'])):
            line = ''
            for symbol in prefix_dict.keys():
                # a bar is "valid" if it contains any pitch letter, i.e. an
                # alphabetic char that is not a rest (Z/z) or spacer (X/x)
                valid_flag = False
                for char in bar_text_dict[symbol][i]:
                    if char.isalpha() and not char in ['Z', 'z', 'X', 'x']:
                        valid_flag = True
                        break
                if valid_flag:
                    if i == 0:
                        # first bar carries the voice prefix and left barline
                        part_patch = '[' + symbol + ']' + prefix_dict[symbol] + left_barline_dict[symbol][0] + bar_text_dict[symbol][0] + right_barline_dict[symbol][0]
                    else:
                        part_patch = '[' + symbol + ']' + bar_text_dict[symbol][i] + right_barline_dict[symbol][i]
                    line += part_patch
            line += '\n'
            reduced_abc_lines.append(line)

        reduced_abc_name = abc_name + '_' + key
        reduced_abc_path = os.path.join(AUGMENTED_FOLDER, key, reduced_abc_name + '.abc')

        with open(reduced_abc_path, 'w', encoding='utf-8') as w:
            w.writelines(reduced_abc_lines)

    return abc_name, ori_key
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
if __name__ == '__main__':

    # Run the pre-processing pipeline over every file in ORI_FOLDER, then
    # split the successfully processed entries into train/eval index files.
    data = []
    file_list = os.listdir(ORI_FOLDER)
    for file in tqdm(file_list):
        ori_abc_path = os.path.join(ORI_FOLDER, file)
        try:
            abc_name, ori_key = abc_preprocess_pipeline(ori_abc_path)
        except:
            # best-effort batch: a failing file is reported and skipped
            print(ori_abc_path, 'failed to pre-process.')
            continue

        # NOTE(review): 'path' records the augmented base path without a key
        # suffix or .abc extension -- presumably the training loader appends
        # these; confirm against the data loader.
        data.append({
            'path': os.path.join(AUGMENTED_FOLDER, abc_name),
            'key': ori_key
        })

    # Shuffle before splitting so eval is a random EVAL_SPLIT fraction.
    random.shuffle(data)
    eval_data = data[ : int(EVAL_SPLIT * len(data))]
    train_data = data[int(EVAL_SPLIT * len(data)) : ]

    # Index files are written next to (named after) the augmented folder.
    data_index_path = AUGMENTED_FOLDER + '.jsonl'
    eval_index_path = AUGMENTED_FOLDER + '_eval.jsonl'
    train_index_path = AUGMENTED_FOLDER + '_train.jsonl'


    # One JSON object per line (JSONL) for each of the three index files.
    with open(data_index_path, 'w', encoding='utf-8') as w:
        for d in data:
            w.write(json.dumps(d) + '\n')
    with open(eval_index_path, 'w', encoding='utf-8') as w:
        for d in eval_data:
            w.write(json.dumps(d) + '\n')
    with open(train_index_path, 'w', encoding='utf-8') as w:
        for d in train_data:
            w.write(json.dumps(d) + '\n')
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
|
3_batch_abc2xml.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ORI_FOLDER = "" # Replace with the path to your folder containing standard/interleaved abc files
|
| 2 |
+
DES_FOLDER = "" # The script will convert the abc files and output musicxml files to this folder
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import math
|
| 6 |
+
import random
|
| 7 |
+
import subprocess
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from multiprocessing import Pool
|
| 10 |
+
|
| 11 |
+
def convert_abc2xml(file_list):
    """Convert a batch of ABC notation files to MusicXML.

    Each file is piped through the external ``abc2xml.py`` script (expected in
    the current working directory); the XML captured from stdout is written to
    DES_FOLDER with the same base name and an ``.xml`` extension. Files that
    yield empty output or raise an exception are appended to
    ``logs/abc2xml_error_log.txt``.

    Intended as a ``multiprocessing.Pool`` worker: takes a list of file paths,
    returns nothing.
    """
    # Hoisted out of the loop: the destination only needs creating once.
    os.makedirs(DES_FOLDER, exist_ok=True)
    # Pass the input path as a separate argv element (shell=False) so file
    # names containing spaces, quotes or shell metacharacters are handled
    # safely -- the original string-built shell command was injection-prone.
    base_cmd = ['python', 'abc2xml.py']
    for file in tqdm(file_list):
        # os.path.basename is robust to both '/' and platform separators,
        # unlike the previous file.split('/')[-1].
        filename = os.path.basename(file)
        try:
            p = subprocess.Popen(base_cmd + [file], stdout=subprocess.PIPE)
            output = p.communicate()[0].decode('utf-8')

            if output == '':
                # Empty stdout means the converter failed silently; log and skip.
                with open("logs/abc2xml_error_log.txt", "a", encoding="utf-8") as f:
                    f.write(file + '\n')
                continue

            # Inputs are filtered to *.abc, so a '.' is always present.
            output_path = os.path.join(DES_FOLDER, filename.rsplit('.', 1)[0] + '.xml')
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(output)
        except Exception as e:
            # Record the failing file together with the exception message.
            with open("logs/abc2xml_error_log.txt", "a", encoding="utf-8") as f:
                f.write(file + ' ' + str(e) + '\n')
|
| 34 |
+
|
| 35 |
+
if __name__ == '__main__':
    os.makedirs("logs", exist_ok=True)

    # Recursively collect every .abc file under ORI_FOLDER.
    file_list = []
    for root, dirs, files in os.walk(ORI_FOLDER):
        for file in files:
            if not file.endswith(".abc"):
                continue
            # Normalize separators so paths are uniform across platforms.
            file_list.append(os.path.join(root, file).replace("\\", "/"))

    # Shuffle so each worker receives a roughly even mix of file sizes.
    random.shuffle(file_list)
    # os.cpu_count() may return None on some platforms; fall back to 1.
    num_processes = os.cpu_count() or 1
    # Contiguous split into one chunk per worker (same partitioning as the
    # original floor-based arithmetic, expressed directly).
    file_lists = []
    for i in range(num_processes):
        start_idx = i * len(file_list) // num_processes
        end_idx = (i + 1) * len(file_list) // num_processes
        file_lists.append(file_list[start_idx:end_idx])

    # Context manager ensures the pool is properly terminated (the original
    # created a bare Pool and never closed/joined it).
    with Pool(processes=num_processes) as pool:
        pool.map(convert_abc2xml, file_lists)
|
LICENSE.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Yashan Wang
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README (1).md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Local Gradio Demo
|
| 2 |
+
|
| 3 |
+
1. Set up the environment:
|
| 4 |
+
|
| 5 |
+
```
|
| 6 |
+
conda create --name notagen python=3.10
|
| 7 |
+
conda activate notagen
|
| 8 |
+
conda install pytorch==2.3.0 pytorch-cuda=11.8 -c pytorch -c nvidia
|
| 9 |
+
pip install accelerate
|
| 10 |
+
pip install optimum
|
| 11 |
+
pip install -r requirements.txt
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
2. Download [NotaGen-X](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth) and put it under ```gradio/```.
|
| 15 |
+
|
| 16 |
+
3. run ```demo.py```:
|
| 17 |
+
|
| 18 |
+
```
|
| 19 |
+
cd gradio/
|
| 20 |
+
python demo.py
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
4. You can then view the demo page at http://0.0.0.0:7861.
|
| 24 |
+
|
| 25 |
+
<p align="center">
|
| 26 |
+
<img src="illustration.png" alt="NotaGen Gradio Demo">
|
| 27 |
+
</p>
|
| 28 |
+
|
| 29 |
+
You can choose period, composer, and instrumentation as a prompt combination for NotaGen's conditional generation. After generation completes, you can save the ABC notation and MusicXML files locally.
|
| 30 |
+
|
| 31 |
+
Unfortunately, the available prompt combinations are currently limited to 112, constrained by the number of musical pieces under each prompt in the fine-tuning dataset. We hope to expand the range and forms of prompts in the future.
|
README (2).md
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🎵 NotaGen: Advancing Musicality in Symbolic Music Generation with Large Language Model Training Paradigms
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<!-- ArXiv -->
|
| 5 |
+
<a href="https://arxiv.org/abs/2502.18008">
|
| 6 |
+
<img src="https://img.shields.io/badge/NotaGen_Paper-ArXiv-%23B31B1B?logo=arxiv&logoColor=white" alt="Paper">
|
| 7 |
+
</a>
|
| 8 |
+
|
| 9 |
+
<!-- HuggingFace -->
|
| 10 |
+
<a href="https://huggingface.co/ElectricAlexis/NotaGen">
|
| 11 |
+
<img src="https://img.shields.io/badge/NotaGen_Weights-HuggingFace-%23FFD21F?logo=huggingface&logoColor=white" alt="Weights">
|
| 12 |
+
</a>
|
| 13 |
+
|
| 14 |
+
<!-- HuggingFace Space -->
|
| 15 |
+
<a href="https://huggingface.co/spaces/ElectricAlexis/NotaGen">
|
| 16 |
+
<img src="https://img.shields.io/badge/NotaGen_Space-Huggingface-✨️?logo=huggingface&logoColor=white" alt="Space">
|
| 17 |
+
</a>
|
| 18 |
+
|
| 19 |
+
<!-- Web Demo -->
|
| 20 |
+
<a href="https://electricalexis.github.io/notagen-demo/">
|
| 21 |
+
<img src="https://img.shields.io/badge/NotaGen_Demo-Web-%23007ACC?logo=google-chrome&logoColor=white" alt="Demo">
|
| 22 |
+
</a>
|
| 23 |
+
</p>
|
| 24 |
+
|
| 25 |
+
<p align="center">
|
| 26 |
+
<img src="notagen.png" alt="NotaGen" width="50%">
|
| 27 |
+
</p>
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
## 📖 Overview
|
| 31 |
+
**NotaGen** is a symbolic music generation model that explores the potential of producing **high-quality classical sheet music**. Inspired by the success of Large Language Models (LLMs), NotaGen adopts a three-stage training paradigm:
|
| 32 |
+
- 🧠 **Pre-training** on 1.6M musical pieces
|
| 33 |
+
- 🎯 **Fine-tuning** on ~9K classical compositions with `period-composer-instrumentation` prompts
|
| 34 |
+
- 🚀 **Reinforcement Learning** using our novel **CLaMP-DPO** method (no human annotations or pre-defined rewards required.)
|
| 35 |
+
|
| 36 |
+
Check our [demo page](https://electricalexis.github.io/notagen-demo/) and enjoy music composed by NotaGen!
|
| 37 |
+
|
| 38 |
+
## ⚙️ Environment Setup
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
conda create --name notagen python=3.10
|
| 42 |
+
conda activate notagen
|
| 43 |
+
conda install pytorch==2.3.0 pytorch-cuda=11.8 -c pytorch -c nvidia
|
| 44 |
+
pip install accelerate
|
| 45 |
+
pip install optimum
|
| 46 |
+
pip install -r requirements.txt
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## 🏋️ NotaGen Model Weights
|
| 50 |
+
|
| 51 |
+
### Pre-training
|
| 52 |
+
We provide pre-trained weights of different scales:
|
| 53 |
+
| Models | Parameters | Patch-level Decoder Layers | Character-level Decoder Layers | Hidden Size | Patch Length (Context Length) |
|
| 54 |
+
| ---- | ---- | ---- | ---- | ---- | ---- |
|
| 55 |
+
| [NotaGen-small](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagen_pretrain_p_size_16_p_length_2048_p_layers_12_c_layers_3_h_size_768_lr_0.0002_batch_8.pth) | 110M | 12 | 3 | 768 | 2048 |
|
| 56 |
+
| [NotaGen-medium](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagen_pretrain_p_size_16_p_length_2048_p_layers_16_c_layers_3_h_size_1024_lr_0.0001_batch_4.pth) | 244M | 16 | 3 | 1024 | 2048 |
|
| 57 |
+
| [NotaGen-large](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagen_pretrain_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_0.0001_batch_4.pth) | 516M | 20 | 6 | 1280 | 1024 |
|
| 58 |
+
|
| 59 |
+
**Notice**: The pre-trained weights cannot be used for conditional generation based on 'period-composer-instrumentation'.
|
| 60 |
+
|
| 61 |
+
### Fine-tuning
|
| 62 |
+
|
| 63 |
+
We fine-tuned NotaGen-large on a corpus of approximately 9k classical pieces. You can download the weights [here](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagen_pretrain-finetune_p_size_16_p_length_1024_p_layers_c_layers_6_20_h_size_1280_lr_1e-05_batch_1.pth).
|
| 64 |
+
|
| 65 |
+
### Reinforcement-Learning
|
| 66 |
+
|
| 67 |
+
After pre-training and fine-tuning, we optimized NotaGen-large with 3 iterations of CLaMP-DPO. You can download the weights [here](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagen_pretrain-finetune-RL3_beta_0.1_lambda_10_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-06_batch_1.pth).
|
| 68 |
+
|
| 69 |
+
### 🌟 NotaGen-X
|
| 70 |
+
|
| 71 |
+
Inspired by Deepseek-R1, we further optimized the training procedures of NotaGen and released a better version --- [NotaGen-X](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth). Compared to the version in the paper, NotaGen-X incorporates the following improvements:
|
| 72 |
+
|
| 73 |
+
- We introduced a post-training stage between pre-training and fine-tuning, refining the model with a classical-style subset of the pre-training dataset.
|
| 74 |
+
- We removed the key augmentation in the Fine-tune stage, making the instrument range of the generated compositions more reasonable.
|
| 75 |
+
- After RL, we utilized the resulting checkpoint to gather a new set of post-training data. Starting from the pre-trained checkpoint, we conducted another round of post-training, fine-tuning, and reinforcement learning.
|
| 76 |
+
|
| 77 |
+
If you want to add a new composer style to NotaGen-X, please refer to issue [#18](https://github.com/ElectricAlexis/NotaGen/issues/18) for more instructions :D
|
| 78 |
+
|
| 79 |
+
## 🎹 Demo
|
| 80 |
+
|
| 81 |
+
### Online Gradio Demo
|
| 82 |
+
|
| 83 |
+
We developed an [online gradio demo](https://huggingface.co/spaces/ElectricAlexis/NotaGen) on Huggingface Space for NotaGen-X. You can input **"Period-Composer-Instrumentation"** as the prompt to have NotaGen generate music, preview the audio / pdf scores, and download them :D
|
| 84 |
+
|
| 85 |
+
<p align="center">
|
| 86 |
+
<img src="gradio/illustration_online.png" alt="NotaGen Gradio Demo">
|
| 87 |
+
</p>
|
| 88 |
+
|
| 89 |
+
### Local Gradio Demo
|
| 90 |
+
|
| 91 |
+
We developed a local Gradio demo for NotaGen-X. You can input **"Period-Composer-Instrumentation"** as the prompt to have NotaGen generate music!
|
| 92 |
+
|
| 93 |
+
<p align="center">
|
| 94 |
+
<img src="gradio/illustration.png" alt="NotaGen Gradio Demo">
|
| 95 |
+
</p>
|
| 96 |
+
|
| 97 |
+
Deploying NotaGen-X inference locally may require 8GB of GPU memory. For implementation details, please view [gradio/README.md](https://github.com/ElectricAlexis/NotaGen/blob/main/gradio/README.md). We are also working on developing an online demo.
|
| 98 |
+
|
| 99 |
+
### Online Colab Notebook
|
| 100 |
+
|
| 101 |
+
Thanks for [@deeplearn-art](https://github.com/deeplearn-art/NotaGen)'s contribution of a [Google Colab notebook for NotaGen](https://colab.research.google.com/drive/1yJA1wG0fiwNeehdQxAUw56i4bTXzoVVv?usp=sharing)! You can run it and access to a Gradio public link to play with this demo. 🤩
|
| 102 |
+
|
| 103 |
+
### ComfyUI
|
| 104 |
+
|
| 105 |
+
Thanks for [@billwuhao](https://github.com/billwuhao/ComfyUI_NotaGen)'s contribution of [a ComfyUI node for NotaGen](https://github.com/billwuhao/ComfyUI_NotaGen)! It can automatically convert generated .abc to .xml, .mp3, and .png formats. You can listen to the generated music and see the sheet music too! Please visit the [repository page](https://github.com/billwuhao/ComfyUI_NotaGen) for more information. 🤩
|
| 106 |
+
|
| 107 |
+
<p align="center">
|
| 108 |
+
<img src="https://github.com/billwuhao/ComfyUI_NotaGen/blob/master/images/2025-03-10_06-24-03.png" alt="NotaGen ComfyUI">
|
| 109 |
+
</p>
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
## 🛠️ Data Pre-processing & Post-processing
|
| 113 |
+
|
| 114 |
+
For converting **ABC notation** files from / to **MusicXML** files, please view [data/README.md](https://github.com/ElectricAlexis/NotaGen/blob/main/data/README.md) for instructions.
|
| 115 |
+
|
| 116 |
+
To illustrate the specific data format, we provide a small dataset of **Schubert's lieder** compositions from the [OpenScore Lieder](https://github.com/OpenScore/Lieder), which includes:
|
| 117 |
+
- 🗂️ Interleaved ABC folders
|
| 118 |
+
- 🗂️ Augmented ABC folders
|
| 119 |
+
- 📄 Data index files for training and evaluation
|
| 120 |
+
|
| 121 |
+
You can download it [here](https://drive.google.com/drive/folders/1iVLkcywzXGcHFodce9nDQyEmK4UDmBtY?usp=sharing) and put it under ```data/```.
|
| 122 |
+
|
| 123 |
+
In the instructions of **Fine-tuning** and **Reinforcement Learning** below, we will use this dataset as an example of our implementation. **It won't include the "period-composer-instrumentation" conditioning**, just for showing how to adapt the pretrained NotaGen to a specific music style.
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
## 🧠 Pre-train
|
| 127 |
+
If you want to use your own data to pre-train a blank **NotaGen** model, please:
|
| 128 |
+
1. Preprocess the data and generate the data index files following the instructions in [data/README.md](https://github.com/ElectricAlexis/NotaGen/blob/main/data/README.md)
|
| 129 |
+
2. Modify the parameters in ```pretrain/config.py```
|
| 130 |
+
|
| 131 |
+
Use this command for pre-training:
|
| 132 |
+
```bash
|
| 133 |
+
cd pretrain/
|
| 134 |
+
accelerate launch --multi_gpu --mixed_precision fp16 train-gen.py
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
## 🎯 Fine-tune
|
| 138 |
+
|
| 139 |
+
Here we give an example on fine-tuning **NotaGen-large** with the **Schubert's lieder** data mentioned above.
|
| 140 |
+
|
| 141 |
+
**Notice:** The use of **NotaGen-large** requires at least **24GB of GPU memory** for training and inference. Alternatively, you may use **NotaGen-small** or **NotaGen-medium** and change the configuration of models in ```finetune/config.py```.
|
| 142 |
+
|
| 143 |
+
### Configuration
|
| 144 |
+
- In ```finetune/config.py```:
|
| 145 |
+
- Modify the ```DATA_TRAIN_INDEX_PATH``` and ```DATA_EVAL_INDEX_PATH```:
|
| 146 |
+
```python
|
| 147 |
+
# Configuration for the data
|
| 148 |
+
DATA_TRAIN_INDEX_PATH = "../data/schubert_augmented_train.jsonl"
|
| 149 |
+
DATA_EVAL_INDEX_PATH = "../data/schubert_augmented_eval.jsonl"
|
| 150 |
+
```
|
| 151 |
+
- Download pre-trained NotaGen weights, and modify the ```PRETRAINED_PATH```:
|
| 152 |
+
```python
|
| 153 |
+
PRETRAINED_PATH = "../pretrain/weights_notagen_pretrain_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_0.0001_batch_4.pth" # Use NotaGen-large
|
| 154 |
+
```
|
| 155 |
+
- ```EXP_TAG``` is for differentiating the models. It will be integrated into the ckpt's name. Here we set it to ```schubert```.
|
| 156 |
+
- You can also modify other parameters like the learning rate.
|
| 157 |
+
|
| 158 |
+
### Execution
|
| 159 |
+
Use this command for fine-tuning:
|
| 160 |
+
```bash
|
| 161 |
+
cd finetune/
|
| 162 |
+
CUDA_VISIBLE_DEVICES=0 python train-gen.py
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
## 🚀 Reinforcement Learning (CLaMP-DPO)
|
| 166 |
+
|
| 167 |
+
Here we give an example on how to use **CLaMP-DPO** to enhance the model fine-tuned with **Schubert's lieder** data.
|
| 168 |
+
|
| 169 |
+
### ⚙️ [CLaMP 2](https://github.com/sanderwood/clamp2) Setup
|
| 170 |
+
|
| 171 |
+
Download model weights and put them under the ```clamp2/```folder:
|
| 172 |
+
- [CLaMP 2 Model Weights](https://huggingface.co/sander-wood/clamp2/blob/main/weights_clamp2_h_size_768_lr_5e-05_batch_128_scale_1_t_length_128_t_model_FacebookAI_xlm-roberta-base_t_dropout_True_m3_True.pth)
|
| 173 |
+
- [M3 Model Weights](https://huggingface.co/sander-wood/clamp2/blob/main/weights_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.pth)
|
| 174 |
+
|
| 175 |
+
### 🔍 Extract Ground Truth Features
|
| 176 |
+
Modify ```input_dir``` and ```output_dir``` in ```clamp2/extract_clamp2.py```:
|
| 177 |
+
```python
|
| 178 |
+
input_dir = '../data/schubert_interleaved' # interleaved abc folder
|
| 179 |
+
output_dir = 'feature/schubert_interleaved' # feature folder
|
| 180 |
+
```
|
| 181 |
+
Extract the features:
|
| 182 |
+
```
|
| 183 |
+
cd clamp2/
|
| 184 |
+
python extract_clamp2.py
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
### 🔄 CLaMP-DPO
|
| 188 |
+
|
| 189 |
+
Here we give an example of an iteration of **CLaMP-DPO** from the initial model fine-tuned on **Schubert's lieder** data.
|
| 190 |
+
|
| 191 |
+
#### 1. Inference
|
| 192 |
+
- Modify the ```INFERENCE_WEIGHTS_PATH``` to path of the fine-tuned weights and ```NUM_SAMPLES``` to generate in ```inference/config.py```:
|
| 193 |
+
```python
|
| 194 |
+
INFERENCE_WEIGHTS_PATH = '../finetune/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1.pth'
|
| 195 |
+
NUM_SAMPLES = 1000
|
| 196 |
+
```
|
| 197 |
+
- Inference:
|
| 198 |
+
```
|
| 199 |
+
cd inference/
|
| 200 |
+
python inference.py
|
| 201 |
+
```
|
| 202 |
+
This will generate an ```output/```folder with two subfolders: ```original``` and ```interleaved```. The ```original/``` subdirectory stores the raw inference outputs from the model, while the ```interleaved/``` subdirectory contains data post-processed with rest measure completion, compatible with CLaMP 2. Each of these subdirectories will contain a model-specific folder, named as a combination of the model's name and its sampling parameters.
|
| 203 |
+
|
| 204 |
+
#### 2. Extract Generated Data Features
|
| 205 |
+
|
| 206 |
+
Modify ```input_dir``` and ```output_dir``` in ```clamp2/extract_clamp2.py```:
|
| 207 |
+
```python
|
| 208 |
+
input_dir = '../output/interleaved/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1_k_9_p_0.9_temp_1.2' # interleaved abc folder
|
| 209 |
+
output_dir = 'feature/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1_k_9_p_0.9_temp_1.2' # feature folder
|
| 210 |
+
```
|
| 211 |
+
Extract the features:
|
| 212 |
+
```
|
| 213 |
+
cd clamp2/
|
| 214 |
+
python extract_clamp2.py
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
#### 3. Statistics on Average CLaMP 2 Score (Optional)
|
| 218 |
+
If you're interested in the **Average CLaMP 2 Score** of the current model, modify the parameters in ```clamp2/statistics.py```:
|
| 219 |
+
```python
|
| 220 |
+
gt_feature_folder = 'feature/schubert_interleaved'
|
| 221 |
+
output_feature_folder = 'feature/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1_k_9_p_0.9_temp_1.2'
|
| 222 |
+
```
|
| 223 |
+
Then run this script:
|
| 224 |
+
```
|
| 225 |
+
cd clamp2/
|
| 226 |
+
python statistics.py
|
| 227 |
+
```
|
| 228 |
+
|
| 229 |
+
#### 4. Construct Preference Data
|
| 230 |
+
Modify the parameters in ```RL/data.py```:
|
| 231 |
+
```python
|
| 232 |
+
gt_feature_folder = '../clamp2/feature/schubert_interleaved'
|
| 233 |
+
output_feature_folder = '../clamp2/feature/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1_k_9_p_0.9_temp_1.2'
|
| 234 |
+
output_original_abc_folder = '../output/original/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1_k_9_p_0.9_temp_1.2'
|
| 235 |
+
output_interleaved_abc_folder = '../output/interleaved/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1_k_9_p_0.9_temp_1.2'
|
| 236 |
+
data_index_path = 'schubert_RL1.json' # Data for the first iteration of RL
|
| 237 |
+
data_select_portion = 0.1
|
| 238 |
+
```
|
| 239 |
+
In this script, the **CLaMP 2 Score** of each generated piece will be calculated and sorted. The portion of data in the chosen and rejected sets is determined by ```data_select_portion```. Additionally, there are also three rules to exclude problematic sheets from the chosen set:
|
| 240 |
+
- Sheets with duration alignment problems are excluded;
|
| 241 |
+
- Sheets that may plagiarize from ground truth data (ld_sim>0.95) are excluded;
|
| 242 |
+
- Sheets where staves for the same instrument are not grouped together are excluded.
|
| 243 |
+
|
| 244 |
+
The preference data file will be named ```data_index_path```, and it records the file paths in the chosen and rejected sets.
|
| 245 |
+
|
| 246 |
+
Run this script:
|
| 247 |
+
```
|
| 248 |
+
cd RL/
|
| 249 |
+
python data.py
|
| 250 |
+
```
|
| 251 |
+
|
| 252 |
+
#### 5. DPO Training
|
| 253 |
+
|
| 254 |
+
Modify the parameters in ```RL/config.py```:
|
| 255 |
+
```python
|
| 256 |
+
DATA_INDEX_PATH = 'schubert_RL1.json' # Preference data path
|
| 257 |
+
PRETRAINED_PATH = '../finetune/weights_notagen_schubert_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-05_batch_1.pth' # The model to go through DPO optimization
|
| 258 |
+
EXP_TAG = 'schubert-RL1' # Model tag for differentiation
|
| 259 |
+
```
|
| 260 |
+
You can also modify other parameters like ```OPTIMIZATION_STEPS``` and DPO hyper-parameters.
|
| 261 |
+
|
| 262 |
+
Run this script:
|
| 263 |
+
```
|
| 264 |
+
cd RL/
|
| 265 |
+
CUDA_VISIBLE_DEVICES=0 python train.py
|
| 266 |
+
```
|
| 267 |
+
After training, a model named ```weights_notagen_schubert-RL1_beta_0.1_lambda_10_p_size_16_p_length_1024_p_layers_20_c_layers_6_h_size_1280_lr_1e-06.pth``` will be saved under ```RL/```. For the second round of CLaMP-DPO, please go back to the first inference stage and let the new model generate pieces.
|
| 268 |
+
|
| 269 |
+
For this small experiment on **Schubert's lieder** data, we post our **Average CLaMP 2 Score** here for the fine-tuned model and models after each iteration of CLaMP-DPO, as a reference:
|
| 270 |
+
|
| 271 |
+
| CLaMP-DPO Iteration (K) | Average CLaMP 2 Score |
|
| 272 |
+
| ---- | ---- |
|
| 273 |
+
| 0 (fine-tuned) | 0.324 |
|
| 274 |
+
| 1 | 0.579 |
|
| 275 |
+
| 2 | 0.778 |
|
| 276 |
+
|
| 277 |
+
If you are interested in this method, have a try on your own style-specific dataset :D
|
| 278 |
+
|
| 279 |
+
## 📚 Citation
|
| 280 |
+
|
| 281 |
+
If you find **NotaGen** or **CLaMP-DPO** useful in your work, please cite our paper.
|
| 282 |
+
|
| 283 |
+
```bibtex
|
| 284 |
+
@misc{wang2025notagenadvancingmusicalitysymbolic,
|
| 285 |
+
title={NotaGen: Advancing Musicality in Symbolic Music Generation with Large Language Model Training Paradigms},
|
| 286 |
+
author={Yashan Wang and Shangda Wu and Jianhuai Hu and Xingjian Du and Yueqi Peng and Yongxin Huang and Shuai Fan and Xiaobing Li and Feng Yu and Maosong Sun},
|
| 287 |
+
year={2025},
|
| 288 |
+
eprint={2502.18008},
|
| 289 |
+
archivePrefix={arXiv},
|
| 290 |
+
primaryClass={cs.SD},
|
| 291 |
+
url={https://arxiv.org/abs/2502.18008},
|
| 292 |
+
}
|
| 293 |
+
```
|
README.md
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Data Pre-processing
|
| 2 |
+
|
| 3 |
+
### Convert from MusicXML
|
| 4 |
+
|
| 5 |
+
- Navigate to the data folder ```cd data/```
|
| 6 |
+
- Modify the ```ORI_FOLDER``` and ```DES_FOLDER``` in ```1_batch_xml2abc.py```, then run this script:
|
| 7 |
+
```
|
| 8 |
+
python 1_batch_xml2abc.py
|
| 9 |
+
```
|
| 10 |
+
This will convert the MusicXML files into standard ABC notation files.
|
| 11 |
+
- Modify the ```ORI_FOLDER```, ```INTERLEAVED_FOLDER```, ```AUGMENTED_FOLDER```, and ```EVAL_SPLIT``` in ```2_data_preprocess.py```:
|
| 12 |
+
|
| 13 |
+
```python
|
| 14 |
+
ORI_FOLDER = '' # Folder containing standard ABC notation files
|
| 15 |
+
INTERLEAVED_FOLDER = '' # Output interleaved ABC notation files that are compatible with CLaMP 2 to this folder
|
| 16 |
+
AUGMENTED_FOLDER = '' # On the basis of interleaved ABC, output key-augmented and rest-omitted files that are compatible with NotaGen to this folder
|
| 17 |
+
EVAL_SPLIT = 0.1 # Evaluation data ratio
|
| 18 |
+
```
|
| 19 |
+
then run this script:
|
| 20 |
+
```
|
| 21 |
+
python 2_data_preprocess.py
|
| 22 |
+
```
|
| 23 |
+
- The script will convert the standard ABC to interleaved ABC, which is compatible with CLaMP 2. The files will be under ```INTERLEAVED_FOLDER```.
|
| 24 |
+
|
| 25 |
+
- This script will make 15 key signature folders under the ```AUGMENTED_FOLDER```, and output interleaved ABC notation files with rest bars omitted. This is the data representation that NotaGen adopts.
|
| 26 |
+
|
| 27 |
+
- This script will also generate data index files for training NotaGen. It will randomly split train and eval sets according to the proportion ```EVAL_SPLIT``` defines. The index files will be named as ```{AUGMENTED_FOLDER}_train.jsonl``` and ```{AUGMENTED_FOLDER}_eval.jsonl```.
|
| 28 |
+
|
| 29 |
+
## Data Post-processing
|
| 30 |
+
|
| 31 |
+
### Preview Sheets in ABC Notation
|
| 32 |
+
|
| 33 |
+
We recommend [EasyABC](https://sourceforge.net/projects/easyabc/), a nice software for ABC Notation previewing, composing and editing.
|
| 34 |
+
|
| 35 |
+
It's needed to add a line "X:1" before each piece to present the score image in EasyABC :D
|
| 36 |
+
|
| 37 |
+
### Convert to MusicXML
|
| 38 |
+
|
| 39 |
+
- Go to the data folder ```cd data/```
|
| 40 |
+
- Modify the ```ORI_FOLDER``` and ```DES_FOLDER``` in ```3_batch_abc2xml.py```, then run this script:
|
| 41 |
+
```
|
| 42 |
+
python 3_batch_abc2xml.py
|
| 43 |
+
```
|
| 44 |
+
This will convert the standard/interleaved ABC notation files into MusicXML files.
|
abc2xml (1).py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
abc2xml (2).py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
abc2xml.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
config (1).py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
EVAL_SPLIT = 0.01 # Fraction of training data used for evaluation
|
| 2 |
+
WANDB_KEY = "<your_wandb_key>" # Set M3/CLaMP2_WANDB_LOG=False if no API key for Weights and Biases logging
|
| 3 |
+
|
| 4 |
+
# -------------------- Configuration for M3 Training --------------------
|
| 5 |
+
TRAIN_FOLDERS = [
|
| 6 |
+
"<path_to_training_data>" # Directory containing training data
|
| 7 |
+
]
|
| 8 |
+
|
| 9 |
+
EVAL_FOLDERS = [
|
| 10 |
+
"" # (Optional) Directory containing evaluation data
|
| 11 |
+
]
|
| 12 |
+
|
| 13 |
+
PATCH_SIZE = 64 # Size of each patch
|
| 14 |
+
PATCH_LENGTH = 512 # Length of the patches
|
| 15 |
+
PATCH_NUM_LAYERS = 12 # Number of layers in the encoder
|
| 16 |
+
TOKEN_NUM_LAYERS = 3 # Number of layers in the decoder
|
| 17 |
+
M3_HIDDEN_SIZE = 768 # Size of the hidden layer
|
| 18 |
+
|
| 19 |
+
M3_NUM_EPOCH = 100 # Maximum number of epochs for training
|
| 20 |
+
M3_LEARNING_RATE = 1e-4 # Learning rate for the optimizer
|
| 21 |
+
M3_BATCH_SIZE = 16 # Batch size per GPU (single card) during training
|
| 22 |
+
M3_MASK_RATIO = 0.45 # Ratio of masked elements during training
|
| 23 |
+
M3_DETERMINISTIC = True # Ensures deterministic results with random seeds
|
| 24 |
+
M3_WANDB_LOG = True # Enable logging to Weights and Biases
|
| 25 |
+
M3_LOAD_CKPT = True # Load model weights from a checkpoint if available
|
| 26 |
+
|
| 27 |
+
M3_WEIGHTS_PATH = (
|
| 28 |
+
"weights_m3_p_size_" + str(PATCH_SIZE) +
|
| 29 |
+
"_p_length_" + str(PATCH_LENGTH) +
|
| 30 |
+
"_t_layers_" + str(TOKEN_NUM_LAYERS) +
|
| 31 |
+
"_p_layers_" + str(PATCH_NUM_LAYERS) +
|
| 32 |
+
"_h_size_" + str(M3_HIDDEN_SIZE) +
|
| 33 |
+
"_lr_" + str(M3_LEARNING_RATE) +
|
| 34 |
+
"_batch_" + str(M3_BATCH_SIZE) +
|
| 35 |
+
"_mask_" + str(M3_MASK_RATIO) + ".pth"
|
| 36 |
+
) # Path to store the model weights
|
| 37 |
+
M3_LOGS_PATH = M3_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt") # Path to save training logs
|
| 38 |
+
|
| 39 |
+
# -------------------- Configuration for CLaMP2 Training ----------------
|
| 40 |
+
TRAIN_JSONL = "<path_to_training_jsonl>" # Path to the JSONL file with training data
|
| 41 |
+
EVAL_JSONL = "" # (Optional) Path to the JSONL file with evaluation data
|
| 42 |
+
|
| 43 |
+
CLAMP2_HIDDEN_SIZE = 768 # Size of the hidden layer
|
| 44 |
+
TEXT_MODEL_NAME = "FacebookAI/xlm-roberta-base" # Name of the pre-trained text model
|
| 45 |
+
|
| 46 |
+
CLAMP2_NUM_EPOCH = 100 # Maximum number of epochs for training
|
| 47 |
+
CLAMP2_LEARNING_RATE = 5e-5 # Learning rate for the optimizer
|
| 48 |
+
CLAMP2_BATCH_SIZE = 128 # Batch size per GPU (single card) during training
|
| 49 |
+
LOGIT_SCALE = 1 # Scaling factor for contrastive loss
|
| 50 |
+
MAX_TEXT_LENGTH = 128 # Maximum allowed length for text input
|
| 51 |
+
TEXT_DROPOUT = True # Whether to apply dropout during text processing
|
| 52 |
+
CLAMP2_DETERMINISTIC = True # Ensures deterministic results with random seeds
|
| 53 |
+
CLAMP2_LOAD_M3 = True # Load weights from the M3 model
|
| 54 |
+
CLAMP2_WANDB_LOG = True # Enable logging to Weights and Biases
|
| 55 |
+
CLAMP2_LOAD_CKPT = True # Load weights from a checkpoint if available
|
| 56 |
+
|
| 57 |
+
CLAMP2_WEIGHTS_PATH = (
|
| 58 |
+
"weights_clamp2_h_size_" + str(CLAMP2_HIDDEN_SIZE) +
|
| 59 |
+
"_lr_" + str(CLAMP2_LEARNING_RATE) +
|
| 60 |
+
"_batch_" + str(CLAMP2_BATCH_SIZE) +
|
| 61 |
+
"_scale_" + str(LOGIT_SCALE) +
|
| 62 |
+
"_t_length_" + str(MAX_TEXT_LENGTH) +
|
| 63 |
+
"_t_model_" + TEXT_MODEL_NAME.replace("/", "_") +
|
| 64 |
+
"_t_dropout_" + str(TEXT_DROPOUT) +
|
| 65 |
+
"_m3_" + str(CLAMP2_LOAD_M3) + ".pth"
|
| 66 |
+
) # Path to store CLaMP2 model weights
|
| 67 |
+
CLAMP2_LOGS_PATH = CLAMP2_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt") # Path to save training logs
|
config (2).py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# Configuration for the data
|
| 4 |
+
DATA_TRAIN_INDEX_PATH = ""
|
| 5 |
+
DATA_EVAL_INDEX_PATH = ""
|
| 6 |
+
|
| 7 |
+
# Configuration for the model
|
| 8 |
+
PATCH_STREAM = True # Stream training / inference
|
| 9 |
+
PATCH_SIZE = 16 # Patch Size
|
| 10 |
+
PATCH_LENGTH = 1024 # Patch Length
|
| 11 |
+
CHAR_NUM_LAYERS = 6 # Number of layers in the decoder
|
| 12 |
+
PATCH_NUM_LAYERS = 20 # Number of layers in the encoder
|
| 13 |
+
HIDDEN_SIZE = 1280 # Hidden Size
|
| 14 |
+
|
| 15 |
+
# Configuration for the training
|
| 16 |
+
BATCH_SIZE = 1
|
| 17 |
+
LEARNING_RATE = 1e-5
|
| 18 |
+
NUM_EPOCHS = 64 # Number of epochs to train for (if early stopping doesn't intervene)
|
| 19 |
+
ACCUMULATION_STEPS = 1 # Accumulation steps to simulate large batch size
|
| 20 |
+
PATCH_SAMPLING_BATCH_SIZE = 0 # Batch size for patch during training, 0 for full context
|
| 21 |
+
LOAD_FROM_CHECKPOINT = False # Whether to load weights from a checkpoint
|
| 22 |
+
WANDB_LOGGING = False # Whether to log to wandb
|
| 23 |
+
WANDB_KEY = '<your_wandb_key>'
|
| 24 |
+
|
| 25 |
+
PRETRAINED_PATH = "" # Path of pretrained weights
|
| 26 |
+
EXP_TAG = '' # Experiment tag for name differentiation
|
| 27 |
+
NAME = EXP_TAG + \
|
| 28 |
+
"_p_size_" + str(PATCH_SIZE) + \
|
| 29 |
+
"_p_length_" + str(PATCH_LENGTH) + \
|
| 30 |
+
"_p_layers_" + str(PATCH_NUM_LAYERS) + \
|
| 31 |
+
"_c_layers_" + str(CHAR_NUM_LAYERS) + \
|
| 32 |
+
"_h_size_" + str(HIDDEN_SIZE) + \
|
| 33 |
+
"_lr_" + str(LEARNING_RATE) + \
|
| 34 |
+
"_batch_" + str(BATCH_SIZE)
|
| 35 |
+
|
| 36 |
+
WEIGHTS_PATH = "weights_notagen_" + NAME + ".pth" # Path to save weights
|
| 37 |
+
LOGS_PATH = "logs_notagen_" + NAME + ".txt" # Path to save logs
|
| 38 |
+
WANDB_NAME = NAME
|
config (3).py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# Configurations for inference
|
| 4 |
+
INFERENCE_WEIGHTS_PATH = 'weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth' # Path to weights for inference
|
| 5 |
+
TOP_K = 9 # Top k for sampling
|
| 6 |
+
TOP_P = 0.9 # Top p for sampling
|
| 7 |
+
TEMPERATURE = 1.2 # Temperature for sampling
|
| 8 |
+
|
| 9 |
+
# Configurations for model
|
| 10 |
+
PATCH_STREAM = True # Stream training / inference
|
| 11 |
+
PATCH_SIZE = 16 # Patch Size
|
| 12 |
+
PATCH_LENGTH = 1024 # Patch Length
|
| 13 |
+
CHAR_NUM_LAYERS = 6 # Number of layers in the decoder
|
| 14 |
+
PATCH_NUM_LAYERS = 20 # Number of layers in the encoder
|
| 15 |
+
HIDDEN_SIZE = 1280 # Hidden Size
|
config (4).py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# Configurations for inference
|
| 4 |
+
INFERENCE_WEIGHTS_PATH = '' # Path to weights for inference
|
| 5 |
+
NUM_SAMPLES = 1000 # Number of samples to generate (only for generate mode)
|
| 6 |
+
TOP_K = 9 # Top k for sampling
|
| 7 |
+
TOP_P = 0.9 # Top p for sampling
|
| 8 |
+
TEMPERATURE = 1.2 # Temperature for sampling
|
| 9 |
+
ORIGINAL_OUTPUT_FOLDER = os.path.join('../output/original', os.path.splitext(os.path.split(INFERENCE_WEIGHTS_PATH)[-1])[0] + '_k_' + str(TOP_K) + '_p_' + str(TOP_P) + '_temp_' + str(TEMPERATURE))
|
| 10 |
+
INTERLEAVED_OUTPUT_FOLDER = os.path.join('../output/interleaved', os.path.splitext(os.path.split(INFERENCE_WEIGHTS_PATH)[-1])[0] + '_k_' + str(TOP_K) + '_p_' + str(TOP_P) + '_temp_' + str(TEMPERATURE))
|
| 11 |
+
|
| 12 |
+
# Configurations for model
|
| 13 |
+
PATCH_STREAM = True # Stream training / inference
|
| 14 |
+
PATCH_SIZE = 16 # Patch Size
|
| 15 |
+
PATCH_LENGTH = 1024 # Patch Length
|
| 16 |
+
CHAR_NUM_LAYERS = 6 # Number of layers in the decoder
|
| 17 |
+
PATCH_NUM_LAYERS = 20 # Number of layers in the encoder
|
| 18 |
+
HIDDEN_SIZE = 1280 # Hidden Size
|
config (5).py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# Configuration for the data
|
| 4 |
+
DATA_TRAIN_INDEX_PATH = ""
|
| 5 |
+
DATA_EVAL_INDEX_PATH = ""
|
| 6 |
+
|
| 7 |
+
# Configuration for the model
|
| 8 |
+
PATCH_STREAM = True
|
| 9 |
+
PATCH_SIZE = 16 # Patch Size
|
| 10 |
+
PATCH_LENGTH = 2048 # Patch Length
|
| 11 |
+
CHAR_NUM_LAYERS = 3 # Number of layers in the decoder
|
| 12 |
+
PATCH_NUM_LAYERS = 12 # Number of layers in the encoder
|
| 13 |
+
HIDDEN_SIZE = 768 # Hidden Size
|
| 14 |
+
|
| 15 |
+
# Configuration for the training
|
| 16 |
+
BATCH_SIZE = 4
|
| 17 |
+
LEARNING_RATE = 1e-4
|
| 18 |
+
NUM_EPOCHS = 128 # Number of epochs to train for (if early stopping doesn't intervene)
|
| 19 |
+
ACCUMULATION_STEPS = 1 # Accumulation steps to simulate large batch size
|
| 20 |
+
PATCH_SAMPLING_BATCH_SIZE = 0 # Batch size for patch during training, 0 for full context
|
| 21 |
+
LOAD_FROM_CHECKPOINT = False # Whether to load weights from a checkpoint
|
| 22 |
+
WANDB_LOGGING = False # Whether to log to wandb
|
| 23 |
+
WANDB_KEY = '<your_wandb_key>'
|
| 24 |
+
|
| 25 |
+
EXP_TAG = 'pretrain' # Experiment tag for differentiation
|
| 26 |
+
NAME = EXP_TAG + \
|
| 27 |
+
"_p_size_" + str(PATCH_SIZE) + \
|
| 28 |
+
"_p_length_" + str(PATCH_LENGTH) + \
|
| 29 |
+
"_p_layers_" + str(PATCH_NUM_LAYERS) + \
|
| 30 |
+
"_c_layers_" + str(CHAR_NUM_LAYERS) + \
|
| 31 |
+
"_h_size_" + str(HIDDEN_SIZE) + \
|
| 32 |
+
"_lr_" + str(LEARNING_RATE) + \
|
| 33 |
+
"_batch_" + str(BATCH_SIZE)
|
| 34 |
+
|
| 35 |
+
WEIGHTS_PATH = "weights_notagen_" + NAME + ".pth" # Path to save weights
|
| 36 |
+
LOGS_PATH = "logs_notagen_" + NAME + ".txt" # Path to save logs
|
| 37 |
+
WANDB_NAME = NAME
|
| 38 |
+
|
| 39 |
+
|
config.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# Configuration for the data
|
| 4 |
+
DATA_INDEX_PATH = ''
|
| 5 |
+
|
| 6 |
+
# Configuration for the model
|
| 7 |
+
PATCH_STREAM = True
|
| 8 |
+
PATCH_SIZE = 16 # Patch Size
|
| 9 |
+
PATCH_LENGTH = 1024 # Patch Length
|
| 10 |
+
CHAR_NUM_LAYERS = 6 # Number of layers in the decoder
|
| 11 |
+
PATCH_NUM_LAYERS = 20 # Number of layers in the encoder
|
| 12 |
+
HIDDEN_SIZE = 1280 # Hidden Size
|
| 13 |
+
|
| 14 |
+
# Configuration for the training
|
| 15 |
+
BETA = 0.1 # beta in DPO's objective function
|
| 16 |
+
LAMBDA = 10 # lambda in DPOP's objective function
|
| 17 |
+
LEARNING_RATE = 1e-6
|
| 18 |
+
OPTIMIZATION_STEPS = 10000 # Optimization steps for DPO
|
| 19 |
+
WANDB_LOGGING = False # Whether to log to wandb
|
| 20 |
+
WANDB_KEY = '<your_wandb_key>'
|
| 21 |
+
|
| 22 |
+
PRETRAINED_PATH = ''
|
| 23 |
+
EXP_TAG = ''
|
| 24 |
+
NAME = EXP_TAG + \
|
| 25 |
+
"_beta_" + str(BETA) + \
|
| 26 |
+
"_lambda_" + str(LAMBDA) + \
|
| 27 |
+
"_p_size_" + str(PATCH_SIZE) + \
|
| 28 |
+
"_p_length_" + str(PATCH_LENGTH) + \
|
| 29 |
+
"_p_layers_" + str(PATCH_NUM_LAYERS) + \
|
| 30 |
+
"_c_layers_" + str(CHAR_NUM_LAYERS) + \
|
| 31 |
+
"_h_size_" + str(HIDDEN_SIZE) + \
|
| 32 |
+
"_lr_" + str(LEARNING_RATE)
|
| 33 |
+
|
| 34 |
+
WEIGHTS_PATH = "weights_notagen_" + NAME + ".pth" # Path to save weights
|
| 35 |
+
WANDB_NAME = NAME
|
data.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gt_feature_folder = '../clamp2/feature/schubert_interleaved'
|
| 2 |
+
output_feature_folder = '../clamp2/feature/weights_notagen_schubert-RL2_beta_0.1_lambda_10_p_size_16_p_length_1024_p_layers_20_h_size_1280_lr_1e-06_k_9_p_0.9_temp_1.2'
|
| 3 |
+
output_original_abc_folder = '../output/original/weights_notagen_schubert-RL2_beta_0.1_lambda_10_p_size_16_p_length_1024_p_layers_20_h_size_1280_lr_1e-06_k_9_p_0.9_temp_1.2'
|
| 4 |
+
output_interleaved_abc_folder = '../output/interleaved/weights_notagen_schubert-RL2_beta_0.1_lambda_10_p_size_16_p_length_1024_p_layers_20_h_size_1280_lr_1e-06_k_9_p_0.9_temp_1.2'
|
| 5 |
+
data_index_path = 'schubert_RL3.json'
|
| 6 |
+
data_select_portion = 0.1
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import re
|
| 10 |
+
import json
|
| 11 |
+
import random
|
| 12 |
+
import numpy as np
|
| 13 |
+
from config import *
|
| 14 |
+
from abctoolkit.check import check_alignment_rotated, check_alignment_unrotated
|
| 15 |
+
from abctoolkit.rotate import unrotate_abc
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_npy_files(folder_path_list):
    """Load feature vectors from a list of .npy file paths.

    The original docstring said "from a specified folder", but the function
    actually receives a list of file paths. Each file stores an array whose
    first row is the feature vector (shape (1, dim) on disk); only that first
    row is kept. Paths not ending in '.npy' are silently skipped.

    Args:
        folder_path_list: Iterable of paths to .npy feature files.

    Returns:
        List of numpy arrays, one per loaded .npy file, in input order.
    """
    return [
        np.load(file_path)[0]
        for file_path in folder_path_list
        if file_path.endswith('.npy')
    ]
|
| 29 |
+
|
| 30 |
+
def average_npy(npy_list):
    """Return the element-wise average of a list of numpy arrays (mean over axis 0)."""
    return np.asarray(npy_list).mean(axis=0)
|
| 35 |
+
|
| 36 |
+
def cosine_similarity(vec1, vec2):
    """Cosine similarity between two numpy vectors.

    NOTE(review): no guard against zero-norm inputs — a zero vector produces a
    division-by-zero warning / nan, matching the original behavior.
    """
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return np.dot(vec1, vec2) / denominator
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def generate_preference_dict():
    """Build the DPO preference set from CLaMP 2 features and write it to JSON.

    Ranks every generated piece by cosine similarity between its CLaMP 2
    feature and the average ground-truth feature, then:
      - chosen set: the highest-ranked pieces that pass three quality filters
        (bar alignment, instrument-stave grouping, no near-plagiarism);
      - rejected set: the `threshold` lowest-ranked pieces, unfiltered.
    The resulting {'chosen': [...], 'rejected': [...]} dict of original-ABC
    file paths is dumped to `data_index_path` (module-level configuration).
    """

    # Average feature of the ground-truth corpus: the reference point for ranking.
    gt_feature_paths = []
    for gt_feature_file in os.listdir(gt_feature_folder):
        gt_feature_paths.append(os.path.join(gt_feature_folder, gt_feature_file))
    gt_features = load_npy_files(gt_feature_paths)
    gt_avg_feature = average_npy(gt_features)

    # Similarity of every generated piece to the ground-truth average feature.
    output_feature_sim_dict = {}
    for file in os.listdir(output_feature_folder):
        output_feature_path = os.path.join(output_feature_folder, file)
        output_feature = np.load(output_feature_path)[0]
        sim = cosine_similarity(gt_avg_feature, output_feature)
        output_feature_sim_dict[file[:-4]] = sim  # key on the stem: strip '.npy'

    # Size of each of the chosen / rejected sets.
    threshold = int(len(output_feature_sim_dict) * data_select_portion)
    # File stems sorted by similarity, best first.
    sorted_output_files = sorted(output_feature_sim_dict.keys(), key=lambda item: output_feature_sim_dict[item], reverse=True)

    chosen_index = 0
    i = 0
    chosen_abc_paths = []
    # Walk down the ranking until `threshold` pieces pass all filters (or the list ends).
    while chosen_index < threshold and i < len(sorted_output_files):

        chosen_flag = True

        file = sorted_output_files[i]
        output_interleaved_abc_path = os.path.join(output_interleaved_abc_folder, file + '.abc')

        with open(output_interleaved_abc_path, 'r') as f:
            abc_lines = f.readlines()

        # check alignment: pieces whose voices disagree on barlines / bar count /
        # bar durations — or that crash the checker — are excluded from the chosen set.
        try:
            abc_lines_unrotated = unrotate_abc(abc_lines)
            barline_equal_flag, bar_no_equal_flag, bar_dur_equal_flag = check_alignment_unrotated(abc_lines_unrotated)
            if not (barline_equal_flag and bar_no_equal_flag and bar_dur_equal_flag):
                raise Exception
        # NOTE(review): bare except also swallows KeyboardInterrupt/SystemExit — confirm intent.
        except:
            chosen_flag = False

        # check header: sheets where staves for the same instrument are not grouped together are excluded from the chosen set.
        appeared_inst = set()
        last_inst = ''
        for line in abc_lines:
            if line.startswith('V:') and 'nm=' in line:
                match = re.search(r'nm="([^"]+)"', line)
                if match:
                    inst = match.group(1)
                    # An instrument name re-appearing after a different one means its
                    # staves are split into non-adjacent groups -> reject.
                    if inst != last_inst and inst in appeared_inst:
                        chosen_flag = False
                        break
                    else:
                        last_inst = inst
                        appeared_inst.add(inst)

        # check plagiarism: sheets with sim > 0.95 to any single ground-truth piece are excluded
        # (each gt feature is re-loaded from disk per candidate).
        output_feature_path = os.path.join(output_feature_folder, file + '.npy')
        output_feature = np.load(output_feature_path)[0]
        for gt_feature_file in os.listdir(gt_feature_folder):
            gt_feature_path = os.path.join(gt_feature_folder, gt_feature_file)
            gt_feature = np.load(gt_feature_path)[0]
            sim = cosine_similarity(output_feature, gt_feature)
            if sim > 0.95:
                chosen_flag = False
                break

        if chosen_flag:
            # Preference pairs reference the *original* (raw model output) ABC files.
            original_abc_path = os.path.join(output_original_abc_folder, file + '.abc')
            chosen_abc_paths.append(original_abc_path)
            chosen_index += 1
        else:
            print(file, 'skipped')

        i += 1

    # Rejected set: the `threshold` lowest-similarity pieces (no quality filtering applied).
    rejected_abc_paths = [os.path.join(output_original_abc_folder, file + '.abc') for file in sorted_output_files[-threshold:]]
    preference_dict = {'chosen': chosen_abc_paths, 'rejected': rejected_abc_paths}

    with open(data_index_path, 'w') as w:
        json.dump(preference_dict, w, indent=4)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
if __name__ == '__main__':
    # Entry point: build the chosen/rejected preference index and save it to data_index_path.
    generate_preference_dict()
|
| 135 |
+
|
| 136 |
+
|
demo.ipynb
ADDED
|
@@ -0,0 +1,821 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "6e5cf1e7-c275-4929-9c44-ec48e26a2d4d",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import os\n",
|
| 11 |
+
"import re\n",
|
| 12 |
+
"import time\n",
|
| 13 |
+
"import torch\n",
|
| 15 |
+
"import random\n",
|
| 16 |
+
"import bisect\n",
|
| 17 |
+
"import json\n",
|
| 18 |
+
"from pathlib import Path\n",
|
| 19 |
+
"from tokenizers import Tokenizer\n",
|
| 20 |
+
"from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config, LlamaModel, LlamaForCausalLM, PreTrainedModel \n",
|
| 21 |
+
"from samplings import top_p_sampling, top_k_sampling, temperature_sampling\n",
|
| 22 |
+
"from abctoolkit.utils import Exclaim_re, Quote_re, SquareBracket_re, Barline_regexPattern\n",
|
| 23 |
+
"from abctoolkit.transpose import Note_list, Pitch_sign_list\n",
|
| 24 |
+
"from abctoolkit.duration import calculate_bartext_duration"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "code",
|
| 29 |
+
"execution_count": null,
|
| 30 |
+
"id": "00fd2ebb-6e53-4038-af85-f9c5f02fde0e",
|
| 31 |
+
"metadata": {},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"# Configurations for inference\n",
|
| 35 |
+
"INFERENCE_WEIGHTS_PATH = '../weights/weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth'   # Path to weights for inference\n",
|
| 36 |
+
"TOP_K = 9 # Top k for sampling\n",
|
| 37 |
+
"TOP_P = 0.9 # Top p for sampling\n",
|
| 38 |
+
"TEMPERATURE = 1.2 # Temperature for sampling\n",
|
| 39 |
+
"\n",
|
| 40 |
+
"# Configurations for model\n",
|
| 41 |
+
"PATCH_STREAM = True # Stream training / inference\n",
|
| 42 |
+
"PATCH_SIZE = 16 # Patch Size\n",
|
| 43 |
+
"PATCH_LENGTH = 1024 # Patch Length\n",
|
| 44 |
+
"CHAR_NUM_LAYERS = 6 # Number of layers in the decoder\n",
|
| 45 |
+
"PATCH_NUM_LAYERS = 20 # Number of layers in the encoder\n",
|
| 46 |
+
"HIDDEN_SIZE = 1280                              # Hidden Size\n",
|
| 46 |
+
"PATCH_SAMPLING_BATCH_SIZE = 0                   # Patch sampling batch size in CharLevelDecoder.forward; 0 = use all patches\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"device = torch.device(\"cuda\")"
|
| 49 |
+
]
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"cell_type": "code",
|
| 53 |
+
"execution_count": null,
|
| 54 |
+
"id": "fb70eb19-8b9c-4864-b711-7a0395b42c49",
|
| 55 |
+
"metadata": {},
|
| 56 |
+
"outputs": [],
|
| 57 |
+
"source": [
|
| 58 |
+
"class Patchilizer:\n",
|
| 59 |
+
" def __init__(self, stream=PATCH_STREAM):\n",
|
| 60 |
+
" self.stream = stream\n",
|
| 61 |
+
" self.delimiters = [\"|:\", \"::\", \":|\", \"[|\", \"||\", \"|]\", \"|\"]\n",
|
| 62 |
+
" self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'\n",
|
| 63 |
+
" self.bos_token_id = 1\n",
|
| 64 |
+
" self.eos_token_id = 2\n",
|
| 65 |
+
" self.special_token_id = 0\n",
|
| 66 |
+
"\n",
|
| 67 |
+
" def split_bars(self, body_lines):\n",
|
| 68 |
+
" \"\"\"\n",
|
| 69 |
+
" Split a body of music into individual bars.\n",
|
| 70 |
+
" \"\"\"\n",
|
| 71 |
+
" new_bars = []\n",
|
| 72 |
+
" try:\n",
|
| 73 |
+
" for line in body_lines:\n",
|
| 74 |
+
" line_bars = re.split(self.regexPattern, line)\n",
|
| 75 |
+
" line_bars = list(filter(None, line_bars))\n",
|
| 76 |
+
" new_line_bars = []\n",
|
| 77 |
+
"\n",
|
| 78 |
+
" if len(line_bars) == 1:\n",
|
| 79 |
+
" new_line_bars = line_bars\n",
|
| 80 |
+
" else:\n",
|
| 81 |
+
" if line_bars[0] in self.delimiters:\n",
|
| 82 |
+
" new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]\n",
|
| 83 |
+
" else:\n",
|
| 84 |
+
" new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]\n",
|
| 85 |
+
" if 'V' not in new_line_bars[-1]:\n",
|
| 86 |
+
"                    new_line_bars[-2] += new_line_bars[-1]  # absorb the trailing barline+'\\n' fragment into the preceding bar\n",
|
| 87 |
+
" new_line_bars = new_line_bars[:-1]\n",
|
| 88 |
+
" new_bars += new_line_bars\n",
|
| 89 |
+
" except:\n",
|
| 90 |
+
" pass\n",
|
| 91 |
+
"\n",
|
| 92 |
+
" return new_bars\n",
|
| 93 |
+
"\n",
|
| 94 |
+
" def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):\n",
|
| 95 |
+
" if not generate_last and len(abc_text) % patch_size != 0:\n",
|
| 96 |
+
" abc_text += chr(self.eos_token_id)\n",
|
| 97 |
+
" patches = [abc_text[i : i + patch_size] for i in range(0, len(abc_text), patch_size)]\n",
|
| 98 |
+
" return patches\n",
|
| 99 |
+
"\n",
|
| 100 |
+
" def patch2chars(self, patch):\n",
|
| 101 |
+
" \"\"\"\n",
|
| 102 |
+
" Convert a patch into a bar.\n",
|
| 103 |
+
" \"\"\"\n",
|
| 104 |
+
" bytes = ''\n",
|
| 105 |
+
" for idx in patch:\n",
|
| 106 |
+
" if idx == self.eos_token_id:\n",
|
| 107 |
+
" break\n",
|
| 108 |
+
" if idx < self.eos_token_id:\n",
|
| 109 |
+
" pass\n",
|
| 110 |
+
" bytes += chr(idx)\n",
|
| 111 |
+
" return bytes\n",
|
| 112 |
+
" \n",
|
| 113 |
+
"\n",
|
| 114 |
+
" def patchilize_metadata(self, metadata_lines):\n",
|
| 115 |
+
"\n",
|
| 116 |
+
" metadata_patches = []\n",
|
| 117 |
+
" for line in metadata_lines:\n",
|
| 118 |
+
" metadata_patches += self.split_patches(line)\n",
|
| 119 |
+
"\n",
|
| 120 |
+
" return metadata_patches\n",
|
| 121 |
+
" \n",
|
| 122 |
+
" def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):\n",
|
| 123 |
+
"\n",
|
| 124 |
+
" tunebody_patches = []\n",
|
| 125 |
+
" bars = self.split_bars(tunebody_lines)\n",
|
| 126 |
+
" if encode_mode == 'train':\n",
|
| 127 |
+
" for bar in bars:\n",
|
| 128 |
+
" tunebody_patches += self.split_patches(bar)\n",
|
| 129 |
+
" elif encode_mode == 'generate':\n",
|
| 130 |
+
" for bar in bars[:-1]:\n",
|
| 131 |
+
" tunebody_patches += self.split_patches(bar)\n",
|
| 132 |
+
" tunebody_patches += self.split_patches(bars[-1], generate_last=True)\n",
|
| 133 |
+
" \n",
|
| 134 |
+
" return tunebody_patches\n",
|
| 135 |
+
"\n",
|
| 136 |
+
" def encode_train(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):\n",
|
| 137 |
+
"\n",
|
| 138 |
+
" lines = abc_text.split('\\n')\n",
|
| 139 |
+
" lines = list(filter(None, lines))\n",
|
| 140 |
+
" lines = [line + '\\n' for line in lines]\n",
|
| 141 |
+
"\n",
|
| 142 |
+
" tunebody_index = -1\n",
|
| 143 |
+
" for i, line in enumerate(lines):\n",
|
| 144 |
+
" if '[V:' in line:\n",
|
| 145 |
+
" tunebody_index = i\n",
|
| 146 |
+
" break\n",
|
| 147 |
+
"\n",
|
| 148 |
+
" metadata_lines = lines[ : tunebody_index]\n",
|
| 149 |
+
" tunebody_lines = lines[tunebody_index : ]\n",
|
| 150 |
+
"\n",
|
| 151 |
+
" if self.stream:\n",
|
| 152 |
+
" tunebody_lines = ['[r:' + str(line_index) + '/' + str(len(tunebody_lines) - line_index - 1) + ']' + line for line_index, line in\n",
|
| 153 |
+
" enumerate(tunebody_lines)] \n",
|
| 154 |
+
"\n",
|
| 155 |
+
" metadata_patches = self.patchilize_metadata(metadata_lines)\n",
|
| 156 |
+
" tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')\n",
|
| 157 |
+
"\n",
|
| 158 |
+
" if add_special_patches:\n",
|
| 159 |
+
" bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)\n",
|
| 160 |
+
" eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)\n",
|
| 161 |
+
"\n",
|
| 162 |
+
" metadata_patches = [bos_patch] + metadata_patches\n",
|
| 163 |
+
" tunebody_patches = tunebody_patches + [eos_patch]\n",
|
| 164 |
+
"\n",
|
| 165 |
+
" if self.stream:\n",
|
| 166 |
+
" if len(metadata_patches) + len(tunebody_patches) > patch_length:\n",
|
| 167 |
+
" available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if '\\n' in patch]\n",
|
| 168 |
+
" line_index_for_cut_index = list(range(len(available_cut_indexes))) \n",
|
| 169 |
+
" end_index = len(metadata_patches) + len(tunebody_patches) - patch_length\n",
|
| 170 |
+
" biggest_index = bisect.bisect_left(available_cut_indexes, end_index) \n",
|
| 171 |
+
" available_cut_indexes = available_cut_indexes[:biggest_index + 1]\n",
|
| 172 |
+
"\n",
|
| 173 |
+
" if len(available_cut_indexes) == 1:\n",
|
| 174 |
+
" choices = ['head']\n",
|
| 175 |
+
" elif len(available_cut_indexes) == 2:\n",
|
| 176 |
+
" choices = ['head', 'tail']\n",
|
| 177 |
+
" else:\n",
|
| 178 |
+
" choices = ['head', 'tail', 'middle']\n",
|
| 179 |
+
" choice = random.choice(choices)\n",
|
| 180 |
+
" if choice == 'head':\n",
|
| 181 |
+
" patches = metadata_patches + tunebody_patches[0:]\n",
|
| 182 |
+
" else:\n",
|
| 183 |
+
" if choice == 'tail':\n",
|
| 184 |
+
" cut_index = len(available_cut_indexes) - 1\n",
|
| 185 |
+
" else:\n",
|
| 186 |
+
" cut_index = random.choice(range(1, len(available_cut_indexes) - 1))\n",
|
| 187 |
+
"\n",
|
| 188 |
+
" line_index = line_index_for_cut_index[cut_index] \n",
|
| 189 |
+
" stream_tunebody_lines = tunebody_lines[line_index : ]\n",
|
| 190 |
+
" \n",
|
| 191 |
+
" stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')\n",
|
| 192 |
+
" if add_special_patches:\n",
|
| 193 |
+
" stream_tunebody_patches = stream_tunebody_patches + [eos_patch]\n",
|
| 194 |
+
" patches = metadata_patches + stream_tunebody_patches\n",
|
| 195 |
+
" else:\n",
|
| 196 |
+
" patches = metadata_patches + tunebody_patches\n",
|
| 197 |
+
" else:\n",
|
| 198 |
+
" patches = metadata_patches + tunebody_patches\n",
|
| 199 |
+
"\n",
|
| 200 |
+
" if cut: \n",
|
| 201 |
+
" patches = patches[ : patch_length]\n",
|
| 202 |
+
" else: \n",
|
| 203 |
+
" pass\n",
|
| 204 |
+
"\n",
|
| 205 |
+
" # encode to ids\n",
|
| 206 |
+
" id_patches = []\n",
|
| 207 |
+
" for patch in patches:\n",
|
| 208 |
+
" id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))\n",
|
| 209 |
+
" id_patches.append(id_patch)\n",
|
| 210 |
+
"\n",
|
| 211 |
+
" return id_patches\n",
|
| 212 |
+
"\n",
|
| 213 |
+
" def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):\n",
|
| 214 |
+
"\n",
|
| 215 |
+
" lines = abc_code.split('\\n')\n",
|
| 216 |
+
" lines = list(filter(None, lines))\n",
|
| 217 |
+
" \n",
|
| 218 |
+
" tunebody_index = None\n",
|
| 219 |
+
" for i, line in enumerate(lines):\n",
|
| 220 |
+
" if line.startswith('[V:') or line.startswith('[r:'):\n",
|
| 221 |
+
" tunebody_index = i\n",
|
| 222 |
+
" break\n",
|
| 223 |
+
" \n",
|
| 224 |
+
" metadata_lines = lines[ : tunebody_index]\n",
|
| 225 |
+
" tunebody_lines = lines[tunebody_index : ] \n",
|
| 226 |
+
" \n",
|
| 227 |
+
" metadata_lines = [line + '\\n' for line in metadata_lines]\n",
|
| 228 |
+
" if self.stream:\n",
|
| 229 |
+
" if not abc_code.endswith('\\n'):\n",
|
| 230 |
+
" tunebody_lines = [tunebody_lines[i] + '\\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]]\n",
|
| 231 |
+
" else:\n",
|
| 232 |
+
" tunebody_lines = [tunebody_lines[i] + '\\n' for i in range(len(tunebody_lines))]\n",
|
| 233 |
+
" else:\n",
|
| 234 |
+
" tunebody_lines = [line + '\\n' for line in tunebody_lines]\n",
|
| 235 |
+
" \n",
|
| 236 |
+
" metadata_patches = self.patchilize_metadata(metadata_lines)\n",
|
| 237 |
+
" tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')\n",
|
| 238 |
+
" \n",
|
| 239 |
+
" if add_special_patches:\n",
|
| 240 |
+
" bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)\n",
|
| 241 |
+
"\n",
|
| 242 |
+
" metadata_patches = [bos_patch] + metadata_patches\n",
|
| 243 |
+
" \n",
|
| 244 |
+
" patches = metadata_patches + tunebody_patches\n",
|
| 245 |
+
" patches = patches[ : patch_length]\n",
|
| 246 |
+
"\n",
|
| 247 |
+
" # encode to ids\n",
|
| 248 |
+
" id_patches = []\n",
|
| 249 |
+
" for patch in patches:\n",
|
| 250 |
+
" if len(patch) < PATCH_SIZE and patch[-1] != chr(self.eos_token_id):\n",
|
| 251 |
+
" id_patch = [ord(c) for c in patch]\n",
|
| 252 |
+
" else:\n",
|
| 253 |
+
" id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))\n",
|
| 254 |
+
" id_patches.append(id_patch)\n",
|
| 255 |
+
" \n",
|
| 256 |
+
" return id_patches\n",
|
| 257 |
+
"\n",
|
| 258 |
+
" def decode(self, patches):\n",
|
| 259 |
+
" \"\"\"\n",
|
| 260 |
+
" Decode patches into music.\n",
|
| 261 |
+
" \"\"\"\n",
|
| 262 |
+
" return ''.join(self.patch2chars(patch) for patch in patches)\n",
|
| 263 |
+
"\n",
|
| 264 |
+
"\n",
|
| 265 |
+
"class PatchLevelDecoder(PreTrainedModel):\n",
|
| 266 |
+
" \"\"\"\n",
|
| 267 |
+
" A Patch-level Decoder model for generating patch features in an auto-regressive manner. \n",
|
| 268 |
+
" It inherits PreTrainedModel from transformers.\n",
|
| 269 |
+
" \"\"\"\n",
|
| 270 |
+
" def __init__(self, config):\n",
|
| 271 |
+
" super().__init__(config)\n",
|
| 272 |
+
" self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)\n",
|
| 273 |
+
" torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)\n",
|
| 274 |
+
" self.base = GPT2Model(config)\n",
|
| 275 |
+
"\n",
|
| 276 |
+
" def forward(self,\n",
|
| 277 |
+
" patches: torch.Tensor,\n",
|
| 278 |
+
" masks=None) -> torch.Tensor:\n",
|
| 279 |
+
" \"\"\"\n",
|
| 280 |
+
" The forward pass of the patch-level decoder model.\n",
|
| 281 |
+
" :param patches: the patches to be encoded\n",
|
| 282 |
+
" :param masks: the masks for the patches\n",
|
| 283 |
+
" :return: the encoded patches\n",
|
| 284 |
+
" \"\"\"\n",
|
| 285 |
+
" patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)\n",
|
| 286 |
+
" patches = patches.reshape(len(patches), -1, PATCH_SIZE * (128))\n",
|
| 287 |
+
" patches = self.patch_embedding(patches.to(self.device))\n",
|
| 288 |
+
"\n",
|
| 289 |
+
" if masks==None:\n",
|
| 290 |
+
" return self.base(inputs_embeds=patches)\n",
|
| 291 |
+
" else:\n",
|
| 292 |
+
" return self.base(inputs_embeds=patches,\n",
|
| 293 |
+
" attention_mask=masks)\n",
|
| 294 |
+
"\n",
|
| 295 |
+
"\n",
|
| 296 |
+
"class CharLevelDecoder(PreTrainedModel):\n",
|
| 297 |
+
" \"\"\"\n",
|
| 298 |
+
" A Char-level Decoder model for generating the chars within each patch in an auto-regressive manner\n",
|
| 299 |
+
" based on the encoded patch features. It inherits PreTrainedModel from transformers.\n",
|
| 300 |
+
" \"\"\"\n",
|
| 301 |
+
" def __init__(self, config):\n",
|
| 302 |
+
" super().__init__(config)\n",
|
| 303 |
+
" self.special_token_id = 0\n",
|
| 304 |
+
" self.bos_token_id = 1\n",
|
| 305 |
+
"\n",
|
| 306 |
+
" self.base = GPT2LMHeadModel(config)\n",
|
| 307 |
+
"\n",
|
| 308 |
+
" def forward(self,\n",
|
| 309 |
+
" encoded_patches: torch.Tensor,\n",
|
| 310 |
+
" target_patches: torch.Tensor):\n",
|
| 311 |
+
" \"\"\"\n",
|
| 312 |
+
" The forward pass of the char-level decoder model.\n",
|
| 313 |
+
" :param encoded_patches: the encoded patches\n",
|
| 314 |
+
" :param target_patches: the target patches\n",
|
| 315 |
+
" :return: the output of the model\n",
|
| 316 |
+
" \"\"\"\n",
|
| 317 |
+
" # preparing the labels for model training\n",
|
| 318 |
+
" target_patches = torch.cat((torch.ones_like(target_patches[:,0:1])*self.bos_token_id, target_patches), dim=1)\n",
|
| 319 |
+
" # print('target_patches shape:', target_patches.shape)\n",
|
| 320 |
+
"\n",
|
| 321 |
+
" target_masks = target_patches == self.special_token_id\n",
|
| 322 |
+
" labels = target_patches.clone().masked_fill_(target_masks, -100)\n",
|
| 323 |
+
"\n",
|
| 324 |
+
" # masking the labels for model training\n",
|
| 325 |
+
" target_masks = torch.ones_like(labels)\n",
|
| 326 |
+
" target_masks = target_masks.masked_fill_(labels == -100, 0)\n",
|
| 327 |
+
"\n",
|
| 328 |
+
" # select patches\n",
|
| 329 |
+
" if PATCH_SAMPLING_BATCH_SIZE!=0 and PATCH_SAMPLING_BATCH_SIZE<target_patches.shape[0]:\n",
|
| 330 |
+
" indices = list(range(len(target_patches)))\n",
|
| 331 |
+
" random.shuffle(indices)\n",
|
| 332 |
+
" selected_indices = sorted(indices[:PATCH_SAMPLING_BATCH_SIZE])\n",
|
| 333 |
+
"\n",
|
| 334 |
+
" target_patches = target_patches[selected_indices,:]\n",
|
| 335 |
+
" target_masks = target_masks[selected_indices,:]\n",
|
| 336 |
+
" encoded_patches = encoded_patches[selected_indices,:]\n",
|
| 337 |
+
"\n",
|
| 338 |
+
" # get input embeddings\n",
|
| 339 |
+
" inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)\n",
|
| 340 |
+
"\n",
|
| 341 |
+
" # concatenate the encoded patches with the input embeddings\n",
|
| 342 |
+
" inputs_embeds = torch.cat((encoded_patches.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1)\n",
|
| 343 |
+
"\n",
|
| 344 |
+
" output = self.base(inputs_embeds=inputs_embeds, \n",
|
| 345 |
+
" attention_mask=target_masks,\n",
|
| 346 |
+
" labels=labels)\n",
|
| 347 |
+
"                           # output_hidden_states=True)\n",
|
| 348 |
+
"\n",
|
| 349 |
+
" return output\n",
|
| 350 |
+
"\n",
|
| 351 |
+
" def generate(self,\n",
|
| 352 |
+
" encoded_patch: torch.Tensor, # [hidden_size]\n",
|
| 353 |
+
" tokens: torch.Tensor): # [1]\n",
|
| 354 |
+
" \"\"\"\n",
|
| 355 |
+
" The generate function for generating a patch based on the encoded patch and already generated tokens.\n",
|
| 356 |
+
" :param encoded_patch: the encoded patch\n",
|
| 357 |
+
" :param tokens: already generated tokens in the patch\n",
|
| 358 |
+
" :return: the probability distribution of next token\n",
|
| 359 |
+
" \"\"\"\n",
|
| 360 |
+
" encoded_patch = encoded_patch.reshape(1, 1, -1) # [1, 1, hidden_size]\n",
|
| 361 |
+
" tokens = tokens.reshape(1, -1)\n",
|
| 362 |
+
"\n",
|
| 363 |
+
" # Get input embeddings\n",
|
| 364 |
+
" tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)\n",
|
| 365 |
+
"\n",
|
| 366 |
+
" # Concatenate the encoded patch with the input embeddings\n",
|
| 367 |
+
" tokens = torch.cat((encoded_patch, tokens[:,1:,:]), dim=1)\n",
|
| 368 |
+
" \n",
|
| 369 |
+
" # Get output from model\n",
|
| 370 |
+
" outputs = self.base(inputs_embeds=tokens)\n",
|
| 371 |
+
" \n",
|
| 372 |
+
" # Get probabilities of next token\n",
|
| 373 |
+
" probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)\n",
|
| 374 |
+
"\n",
|
| 375 |
+
" return probs\n",
|
| 376 |
+
"\n",
|
| 377 |
+
"class NotaGenLMHeadModel(PreTrainedModel):\n",
|
| 378 |
+
" \"\"\"\n",
|
| 379 |
+
" NotaGen is a language model with a hierarchical structure.\n",
|
| 380 |
+
" It includes a patch-level decoder and a char-level decoder.\n",
|
| 381 |
+
" The patch-level decoder is used to generate patch features in an auto-regressive manner.\n",
|
| 382 |
+
" The char-level decoder is used to generate the chars within each patch in an auto-regressive manner.\n",
|
| 383 |
+
" It inherits PreTrainedModel from transformers.\n",
|
| 384 |
+
" \"\"\"\n",
|
| 385 |
+
" def __init__(self, encoder_config, decoder_config):\n",
|
| 386 |
+
" super().__init__(encoder_config)\n",
|
| 387 |
+
" self.special_token_id = 0\n",
|
| 388 |
+
" self.bos_token_id = 1\n",
|
| 389 |
+
" self.eos_token_id = 2\n",
|
| 390 |
+
" self.patch_level_decoder = PatchLevelDecoder(encoder_config)\n",
|
| 391 |
+
" self.char_level_decoder = CharLevelDecoder(decoder_config)\n",
|
| 392 |
+
"\n",
|
| 393 |
+
" def forward(self,\n",
|
| 394 |
+
" patches: torch.Tensor,\n",
|
| 395 |
+
" masks: torch.Tensor):\n",
|
| 396 |
+
" \"\"\"\n",
|
| 397 |
+
" The forward pass of the bGPT model.\n",
|
| 398 |
+
" :param patches: the patches to be encoded\n",
|
| 399 |
+
" :param masks: the masks for the patches\n",
|
| 400 |
+
" :return: the decoded patches\n",
|
| 401 |
+
" \"\"\"\n",
|
| 402 |
+
" patches = patches.reshape(len(patches), -1, PATCH_SIZE)\n",
|
| 403 |
+
" encoded_patches = self.patch_level_decoder(patches, masks)[\"last_hidden_state\"]\n",
|
| 404 |
+
" \n",
|
| 405 |
+
" left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)\n",
|
| 406 |
+
" masks[:, 0] = 0\n",
|
| 407 |
+
" \n",
|
| 408 |
+
" encoded_patches = encoded_patches[left_shift_masks == 1]\n",
|
| 409 |
+
" patches = patches[masks == 1] \n",
|
| 410 |
+
"\n",
|
| 411 |
+
" return self.char_level_decoder(encoded_patches, patches)\n",
|
| 412 |
+
" \n",
|
| 413 |
+
" def generate(self,\n",
|
| 414 |
+
" patches: torch.Tensor,\n",
|
| 415 |
+
" top_k=0,\n",
|
| 416 |
+
" top_p=1,\n",
|
| 417 |
+
" temperature=1.0):\n",
|
| 418 |
+
" \"\"\"\n",
|
| 419 |
+
" The generate function for generating patches based on patches.\n",
|
| 420 |
+
" :param patches: the patches to be encoded\n",
|
| 421 |
+
" :param top_k: the top k for sampling\n",
|
| 422 |
+
" :param top_p: the top p for sampling\n",
|
| 423 |
+
" :param temperature: the temperature for sampling\n",
|
| 424 |
+
" :return: the generated patches\n",
|
| 425 |
+
" \"\"\"\n",
|
| 426 |
+
" if patches.shape[-1] % PATCH_SIZE != 0:\n",
|
| 427 |
+
" tokens = patches[:,:,-(patches.shape[-1]%PATCH_SIZE):].squeeze(0, 1)\n",
|
| 428 |
+
" tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1)\n",
|
| 429 |
+
" patches = patches[:,:,:-(patches.shape[-1]%PATCH_SIZE)]\n",
|
| 430 |
+
" else:\n",
|
| 431 |
+
" tokens = torch.tensor([self.bos_token_id], device=self.device)\n",
|
| 432 |
+
"\n",
|
| 433 |
+
" patches = patches.reshape(len(patches), -1, PATCH_SIZE) # [bs, seq, patch_size]\n",
|
| 434 |
+
" encoded_patches = self.patch_level_decoder(patches)[\"last_hidden_state\"] # [bs, seq, hidden_size]\n",
|
| 435 |
+
" generated_patch = [] \n",
|
| 436 |
+
"\n",
|
| 437 |
+
" while True:\n",
|
| 438 |
+
" prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy() # [128]\n",
|
| 439 |
+
" prob = top_k_sampling(prob, top_k=top_k, return_probs=True) # [128]\n",
|
| 440 |
+
" prob = top_p_sampling(prob, top_p=top_p, return_probs=True) # [128]\n",
|
| 441 |
+
" token = temperature_sampling(prob, temperature=temperature) # int\n",
|
| 442 |
+
" char = chr(token)\n",
|
| 443 |
+
" generated_patch.append(token)\n",
|
| 444 |
+
"\n",
|
| 445 |
+
" if len(tokens) >= PATCH_SIZE:# or token == self.eos_token_id:\n",
|
| 446 |
+
" break\n",
|
| 447 |
+
" else:\n",
|
| 448 |
+
" tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)\n",
|
| 449 |
+
" \n",
|
| 450 |
+
" return generated_patch\n",
|
| 451 |
+
"\n",
|
| 452 |
+
"def clean_to_abc(raw_text, unreduce=True, output_path='output.abc'):\n",
|
| 453 |
+
" # Remove [r:x/y] tags\n",
|
| 454 |
+
" cleaned = re.sub(r'\\[r:\\d+/\\d+\\]', '', raw_text)\n",
|
| 455 |
+
"\n",
|
| 456 |
+
" # Add required ABC headers\n",
|
| 457 |
+
" lines = cleaned.strip().splitlines()\n",
|
| 458 |
+
" header_inserted = False\n",
|
| 459 |
+
" abc_lines = []\n",
|
| 460 |
+
" for line in lines:\n",
|
| 461 |
+
" if not header_inserted and line.startswith('%%score'):\n",
|
| 462 |
+
" abc_lines.insert(0, 'T:Generated\\n')\n",
|
| 463 |
+
" abc_lines.insert(0, 'X:1\\n')\n",
|
| 464 |
+
" header_inserted = True\n",
|
| 465 |
+
" abc_lines.append(line if line.endswith('\\n') else line + '\\n')\n",
|
| 466 |
+
"\n",
|
| 467 |
+
" # Optional: fill missing rests\n",
|
| 468 |
+
" if unreduce:\n",
|
| 469 |
+
" try:\n",
|
| 470 |
+
" abc_lines = rest_unreduce(abc_lines)\n",
|
| 471 |
+
" except Exception as e:\n",
|
| 472 |
+
" print(\"Unreduce failed:\", e)\n",
|
| 473 |
+
"\n",
|
| 474 |
+
" # Save to .abc file\n",
|
| 475 |
+
" Path(output_path).write_text(''.join(abc_lines), encoding='utf-8')\n",
|
| 476 |
+
" print(f\"Saved cleaned ABC to {output_path}\")\n",
|
| 477 |
+
" return output_path"
|
| 478 |
+
]
|
| 479 |
+
},
|
| 480 |
+
{
|
| 481 |
+
"cell_type": "code",
|
| 482 |
+
"execution_count": null,
|
| 483 |
+
"id": "6d126533-a9a1-48a5-9b1b-be6da37a55ad",
|
| 484 |
+
"metadata": {},
|
| 485 |
+
"outputs": [],
|
| 486 |
+
"source": [
|
| 487 |
+
"Note_list = Note_list + ['z', 'x']\n",
|
| 488 |
+
"\n",
|
| 489 |
+
"patchilizer = Patchilizer()\n",
|
| 490 |
+
"\n",
|
| 491 |
+
"patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,\n",
|
| 492 |
+
" max_length=PATCH_LENGTH,\n",
|
| 493 |
+
" max_position_embeddings=PATCH_LENGTH,\n",
|
| 494 |
+
" n_embd=HIDDEN_SIZE,\n",
|
| 495 |
+
" num_attention_heads=HIDDEN_SIZE // 64,\n",
|
| 496 |
+
" vocab_size=1)\n",
|
| 497 |
+
"byte_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,\n",
|
| 498 |
+
" max_length=PATCH_SIZE + 1,\n",
|
| 499 |
+
" max_position_embeddings=PATCH_SIZE + 1,\n",
|
| 500 |
+
" hidden_size=HIDDEN_SIZE,\n",
|
| 501 |
+
" num_attention_heads=HIDDEN_SIZE // 64,\n",
|
| 502 |
+
" vocab_size=128)\n",
|
| 503 |
+
"\n",
|
| 504 |
+
"model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config).to(device)\n",
|
| 505 |
+
"\n",
|
| 506 |
+
"def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):\n",
|
| 507 |
+
" \"\"\"\n",
|
| 508 |
+
" Prepare model for k-bit training.\n",
|
| 509 |
+
" Features include:\n",
|
| 510 |
+
" 1. Convert model to mixed precision (FP16).\n",
|
| 511 |
+
" 2. Disable unnecessary gradient computations.\n",
|
| 512 |
+
" 3. Enable gradient checkpointing (optional).\n",
|
| 513 |
+
" \"\"\"\n",
|
| 514 |
+
" # Convert model to mixed precision\n",
|
| 515 |
+
" model = model.to(dtype=torch.float16)\n",
|
| 516 |
+
"\n",
|
| 517 |
+
"    # Freeze any parameters still stored in float32 (no-op after the fp16 cast above)\n",
|
| 518 |
+
" for param in model.parameters():\n",
|
| 519 |
+
" if param.dtype == torch.float32:\n",
|
| 520 |
+
" param.requires_grad = False\n",
|
| 521 |
+
"\n",
|
| 522 |
+
" # Enable gradient checkpointing\n",
|
| 523 |
+
" if use_gradient_checkpointing:\n",
|
| 524 |
+
" model.gradient_checkpointing_enable()\n",
|
| 525 |
+
"\n",
|
| 526 |
+
" return model\n",
|
| 527 |
+
"\n",
|
| 528 |
+
"\n",
|
| 529 |
+
"model = prepare_model_for_kbit_training(\n",
|
| 530 |
+
" model,\n",
|
| 531 |
+
" use_gradient_checkpointing=False \n",
|
| 532 |
+
")\n",
|
| 533 |
+
"\n",
|
| 534 |
+
"print(\"Parameter Number: \" + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))\n",
|
| 535 |
+
"\n",
|
| 536 |
+
"checkpoint = torch.load(INFERENCE_WEIGHTS_PATH, map_location=torch.device(device))\n",
|
| 537 |
+
"model.load_state_dict(checkpoint['model'])\n",
|
| 538 |
+
"model = model.to(device)\n",
|
| 539 |
+
"model.eval()\n",
|
| 540 |
+
"\n",
|
| 541 |
+
"def complete_brackets(s):\n",
|
| 542 |
+
" stack = []\n",
|
| 543 |
+
" bracket_map = {'{': '}', '[': ']', '(': ')'}\n",
|
| 544 |
+
" \n",
|
| 545 |
+
" # Iterate through each character, handle bracket matching\n",
|
| 546 |
+
" for char in s:\n",
|
| 547 |
+
" if char in bracket_map:\n",
|
| 548 |
+
" stack.append(char)\n",
|
| 549 |
+
" elif char in bracket_map.values():\n",
|
| 550 |
+
" # Find the corresponding left bracket\n",
|
| 551 |
+
" for key, value in bracket_map.items():\n",
|
| 552 |
+
" if value == char:\n",
|
| 553 |
+
" if stack and stack[-1] == key:\n",
|
| 554 |
+
" stack.pop()\n",
|
| 555 |
+
" break # Found matching right bracket, process next character\n",
|
| 556 |
+
" \n",
|
| 557 |
+
" # Complete missing right brackets (in reverse order of remaining left brackets in stack)\n",
|
| 558 |
+
" completion = ''.join(bracket_map[c] for c in reversed(stack))\n",
|
| 559 |
+
" return s + completion\n",
|
| 560 |
+
"\n",
|
| 561 |
+
"\n",
|
| 562 |
+
"def rest_unreduce(abc_lines):\n",
|
| 563 |
+
"\n",
|
| 564 |
+
" tunebody_index = None\n",
|
| 565 |
+
" for i in range(len(abc_lines)):\n",
|
| 566 |
+
" if abc_lines[i].startswith('%%score'):\n",
|
| 567 |
+
" abc_lines[i] = complete_brackets(abc_lines[i])\n",
|
| 568 |
+
" if '[V:' in abc_lines[i]:\n",
|
| 569 |
+
" tunebody_index = i\n",
|
| 570 |
+
" break\n",
|
| 571 |
+
"\n",
|
| 572 |
+
" metadata_lines = abc_lines[: tunebody_index]\n",
|
| 573 |
+
" tunebody_lines = abc_lines[tunebody_index:]\n",
|
| 574 |
+
"\n",
|
| 575 |
+
" part_symbol_list = []\n",
|
| 576 |
+
" voice_group_list = []\n",
|
| 577 |
+
" for line in metadata_lines:\n",
|
| 578 |
+
" if line.startswith('%%score'):\n",
|
| 579 |
+
" for round_bracket_match in re.findall(r'\\((.*?)\\)', line):\n",
|
| 580 |
+
" voice_group_list.append(round_bracket_match.split())\n",
|
| 581 |
+
" existed_voices = [item for sublist in voice_group_list for item in sublist]\n",
|
| 582 |
+
" if line.startswith('V:'):\n",
|
| 583 |
+
" symbol = line.split()[0]\n",
|
| 584 |
+
" part_symbol_list.append(symbol)\n",
|
| 585 |
+
" if symbol[2:] not in existed_voices:\n",
|
| 586 |
+
" voice_group_list.append([symbol[2:]])\n",
|
| 587 |
+
" z_symbol_list = [] # voices that use z as rest\n",
|
| 588 |
+
" x_symbol_list = [] # voices that use x as rest\n",
|
| 589 |
+
" for voice_group in voice_group_list:\n",
|
| 590 |
+
" z_symbol_list.append('V:' + voice_group[0])\n",
|
| 591 |
+
" for j in range(1, len(voice_group)):\n",
|
| 592 |
+
" x_symbol_list.append('V:' + voice_group[j])\n",
|
| 593 |
+
"\n",
|
| 594 |
+
" part_symbol_list.sort(key=lambda x: int(x[2:]))\n",
|
| 595 |
+
"\n",
|
| 596 |
+
" unreduced_tunebody_lines = []\n",
|
| 597 |
+
"\n",
|
| 598 |
+
" for i, line in enumerate(tunebody_lines):\n",
|
| 599 |
+
" unreduced_line = ''\n",
|
| 600 |
+
"\n",
|
| 601 |
+
" line = re.sub(r'^\\[r:[^\\]]*\\]', '', line)\n",
|
| 602 |
+
"\n",
|
| 603 |
+
" pattern = r'\\[V:(\\d+)\\](.*?)(?=\\[V:|$)'\n",
|
| 604 |
+
" matches = re.findall(pattern, line)\n",
|
| 605 |
+
"\n",
|
| 606 |
+
" line_bar_dict = {}\n",
|
| 607 |
+
" for match in matches:\n",
|
| 608 |
+
" key = f'V:{match[0]}'\n",
|
| 609 |
+
" value = match[1]\n",
|
| 610 |
+
" line_bar_dict[key] = value\n",
|
| 611 |
+
"\n",
|
| 612 |
+
" # calculate duration and collect barline\n",
|
| 613 |
+
" dur_dict = {} \n",
|
| 614 |
+
" for symbol, bartext in line_bar_dict.items():\n",
|
| 615 |
+
" right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])\n",
|
| 616 |
+
" bartext = bartext[:-len(right_barline)]\n",
|
| 617 |
+
" try:\n",
|
| 618 |
+
" bar_dur = calculate_bartext_duration(bartext)\n",
|
| 619 |
+
" except:\n",
|
| 620 |
+
" bar_dur = None\n",
|
| 621 |
+
" if bar_dur is not None:\n",
|
| 622 |
+
" if bar_dur not in dur_dict.keys():\n",
|
| 623 |
+
" dur_dict[bar_dur] = 1\n",
|
| 624 |
+
" else:\n",
|
| 625 |
+
" dur_dict[bar_dur] += 1\n",
|
| 626 |
+
"\n",
|
| 627 |
+
" try:\n",
|
| 628 |
+
" ref_dur = max(dur_dict, key=dur_dict.get)\n",
|
| 629 |
+
" except:\n",
|
| 630 |
+
" pass # use last ref_dur\n",
|
| 631 |
+
"\n",
|
| 632 |
+
" if i == 0:\n",
|
| 633 |
+
" prefix_left_barline = line.split('[V:')[0]\n",
|
| 634 |
+
" else:\n",
|
| 635 |
+
" prefix_left_barline = ''\n",
|
| 636 |
+
"\n",
|
| 637 |
+
" for symbol in part_symbol_list:\n",
|
| 638 |
+
" if symbol in line_bar_dict.keys():\n",
|
| 639 |
+
" symbol_bartext = line_bar_dict[symbol]\n",
|
| 640 |
+
" else:\n",
|
| 641 |
+
" if symbol in z_symbol_list:\n",
|
| 642 |
+
" symbol_bartext = prefix_left_barline + 'z' + str(ref_dur) + right_barline\n",
|
| 643 |
+
" elif symbol in x_symbol_list:\n",
|
| 644 |
+
" symbol_bartext = prefix_left_barline + 'x' + str(ref_dur) + right_barline\n",
|
| 645 |
+
" unreduced_line += '[' + symbol + ']' + symbol_bartext\n",
|
| 646 |
+
"\n",
|
| 647 |
+
" unreduced_tunebody_lines.append(unreduced_line + '\\n')\n",
|
| 648 |
+
"\n",
|
| 649 |
+
" unreduced_lines = metadata_lines + unreduced_tunebody_lines\n",
|
| 650 |
+
"\n",
|
| 651 |
+
" return unreduced_lines\n",
|
| 652 |
+
"\n",
|
| 653 |
+
"\n",
|
| 654 |
+
"def inference_patch(period, composer, instrumentation):\n",
|
| 655 |
+
"\n",
|
| 656 |
+
" prompt_lines=[\n",
|
| 657 |
+
" '%' + period + '\\n',\n",
|
| 658 |
+
" '%' + composer + '\\n',\n",
|
| 659 |
+
" '%' + instrumentation + '\\n']\n",
|
| 660 |
+
"\n",
|
| 661 |
+
" while True:\n",
|
| 662 |
+
"\n",
|
| 663 |
+
" failure_flag = False\n",
|
| 664 |
+
"\n",
|
| 665 |
+
" bos_patch = [patchilizer.bos_token_id] * (PATCH_SIZE - 1) + [patchilizer.eos_token_id]\n",
|
| 666 |
+
"\n",
|
| 667 |
+
" start_time = time.time()\n",
|
| 668 |
+
"\n",
|
| 669 |
+
" prompt_patches = patchilizer.patchilize_metadata(prompt_lines)\n",
|
| 670 |
+
" byte_list = list(''.join(prompt_lines))\n",
|
| 671 |
+
" context_tunebody_byte_list = []\n",
|
| 672 |
+
" metadata_byte_list = []\n",
|
| 673 |
+
"\n",
|
| 674 |
+
" print(''.join(byte_list), end='')\n",
|
| 675 |
+
"\n",
|
| 676 |
+
" prompt_patches = [[ord(c) for c in patch] + [patchilizer.special_token_id] * (PATCH_SIZE - len(patch)) for patch\n",
|
| 677 |
+
" in prompt_patches]\n",
|
| 678 |
+
" prompt_patches.insert(0, bos_patch)\n",
|
| 679 |
+
"\n",
|
| 680 |
+
" input_patches = torch.tensor(prompt_patches, device=device).reshape(1, -1)\n",
|
| 681 |
+
"\n",
|
| 682 |
+
" end_flag = False\n",
|
| 683 |
+
" cut_index = None\n",
|
| 684 |
+
"\n",
|
| 685 |
+
" tunebody_flag = False\n",
|
| 686 |
+
"\n",
|
| 687 |
+
" with torch.inference_mode():\n",
|
| 688 |
+
" \n",
|
| 689 |
+
" while True:\n",
|
| 690 |
+
" with torch.autocast(device_type='cuda', dtype=torch.float16):\n",
|
| 691 |
+
" predicted_patch = model.generate(input_patches.unsqueeze(0),\n",
|
| 692 |
+
" top_k=TOP_K,\n",
|
| 693 |
+
" top_p=TOP_P,\n",
|
| 694 |
+
" temperature=TEMPERATURE)\n",
|
| 695 |
+
" if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'): # 初次进入tunebody,必须以[r:0/开头\n",
|
| 696 |
+
" tunebody_flag = True\n",
|
| 697 |
+
" r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)\n",
|
| 698 |
+
" temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)\n",
|
| 699 |
+
" predicted_patch = model.generate(temp_input_patches.unsqueeze(0),\n",
|
| 700 |
+
" top_k=TOP_K,\n",
|
| 701 |
+
" top_p=TOP_P,\n",
|
| 702 |
+
" temperature=TEMPERATURE)\n",
|
| 703 |
+
" predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch\n",
|
| 704 |
+
" if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:\n",
|
| 705 |
+
" end_flag = True\n",
|
| 706 |
+
" break\n",
|
| 707 |
+
" next_patch = patchilizer.decode([predicted_patch])\n",
|
| 708 |
+
"\n",
|
| 709 |
+
" for char in next_patch:\n",
|
| 710 |
+
" byte_list.append(char)\n",
|
| 711 |
+
" if tunebody_flag:\n",
|
| 712 |
+
" context_tunebody_byte_list.append(char)\n",
|
| 713 |
+
" else:\n",
|
| 714 |
+
" metadata_byte_list.append(char)\n",
|
| 715 |
+
" print(char, end='')\n",
|
| 716 |
+
"\n",
|
| 717 |
+
" patch_end_flag = False\n",
|
| 718 |
+
" for j in range(len(predicted_patch)):\n",
|
| 719 |
+
" if patch_end_flag:\n",
|
| 720 |
+
" predicted_patch[j] = patchilizer.special_token_id\n",
|
| 721 |
+
" if predicted_patch[j] == patchilizer.eos_token_id:\n",
|
| 722 |
+
" patch_end_flag = True\n",
|
| 723 |
+
"\n",
|
| 724 |
+
" predicted_patch = torch.tensor([predicted_patch], device=device) # (1, 16)\n",
|
| 725 |
+
" input_patches = torch.cat([input_patches, predicted_patch], dim=1) # (1, 16 * patch_len)\n",
|
| 726 |
+
"\n",
|
| 727 |
+
" if len(byte_list) > 102400:\n",
|
| 728 |
+
" failure_flag = True\n",
|
| 729 |
+
" break\n",
|
| 730 |
+
" if time.time() - start_time > 10 * 60: \n",
|
| 731 |
+
" failure_flag = True\n",
|
| 732 |
+
" break\n",
|
| 733 |
+
"\n",
|
| 734 |
+
" if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:\n",
|
| 735 |
+
" print('Stream generating...')\n",
|
| 736 |
+
"\n",
|
| 737 |
+
" metadata = ''.join(metadata_byte_list)\n",
|
| 738 |
+
" context_tunebody = ''.join(context_tunebody_byte_list)\n",
|
| 739 |
+
"\n",
|
| 740 |
+
" if '\\n' not in context_tunebody:\n",
|
| 741 |
+
" break # Generated content is all metadata, abandon\n",
|
| 742 |
+
"\n",
|
| 743 |
+
" context_tunebody_liness = context_tunebody.split('\\n')\n",
|
| 744 |
+
" if not context_tunebody.endswith('\\n'):\n",
|
| 745 |
+
" context_tunebody_liness = [context_tunebody_liness[i] + '\\n' for i in range(len(context_tunebody_liness) - 1)] + [context_tunebody_liness[-1]]\n",
|
| 746 |
+
" else:\n",
|
| 747 |
+
" context_tunebody_liness = [context_tunebody_liness[i] + '\\n' for i in range(len(context_tunebody_liness))]\n",
|
| 748 |
+
"\n",
|
| 749 |
+
" cut_index = len(context_tunebody_liness) // 2\n",
|
| 750 |
+
" abc_code_slice = metadata + ''.join(context_tunebody_liness[-cut_index:])\n",
|
| 751 |
+
"\n",
|
| 752 |
+
" input_patches = patchilizer.encode_generate(abc_code_slice)\n",
|
| 753 |
+
"\n",
|
| 754 |
+
" input_patches = [item for sublist in input_patches for item in sublist]\n",
|
| 755 |
+
" input_patches = torch.tensor([input_patches], device=device)\n",
|
| 756 |
+
" input_patches = input_patches.reshape(1, -1)\n",
|
| 757 |
+
"\n",
|
| 758 |
+
" context_tunebody_byte_list = list(''.join(context_tunebody_liness[-cut_index:]))\n",
|
| 759 |
+
"\n",
|
| 760 |
+
" if not failure_flag:\n",
|
| 761 |
+
" abc_text = ''.join(byte_list)\n",
|
| 762 |
+
"\n",
|
| 763 |
+
" # unreduce\n",
|
| 764 |
+
" abc_lines = abc_text.split('\\n')\n",
|
| 765 |
+
" abc_lines = list(filter(None, abc_lines))\n",
|
| 766 |
+
" abc_lines = [line + '\\n' for line in abc_lines]\n",
|
| 767 |
+
" try:\n",
|
| 768 |
+
" unreduced_abc_lines = rest_unreduce(abc_lines)\n",
|
| 769 |
+
" except:\n",
|
| 770 |
+
" failure_flag = True\n",
|
| 771 |
+
" pass\n",
|
| 772 |
+
" else:\n",
|
| 773 |
+
" unreduced_abc_lines = [line for line in unreduced_abc_lines if not(line.startswith('%') and not line.startswith('%%'))]\n",
|
| 774 |
+
" unreduced_abc_lines = ['X:1\\n'] + unreduced_abc_lines\n",
|
| 775 |
+
" unreduced_abc_text = ''.join(unreduced_abc_lines)\n",
|
| 776 |
+
" return unreduced_abc_text"
|
| 777 |
+
]
|
| 778 |
+
},
|
| 779 |
+
{
|
| 780 |
+
"cell_type": "code",
|
| 781 |
+
"execution_count": null,
|
| 782 |
+
"id": "502c4420-533b-43cc-80ca-b2c94cd4be04",
|
| 783 |
+
"metadata": {},
|
| 784 |
+
"outputs": [],
|
| 785 |
+
"source": [
|
| 786 |
+
"result = inference_patch('Classical', 'Beethoven, Ludwig van', 'Art Song')\n",
|
| 787 |
+
"\n",
|
| 788 |
+
"abc_lines = result.splitlines()\n",
|
| 789 |
+
"abc_lines = [line + '\\n' for line in abc_lines if line.strip()] # Add newlines and remove empty lines\n",
|
| 790 |
+
"\n",
|
| 791 |
+
"abc_lines = rest_unreduce(abc_lines)\n",
|
| 792 |
+
"\n",
|
| 793 |
+
"with open(\"output.abc\", \"w\", encoding=\"utf-8\") as f:\n",
|
| 794 |
+
" f.writelines(abc_lines)\n",
|
| 795 |
+
"\n",
|
| 796 |
+
"!python abc2xml.py -o . output.abc"
|
| 797 |
+
]
|
| 798 |
+
}
|
| 799 |
+
],
|
| 800 |
+
"metadata": {
|
| 801 |
+
"kernelspec": {
|
| 802 |
+
"display_name": "Python 3 (ipykernel)",
|
| 803 |
+
"language": "python",
|
| 804 |
+
"name": "python3"
|
| 805 |
+
},
|
| 806 |
+
"language_info": {
|
| 807 |
+
"codemirror_mode": {
|
| 808 |
+
"name": "ipython",
|
| 809 |
+
"version": 3
|
| 810 |
+
},
|
| 811 |
+
"file_extension": ".py",
|
| 812 |
+
"mimetype": "text/x-python",
|
| 813 |
+
"name": "python",
|
| 814 |
+
"nbconvert_exporter": "python",
|
| 815 |
+
"pygments_lexer": "ipython3",
|
| 816 |
+
"version": "3.10.0"
|
| 817 |
+
}
|
| 818 |
+
},
|
| 819 |
+
"nbformat": 4,
|
| 820 |
+
"nbformat_minor": 5
|
| 821 |
+
}
|
demo.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import sys
|
| 3 |
+
import threading
|
| 4 |
+
import queue
|
| 5 |
+
from io import TextIOBase
|
| 6 |
+
from inference import inference_patch
|
| 7 |
+
import datetime
|
| 8 |
+
import subprocess
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
# Predefined valid combinations set.
# Each non-empty line of prompts.txt is "<period>_<composer>_<instrumentation>".
with open('prompts.txt', 'r', encoding='utf-8') as f:
    prompts = f.readlines()
valid_combinations = set()
for prompt in prompts:
    prompt = prompt.strip()
    if not prompt:
        continue  # skip blank lines instead of crashing below
    parts = prompt.split('_')
    if len(parts) < 3:
        continue  # malformed line: a valid prompt has three '_'-separated fields
    valid_combinations.add((parts[0], parts[1], parts[2]))

# Generate available options for the three dropdowns.
periods = sorted({p for p, _, _ in valid_combinations})
composers = sorted({c for _, c, _ in valid_combinations})
instruments = sorted({i for _, _, i in valid_combinations})
|
| 24 |
+
|
| 25 |
+
# Dynamic component updates
|
| 26 |
+
def update_components(period, composer):
    """Refresh the composer and instrumentation dropdowns.

    Restricts both dependent dropdowns to the period/composer/instrumentation
    triples loaded from prompts.txt. With no period selected, both dropdowns
    are emptied and disabled.
    """
    if not period:
        # Nothing selected yet: disable both dependent dropdowns.
        return [
            gr.Dropdown(choices=[], value=None, interactive=False),
            gr.Dropdown(choices=[], value=None, interactive=False)
        ]

    composer_choices = sorted({c for p, c, _ in valid_combinations if p == period})
    if composer:
        instrument_choices = sorted(
            {i for p, c, i in valid_combinations if p == period and c == composer}
        )
    else:
        instrument_choices = []

    # Keep the current composer only if it is still valid for this period.
    composer_dd = gr.Dropdown(
        choices=composer_choices,
        value=composer if composer in composer_choices else None,
        interactive=True
    )
    instrument_dd = gr.Dropdown(
        choices=instrument_choices,
        value=None,
        interactive=bool(instrument_choices)
    )
    return [composer_dd, instrument_dd]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class RealtimeStream(TextIOBase):
    """Write-only text stream that forwards every chunk to a queue.

    Installed as sys.stdout so the generation thread's prints can be
    streamed to the UI in real time.
    """

    def __init__(self, queue):
        # Consumer queue shared with the UI polling loop.
        self.queue = queue

    def write(self, text):
        """Push *text* onto the queue and report it fully written."""
        self.queue.put(text)
        written = len(text)
        return written
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def save_and_convert(abc_content, period, composer, instrumentation):
    """Write the generated ABC text to disk and convert it to MusicXML.

    Output files are named "<timestamp>_<period>_<composer>_<instrumentation>"
    with .abc and .xml extensions; conversion is delegated to abc2xml.py.

    Raises:
        gr.Error: if the prompt selection is incomplete, or the external
            conversion process exits non-zero.
    """
    if not (period and composer and instrumentation):
        raise gr.Error("Please complete a valid generation first before saving")

    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    base_name = f"{stamp}_{period}_{composer}_{instrumentation}"
    abc_path = f"{base_name}.abc"
    xml_path = f"{base_name}.xml"

    with open(abc_path, "w", encoding="utf-8") as out:
        out.write(abc_content)

    try:
        # abc2xml.py writes the .xml next to the .abc (-o .).
        subprocess.run(
            ["python", "abc2xml.py", "-o", ".", abc_path],
            check=True,
            capture_output=True,
            text=True
        )
    except subprocess.CalledProcessError as exc:
        error_msg = f"Conversion failed: {exc.stderr}" if exc.stderr else "Unknown error"
        raise gr.Error(f"ABC to XML conversion failed: {error_msg}. Please try to generate another composition.")

    return f"Saved successfully: {abc_path} -> {xml_path}"
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def generate_music(period, composer, instrumentation):
    """Stream a generated composition for the given prompt triple.

    Generator used as a Gradio event handler. Yields tuples of
    (process_output, final_result): process_output grows as the model
    prints characters, and final_result stays None until generation ends.

    Raises:
        gr.Error: if (period, composer, instrumentation) is not one of the
            combinations loaded from prompts.txt.
    """
    if (period, composer, instrumentation) not in valid_combinations:
        raise gr.Error("Invalid prompt combination! Please re-select from the period options")

    # Redirect stdout so the inference thread's prints flow into a queue.
    # NOTE(review): sys.stdout is process-global, so concurrent generations
    # would interleave output — assumes one generation at a time; confirm.
    output_queue = queue.Queue()
    original_stdout = sys.stdout
    sys.stdout = RealtimeStream(output_queue)

    result_container = []
    def run_inference():
        try:
            result_container.append(inference_patch(period, composer, instrumentation))
        finally:
            # Always restore stdout, even if inference raises.
            sys.stdout = original_stdout

    thread = threading.Thread(target=run_inference)
    thread.start()

    # Poll the queue while the worker runs, yielding the growing log.
    process_output = ""
    while thread.is_alive():
        try:
            text = output_queue.get(timeout=0.1)
            process_output += text
            yield process_output, None
        except queue.Empty:
            continue

    # Drain anything printed between the last poll and thread exit.
    while not output_queue.empty():
        text = output_queue.get()
        process_output += text
        yield process_output, None

    # An empty container means the worker raised; surface an empty score.
    final_result = result_container[0] if result_container else ""
    yield process_output, final_result
|
| 121 |
+
|
| 122 |
+
with gr.Blocks() as demo:
    gr.Markdown("## NotaGen")

    with gr.Row():
        # Left column: prompt selection, generate button, live log.
        with gr.Column():
            period_dd = gr.Dropdown(
                choices=periods,
                value=None,
                label="Period",
                interactive=True
            )
            composer_dd = gr.Dropdown(
                choices=[],
                value=None,
                label="Composer",
                interactive=False
            )
            instrument_dd = gr.Dropdown(
                choices=[],
                value=None,
                label="Instrumentation",
                interactive=False
            )

            generate_btn = gr.Button("Generate!", variant="primary")

            process_output = gr.Textbox(
                label="Generation process",
                interactive=False,
                lines=15,
                max_lines=15,
                placeholder="Generation progress will be shown here...",
                elem_classes="process-output"
            )

        # Right column: final post-processed score and save controls.
        with gr.Column():
            final_output = gr.Textbox(
                label="Post-processed ABC notation scores",
                interactive=True,
                lines=23,
                placeholder="Post-processed ABC scores will be shown here...",
                elem_classes="final-output"
            )

            with gr.Row():
                save_btn = gr.Button("💾 Save as ABC & XML files", variant="secondary")

            save_status = gr.Textbox(
                label="Save Status",
                interactive=False,
                visible=True,
                max_lines=2
            )

    # Cascade the dropdowns: changing period or composer re-filters the
    # dependent choices via update_components.
    period_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )
    composer_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )

    # generate_music is a generator, so the two outputs stream live.
    generate_btn.click(
        generate_music,
        inputs=[period_dd, composer_dd, instrument_dd],
        outputs=[process_output, final_output]
    )

    save_btn.click(
        save_and_convert,
        inputs=[final_output, period_dd, composer_dd, instrument_dd],
        outputs=[save_status]
    )
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# Styling for the two output panes (classes set via elem_classes above).
css = """
.process-output {
    background-color: #f0f0f0;
    font-family: monospace;
    padding: 10px;
    border-radius: 5px;
}
.final-output {
    background-color: #ffffff;
    font-family: sans-serif;
    padding: 10px;
    border-radius: 5px;
}

.process-output textarea {
    max-height: 500px !important;
    overflow-y: auto !important;
    white-space: pre-wrap;
}

"""
# NOTE(review): this selector targets a button id that is never assigned
# (the save button sets no elem_id), so the hover rule likely has no
# effect — confirm against the rendered DOM before relying on it.
css += """
button#💾-save-convert:hover {
    background-color: #ffe6e6;
}
"""

demo.css = css
|
| 230 |
+
|
| 231 |
+
if __name__ == "__main__":

    # Bind to all interfaces so the demo is reachable from other hosts.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7861
    )
|
extract_clamp2.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
input_dir = ''  # interleaved abc folder
output_dir = ''  # feature folder

import os
import json
import random
import torch
import numpy as np
from tqdm import tqdm
from config import *
from utils import *
from samplings import *
from accelerate import Accelerator
from transformers import BertConfig, AutoTokenizer
import argparse


# Whether extract_feature pools segments into one normalized vector.
normalize = True

# Start each run with fresh log/bookkeeping files.
os.makedirs("logs", exist_ok=True)
for file in ["logs/files_extract_clamp2.json",
             "logs/files_shuffle_extract_clamp2.json",
             "logs/log_extract_clamp2.txt",
             "logs/pass_extract_clamp2.txt",
             "logs/skip_extract_clamp2.txt"]:
    if os.path.exists(file):
        os.remove(file)

# Collect every .txt/.abc/.mtf file under the input directory.
files = []
for root, dirs, fs in os.walk(input_dir):
    for f in fs:
        if f.endswith(".txt") or f.endswith(".abc") or f.endswith(".mtf"):
            files.append(os.path.join(root, f))
print(f"Found {len(files)} files in total")
with open("logs/files_extract_clamp2.json", "w", encoding="utf-8") as f:
    json.dump(files, f)
# Shuffle so multi-process shards get a balanced mix of file sizes.
random.shuffle(files)
with open("logs/files_shuffle_extract_clamp2.json", "w", encoding="utf-8") as f:
    json.dump(files, f)
|
| 40 |
+
|
| 41 |
+
accelerator = Accelerator()
device = accelerator.device
print("Using device:", device)
# Log to the file managed by the startup cleanup above
# (was "logs/log_extract_clamp.txt", which was never removed between runs).
with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f:
    f.write("Using device: " + str(device) + "\n")

# M3 symbolic-music encoder configuration (BERT backbone).
m3_config = BertConfig(vocab_size=1,
                       hidden_size=M3_HIDDEN_SIZE,
                       num_hidden_layers=PATCH_NUM_LAYERS,
                       num_attention_heads=M3_HIDDEN_SIZE//64,
                       intermediate_size=M3_HIDDEN_SIZE*4,
                       max_position_embeddings=PATCH_LENGTH)
model = CLaMP2Model(m3_config,
                    text_model_name=TEXT_MODEL_NAME,
                    hidden_size=CLAMP2_HIDDEN_SIZE,
                    load_m3=CLAMP2_LOAD_M3)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
patchilizer = M3Patchilizer()

# print parameter number
print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

model.eval()
checkpoint = torch.load(CLAMP2_WEIGHTS_PATH, map_location='cpu', weights_only=True)
print(f"Successfully Loaded CLaMP 2 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
model.load_state_dict(checkpoint['model'])
|
| 68 |
+
|
| 69 |
+
def extract_feature(filename, get_normalized=normalize):
    """Encode one file with CLaMP 2 and return its feature representation.

    .txt files go through the text encoder; anything else (.abc/.mtf)
    through the music (M3) encoder. Long inputs are split into
    max-length segments. When get_normalized is True, segment features
    are length-weighted and averaged into a single vector; otherwise the
    per-position hidden states are concatenated.
    """
    with open(filename, "r", encoding="utf-8") as f:
        lines = f.readlines()

    # Drop single-'%' comment lines but keep '%%' directives.
    filtered_lines = []
    for line in lines:
        if line.startswith('%') and not line.startswith('%%'):
            pass
        else:
            filtered_lines.append(line)

    item = ''.join(filtered_lines)

    if filename.endswith(".txt"):
        # Text input: deduplicate lines and join with the tokenizer's
        # separator token.
        # NOTE(review): set() makes line order nondeterministic run-to-run.
        item = list(set(item.split("\n")))
        item = "\n".join(item)
        item = item.split("\n")
        item = [c for c in item if len(c) > 0]
        item = tokenizer.sep_token.join(item)
        input_data = tokenizer(item, return_tensors="pt")
        input_data = input_data['input_ids'].squeeze(0)
        max_input_length = MAX_TEXT_LENGTH
    else:
        # Music input: patchilize into fixed-size patches.
        input_data = patchilizer.encode(item, add_special_patches=True)
        input_data = torch.tensor(input_data)
        max_input_length = PATCH_LENGTH

    # Chunk into max-length segments; the final (short) chunk is replaced
    # by the last max_input_length elements so it is full-width (its
    # duplicated prefix is trimmed/re-weighted below).
    segment_list = []
    for i in range(0, len(input_data), max_input_length):
        segment_list.append(input_data[i:i+max_input_length])
    segment_list[-1] = input_data[-max_input_length:]

    last_hidden_states_list = []

    for input_segment in segment_list:
        # Mask covers the real positions; pad the segment to full length.
        input_masks = torch.tensor([1]*input_segment.size(0))
        if filename.endswith(".txt"):
            pad_indices = torch.ones(MAX_TEXT_LENGTH - input_segment.size(0)).long() * tokenizer.pad_token_id
        else:
            pad_indices = torch.ones((PATCH_LENGTH - input_segment.size(0), PATCH_SIZE)).long() * patchilizer.pad_token_id
        input_masks = torch.cat((input_masks, torch.zeros(max_input_length - input_segment.size(0))), 0)
        input_segment = torch.cat((input_segment, pad_indices), 0)

        if filename.endswith(".txt"):
            last_hidden_states = model.get_text_features(text_inputs=input_segment.unsqueeze(0).to(device),
                                                         text_masks=input_masks.unsqueeze(0).to(device),
                                                         get_normalized=get_normalized)
        else:
            last_hidden_states = model.get_music_features(music_inputs=input_segment.unsqueeze(0).to(device),
                                                          music_masks=input_masks.unsqueeze(0).to(device),
                                                          get_normalized=get_normalized)
        if not get_normalized:
            # Keep only the unpadded positions.
            last_hidden_states = last_hidden_states[:, :input_masks.sum().long().item(), :]
        last_hidden_states_list.append(last_hidden_states)

    if not get_normalized:
        # Concatenate per-position states; trim from the right-aligned
        # final segment the overlap it shares with the previous one
        # (a 0 remainder keeps the whole segment, since [-0:] == [0:]).
        last_hidden_states_list = [last_hidden_states[0] for last_hidden_states in last_hidden_states_list]
        last_hidden_states_list[-1] = last_hidden_states_list[-1][-(len(input_data)%max_input_length):]
        last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
    else:
        # Weight each segment's pooled feature by its true length, then
        # average into one vector.
        full_chunk_cnt = len(input_data) // max_input_length
        remain_chunk_len = len(input_data) % max_input_length
        if remain_chunk_len == 0:
            feature_weights = torch.tensor([max_input_length] * full_chunk_cnt, device=device).view(-1, 1)
        else:
            feature_weights = torch.tensor([max_input_length] * full_chunk_cnt + [remain_chunk_len], device=device).view(-1, 1)

        last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
        last_hidden_states_list = last_hidden_states_list * feature_weights
        last_hidden_states_list = last_hidden_states_list.sum(dim=0) / feature_weights.sum()

    return last_hidden_states_list
|
| 141 |
+
|
| 142 |
+
def process_directory(input_dir, output_dir, files):
    """Extract CLaMP 2 features for this process's shard of *files*.

    Each input file is mirrored under *output_dir* with a .npy extension;
    existing outputs are skipped. The file list is split contiguously
    across Accelerate processes, the last process taking the remainder.
    Failures are logged and do not stop the run.
    """
    print(f"Found {len(files)} files in total")
    # Log filename fixed to match the files cleaned up at startup
    # (was "logs/log_extract_clamp.txt").
    with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f:
        f.write("Found " + str(len(files)) + " files in total\n")

    # calculate the number of files to process per GPU
    num_files_per_gpu = len(files) // accelerator.num_processes

    # calculate the start and end index for the current GPU
    start_idx = accelerator.process_index * num_files_per_gpu
    end_idx = start_idx + num_files_per_gpu
    # The last process also handles the division remainder.
    if accelerator.process_index == accelerator.num_processes - 1:
        end_idx = len(files)

    files_to_process = files[start_idx:end_idx]

    # process the files
    for file in tqdm(files_to_process):
        # Mirror the input subdirectory structure under output_dir.
        output_subdir = output_dir + os.path.dirname(file)[len(input_dir):]
        try:
            os.makedirs(output_subdir, exist_ok=True)
        except Exception as e:
            print(output_subdir + " can not be created\n" + str(e))
            with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f:
                f.write(output_subdir + " can not be created\n" + str(e) + "\n")

        output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy")

        if os.path.exists(output_file):
            print(f"Skipping {file}, output already exists")
            with open("logs/skip_extract_clamp2.txt", "a", encoding="utf-8") as f:
                f.write(file + "\n")
            continue

        try:
            with torch.no_grad():
                features = extract_feature(file).unsqueeze(0)
            np.save(output_file, features.detach().cpu().numpy())
            with open("logs/pass_extract_clamp2.txt", "a", encoding="utf-8") as f:
                f.write(file + "\n")
        except Exception as e:
            # Best-effort batch job: record the failure and move on.
            print(f"Failed to process {file}: {e}")
            with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f:
                f.write("Failed to process " + file + ": " + str(e) + "\n")
|
| 186 |
+
|
| 187 |
+
# Re-load the shuffled list so every process sees the identical ordering.
with open("logs/files_shuffle_extract_clamp2.json", "r", encoding="utf-8") as f:
    files = json.load(f)

# process the files
process_directory(input_dir, output_dir, files)

# Log filename fixed to match the files cleaned up at startup
# (was "logs/log_extract_clamp.txt").
with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f:
    f.write("GPU ID: " + str(device) + "\n")
|
illustration.png
ADDED
|
Git LFS Details
|
illustration_online.png
ADDED
|
Git LFS Details
|
inference (1).py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
import time
import torch
from utils import *
from config import *
from transformers import GPT2Config, LlamaConfig
from abctoolkit.utils import Exclaim_re, Quote_re, SquareBracket_re, Barline_regexPattern
from abctoolkit.transpose import Note_list, Pitch_sign_list
from abctoolkit.duration import calculate_bartext_duration

# Treat 'z' (visible rest) and 'x' (invisible rest) as note symbols too.
Note_list = Note_list + ['z', 'x']

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

os.makedirs(ORIGINAL_OUTPUT_FOLDER, exist_ok=True)
os.makedirs(INTERLEAVED_OUTPUT_FOLDER, exist_ok=True)

patchilizer = Patchilizer()

# Patch-level decoder configuration (operates on bar/line patches).
patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
                          max_length=PATCH_LENGTH,
                          max_position_embeddings=PATCH_LENGTH,
                          n_embd=HIDDEN_SIZE,
                          num_attention_heads=HIDDEN_SIZE // 64,
                          vocab_size=1)
# Character-level decoder configuration (128 = ASCII vocabulary).
byte_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
                         max_length=PATCH_SIZE + 1,
                         max_position_embeddings=PATCH_SIZE + 1,
                         hidden_size=HIDDEN_SIZE,
                         num_attention_heads=HIDDEN_SIZE // 64,
                         vocab_size=128)

model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config)

print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

checkpoint = torch.load(INFERENCE_WEIGHTS_PATH, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model.eval()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def rest_unreduce(abc_lines):
    """Expand a rest-reduced interleaved ABC piece back to its full form.

    In reduced form a tunebody line only lists the voices that actually play
    in that bar.  This function re-inserts the omitted voices as whole-bar
    rests: 'z' (visible rest) for the first voice of each %%score group and
    'x' (invisible rest) for the remaining voices of the group.

    Args:
        abc_lines: list of newline-terminated ABC lines, metadata first,
            followed by the interleaved tunebody.

    Returns:
        A new list of lines in which every declared voice appears in every
        tunebody line.

    Raises:
        ValueError: if no tunebody line (one containing '[V:') exists.
            (Previously this fell through and crashed later with a
            TypeError from ``abc_lines[None:]``.)
    """
    # Locate the first tunebody line; everything before it is metadata.
    tunebody_index = None
    for i in range(len(abc_lines)):
        if '[V:' in abc_lines[i]:
            tunebody_index = i
            break
    if tunebody_index is None:
        raise ValueError("no tunebody line containing '[V:' found")

    metadata_lines = abc_lines[: tunebody_index]
    tunebody_lines = abc_lines[tunebody_index:]

    # Parse the %%score directive and V: headers to learn, for each voice,
    # whether it should be padded with 'z' (group leader) or 'x' (member).
    part_symbol_list = []
    voice_group_list = []
    # Initialised up front: previously first bound only inside the %%score
    # branch, raising NameError when a V: header preceded any %%score line.
    existed_voices = []
    for line in metadata_lines:
        if line.startswith('%%score'):
            for round_bracket_match in re.findall(r'\((.*?)\)', line):
                voice_group_list.append(round_bracket_match.split())
            existed_voices = [item for sublist in voice_group_list for item in sublist]
        if line.startswith('V:'):
            symbol = line.split()[0]
            part_symbol_list.append(symbol)
            if symbol[2:] not in existed_voices:
                voice_group_list.append([symbol[2:]])
    z_symbol_list = []  # voices that use z as rest (first of each group)
    x_symbol_list = []  # voices that use x as rest (other group members)
    for voice_group in voice_group_list:
        z_symbol_list.append('V:' + voice_group[0])
        for j in range(1, len(voice_group)):
            x_symbol_list.append('V:' + voice_group[j])

    # Voice symbols are 'V:<number>'; sort numerically, not lexically.
    part_symbol_list.sort(key=lambda x: int(x[2:]))

    unreduced_tunebody_lines = []

    # Carried across lines: the most recent reference bar duration and the
    # most recent right barline.  Previously both were first bound inside
    # loops, so a first line with no parsable bars raised NameError.
    ref_dur = None
    right_barline = ''

    for i, line in enumerate(tunebody_lines):
        unreduced_line = ''

        # Strip the '[r:...]' reduction marker, if present.
        line = re.sub(r'^\[r:[^\]]*\]', '', line)

        # Split the line into per-voice bar texts.
        pattern = r'\[V:(\d+)\](.*?)(?=\[V:|$)'
        matches = re.findall(pattern, line)

        line_bar_dict = {}
        for match in matches:
            key = f'V:{match[0]}'
            value = match[1]
            line_bar_dict[key] = value

        # Vote on the bar duration: the most common duration among the
        # voices present in this line becomes the rest duration.
        dur_dict = {}
        for symbol, bartext in line_bar_dict.items():
            right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])
            bartext = bartext[:-len(right_barline)]
            try:
                bar_dur = calculate_bartext_duration(bartext)
            except Exception:
                bar_dur = None
            if bar_dur is not None:
                dur_dict[bar_dur] = dur_dict.get(bar_dur, 0) + 1

        if dur_dict:
            ref_dur = max(dur_dict, key=dur_dict.get)
        # else: keep the previous line's ref_dur ("use last ref_dur").

        # Only the very first tunebody line may carry a left barline prefix.
        if i == 0:
            prefix_left_barline = line.split('[V:')[0]
        else:
            prefix_left_barline = ''

        # Emit every declared voice, synthesising rest bars where missing.
        for symbol in part_symbol_list:
            if symbol in line_bar_dict.keys():
                symbol_bartext = line_bar_dict[symbol]
            else:
                if symbol in z_symbol_list:
                    symbol_bartext = prefix_left_barline + 'z' + str(ref_dur) + right_barline
                elif symbol in x_symbol_list:
                    symbol_bartext = prefix_left_barline + 'x' + str(ref_dur) + right_barline
            unreduced_line += '[' + symbol + ']' + symbol_bartext

        unreduced_tunebody_lines.append(unreduced_line + '\n')

    unreduced_lines = metadata_lines + unreduced_tunebody_lines

    return unreduced_lines
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def inference_patch(prompt_lines=[], pieces=NUM_SAMPLES):
    """Generate `pieces` ABC samples with the module-level model and save them.

    Args:
        prompt_lines: optional metadata prompt lines fed to the model before
            generation starts.  (The mutable default is safe here: the list
            is only read, never mutated.)
        pieces: number of successful samples to produce before returning.

    Side effects: on success, writes the rest-unreduced piece to
    INTERLEAVED_OUTPUT_FOLDER and (when unreduction succeeds) the raw text
    to ORIGINAL_OUTPUT_FOLDER; streams every generated character to stdout.
    """

    file_no = 1

    # Sentinel patch marking the beginning of a sequence.
    bos_patch = [patchilizer.bos_token_id] * (PATCH_SIZE - 1) + [patchilizer.eos_token_id]

    while file_no <= pieces:

        start_time = time.time()
        start_time_format = time.strftime("%Y%m%d-%H%M%S")

        # Encode the prompt into fixed-size patches, padded with the
        # special token, and prefix the BOS patch.
        prompt_patches = patchilizer.patchilize_metadata(prompt_lines)
        byte_list = list(''.join(prompt_lines))
        print(''.join(byte_list), end='')

        prompt_patches = [[ord(c) for c in patch] + [patchilizer.special_token_id] * (PATCH_SIZE - len(patch)) for patch
                          in prompt_patches]
        prompt_patches.insert(0, bos_patch)

        input_patches = torch.tensor(prompt_patches, device=device).reshape(1, -1)

        failure_flag = False
        end_flag = False
        cut_index = None

        tunebody_flag = False
        while True:
            predicted_patch = model.generate(input_patches.unsqueeze(0),
                                             top_k=TOP_K,
                                             top_p=TOP_P,
                                             temperature=TEMPERATURE)
            # On first entry into the tunebody, force it to start with '[r:0/'
            # by re-generating with that prefix appended to the context.
            if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'):  # start with [r:0/
                tunebody_flag = True
                r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)
                temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
                predicted_patch = model.generate(temp_input_patches.unsqueeze(0),
                                                 top_k=TOP_K,
                                                 top_p=TOP_P,
                                                 temperature=TEMPERATURE)
                predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch
            # BOS immediately followed by EOS signals end of piece.
            if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
                end_flag = True
                break
            next_patch = patchilizer.decode([predicted_patch])

            for char in next_patch:
                byte_list.append(char)
                print(char, end='')

            # Blank out everything after the first EOS inside the patch with
            # the special (padding) token.
            patch_end_flag = False
            for j in range(len(predicted_patch)):
                if patch_end_flag:
                    predicted_patch[j] = patchilizer.special_token_id
                if predicted_patch[j] == patchilizer.eos_token_id:
                    patch_end_flag = True

            predicted_patch = torch.tensor([predicted_patch], device=device)  # (1, 16)
            input_patches = torch.cat([input_patches, predicted_patch], dim=1)  # (1, 16 * patch_len)

            # Abandon runaway generations (too long or too slow).
            if len(byte_list) > 102400:
                failure_flag = True
                break
            if time.time() - start_time > 20 * 60:
                failure_flag = True
                break

            # Context window full: keep metadata plus the newest half of the
            # tunebody and continue streaming from that shortened context.
            if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
                print('Stream generating...')
                abc_code = ''.join(byte_list)
                abc_lines = abc_code.split('\n')

                tunebody_index = None
                for i, line in enumerate(abc_lines):
                    if line.startswith('[r:') or line.startswith('[V:'):
                        tunebody_index = i
                        break
                if tunebody_index is None or tunebody_index == len(abc_lines) - 1:
                    break

                metadata_lines = abc_lines[:tunebody_index]
                tunebody_lines = abc_lines[tunebody_index:]

                # Re-append the newlines removed by split(); the final line
                # keeps its unterminated state if generation stopped mid-line.
                metadata_lines = [line + '\n' for line in metadata_lines]
                if not abc_code.endswith('\n'):
                    tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [
                        tunebody_lines[-1]]
                else:
                    tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]

                if cut_index is None:
                    cut_index = len(tunebody_lines) // 2

                abc_code_slice = ''.join(metadata_lines + tunebody_lines[-cut_index:])
                input_patches = patchilizer.encode_generate(abc_code_slice)

                input_patches = [item for sublist in input_patches for item in sublist]
                input_patches = torch.tensor([input_patches], device=device)
                input_patches = input_patches.reshape(1, -1)

        if not failure_flag:
            generation_time_cost = time.time() - start_time

            abc_text = ''.join(byte_list)
            # Filename encodes timestamp, generation time, and sample index.
            filename = time.strftime("%Y%m%d-%H%M%S") + \
                       "_" + format(generation_time_cost, '.2f') + '_' + str(file_no) + ".abc"

            # unreduce
            unreduced_output_path = os.path.join(INTERLEAVED_OUTPUT_FOLDER, filename)

            abc_lines = abc_text.split('\n')
            abc_lines = list(filter(None, abc_lines))
            abc_lines = [line + '\n' for line in abc_lines]
            try:
                abc_lines = rest_unreduce(abc_lines)

                with open(unreduced_output_path, 'w') as file:
                    file.writelines(abc_lines)
            except:
                # NOTE(review): unreduction failures are deliberately
                # swallowed; the raw text is then saved below instead.
                pass
            else:
                # original
                original_output_path = os.path.join(ORIGINAL_OUTPUT_FOLDER, filename)
                with open(original_output_path, 'w') as w:
                    w.write(abc_text)

            file_no += 1

        else:
            print('failed')
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
if __name__ == '__main__':

    # Default run: unconditional generation with an empty prompt.
    inference_patch()
|
inference.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Online inference module: builds the NotaGen model once at import time so
# that inference_patch() can be called repeatedly (e.g. from a demo UI).
import os
import time
import torch
from utils import *
from config import *
from transformers import GPT2Config
from abctoolkit.utils import Exclaim_re, Quote_re, SquareBracket_re, Barline_regexPattern
from abctoolkit.transpose import Note_list, Pitch_sign_list
from abctoolkit.duration import calculate_bartext_duration

# Extend the pitch alphabet with the two ABC rest symbols.
Note_list = Note_list + ['z', 'x']

# Device priority: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

patchilizer = Patchilizer()

# Patch-level decoder configuration.
patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
                          max_length=PATCH_LENGTH,
                          max_position_embeddings=PATCH_LENGTH,
                          n_embd=HIDDEN_SIZE,
                          num_attention_heads=HIDDEN_SIZE // 64,
                          vocab_size=1)
# Character-level decoder configuration (vocab = 128 ASCII byte values).
byte_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
                         max_length=PATCH_SIZE + 1,
                         max_position_embeddings=PATCH_SIZE + 1,
                         hidden_size=HIDDEN_SIZE,
                         num_attention_heads=HIDDEN_SIZE // 64,
                         vocab_size=128)

model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config).to(device)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
    """Cast *model* to FP16 and optionally enable gradient checkpointing.

    Steps:
      1. Convert all parameters to half precision (in place).
      2. Freeze any parameter still left in float32 — after the cast this
         loop is normally a no-op, but it guards modules that resist
         conversion.
      3. Enable gradient checkpointing when requested (the model must
         expose ``gradient_checkpointing_enable``, e.g. HF models).

    Returns the prepared model.
    """
    # Mixed-precision cast; nn.Module.to mutates and returns the module.
    half_model = model.to(dtype=torch.float16)

    # Any parameter that is still float32 is excluded from training.
    for parameter in half_model.parameters():
        if parameter.dtype == torch.float32:
            parameter.requires_grad = False

    if use_gradient_checkpointing:
        half_model.gradient_checkpointing_enable()

    return half_model
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Cast the model to FP16 for inference (gradient checkpointing off — it is
# only useful during training).
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=False
)

print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

# Load the fine-tuned weights and switch to evaluation mode.
checkpoint = torch.load(INFERENCE_WEIGHTS_PATH, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model.eval()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def complete_brackets(s):
    """Return *s* with closing brackets appended for every unmatched opener.

    Tracks '{', '[' and '(' on a stack; a closing bracket pops the stack
    only when it matches the most recent opener (mismatched or stray
    closers are left alone).  Whatever openers remain are closed in
    reverse order and appended to the string.
    """
    closers = {'{': '}', '[': ']', '(': ')'}
    openers = {close: open_ for open_, close in closers.items()}

    pending = []
    for ch in s:
        if ch in closers:
            # Opening bracket: remember it.
            pending.append(ch)
        elif ch in openers:
            # Closing bracket: consume the matching opener if it is on top.
            if pending and pending[-1] == openers[ch]:
                pending.pop()

    # Close the still-open brackets, innermost first.
    return s + ''.join(closers[ch] for ch in reversed(pending))
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def rest_unreduce(abc_lines):
    """Expand a rest-reduced interleaved ABC piece back to its full form.

    Reduced tunebody lines only list the voices that actually play; this
    re-inserts the missing voices as whole-bar rests ('z' for the first
    voice of each %%score group, invisible 'x' for the other members).
    The %%score line is repaired with complete_brackets() before parsing.

    NOTE(review): ref_dur, right_barline and existed_voices are first bound
    inside loops — unusual inputs (a first tunebody line with no parsable
    bars, or a V: header before any %%score) would raise NameError; confirm
    whether the generator guarantees well-formed input.
    """

    # Find the first tunebody line, fixing unbalanced %%score brackets on
    # the way; everything before it is metadata.
    tunebody_index = None
    for i in range(len(abc_lines)):
        if abc_lines[i].startswith('%%score'):
            abc_lines[i] = complete_brackets(abc_lines[i])
        if '[V:' in abc_lines[i]:
            tunebody_index = i
            break

    metadata_lines = abc_lines[: tunebody_index]
    tunebody_lines = abc_lines[tunebody_index:]

    # Parse %%score groups and V: headers to decide which rest symbol each
    # voice gets when it is missing from a line.
    part_symbol_list = []
    voice_group_list = []
    for line in metadata_lines:
        if line.startswith('%%score'):
            for round_bracket_match in re.findall(r'\((.*?)\)', line):
                voice_group_list.append(round_bracket_match.split())
            existed_voices = [item for sublist in voice_group_list for item in sublist]
        if line.startswith('V:'):
            symbol = line.split()[0]
            part_symbol_list.append(symbol)
            if symbol[2:] not in existed_voices:
                voice_group_list.append([symbol[2:]])
    z_symbol_list = []  # voices that use z as rest
    x_symbol_list = []  # voices that use x as rest
    for voice_group in voice_group_list:
        z_symbol_list.append('V:' + voice_group[0])
        for j in range(1, len(voice_group)):
            x_symbol_list.append('V:' + voice_group[j])

    # Voice symbols are 'V:<number>'; sort numerically, not lexically.
    part_symbol_list.sort(key=lambda x: int(x[2:]))

    unreduced_tunebody_lines = []

    for i, line in enumerate(tunebody_lines):
        unreduced_line = ''

        # Strip the '[r:...]' reduction marker, if present.
        line = re.sub(r'^\[r:[^\]]*\]', '', line)

        # Split the line into per-voice bar texts.
        pattern = r'\[V:(\d+)\](.*?)(?=\[V:|$)'
        matches = re.findall(pattern, line)

        line_bar_dict = {}
        for match in matches:
            key = f'V:{match[0]}'
            value = match[1]
            line_bar_dict[key] = value

        # calculate duration and collect barline
        dur_dict = {}
        for symbol, bartext in line_bar_dict.items():
            right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])
            bartext = bartext[:-len(right_barline)]
            try:
                bar_dur = calculate_bartext_duration(bartext)
            except:
                bar_dur = None
            if bar_dur is not None:
                if bar_dur not in dur_dict.keys():
                    dur_dict[bar_dur] = 1
                else:
                    dur_dict[bar_dur] += 1

        # Majority vote on the bar duration used for synthesized rests.
        try:
            ref_dur = max(dur_dict, key=dur_dict.get)
        except:
            pass  # use last ref_dur

        # Only the very first tunebody line may carry a left-barline prefix.
        if i == 0:
            prefix_left_barline = line.split('[V:')[0]
        else:
            prefix_left_barline = ''

        # Emit every declared voice, synthesizing rest bars where missing.
        for symbol in part_symbol_list:
            if symbol in line_bar_dict.keys():
                symbol_bartext = line_bar_dict[symbol]
            else:
                if symbol in z_symbol_list:
                    symbol_bartext = prefix_left_barline + 'z' + str(ref_dur) + right_barline
                elif symbol in x_symbol_list:
                    symbol_bartext = prefix_left_barline + 'x' + str(ref_dur) + right_barline
            unreduced_line += '[' + symbol + ']' + symbol_bartext

        unreduced_tunebody_lines.append(unreduced_line + '\n')

    unreduced_lines = metadata_lines + unreduced_tunebody_lines

    return unreduced_lines
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def inference_patch(period, composer, instrumentation):
    """Generate one piece of ABC notation conditioned on a textual prompt.

    Args:
        period: stylistic period, e.g. 'Classical' (first '%' prompt line).
        composer: composer name, e.g. 'Beethoven, Ludwig van'.
        instrumentation: ensemble type, e.g. 'Keyboard'.

    Returns:
        The rest-unreduced ABC text (str) with an 'X:1' header prepended.
        Retries generation from scratch until one attempt succeeds, so this
        only returns on success.

    Side effects: streams the generated characters to stdout.
    """
    prompt_lines = [
        '%' + period + '\n',
        '%' + composer + '\n',
        '%' + instrumentation + '\n']

    while True:

        failure_flag = False

        # Sentinel patch marking the beginning of a sequence.
        bos_patch = [patchilizer.bos_token_id] * (PATCH_SIZE - 1) + [patchilizer.eos_token_id]

        start_time = time.time()

        prompt_patches = patchilizer.patchilize_metadata(prompt_lines)
        byte_list = list(''.join(prompt_lines))
        context_tunebody_byte_list = []
        metadata_byte_list = []

        print(''.join(byte_list), end='')

        # Pad each prompt patch to PATCH_SIZE and prefix the BOS patch.
        prompt_patches = [[ord(c) for c in patch] + [patchilizer.special_token_id] * (PATCH_SIZE - len(patch)) for patch
                          in prompt_patches]
        prompt_patches.insert(0, bos_patch)

        input_patches = torch.tensor(prompt_patches, device=device).reshape(1, -1)

        end_flag = False
        cut_index = None

        tunebody_flag = False

        with torch.inference_mode():

            while True:
                # NOTE(review): autocast is hard-wired to CUDA even though
                # the module may have selected 'mps' or 'cpu' above —
                # confirm behaviour on non-CUDA devices.
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    predicted_patch = model.generate(input_patches.unsqueeze(0),
                                                     top_k=TOP_K,
                                                     top_p=TOP_P,
                                                     temperature=TEMPERATURE)
                # On first entry into the tunebody, force it to start with
                # '[r:0/' by re-generating with that prefix appended.
                if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'):
                    tunebody_flag = True
                    r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)
                    temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
                    predicted_patch = model.generate(temp_input_patches.unsqueeze(0),
                                                     top_k=TOP_K,
                                                     top_p=TOP_P,
                                                     temperature=TEMPERATURE)
                    predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch
                # BOS immediately followed by EOS signals end of piece.
                if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
                    end_flag = True
                    break
                next_patch = patchilizer.decode([predicted_patch])

                for char in next_patch:
                    byte_list.append(char)
                    if tunebody_flag:
                        context_tunebody_byte_list.append(char)
                    else:
                        metadata_byte_list.append(char)
                    print(char, end='')

                # Blank out everything after the first EOS inside the patch.
                patch_end_flag = False
                for j in range(len(predicted_patch)):
                    if patch_end_flag:
                        predicted_patch[j] = patchilizer.special_token_id
                    if predicted_patch[j] == patchilizer.eos_token_id:
                        patch_end_flag = True

                predicted_patch = torch.tensor([predicted_patch], device=device)  # (1, PATCH_SIZE)
                input_patches = torch.cat([input_patches, predicted_patch], dim=1)  # (1, PATCH_SIZE * patch_len)

                # Abandon runaway generations (too long or too slow).
                if len(byte_list) > 102400:
                    failure_flag = True
                    break
                if time.time() - start_time > 10 * 60:
                    failure_flag = True
                    break

                # Context window full: keep the metadata plus the second
                # half of the tunebody generated so far, and keep streaming.
                if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
                    print('Stream generating...')

                    metadata = ''.join(metadata_byte_list)
                    context_tunebody = ''.join(context_tunebody_byte_list)

                    if '\n' not in context_tunebody:
                        break  # Generated content is all metadata, abandon

                    # BUG FIX: this list was previously bound as
                    # 'context_tunebody_liness' but read back further down
                    # as 'context_tunebody_lines', raising NameError the
                    # first time stream generation kicked in.
                    context_tunebody_lines = context_tunebody.split('\n')
                    # Re-append the newlines removed by split(); keep the
                    # final line unterminated if generation stopped mid-line.
                    if not context_tunebody.endswith('\n'):
                        context_tunebody_lines = [context_tunebody_lines[i] + '\n' for i in range(len(context_tunebody_lines) - 1)] + [context_tunebody_lines[-1]]
                    else:
                        context_tunebody_lines = [context_tunebody_lines[i] + '\n' for i in range(len(context_tunebody_lines))]

                    cut_index = len(context_tunebody_lines) // 2
                    abc_code_slice = metadata + ''.join(context_tunebody_lines[-cut_index:])

                    input_patches = patchilizer.encode_generate(abc_code_slice)

                    input_patches = [item for sublist in input_patches for item in sublist]
                    input_patches = torch.tensor([input_patches], device=device)
                    input_patches = input_patches.reshape(1, -1)

                    context_tunebody_byte_list = list(''.join(context_tunebody_lines[-cut_index:]))

        if not failure_flag:
            abc_text = ''.join(byte_list)

            # unreduce
            abc_lines = abc_text.split('\n')
            abc_lines = list(filter(None, abc_lines))
            abc_lines = [line + '\n' for line in abc_lines]
            try:
                unreduced_abc_lines = rest_unreduce(abc_lines)
            except Exception:
                # Unreduction failed: retry the whole generation.
                failure_flag = True
            else:
                # Drop '%'-comment prompt lines (keeping '%%' directives)
                # and prepend the X:1 reference header required by ABC.
                unreduced_abc_lines = [line for line in unreduced_abc_lines if not (line.startswith('%') and not line.startswith('%%'))]
                unreduced_abc_lines = ['X:1\n'] + unreduced_abc_lines
                unreduced_abc_text = ''.join(unreduced_abc_lines)
                return unreduced_abc_text
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
if __name__ == '__main__':
    # Smoke test: generate one Classical keyboard piece in Beethoven's style.
    inference_patch('Classical', 'Beethoven, Ludwig van', 'Keyboard')
|
notagen.png
ADDED
|
Git LFS Details
|
prompts.txt
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Baroque_Bach, Johann Sebastian_Chamber
|
| 2 |
+
Baroque_Bach, Johann Sebastian_Choral
|
| 3 |
+
Baroque_Bach, Johann Sebastian_Keyboard
|
| 4 |
+
Baroque_Bach, Johann Sebastian_Orchestral
|
| 5 |
+
Baroque_Bach, Johann Sebastian_Vocal-Orchestral
|
| 6 |
+
Baroque_Corelli, Arcangelo_Chamber
|
| 7 |
+
Baroque_Corelli, Arcangelo_Orchestral
|
| 8 |
+
Baroque_Handel, George Frideric_Chamber
|
| 9 |
+
Baroque_Handel, George Frideric_Keyboard
|
| 10 |
+
Baroque_Handel, George Frideric_Orchestral
|
| 11 |
+
Baroque_Handel, George Frideric_Vocal-Orchestral
|
| 12 |
+
Baroque_Scarlatti, Domenico_Keyboard
|
| 13 |
+
Baroque_Vivaldi, Antonio_Chamber
|
| 14 |
+
Baroque_Vivaldi, Antonio_Orchestral
|
| 15 |
+
Baroque_Vivaldi, Antonio_Vocal-Orchestral
|
| 16 |
+
Classical_Beethoven, Ludwig van_Art Song
|
| 17 |
+
Classical_Beethoven, Ludwig van_Chamber
|
| 18 |
+
Classical_Beethoven, Ludwig van_Keyboard
|
| 19 |
+
Classical_Beethoven, Ludwig van_Orchestral
|
| 20 |
+
Classical_Haydn, Joseph_Chamber
|
| 21 |
+
Classical_Haydn, Joseph_Keyboard
|
| 22 |
+
Classical_Haydn, Joseph_Orchestral
|
| 23 |
+
Classical_Haydn, Joseph_Vocal-Orchestral
|
| 24 |
+
Classical_Mozart, Wolfgang Amadeus_Chamber
|
| 25 |
+
Classical_Mozart, Wolfgang Amadeus_Choral
|
| 26 |
+
Classical_Mozart, Wolfgang Amadeus_Keyboard
|
| 27 |
+
Classical_Mozart, Wolfgang Amadeus_Orchestral
|
| 28 |
+
Classical_Mozart, Wolfgang Amadeus_Vocal-Orchestral
|
| 29 |
+
Classical_Paradis, Maria Theresia von_Art Song
|
| 30 |
+
Classical_Reichardt, Louise_Art Song
|
| 31 |
+
Classical_Saint-Georges, Joseph Bologne_Chamber
|
| 32 |
+
Classical_Schroter, Corona_Art Song
|
| 33 |
+
Romantic_Bartok, Bela_Keyboard
|
| 34 |
+
Romantic_Berlioz, Hector_Choral
|
| 35 |
+
Romantic_Bizet, Georges_Art Song
|
| 36 |
+
Romantic_Boulanger, Lili_Art Song
|
| 37 |
+
Romantic_Boulton, Harold_Art Song
|
| 38 |
+
Romantic_Brahms, Johannes_Art Song
|
| 39 |
+
Romantic_Brahms, Johannes_Chamber
|
| 40 |
+
Romantic_Brahms, Johannes_Choral
|
| 41 |
+
Romantic_Brahms, Johannes_Keyboard
|
| 42 |
+
Romantic_Brahms, Johannes_Orchestral
|
| 43 |
+
Romantic_Burgmuller, Friedrich_Keyboard
|
| 44 |
+
Romantic_Butterworth, George_Art Song
|
| 45 |
+
Romantic_Chaminade, Cecile_Art Song
|
| 46 |
+
Romantic_Chausson, Ernest_Art Song
|
| 47 |
+
Romantic_Chopin, Frederic_Art Song
|
| 48 |
+
Romantic_Chopin, Frederic_Keyboard
|
| 49 |
+
Romantic_Cornelius, Peter_Art Song
|
| 50 |
+
Romantic_Debussy, Claude_Art Song
|
| 51 |
+
Romantic_Debussy, Claude_Keyboard
|
| 52 |
+
Romantic_Dvorak, Antonin_Chamber
|
| 53 |
+
Romantic_Dvorak, Antonin_Choral
|
| 54 |
+
Romantic_Dvorak, Antonin_Keyboard
|
| 55 |
+
Romantic_Dvorak, Antonin_Orchestral
|
| 56 |
+
Romantic_Faisst, Clara_Art Song
|
| 57 |
+
Romantic_Faure, Gabriel_Art Song
|
| 58 |
+
Romantic_Faure, Gabriel_Chamber
|
| 59 |
+
Romantic_Faure, Gabriel_Keyboard
|
| 60 |
+
Romantic_Franz, Robert_Art Song
|
| 61 |
+
Romantic_Gonzaga, Chiquinha_Art Song
|
| 62 |
+
Romantic_Grandval, Clemence de_Art Song
|
| 63 |
+
Romantic_Grieg, Edvard_Keyboard
|
| 64 |
+
Romantic_Grieg, Edvard_Orchestral
|
| 65 |
+
Romantic_Hensel, Fanny_Art Song
|
| 66 |
+
Romantic_Holmes, Augusta Mary Anne_Art Song
|
| 67 |
+
Romantic_Jaell, Marie_Art Song
|
| 68 |
+
Romantic_Kinkel, Johanna_Art Song
|
| 69 |
+
Romantic_Kralik, Mathilde_Art Song
|
| 70 |
+
Romantic_Lang, Josephine_Art Song
|
| 71 |
+
Romantic_Lehmann, Liza_Art Song
|
| 72 |
+
Romantic_Liszt, Franz_Keyboard
|
| 73 |
+
Romantic_Mayer, Emilie_Chamber
|
| 74 |
+
Romantic_Medtner, Nikolay_Keyboard
|
| 75 |
+
Romantic_Mendelssohn, Felix_Art Song
|
| 76 |
+
Romantic_Mendelssohn, Felix_Chamber
|
| 77 |
+
Romantic_Mendelssohn, Felix_Choral
|
| 78 |
+
Romantic_Mendelssohn, Felix_Keyboard
|
| 79 |
+
Romantic_Mendelssohn, Felix_Orchestral
|
| 80 |
+
Romantic_Munktell, Helena_Art Song
|
| 81 |
+
Romantic_Parratt, Walter_Choral
|
| 82 |
+
Romantic_Prokofiev, Sergey_Keyboard
|
| 83 |
+
Romantic_Rachmaninoff, Sergei_Choral
|
| 84 |
+
Romantic_Rachmaninoff, Sergei_Keyboard
|
| 85 |
+
Romantic_Ravel, Maurice_Art Song
|
| 86 |
+
Romantic_Ravel, Maurice_Chamber
|
| 87 |
+
Romantic_Ravel, Maurice_Keyboard
|
| 88 |
+
Romantic_Saint-Saens, Camille_Chamber
|
| 89 |
+
Romantic_Saint-Saens, Camille_Keyboard
|
| 90 |
+
Romantic_Saint-Saens, Camille_Orchestral
|
| 91 |
+
Romantic_Satie, Erik_Art Song
|
| 92 |
+
Romantic_Satie, Erik_Keyboard
|
| 93 |
+
Romantic_Schubert, Franz_Art Song
|
| 94 |
+
Romantic_Schubert, Franz_Chamber
|
| 95 |
+
Romantic_Schubert, Franz_Choral
|
| 96 |
+
Romantic_Schubert, Franz_Keyboard
|
| 97 |
+
Romantic_Schumann, Clara_Art Song
|
| 98 |
+
Romantic_Schumann, Robert_Art Song
|
| 99 |
+
Romantic_Schumann, Robert_Chamber
|
| 100 |
+
Romantic_Schumann, Robert_Choral
|
| 101 |
+
Romantic_Schumann, Robert_Keyboard
|
| 102 |
+
Romantic_Scriabin, Aleksandr_Keyboard
|
| 103 |
+
Romantic_Shostakovich, Dmitry_Chamber
|
| 104 |
+
Romantic_Shostakovich, Dmitry_Keyboard
|
| 105 |
+
Romantic_Sibelius, Jean_Keyboard
|
| 106 |
+
Romantic_Smetana, Bedrich_Keyboard
|
| 107 |
+
Romantic_Tchaikovsky, Pyotr_Keyboard
|
| 108 |
+
Romantic_Tchaikovsky, Pyotr_Orchestral
|
| 109 |
+
Romantic_Viardot, Pauline_Art Song
|
| 110 |
+
Romantic_Warlock, Peter_Art Song
|
| 111 |
+
Romantic_Wolf, Hugo_Art Song
|
| 112 |
+
Romantic_Zumsteeg, Emilie_Art Song
|
requirements (6).txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers==4.40.0
|
| 2 |
+
numpy==1.26.4
|
| 3 |
+
wandb==0.17.2
|
| 4 |
+
abctoolkit==0.0.6
|
| 5 |
+
samplings==0.1.7
|
| 6 |
+
pyparsing==3.2.1
|
| 7 |
+
gradio==5.17.1
|
statistics.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Folders holding pre-extracted CLaMP2 feature vectors (.npy files);
# fill these in before running.
gt_feature_folder = ''
output_feature_folder = ''

import os
import json
import random
import re
import numpy as np
from config import *
|
| 10 |
+
|
| 11 |
+
def load_npy_files(folder_path_list):
    """Load the leading row of every '.npy' file in *folder_path_list*.

    Paths without a '.npy' suffix are silently skipped.  Each loaded array
    is indexed with [0], i.e. only its first row/vector is kept.

    Returns a list of numpy arrays, in input order.
    """
    return [np.load(path)[0] for path in folder_path_list if path.endswith('.npy')]
|
| 22 |
+
|
| 23 |
+
def average_npy(npy_list):
    """Element-wise mean of a list of equally-shaped numpy arrays."""
    stacked = np.asarray(npy_list)
    return stacked.mean(axis=0)
|
| 28 |
+
|
| 29 |
+
def cosine_similarity(vec1, vec2):
    """Cosine similarity between two 1-D numpy arrays.

    NOTE(review): an all-zero vector makes the denominator 0 and yields
    nan/inf with a numpy warning — same as the original implementation.
    """
    numerator = np.dot(vec1, vec2)
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return numerator / denominator
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def test_generated_results_similarity():
    """Print the average "CLaMP 2 score": mean cosine similarity between each
    generated feature and the averaged ground-truth feature."""
    gt_paths = [
        os.path.join(gt_feature_folder, fname)
        for fname in os.listdir(gt_feature_folder)
    ]
    gt_avg_feature = average_npy(load_npy_files(gt_paths))

    scores = []
    for fname in os.listdir(output_feature_folder):
        feature = np.load(os.path.join(output_feature_folder, fname))[0]
        scores.append(cosine_similarity(gt_avg_feature, feature))
    avg_clampscore = sum(scores) / len(scores)

    print('average clamp 2 score:', avg_clampscore)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if __name__ == '__main__':
    # Entry point: compare generated features against the ground-truth set.
    test_generated_results_similarity()
|
| 68 |
+
|
train-gen (1).py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gc
|
| 3 |
+
import time
|
| 4 |
+
import math
|
| 5 |
+
import json
|
| 6 |
+
import wandb
|
| 7 |
+
import torch
|
| 8 |
+
import random
|
| 9 |
+
import numpy as np
|
| 10 |
+
from utils import *
|
| 11 |
+
from config import *
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from copy import deepcopy
|
| 14 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 15 |
+
from torch.utils.data import Dataset, DataLoader
|
| 16 |
+
from transformers import GPT2Config, LlamaConfig, get_scheduler, get_constant_schedule_with_warmup
|
| 17 |
+
import torch.distributed as dist
|
| 18 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 19 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 20 |
+
|
| 21 |
+
# Set up distributed training (env vars are provided by torchrun / launch utility)
world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0
local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0

if world_size > 1:
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    dist.init_process_group(backend='nccl') if world_size > 1 else None
else:
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Set random seed, offset by rank so each process draws different augmentations
seed = 0 + global_rank
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

batch_size = BATCH_SIZE

# Byte/patch tokenizer for ABC notation (defined in utils).
patchilizer = Patchilizer()

# Patch-level model config.  vocab_size=1 — presumably the patch-level model
# consumes embedded patches rather than token ids; confirm in utils.
patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
                          max_length=PATCH_LENGTH,
                          max_position_embeddings=PATCH_LENGTH,
                          n_embd=HIDDEN_SIZE,
                          num_attention_heads=HIDDEN_SIZE//64,
                          vocab_size=1)
# Character-level decoder config; vocab_size=128 (byte values 0-127).
char_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
                         max_length=PATCH_SIZE+1,
                         max_position_embeddings=PATCH_SIZE+1,
                         hidden_size=HIDDEN_SIZE,
                         num_attention_heads=HIDDEN_SIZE//64,
                         vocab_size=128)

model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=char_config)

model = model.to(device)

# print parameter number
print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

if world_size > 1:
    model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

scaler = GradScaler()  # AMP gradient scaler for mixed-precision training
is_autocast = True
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def clear_unused_tensors():
    """Best-effort release of stray CUDA tensors between training iterations.

    Detaches and drops any GC-tracked CUDA tensor that is neither a model
    parameter nor an optimizer state tensor, then empties the CUDA cache.
    GC is paused during the scan so the object list stays stable.

    BUGFIX: the original code used ``weakref`` without importing it anywhere
    in the file, so every call raised NameError which the bare ``except:``
    silently swallowed — the cleanup never actually ran.
    """
    import weakref  # local import: not present in the module's import block

    gc.disable()  # Temporarily disable garbage collection while scanning
    try:
        # Tensor ids owned by the model (unwrap DDP if present).
        if hasattr(model, "module"):
            model_tensors = {id(p) for p in model.module.parameters()}
        else:
            model_tensors = {id(p) for p in model.parameters()}

        # Tensor ids owned by the optimizer state (e.g. AdamW moments).
        optimizer_tensors = {
            id(state)
            for state_dict in optimizer.state.values()
            for state in state_dict.values()
            if isinstance(state, torch.Tensor)  # Ensure only tensors are considered
        }

        # All CUDA tensors currently tracked by the garbage collector.
        tensors = [obj for obj in gc.get_objects() if isinstance(obj, torch.Tensor) and obj.is_cuda]

        # Weak references so the scan itself does not keep tensors alive.
        tensor_refs = [weakref.ref(tensor) for tensor in tensors]

        for tensor_ref in tensor_refs:
            tensor = tensor_ref()  # Dereference the weak reference
            if tensor is not None and id(tensor) not in model_tensors and id(tensor) not in optimizer_tensors:
                tensor.detach_()  # Detach from computation graph
                del tensor  # Delete the tensor reference
    except Exception:
        # Cleanup is opportunistic; never let it break the training loop.
        pass
    finally:
        gc.enable()  # Re-enable garbage collection
        gc.collect()  # Force a garbage collection
        torch.cuda.empty_cache()  # Clear the CUDA cache
|
| 110 |
+
|
| 111 |
+
def collate_batch(input_batches):
    """Pad a list of (patches, masks) samples into batch tensors on `device`."""
    patch_seqs, mask_seqs = zip(*input_batches)
    pad = torch.nn.utils.rnn.pad_sequence
    padded_patches = pad(patch_seqs, batch_first=True, padding_value=0)
    padded_masks = pad(mask_seqs, batch_first=True, padding_value=0)
    return padded_patches.to(device), padded_masks.to(device)
|
| 118 |
+
|
| 119 |
+
def split_into_minibatches(input_patches, input_masks, minibatch_size):
    """Chunk a batch into consecutive (patches, masks) minibatches of the given size."""
    starts = range(0, len(input_patches), minibatch_size)
    return [
        (input_patches[s:s + minibatch_size], input_masks[s:s + minibatch_size])
        for s in starts
    ]
|
| 127 |
+
|
| 128 |
+
class NotaGenDataset(Dataset):
    """Dataset of key-transposed ABC notation files for NotaGen training.

    Each index entry provides the path of an original ABC file; __getitem__
    picks a random target key and loads the pre-transposed variant stored at
    <folder>/<key>/<name>_<key>.abc.  The transposed files are assumed to have
    been generated beforehand by the preprocessing step — TODO confirm.
    """

    def __init__(self, filenames):
        # List of dicts; each must contain a 'path' entry.
        self.filenames = filenames

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filepath = self.filenames[idx]['path']

        # Uniformly sample one of 15 notated keys (key augmentation).
        key = random.choice(['C#', 'F#', 'B', 'E', 'A', 'D', 'G', 'C', 'F', 'Bb', 'Eb', 'Ab', 'Db', 'Gb', 'Cb'])

        # The transposed variant lives in a per-key subfolder next to the original.
        folder = os.path.dirname(filepath)
        name = os.path.split(filepath)[-1]
        des_filepath = os.path.join(folder, key, name + '_' + key + '.abc')

        with open(des_filepath, 'r', encoding='utf-8') as f:
            abc_text = f.read()

        # Patch-encode the ABC text; mask is all ones (padding added in collate_batch).
        file_bytes = patchilizer.encode_train(abc_text)
        file_masks = [1] * len(file_bytes)

        file_bytes = torch.tensor(file_bytes, dtype=torch.long)
        file_masks = torch.tensor(file_masks, dtype=torch.long)

        return file_bytes, file_masks
|
| 155 |
+
|
| 156 |
+
def process_one_batch(batch):
    """Forward one (patches, masks) minibatch and return the loss.

    In multi-GPU runs the per-rank losses are reduced onto rank 0 (summed),
    averaged over `world_size`, and broadcast back so every rank reports the
    same value.
    """
    input_patches, input_masks = batch
    loss = model(input_patches, input_masks).loss

    # Reduce the loss on GPU 0, then share the averaged value with all ranks.
    if world_size > 1:
        loss = loss.unsqueeze(0)
        dist.reduce(loss, dst=0)
        loss = loss / world_size
        dist.broadcast(loss, src=0)

    return loss
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# do one epoch for training
|
| 171 |
+
def train_epoch(epoch):
    """Run one training epoch over `train_set`; returns the mean train loss.

    Uses AMP autocast + GradScaler.  Each DataLoader batch is split into
    ACCUMULATION_STEPS minibatches whose losses are scaled by
    1/ACCUMULATION_STEPS before backward, then a single optimizer step is
    taken per accumulated batch.
    """
    tqdm_train_set = tqdm(train_set)
    total_train_loss = 0
    iter_idx = 1
    model.train()
    # Global step offset so wandb logging is continuous across epochs.
    train_steps = (epoch-1)*len(train_set)

    for batch in tqdm_train_set:
        minibatches = split_into_minibatches(batch[0], batch[1], BATCH_SIZE//ACCUMULATION_STEPS)
        for minibatch in minibatches:
            with autocast():
                loss = process_one_batch(minibatch) / ACCUMULATION_STEPS
            scaler.scale(loss).backward()
            total_train_loss += loss.item()
        scaler.step(optimizer)
        scaler.update()

        lr_scheduler.step()
        model.zero_grad(set_to_none=True)
        tqdm_train_set.set_postfix({str(global_rank)+'_train_loss': total_train_loss / iter_idx})
        train_steps += 1

        # Log the training loss to wandb (rank 0 only)
        if global_rank==0 and WANDB_LOGGING:
            wandb.log({"train_loss": total_train_loss / iter_idx}, step=train_steps)

        iter_idx += 1
        # Periodically sweep stray CUDA tensors to limit memory growth.
        if iter_idx % 1000 == 0:
            clear_unused_tensors()

    return total_train_loss / (iter_idx-1)
|
| 202 |
+
|
| 203 |
+
# do one epoch for eval
|
| 204 |
+
def eval_epoch():
    """Run one evaluation pass over `eval_set`; returns the mean eval loss.

    Loss is computed under torch.no_grad with the same minibatch split and
    1/ACCUMULATION_STEPS scaling as training, so train and eval losses are
    directly comparable.  (Removed the unused local `total_eval_bpb`.)
    """
    tqdm_eval_set = tqdm(eval_set)
    total_eval_loss = 0
    iter_idx = 1
    model.eval()

    # Evaluate data for one epoch
    for batch in tqdm_eval_set:
        minibatches = split_into_minibatches(batch[0], batch[1], BATCH_SIZE//ACCUMULATION_STEPS)
        for minibatch in minibatches:
            with torch.no_grad():
                loss = process_one_batch(minibatch) / ACCUMULATION_STEPS
            total_eval_loss += loss.item()
        tqdm_eval_set.set_postfix({str(global_rank)+'_eval_loss': total_eval_loss / iter_idx})
        iter_idx += 1
    return total_eval_loss / (iter_idx-1)
|
| 221 |
+
|
| 222 |
+
# train and eval
|
| 223 |
+
if __name__ == "__main__":

    # Initialize wandb (rank 0 only)
    if WANDB_LOGGING and global_rank==0:
        wandb.login(key=WANDB_KEY)
        wandb.init(project="notagen",
                   name=WANDB_NAME)

    # Load data: one JSON record per line (JSONL index files).
    with open(DATA_TRAIN_INDEX_PATH, "r", encoding="utf-8") as f:
        print("Loading Data...")
        train_files = [json.loads(line) for line in f]

    with open(DATA_EVAL_INDEX_PATH, "r", encoding="utf-8") as f:
        print("Loading Data...")
        eval_files = [json.loads(line) for line in f]

    # Truncate each split to a whole number of batches.
    train_batch_nums = int(len(train_files) / batch_size)
    eval_batch_nums = int(len(eval_files) / batch_size)

    random.shuffle(train_files)
    random.shuffle(eval_files)

    train_files = train_files[:train_batch_nums*batch_size]
    eval_files = eval_files[:eval_batch_nums*batch_size]

    train_set = NotaGenDataset(train_files)
    eval_set = NotaGenDataset(eval_files)

    train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=local_rank)
    eval_sampler = DistributedSampler(eval_set, num_replicas=world_size, rank=local_rank)

    # Samplers handle shuffling; shuffle kwarg stays False when a sampler is set.
    # (Fixed copy-paste: the eval loader previously checked train_sampler.)
    train_set = DataLoader(train_set, batch_size=batch_size, collate_fn=collate_batch, sampler=train_sampler, shuffle=(train_sampler is None))
    eval_set = DataLoader(eval_set, batch_size=batch_size, collate_fn=collate_batch, sampler=eval_sampler, shuffle=(eval_sampler is None))

    model = model.to(device)
    # BUGFIX: create the optimizer BEFORE the LR scheduler.  The original code
    # built the scheduler around the module-level optimizer and then rebound
    # `optimizer` to a fresh AdamW, so lr_scheduler.step() warmed up an
    # optimizer that was never used for training.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=1000)

    if LOAD_FROM_CHECKPOINT and os.path.exists(WEIGHTS_PATH):
        # Load checkpoint to CPU first to avoid a GPU memory spike.
        checkpoint = torch.load(WEIGHTS_PATH, map_location='cpu')

        if torch.cuda.device_count() > 1:
            # DDP-wrapped model: load the state dict into model.module.
            cpu_model = deepcopy(model.module)
            cpu_model.load_state_dict(checkpoint['model'])
            model.module.load_state_dict(cpu_model.state_dict())
        else:
            # Load into a CPU clone of the model, then copy back.
            cpu_model = deepcopy(model)
            cpu_model.load_state_dict(checkpoint['model'])
            model.load_state_dict(cpu_model.state_dict())
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_sched'])
        pre_epoch = checkpoint['epoch']
        best_epoch = checkpoint['best_epoch']
        min_eval_loss = checkpoint['min_eval_loss']
        print("Successfully Loaded Checkpoint from Epoch %d" % pre_epoch)
        checkpoint = None  # release checkpoint memory

    else:
        pre_epoch = 0
        best_epoch = 0
        min_eval_loss = 100

    for epoch in range(1+pre_epoch, NUM_EPOCHS+1):
        train_sampler.set_epoch(epoch)
        eval_sampler.set_epoch(epoch)
        print('-' * 21 + "Epoch " + str(epoch) + '-' * 21)
        train_loss = train_epoch(epoch)
        eval_loss = eval_epoch()
        if global_rank==0:
            with open(LOGS_PATH,'a') as f:
                f.write("Epoch " + str(epoch) + "\ntrain_loss: " + str(train_loss) + "\neval_loss: " +str(eval_loss) + "\ntime: " + time.asctime(time.localtime(time.time())) + "\n\n")
            # Keep only the best (lowest eval loss) checkpoint.
            if eval_loss < min_eval_loss:
                best_epoch = epoch
                min_eval_loss = eval_loss
                checkpoint = {
                    'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_sched': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'best_epoch': best_epoch,
                    'min_eval_loss': min_eval_loss
                }
                torch.save(checkpoint, WEIGHTS_PATH)

        if world_size > 1:
            dist.barrier()

    if global_rank==0:
        print("Best Eval Epoch : "+str(best_epoch))
        print("Min Eval Loss : "+str(min_eval_loss))
|
| 325 |
+
|
train-gen.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gc
|
| 3 |
+
import time
|
| 4 |
+
import math
|
| 5 |
+
import json
|
| 6 |
+
import wandb
|
| 7 |
+
import torch
|
| 8 |
+
import random
|
| 9 |
+
import numpy as np
|
| 10 |
+
from abctoolkit.transpose import Key2index, Key2Mode
|
| 11 |
+
from utils import *
|
| 12 |
+
from config import *
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from copy import deepcopy
|
| 15 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 16 |
+
from torch.utils.data import Dataset, DataLoader
|
| 17 |
+
from transformers import GPT2Config, LlamaConfig, get_scheduler, get_constant_schedule_with_warmup
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 20 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 21 |
+
|
| 22 |
+
# Reverse lookup tables built from abctoolkit's Key2index / Key2Mode.
# Indices 1 and 11 are excluded from Index2Key — presumably because those
# pitch classes map to enharmonic pairs (Db/C#, B/Cb) that are resolved
# explicitly in NotaGenDataset.__getitem__; confirm against abctoolkit.
Index2Key = {index: key for key, index in Key2index.items() if index not in [1, 11]}
Mode2Key = {mode: key for key, mode_list in Key2Mode.items() for mode in mode_list }

# Set up distributed training (env vars are provided by torchrun / launch utility)
world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0
local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0

if world_size > 1:
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    dist.init_process_group(backend='nccl') if world_size > 1 else None
else:
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Set random seed, offset by rank so each process draws different augmentations
seed = 0 + global_rank
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

batch_size = BATCH_SIZE

# Byte/patch tokenizer for ABC notation (defined in utils).
patchilizer = Patchilizer()

# Patch-level model config.  vocab_size=1 — presumably the patch-level model
# consumes embedded patches rather than token ids; confirm in utils.
patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
                          max_length=PATCH_LENGTH,
                          max_position_embeddings=PATCH_LENGTH,
                          n_embd=HIDDEN_SIZE,
                          num_attention_heads=HIDDEN_SIZE//64,
                          vocab_size=1)
# Character-level decoder config; vocab_size=128 (byte values 0-127).
char_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
                         max_length=PATCH_SIZE+1,
                         max_position_embeddings=PATCH_SIZE+1,
                         hidden_size=HIDDEN_SIZE,
                         num_attention_heads=HIDDEN_SIZE//64,
                         vocab_size=128)

model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=char_config)

model = model.to(device)

# print parameter number
print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

if world_size > 1:
    model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

scaler = GradScaler()  # AMP gradient scaler for mixed-precision training
is_autocast = True
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def clear_unused_tensors():
    """Best-effort release of stray CUDA tensors between training iterations.

    Detaches and drops any GC-tracked CUDA tensor that is neither a model
    parameter nor an optimizer state tensor, then empties the CUDA cache.
    GC is paused during the scan so the object list stays stable.

    BUGFIX: the original code used ``weakref`` without importing it anywhere
    in the file, so every call raised NameError which the bare ``except:``
    silently swallowed — the cleanup never actually ran.
    """
    import weakref  # local import: not present in the module's import block

    gc.disable()  # Temporarily disable garbage collection while scanning
    try:
        # Tensor ids owned by the model (unwrap DDP if present).
        if hasattr(model, "module"):
            model_tensors = {id(p) for p in model.module.parameters()}
        else:
            model_tensors = {id(p) for p in model.parameters()}

        # Tensor ids owned by the optimizer state (e.g. AdamW moments).
        optimizer_tensors = {
            id(state)
            for state_dict in optimizer.state.values()
            for state in state_dict.values()
            if isinstance(state, torch.Tensor)  # Ensure only tensors are considered
        }

        # All CUDA tensors currently tracked by the garbage collector.
        tensors = [obj for obj in gc.get_objects() if isinstance(obj, torch.Tensor) and obj.is_cuda]

        # Weak references so the scan itself does not keep tensors alive.
        tensor_refs = [weakref.ref(tensor) for tensor in tensors]

        for tensor_ref in tensor_refs:
            tensor = tensor_ref()  # Dereference the weak reference
            if tensor is not None and id(tensor) not in model_tensors and id(tensor) not in optimizer_tensors:
                tensor.detach_()  # Detach from computation graph
                del tensor  # Delete the tensor reference
    except Exception:
        # Cleanup is opportunistic; never let it break the training loop.
        pass
    finally:
        gc.enable()  # Re-enable garbage collection
        gc.collect()  # Force a garbage collection
        torch.cuda.empty_cache()  # Clear the CUDA cache
|
| 114 |
+
|
| 115 |
+
def collate_batch(input_batches):
    """Pad a list of (patches, masks) samples into batch tensors on `device`."""
    patch_seqs, mask_seqs = zip(*input_batches)
    pad = torch.nn.utils.rnn.pad_sequence
    padded_patches = pad(patch_seqs, batch_first=True, padding_value=0)
    padded_masks = pad(mask_seqs, batch_first=True, padding_value=0)
    return padded_patches.to(device), padded_masks.to(device)
|
| 122 |
+
|
| 123 |
+
def split_into_minibatches(input_patches, input_masks, minibatch_size):
    """Chunk a batch into consecutive (patches, masks) minibatches of the given size."""
    starts = range(0, len(input_patches), minibatch_size)
    return [
        (input_patches[s:s + minibatch_size], input_masks[s:s + minibatch_size])
        for s in starts
    ]
|
| 131 |
+
|
| 132 |
+
class NotaGenDataset(Dataset):
    """Dataset of ABC files with probabilistic key-transposition augmentation.

    __getitem__ samples a target key within +/-3 semitones of the piece's
    original key (weights 1,2,3,4,3,2,1 out of 16, peaking at the original
    key) and loads the pre-transposed ABC file stored at
    <folder>/<key>/<name>_<key>.abc.  The transposed files are assumed to
    exist already — TODO confirm against the preprocessing step.
    """

    def __init__(self, filenames):
        # List of dicts; each must contain 'path' and 'key' entries.
        self.filenames = filenames

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filepath = self.filenames[idx]['path']
        # Canonical key of the piece (mode name -> key via abctoolkit tables).
        ori_key = Mode2Key[self.filenames[idx]['key']]

        # Choose a key to transpose to, according to a probability distribution
        # over offsets -3..+3 semitones relative to the original key.
        ori_key_index = Key2index[ori_key]
        available_index = [(ori_key_index + offset) % 12 for offset in range(-3, 4)]
        index_prob = [1/16, 2/16, 3/16, 4/16, 3/16, 2/16, 1/16]
        # Cumulative distribution boundaries for inverse-CDF sampling.
        index_prob_range = [0] + [sum(index_prob[0 : i + 1]) for i in range(len(index_prob))]
        random_number = random.random()
        for i in range(len(index_prob_range) - 1):
            if index_prob_range[i] <= random_number < index_prob_range[i + 1]:
                des_key_index = available_index[i]
                if des_key_index == 1:
                    # Enharmonic choice: mostly Db, occasionally C#.
                    des_key = 'Db' if random.random() < 0.8 else 'C#'
                elif des_key_index == 11:
                    # Enharmonic choice: mostly B, occasionally Cb.
                    des_key = 'B' if random.random() < 0.8 else 'Cb'
                elif des_key_index == 6:
                    # F# and Gb equally likely.
                    des_key = 'F#' if random.random() < 0.5 else 'Gb'
                else:
                    des_key = Index2Key[des_key_index]

        # The transposed variant lives in a per-key subfolder next to the original.
        folder = os.path.dirname(filepath)
        name = os.path.split(filepath)[-1]
        des_filepath = os.path.join(folder, des_key, name + '_' + des_key + '.abc')

        with open(des_filepath, 'r', encoding='utf-8') as f:
            abc_text = f.read()

        # Patch-encode the ABC text; mask is all ones (padding added in collate_batch).
        file_bytes = patchilizer.encode_train(abc_text)
        file_masks = [1] * len(file_bytes)

        file_bytes = torch.tensor(file_bytes, dtype=torch.long)
        file_masks = torch.tensor(file_masks, dtype=torch.long)

        return file_bytes, file_masks
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def process_one_batch(batch):
    """Forward one (patches, masks) minibatch and return the loss.

    In multi-GPU runs the per-rank losses are reduced onto rank 0 (summed),
    averaged over `world_size`, and broadcast back so every rank reports the
    same value.
    """
    input_patches, input_masks = batch
    loss = model(input_patches, input_masks).loss

    # Reduce the loss on GPU 0, then share the averaged value with all ranks.
    if world_size > 1:
        loss = loss.unsqueeze(0)
        dist.reduce(loss, dst=0)
        loss = loss / world_size
        dist.broadcast(loss, src=0)

    return loss
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# do one epoch for training
|
| 193 |
+
def train_epoch(epoch):
    """Run one training epoch over `train_set`; returns the mean train loss.

    Uses AMP autocast + GradScaler.  Each DataLoader batch is split into
    ACCUMULATION_STEPS minibatches whose losses are scaled by
    1/ACCUMULATION_STEPS before backward, then a single optimizer step is
    taken per accumulated batch.
    """
    tqdm_train_set = tqdm(train_set)
    total_train_loss = 0
    iter_idx = 1
    model.train()
    # Global step offset so wandb logging is continuous across epochs.
    train_steps = (epoch-1)*len(train_set)

    for batch in tqdm_train_set:
        minibatches = split_into_minibatches(batch[0], batch[1], BATCH_SIZE//ACCUMULATION_STEPS)
        for minibatch in minibatches:
            with autocast():
                loss = process_one_batch(minibatch) / ACCUMULATION_STEPS
            scaler.scale(loss).backward()
            total_train_loss += loss.item()
        scaler.step(optimizer)
        scaler.update()

        lr_scheduler.step()
        model.zero_grad(set_to_none=True)
        tqdm_train_set.set_postfix({str(global_rank)+'_train_loss': total_train_loss / iter_idx})
        train_steps += 1

        # Log the training loss to wandb (rank 0 only)
        if global_rank==0 and WANDB_LOGGING:
            wandb.log({"train_loss": total_train_loss / iter_idx}, step=train_steps)

        iter_idx += 1
        # Periodically sweep stray CUDA tensors to limit memory growth.
        if iter_idx % 1000 == 0:
            clear_unused_tensors()

    return total_train_loss / (iter_idx-1)
|
| 224 |
+
|
| 225 |
+
# do one epoch for eval
|
| 226 |
+
def eval_epoch():
    """Run one evaluation epoch without gradients and return the mean loss.

    Uses module-level globals: `model`, `eval_set` (DataLoader), and the
    same minibatch splitting as training so loss scaling matches.
    """
    tqdm_eval_set = tqdm(eval_set)
    total_eval_loss = 0
    # NOTE(review): total_eval_bpb is never accumulated or returned — looks
    # like a leftover from a bits-per-byte metric; confirm before removing.
    total_eval_bpb = 0
    iter_idx = 1
    model.eval()

    # Evaluate data for one epoch
    for batch in tqdm_eval_set:
        minibatches = split_into_minibatches(batch[0], batch[1], BATCH_SIZE//ACCUMULATION_STEPS)
        for minibatch in minibatches:
            with torch.no_grad():
                # Same ACCUMULATION_STEPS scaling as training so the numbers
                # are directly comparable.
                loss = process_one_batch(minibatch) / ACCUMULATION_STEPS
            total_eval_loss += loss.item()
        tqdm_eval_set.set_postfix({str(global_rank)+'_eval_loss': total_eval_loss / iter_idx})
        iter_idx += 1
    return total_eval_loss / (iter_idx-1)
|
| 243 |
+
|
| 244 |
+
# train and eval
|
| 245 |
+
# train and eval
if __name__ == "__main__":

    # Initialize wandb (main process only, to avoid duplicate runs)
    if WANDB_LOGGING and global_rank==0:
        wandb.login(key=WANDB_KEY)
        wandb.init(project="notagen",
                   name=WANDB_NAME)

    # load data: one JSON record per line (jsonl index files)
    with open(DATA_TRAIN_INDEX_PATH, "r", encoding="utf-8") as f:
        print("Loading Data...")
        train_files = []
        for line in f:
            train_files.append(json.loads(line))

    with open(DATA_EVAL_INDEX_PATH, "r", encoding="utf-8") as f:
        print("Loading Data...")
        eval_files = []
        for line in f:
            eval_files.append(json.loads(line))

    # Fall back to an automatic split when no eval index is provided.
    if len(eval_files) == 0:
        train_files, eval_files = split_data(train_files)

    # Keep only whole batches (drop the trailing partial batch).
    train_batch_nums = int(len(train_files) / batch_size)
    eval_batch_nums = int(len(eval_files) / batch_size)

    random.shuffle(train_files)
    random.shuffle(eval_files)

    train_files = train_files[:train_batch_nums*batch_size]
    eval_files = eval_files[:eval_batch_nums*batch_size]

    train_set = NotaGenDataset(train_files)
    eval_set = NotaGenDataset(eval_files)

    train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=local_rank)
    eval_sampler = DistributedSampler(eval_set, num_replicas=world_size, rank=local_rank)

    # shuffle must stay False when a sampler is supplied; also fixed the
    # eval loader to key on its own sampler rather than train_sampler.
    train_set = DataLoader(train_set, batch_size=batch_size, collate_fn=collate_batch, sampler=train_sampler, shuffle = (train_sampler is None))
    eval_set = DataLoader(eval_set, batch_size=batch_size, collate_fn=collate_batch, sampler=eval_sampler, shuffle = (eval_sampler is None))

    # FIX: move the model to the device and (re)create the optimizer BEFORE
    # building the LR scheduler. The original code built the scheduler first
    # and then rebound `optimizer` to a fresh AdamW instance, so the warmup
    # scheduler kept stepping the stale optimizer and the live optimizer's
    # learning rate never followed the schedule.
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=1000)

    if not LOAD_FROM_CHECKPOINT:
        if os.path.exists(PRETRAINED_PATH):
            # Load pre-trained checkpoint to CPU
            checkpoint = torch.load(PRETRAINED_PATH, map_location='cpu')

            # Here, model is assumed to be on GPU
            # Load state dict to CPU model first, then move the model to GPU
            if torch.cuda.device_count() > 1:
                # If you have a DataParallel model, you need to load to model.module instead
                cpu_model = deepcopy(model.module)
                cpu_model.load_state_dict(checkpoint['model'])
                model.module.load_state_dict(cpu_model.state_dict())
            else:
                # Load to a CPU clone of the model, then load back
                cpu_model = deepcopy(model)
                cpu_model.load_state_dict(checkpoint['model'])
                model.load_state_dict(cpu_model.state_dict())

            print(f"Successfully Loaded Pretrained Checkpoint at Epoch {checkpoint['epoch']} with Loss {checkpoint['min_eval_loss']}")

            pre_epoch = 0
            best_epoch = 0
            min_eval_loss = 100
        else:
            raise Exception('Pre-trained Checkpoint not found. Please check your pre-trained ckpt path.')

    else:
        if os.path.exists(WEIGHTS_PATH):
            # Load checkpoint to CPU
            checkpoint = torch.load(WEIGHTS_PATH, map_location='cpu')

            # Here, model is assumed to be on GPU
            # Load state dict to CPU model first, then move the model to GPU
            if torch.cuda.device_count() > 1:
                # If you have a DataParallel model, you need to load to model.module instead
                cpu_model = deepcopy(model.module)
                cpu_model.load_state_dict(checkpoint['model'])
                model.module.load_state_dict(cpu_model.state_dict())
            else:
                # Load to a CPU clone of the model, then load back
                cpu_model = deepcopy(model)
                cpu_model.load_state_dict(checkpoint['model'])
                model.load_state_dict(cpu_model.state_dict())
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_sched'])
            pre_epoch = checkpoint['epoch']
            best_epoch = checkpoint['best_epoch']
            min_eval_loss = checkpoint['min_eval_loss']
            print("Successfully Loaded Checkpoint from Epoch %d" % pre_epoch)
            # Free the large checkpoint dict before training starts.
            checkpoint = None

        else:
            raise Exception('Checkpoint not found to continue training. Please check your parameter settings.')


    for epoch in range(1+pre_epoch, NUM_EPOCHS+1):
        # Re-seed the samplers per epoch so shards reshuffle consistently.
        train_sampler.set_epoch(epoch)
        eval_sampler.set_epoch(epoch)
        print('-' * 21 + "Epoch " + str(epoch) + '-' * 21)
        train_loss = train_epoch(epoch)
        eval_loss = eval_epoch()
        if global_rank==0:
            with open(LOGS_PATH,'a') as f:
                f.write("Epoch " + str(epoch) + "\ntrain_loss: " + str(train_loss) + "\neval_loss: " +str(eval_loss) + "\ntime: " + time.asctime(time.localtime(time.time())) + "\n\n")
            # Save only when eval improves (best-checkpoint policy).
            if eval_loss < min_eval_loss:
                best_epoch = epoch
                min_eval_loss = eval_loss
                checkpoint = {
                    'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_sched': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'best_epoch': best_epoch,
                    'min_eval_loss': min_eval_loss
                }
                torch.save(checkpoint, WEIGHTS_PATH)

        if world_size > 1:
            dist.barrier()

    if global_rank==0:
        print("Best Eval Epoch : "+str(best_epoch))
        print("Min Eval Loss : "+str(min_eval_loss))
|
train.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import gc
import time
import math
import json
import wandb
import torch
import random
import numpy as np
from abctoolkit.transpose import Key2index, Key2Mode
from utils import *
from config import *
from data import generate_preference_dict
from tqdm import tqdm
from copy import deepcopy
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Config, get_scheduler, get_constant_schedule_with_warmup


# Run on GPU when available; all models/tensors below are moved here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Set random seed across python / numpy / torch / cuDNN for reproducibility.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Converts raw ABC text into fixed-size byte patches (defined in utils).
patchilizer = Patchilizer()

# Patch-level decoder config; vocab_size=1 because patch embeddings are fed
# in directly rather than looked up from a token vocabulary.
patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
                          max_length=PATCH_LENGTH,
                          max_position_embeddings=PATCH_LENGTH,
                          n_embd=HIDDEN_SIZE,
                          num_attention_heads=HIDDEN_SIZE//64,
                          vocab_size=1)
# Character-level decoder config; vocab_size=128 matches the ASCII byte range.
char_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
                         max_length=PATCH_SIZE+1,
                         max_position_embeddings=PATCH_SIZE+1,
                         hidden_size=HIDDEN_SIZE,
                         num_attention_heads=HIDDEN_SIZE//64,
                         vocab_size=128)

# Preference-optimization setup: `model` is the trained policy, `model_ref`
# the frozen reference; both are loaded from the same pretrained checkpoint
# in __main__ below.
model_ref = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=char_config)
model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=char_config)


model_ref = model_ref.to(device)
model = model.to(device)


# print parameter number
print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

# Only the policy is optimized; model_ref has no optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
+
|
| 60 |
+
def collate_batch(input_batches):
    """Collate one (chosen, rejected) preference pair for the model.

    Takes the 4-tuple produced by NotaGenDataset.__getitem__ —
    (pos_patches, pos_masks, neg_patches, neg_masks) — adds a leading
    batch dimension of size 1 to each tensor, zero-pads via pad_sequence,
    and moves everything onto the module-level `device`. Returns the four
    tensors in the same order as received.
    """
    def _batch_pad_move(tensor):
        # Batch of one: (L, ...) -> (1, L, ...).
        with_batch_dim = tensor.unsqueeze(0)
        # pad_sequence over a single-element batch is effectively a no-op
        # pad, kept for shape consistency with multi-sample collation.
        padded = torch.nn.utils.rnn.pad_sequence(with_batch_dim, batch_first=True, padding_value=0)
        return padded.to(device)

    return tuple(_batch_pad_move(t) for t in input_batches)
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class NotaGenDataset(Dataset):
    """Preference-pair dataset for DPO-style fine-tuning.

    Builds the full cross product of chosen x rejected file paths; each item
    yields patch-encoded byte tensors plus all-ones attention masks for both
    sides of the pair. Uses the module-level `patchilizer` for encoding.
    """
    def __init__(self, preference_dict):
        # preference_dict: {'chosen': [filepath, ...], 'rejected': [filepath, ...]}
        self.preference_dict = preference_dict
        self.pair_list = []
        # Cartesian product: every chosen file paired with every rejected file.
        for pos_filepath in self.preference_dict['chosen']:
            for neg_filepath in self.preference_dict['rejected']:
                self.pair_list.append({'chosen': pos_filepath, 'rejected': neg_filepath})

    def __len__(self):
        return len(self.pair_list)

    def __getitem__(self, idx):
        """Return (pos_bytes, pos_masks, neg_bytes, neg_masks) long tensors."""
        try:
            pair = self.pair_list[idx]
            pos_filepath = pair['chosen']
            neg_filepath = pair['rejected']

            with open(pos_filepath, 'r', encoding='utf-8') as f:
                pos_abc_text = f.read()
            with open(neg_filepath, 'r', encoding='utf-8') as f:
                neg_abc_text = f.read()

            pos_file_bytes = patchilizer.encode(pos_abc_text)
            # All positions are real tokens, so the mask is all ones.
            pos_file_masks = [1] * len(pos_file_bytes)
            neg_file_bytes = patchilizer.encode(neg_abc_text)
            neg_file_masks = [1] * len(neg_file_bytes)

            pos_file_bytes = torch.tensor(pos_file_bytes, dtype=torch.long)
            pos_file_masks = torch.tensor(pos_file_masks, dtype=torch.long)
            neg_file_bytes = torch.tensor(neg_file_bytes, dtype=torch.long)
            neg_file_masks = torch.tensor(neg_file_masks, dtype=torch.long)

            return pos_file_bytes, pos_file_masks, neg_file_bytes, neg_file_masks
        except Exception as e:
            # Best-effort fallback: on any read/encode failure, serve the next
            # pair instead. NOTE(review): if every file is unreadable this
            # recurses until RecursionError — confirm inputs are validated
            # upstream.
            print(e)
            return self.__getitem__((idx+1) % len(self.pair_list))
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def process_one_batch(batch):
    """Compute the preference (DPO-style) loss for one (chosen, rejected) pair.

    `model(...)` is expected to return the log-probability of the sequence
    under the policy; `model_ref` is the frozen reference policy — both are
    module-level globals, as are BETA and LAMBDA.

    Loss: -logsigmoid(BETA * (policy_margin - ref_margin
                              - LAMBDA * relu(ref_pos_logps - policy_pos_logps)))
    where the relu term penalizes the policy for losing likelihood on the
    chosen sample relative to the reference.
    """
    pos_input_patches, pos_input_masks, neg_input_patches, neg_input_masks = batch
    # Clone inputs for the reference pass so the policy pass cannot observe
    # any in-place modification.
    pos_input_patches_ref = pos_input_patches.clone()
    pos_input_masks_ref = pos_input_masks.clone()
    neg_input_patches_ref = neg_input_patches.clone()
    neg_input_masks_ref = neg_input_masks.clone()
    policy_pos_logps = model(pos_input_patches, pos_input_masks)
    policy_neg_logps = model(neg_input_patches, neg_input_masks)
    with torch.no_grad():
        ref_pos_logps = model_ref(pos_input_patches_ref, pos_input_masks_ref).detach()
        ref_neg_logps = model_ref(neg_input_patches_ref, neg_input_masks_ref).detach()
    logits = (policy_pos_logps - policy_neg_logps) - (ref_pos_logps - ref_neg_logps)
    # FIX: the original used Python's builtin max(0, tensor), which raises
    # "Boolean value of Tensor with more than one element is ambiguous" for
    # any non-scalar tensor and only worked for 0-dim log-probs.
    # torch.clamp(min=0) is the elementwise equivalent and identical for
    # scalars.
    penalty = torch.clamp(ref_pos_logps - policy_pos_logps, min=0)
    loss = - torch.nn.functional.logsigmoid(BETA * (logits - LAMBDA * penalty))
    return loss
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# train
|
| 130 |
+
# train
if __name__ == "__main__":

    # Initialize wandb
    if WANDB_LOGGING:
        wandb.login(key=WANDB_KEY)
        wandb.init(project="notagen",
                   name=WANDB_NAME)

    # load data: {'chosen': [...], 'rejected': [...]} preference index
    with open(DATA_INDEX_PATH, 'r') as f:
        preference_dict = json.loads(f.read())

    train_set = NotaGenDataset(preference_dict)

    # Load the same pretrained weights into both the policy and the
    # frozen reference model (via CPU clones to avoid GPU OOM on load).
    if os.path.exists(PRETRAINED_PATH):
        checkpoint = torch.load(PRETRAINED_PATH, map_location='cpu')
        cpu_model = deepcopy(model)
        cpu_model.load_state_dict(checkpoint['model'])
        model.load_state_dict(cpu_model.state_dict())
        cpu_model_ref = deepcopy(model_ref)
        cpu_model_ref.load_state_dict(checkpoint['model'])
        model_ref.load_state_dict(cpu_model_ref.state_dict())
    else:
        raise Exception('No pre-trained model loaded.')

    model.train()
    # FIX: keep the reference policy in eval mode so its log-probs are
    # deterministic (no dropout noise in the DPO reference term).
    model_ref.eval()
    total_train_loss = 0
    iter_idx = 1

    tqdm_set = tqdm(range(OPTIMIZATION_STEPS))
    for i in tqdm_set:
        # Sample one preference pair uniformly per optimization step.
        idx = random.randint(0, len(train_set)-1)
        batch = train_set[idx]
        batch = collate_batch(batch)

        loss = process_one_batch(batch)
        total_train_loss += loss.item()

        loss.backward()
        # FIX: `torch.nn.utils.clip_grad_norm` (no underscore) was deprecated
        # and has been removed in recent PyTorch releases; use the in-place
        # `clip_grad_norm_`.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        model.zero_grad(set_to_none=True)
        tqdm_set.set_postfix({'train_loss': total_train_loss / (i + 1)})

        # Log the training loss to wandb
        if WANDB_LOGGING:
            wandb.log({"train_loss": total_train_loss / (i + 1)}, step=i+1)

    # Save only the final model weights (no optimizer state).
    checkpoint = {'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict()}

    torch.save(checkpoint, WEIGHTS_PATH)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
|
utils (1).py
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import os
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import random
|
| 6 |
+
from config import *
|
| 7 |
+
from unidecode import unidecode
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
from transformers import AutoModel, BertModel, GPT2LMHeadModel, PreTrainedModel, GPT2Config
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import torch.distributed.nn
|
| 13 |
+
from torch import distributed as dist
|
| 14 |
+
|
| 15 |
+
has_distributed = True
|
| 16 |
+
except ImportError:
|
| 17 |
+
has_distributed = False
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
import horovod.torch as hvd
|
| 21 |
+
except ImportError:
|
| 22 |
+
hvd = None
|
| 23 |
+
|
| 24 |
+
class ClipLoss(torch.nn.Module):
    """CLIP-style symmetric InfoNCE loss (ported from OpenCLIP).

    Matched (image_features[i], text_features[i]) pairs — here music/text —
    are positives; all other in-batch combinations are negatives. Supports
    multi-GPU training by gathering features across ranks via
    torch.distributed or horovod.
    """

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        # local_loss: score only the local shard against the gathered global
        # set (len(local) x len(global) logits) instead of the full matrix.
        self.local_loss = local_loss
        # gather_with_grad: autograd-aware all_gather so gradients flow back
        # to features computed on other ranks.
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state for get_ground_truth()
        self.prev_num_logits = 0
        self.labels = {}

    def gather_features(
            self,
            image_features,
            text_features,
            local_loss=False,
            gather_with_grad=False,
            rank=0,
            world_size=1,
            use_horovod=False
    ):
        """All-gather both feature tensors across ranks; returns global tensors."""
        assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
        if use_horovod:
            assert hvd is not None, 'Please install horovod'
            if gather_with_grad:
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            else:
                with torch.no_grad():
                    all_image_features = hvd.allgather(image_features)
                    all_text_features = hvd.allgather(text_features)
                if not local_loss:
                    # ensure grads for local rank when all_* features don't have a gradient
                    gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                    gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                    gathered_image_features[rank] = image_features
                    gathered_text_features[rank] = text_features
                    all_image_features = torch.cat(gathered_image_features, dim=0)
                    all_text_features = torch.cat(gathered_text_features, dim=0)
        else:
            # We gather tensors from all gpus
            if gather_with_grad:
                all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
                all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
            else:
                gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
                gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
                dist.all_gather(gathered_image_features, image_features)
                dist.all_gather(gathered_text_features, text_features)
                if not local_loss:
                    # ensure grads for local rank when all_* features don't have a gradient
                    gathered_image_features[rank] = image_features
                    gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)

        return all_image_features, all_text_features

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Diagonal targets [0, num_logits), offset by rank when local_loss."""
        # calculated ground-truth and cache if enabled
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                # Local rows match the rank's own slice of the global columns.
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        """Scaled similarity logit matrices in both directions."""
        if self.world_size > 1:
            all_image_features, all_text_features = self.gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        """Symmetric cross-entropy over both logit directions (mean of the two)."""
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss
|
| 137 |
+
|
| 138 |
+
class M3Patchilizer:
    """Splits ABC notation (or MTF 'ticks_per_beat' format) text into
    bar-level string patches and converts them to/from fixed-length byte-id
    sequences. Depends on config globals PATCH_SIZE and PATCH_LENGTH.
    """
    def __init__(self):
        # Bar-line delimiters; multi-char forms precede "|" so the regex
        # alternation matches e.g. "|:" before the bare bar line.
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        # Capturing group so re.split keeps the delimiters in the output.
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def split_bars(self, body):
        """Split a music line into bars, each ending with its bar-line token."""
        bars = re.split(self.regexPattern, ''.join(body))
        bars = list(filter(None, bars)) # remove empty strings
        if bars[0] in self.delimiters:
            # Line starts with a bar line: fold it into the first bar.
            bars[1] = bars[0] + bars[1]
            bars = bars[1:]
        # Re-pair (content, delimiter) into single bar strings.
        bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
        return bars

    def bar2patch(self, bar, patch_size=PATCH_SIZE):
        """Encode one bar as [bos, ord(c)..., eos], truncated/padded to patch_size."""
        patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
        patch = patch[:patch_size]
        patch += [self.pad_token_id] * (patch_size - len(patch))
        return patch

    def patch2bar(self, patch):
        """Decode a patch back to text, dropping special ids (<= mask_token_id)."""
        return ''.join(chr(idx) if idx > self.mask_token_id else '' for idx in patch)

    def encode(self,
               item,
               patch_size=PATCH_SIZE,
               add_special_patches=False,
               truncate=False,
               random_truncate=False):
        """Convert a full piece of text into a list of byte-id patches.

        MTF input (first line begins with "ticks_per_beat"): consecutive
        lines sharing a prefix are tab-merged into one patch while they fit
        in patch_size-2 characters. ABC input: metadata lines (e.g. "K:",
        "%%score") become one patch each; music lines are split into bars.
        Optionally adds bos/eos sentinel patches and truncates to
        PATCH_LENGTH (head, or a random head/tail/middle window).
        """
        item = unidecode(item)  # fold non-ASCII to ASCII so ord(c) < 128
        lines = re.findall(r'.*?\n|.*$', item)
        lines = list(filter(None, lines)) # remove empty lines

        patches = []

        if lines[0].split(" ")[0] == "ticks_per_beat":
            # MTF branch: merge same-prefix event lines into shared patches.
            patch = ""
            for line in lines:
                if patch.startswith(line.split(" ")[0]) and (len(patch) + len(" ".join(line.split(" ")[1:])) <= patch_size-2):
                    # Drop trailing newline, tab-append the payload.
                    patch = patch[:-1] + "\t" + " ".join(line.split(" ")[1:])
                else:
                    if patch:
                        patches.append(patch)
                    patch = line
            if patch!="":
                patches.append(patch)
        else:
            # ABC branch.
            for line in lines:
                if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%')):
                    # Metadata/header line -> one patch.
                    patches.append(line)
                else:
                    bars = self.split_bars(line)
                    if bars:
                        bars[-1] += '\n'  # keep the line break on the last bar
                        patches.extend(bars)

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * patch_size
            eos_patch = chr(self.eos_token_id) * patch_size
            patches = [bos_patch] + patches + [eos_patch]

        if len(patches) > PATCH_LENGTH and truncate:
            choices = ["head", "tail", "middle"]
            choice = random.choice(choices)
            if choice=="head" or random_truncate==False:
                patches = patches[:PATCH_LENGTH]
            elif choice=="tail":
                patches = patches[-PATCH_LENGTH:]
            else:
                start = random.randint(1, len(patches)-PATCH_LENGTH)
                patches = patches[start:start+PATCH_LENGTH]

        patches = [self.bar2patch(patch) for patch in patches]

        return patches

    def decode(self, patches):
        """Inverse of encode: concatenate the decoded text of all patches."""
        return ''.join(self.patch2bar(patch) for patch in patches)
|
| 221 |
+
|
| 222 |
+
class M3PatchEncoder(PreTrainedModel):
    """BERT encoder over bar-level byte patches.

    Each patch is PATCH_SIZE byte ids in [0, 128); a patch is one-hot
    encoded, flattened, and linearly projected into the BERT hidden space
    before the transformer runs over the patch sequence.
    """
    def __init__(self, config):
        super(M3PatchEncoder, self).__init__(config)
        # Projects a flattened one-hot patch (PATCH_SIZE * 128) to M3_HIDDEN_SIZE.
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE*128, M3_HIDDEN_SIZE)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = BertModel(config=config)
        # Special byte ids shared across the M3 components.
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                input_patches, # [batch_size, seq_length, hidden_size]
                input_masks): # [batch_size, seq_length]
        # Transform input_patches into embeddings: one-hot over the byte
        # vocabulary, flatten per patch, then project.
        input_patches = torch.nn.functional.one_hot(input_patches, num_classes=128)
        input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE*128).type(torch.FloatTensor)
        input_patches = self.patch_embedding(input_patches.to(self.device))

        # Apply BERT model to input_patches and input_masks
        return self.base(inputs_embeds=input_patches, attention_mask=input_masks)
|
| 243 |
+
|
| 244 |
+
class M3TokenDecoder(PreTrainedModel):
    """GPT-2 character-level decoder that reconstructs a patch's bytes from
    its encoded patch feature (prepended as the first input embedding).
    """
    def __init__(self, config):
        super(M3TokenDecoder, self).__init__(config)
        self.base = GPT2LMHeadModel(config=config)
        # Special byte ids shared across the M3 components.
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                patch_features, # [batch_size, hidden_size]
                target_patches): # [batch_size, seq_length]
        # get input embeddings from the GPT-2 token embedding table
        inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)

        # concatenate the encoded patches with the input embeddings:
        # the patch feature replaces the embedding at position 0
        inputs_embeds = torch.cat((patch_features.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1)

        # preparing the labels for model training: pad positions get the
        # conventional -100 label so the LM loss ignores them
        target_masks = target_patches == self.pad_token_id
        target_patches = target_patches.clone().masked_fill_(target_masks, -100)

        # get the attention mask (1 on real tokens, 0 on padding)
        target_masks = ~target_masks
        target_masks = target_masks.type(torch.int)

        return self.base(inputs_embeds=inputs_embeds,
                         attention_mask=target_masks,
                         labels=target_patches)

    def generate(self,
                 patch_feature,
                 tokens):
        """Return next-byte probabilities (numpy array over the byte vocab)
        given one patch feature and the bytes generated so far."""
        # reshape the patch_feature and tokens to a batch of one
        patch_feature = patch_feature.reshape(1, 1, -1)
        tokens = tokens.reshape(1, -1)

        # get input embeddings
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)

        # concatenate the encoded patches with the input embeddings
        tokens = torch.cat((patch_feature, tokens[:,1:,:]), dim=1)

        # get the outputs from the model
        outputs = self.base(inputs_embeds=tokens)

        # get the probabilities of the next token (last position's logits)
        probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)

        return probs.detach().cpu().numpy()
|
| 294 |
+
|
| 295 |
+
class M3Model(PreTrainedModel):
    """M3 masked-patch model: BERT patch encoder + GPT-2 byte decoder.

    Only the positions flagged by `selected_indices` are reconstructed by
    the decoder (masked-patch prediction objective).
    """
    def __init__(self, encoder_config, decoder_config):
        super(M3Model, self).__init__(encoder_config)
        self.encoder = M3PatchEncoder(encoder_config)
        self.decoder = M3TokenDecoder(decoder_config)
        # Special byte ids shared across the M3 components.
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                input_patches, # [batch_size, seq_length, hidden_size]
                input_masks, # [batch_size, seq_length]
                selected_indices, # [batch_size, seq_length]
                target_patches): # [batch_size, seq_length, hidden_size]
        input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE).to(self.device)
        input_masks = input_masks.to(self.device)
        selected_indices = selected_indices.to(self.device)
        target_patches = target_patches.reshape(len(target_patches), -1, PATCH_SIZE).to(self.device)

        # Pass the input_patches and input_masks through the encoder
        outputs = self.encoder(input_patches, input_masks)["last_hidden_state"]

        # Keep only the positions selected for reconstruction; boolean
        # indexing flattens batch and sequence into one axis.
        target_patches = target_patches[selected_indices.bool()]
        patch_features = outputs[selected_indices.bool()]

        # Pass patch_features and target_patches through the decoder
        return self.decoder(patch_features, target_patches)
|
| 324 |
+
|
| 325 |
+
class CLaMP2Model(PreTrainedModel):
    """CLIP-style contrastive model pairing a pretrained text encoder with the
    M3 patch-level music encoder, both projected into a shared embedding space
    and trained with a contrastive (CLIP) loss over (text, music) pairs."""
    def __init__(self,
                 music_config,
                 global_rank=None,
                 world_size=None,
                 text_model_name=TEXT_MODEL_NAME,
                 hidden_size=CLAMP2_HIDDEN_SIZE,
                 load_m3=CLAMP2_LOAD_M3):
        """
        :param music_config: transformer config for the M3 music encoder
        :param global_rank: distributed rank; defaults to 0 when unset
        :param world_size: distributed world size; defaults to 1 when unset
        :param text_model_name: Hugging Face model name for the text encoder
        :param hidden_size: dimensionality of the shared projection space
        :param load_m3: if True and M3_WEIGHTS_PATH exists, warm-start the
                        music encoder from a pretrained M3 checkpoint
        """
        super(CLaMP2Model, self).__init__(music_config)

        self.text_model = AutoModel.from_pretrained(text_model_name) # Load the text model
        self.text_proj = torch.nn.Linear(self.text_model.config.hidden_size, hidden_size) # Linear layer for text projections
        torch.nn.init.normal_(self.text_proj.weight, std=0.02) # Initialize weights with normal distribution

        self.music_model = M3PatchEncoder(music_config) # Initialize the music model
        self.music_proj = torch.nn.Linear(M3_HIDDEN_SIZE, hidden_size) # Linear layer for music projections
        torch.nn.init.normal_(self.music_proj.weight, std=0.02) # Initialize weights with normal distribution

        # Single-process fallback when distributed info is not provided.
        if global_rank==None or world_size==None:
            global_rank = 0
            world_size = 1

        self.loss_fn = ClipLoss(local_loss=False,
                                gather_with_grad=True,
                                cache_labels=False,
                                rank=global_rank,
                                world_size=world_size,
                                use_horovod=False)

        # Optionally warm-start the music encoder from an M3 checkpoint: build
        # a full M3Model, load the weights, and keep only its encoder half.
        if load_m3 and os.path.exists(M3_WEIGHTS_PATH):
            checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True)
            decoder_config = GPT2Config(vocab_size=128,
                                        n_positions=PATCH_SIZE,
                                        n_embd=M3_HIDDEN_SIZE,
                                        n_layer=TOKEN_NUM_LAYERS,
                                        n_head=M3_HIDDEN_SIZE//64,
                                        n_inner=M3_HIDDEN_SIZE*4)
            model = M3Model(music_config, decoder_config)
            model.load_state_dict(checkpoint['model'])
            self.music_model = model.encoder
            model = None  # drop the decoder half; only the encoder is kept
            print(f"Successfully Loaded M3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")

    def avg_pooling(self, input_features, input_masks):
        """Masked average pooling over the sequence dimension (dim=1)."""
        input_masks = input_masks.unsqueeze(-1).to(self.device) # add a dimension to match the feature dimension
        input_features = input_features * input_masks # apply mask to input_features
        avg_pool = input_features.sum(dim=1) / input_masks.sum(dim=1) # calculate average pooling

        return avg_pool

    def get_text_features(self,
                          text_inputs,
                          text_masks,
                          get_normalized=False):
        """Encode text; with get_normalized=True, additionally mask-pool and
        project into the shared space."""
        text_features = self.text_model(text_inputs.to(self.device),
                                        attention_mask=text_masks.to(self.device))['last_hidden_state']

        if get_normalized:
            text_features = self.avg_pooling(text_features, text_masks)
            text_features = self.text_proj(text_features)

        return text_features

    def get_music_features(self,
                           music_inputs,
                           music_masks,
                           get_normalized=False):
        """Encode music patches; with get_normalized=True, additionally
        mask-pool and project into the shared space."""
        music_features = self.music_model(music_inputs.to(self.device),
                                          music_masks.to(self.device))['last_hidden_state']

        if get_normalized:
            music_features = self.avg_pooling(music_features, music_masks)
            music_features = self.music_proj(music_features)

        return music_features

    def forward(self,
                text_inputs, # [batch_size, seq_length]
                text_masks, # [batch_size, seq_length]
                music_inputs, # [batch_size, seq_length, hidden_size]
                music_masks): # [batch_size, seq_length]
        """Compute the contrastive loss between projected text and music
        features of a batch of (text, music) pairs."""
        # Compute the text features
        text_features = self.get_text_features(text_inputs, text_masks, get_normalized=True)

        # Compute the music features
        music_features = self.get_music_features(music_inputs, music_masks, get_normalized=True)

        return self.loss_fn(text_features,
                            music_features,
                            LOGIT_SCALE,
                            output_dict=False)
|
| 416 |
+
|
| 417 |
+
def split_data(data, eval_ratio=EVAL_SPLIT):
    """Shuffle *data* in place and split it into (train, eval) subsets.

    :param data: list of samples; NOTE: shuffled in place as a side effect
    :param eval_ratio: fraction of samples assigned to the evaluation split
    :return: (train_set, eval_set) tuple of lists
    """
    random.shuffle(data)
    n_eval = int(len(data) * eval_ratio)
    # The first n_eval shuffled items form the evaluation split.
    return data[n_eval:], data[:n_eval]
|
| 423 |
+
|
| 424 |
+
def mask_patches(target_patches, patchilizer, mode):
    """Corrupt a random subset of patches for M3 masked-modeling training.

    A fraction M3_MASK_RATIO of the patch positions is selected. In "eval"
    mode the selection is left untouched ("original"); otherwise one
    corruption is drawn for the whole batch: with probability 0.8 the
    selected patches are replaced by mask-token patches, 0.1 their content
    characters are shuffled, 0.1 they are left as-is.

    :param target_patches: sequence of patches (each a list of PATCH_SIZE char ids)
    :param patchilizer: provides ``mask_token_id`` and ``eos_token_id``
    :param mode: "eval" for no corruption; any other value enables training corruption
    :return: (input_patches, selected_indices) — the possibly-corrupted patch
             tensor and a float tensor with 1. at the selected positions
    """
    positions = list(range(len(target_patches)))
    random.shuffle(positions)
    chosen = positions[:math.ceil(M3_MASK_RATIO * len(positions))]
    sorted_indices = sorted(chosen)
    input_patches = torch.tensor(target_patches)

    if mode == "eval":
        choice = "original"
    else:
        choice = random.choices(["mask", "shuffle", "original"], weights=[0.8, 0.1, 0.1])[0]

    if choice == "mask":
        # Replace every selected patch with a full mask-token patch.
        input_patches[sorted_indices] = torch.tensor([patchilizer.mask_token_id] * PATCH_SIZE)
    elif choice == "shuffle":
        for idx in sorted_indices:
            patch = input_patches[idx]
            try:
                index_eos = (patch == patchilizer.eos_token_id).nonzero().item()
            except (RuntimeError, ValueError):
                # .item() fails when there is no (or more than one) EOS match;
                # treat the whole patch as content. (Narrowed from a bare
                # `except:` that also swallowed KeyboardInterrupt/SystemExit.)
                index_eos = len(patch)

            # Shuffle only the content chars: keep position 0 and everything
            # from the EOS token onward in place.
            perm = list(range(1, index_eos))
            random.shuffle(perm)
            perm = [0] + perm + list(range(index_eos, len(patch)))
            input_patches[idx] = patch[perm]

    selected_indices = torch.zeros(len(target_patches))
    selected_indices[sorted_indices] = 1.

    return input_patches, selected_indices
|
| 455 |
+
|
| 456 |
+
def remove_instrument_info(item):
    """Strip instrument names from a symbolic-music string.

    Detects the format from the first line: MTF data starts with
    ``ticks_per_beat``; anything else is treated as ABC notation. For ABC,
    ``nm=`` / ``snm=`` (name / short name) fields are cut off voice (``V:``)
    lines; for MTF, the program number of every ``program_change`` event is
    reset to 0.

    :param item: the full ABC or MTF text
    :return: the text with instrument information removed
    """
    # Split into lines while keeping their trailing newlines.
    lines = re.findall(r'.*?\n|.*$', item)
    lines = list(filter(None, lines))

    # `fmt` replaces the original local named `type`, which shadowed the builtin.
    if lines[0].split(" ")[0] == "ticks_per_beat":
        fmt = "mtf"
    else:
        fmt = "abc"

    cleaned_lines = []
    for line in lines:
        if fmt == "abc" and line.startswith("V:"):
            # Find the position of " nm=" or " snm=".
            nm_pos = line.find(" nm=")
            snm_pos = line.find(" snm=")
            # Keep only the part before the first name field and restore the
            # newline that the truncation removed.
            if nm_pos != -1:
                line = line[:nm_pos]
            elif snm_pos != -1:
                line = line[:snm_pos]
            if nm_pos != -1 or snm_pos != -1:
                line += "\n"
        elif fmt == "mtf" and line.startswith("program_change"):
            # Zero out the program number (last field) of the event.
            line = " ".join(line.split(" ")[:-1]) + " 0\n"

        cleaned_lines.append(line)

    return ''.join(cleaned_lines)
|
utils (2).py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import random
|
| 3 |
+
import bisect
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
import numpy as np
|
| 7 |
+
from config import *
|
| 8 |
+
from transformers import GPT2Model, GPT2LMHeadModel, LlamaModel, LlamaForCausalLM, PreTrainedModel
|
| 9 |
+
from samplings import top_p_sampling, top_k_sampling, temperature_sampling
|
| 10 |
+
from tokenizers import Tokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Patchilizer:
    """Converts ABC text to/from fixed-size character patches.

    Characters are encoded by their ordinal value; ids 0/1/2 are reserved for
    the special/BOS/EOS tokens. In ``stream`` mode, tunebody lines get an
    ``[r:i/j]`` prefix (line index / lines remaining) and long pieces may be
    cut to fit PATCH_LENGTH.
    """
    def __init__(self, stream=PATCH_STREAM):
        # Whether to use stream (line-indexed, cuttable) encoding.
        self.stream = stream
        # Bar-line delimiters, ordered so longer ones match first in the regex.
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.special_token_id = 0

    def split_bars(self, body_lines):
        """
        Split a body of music into individual bars.
        """
        new_bars = []
        try:
            for line in body_lines:
                # Split on bar lines; the capturing group keeps the delimiters.
                line_bars = re.split(self.regexPattern, line)
                line_bars = list(filter(None, line_bars))
                new_line_bars = []

                if len(line_bars) == 1:
                    new_line_bars = line_bars
                else:
                    # Re-attach each delimiter to the bar content that follows it.
                    if line_bars[0] in self.delimiters:
                        new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]
                    else:
                        new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]
                    if 'V' not in new_line_bars[-1]:
                        # Absorb a trailing "barline + newline" fragment into the
                        # previous bar instead of keeping it as its own bar.
                        new_line_bars[-2] += new_line_bars[-1]
                        new_line_bars = new_line_bars[:-1]
                new_bars += new_line_bars
        except:
            # NOTE(review): bare except silently drops lines on any error.
            pass

        return new_bars

    def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):
        """Chop *abc_text* into patch_size chunks; unless generating the last
        (open-ended) patch, terminate non-aligned text with an EOS char."""
        if not generate_last and len(abc_text) % patch_size != 0:
            abc_text += chr(self.eos_token_id)
        patches = [abc_text[i : i + patch_size] for i in range(0, len(abc_text), patch_size)]
        return patches

    def patch2chars(self, patch):
        """
        Convert a patch into a bar.
        """
        bytes = ''
        for idx in patch:
            if idx == self.eos_token_id:
                break
            if idx < self.eos_token_id:
                # NOTE(review): dead branch — special/BOS ids fall through and
                # are appended like any other char.
                pass
            bytes += chr(idx)
        return bytes


    def patchilize_metadata(self, metadata_lines):
        """Encode each metadata line into its own sequence of patches."""
        metadata_patches = []
        for line in metadata_lines:
            metadata_patches += self.split_patches(line)

        return metadata_patches

    def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):
        """Encode tunebody bars into patches; in 'generate' mode the final bar
        is left open (no EOS padding) so generation can continue it."""
        tunebody_patches = []
        bars = self.split_bars(tunebody_lines)
        if encode_mode == 'train':
            for bar in bars:
                tunebody_patches += self.split_patches(bar)
        elif encode_mode == 'generate':
            for bar in bars[:-1]:
                tunebody_patches += self.split_patches(bar)
            tunebody_patches += self.split_patches(bars[-1], generate_last=True)

        return tunebody_patches

    def encode_train(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):
        """Encode a full ABC piece into id patches for training.

        Splits metadata from the tunebody (first line containing '[V:'), adds
        stream '[r:i/j]' prefixes when enabled, optionally wraps with BOS/EOS
        patches, and — if the piece exceeds patch_length in stream mode —
        randomly keeps the head, the tail, or a middle-to-end slice.

        :return: list of patches, each a list of patch_size char ids
                 (0-padded)
        """
        lines = abc_text.split('\n')
        lines = list(filter(None, lines))
        lines = [line + '\n' for line in lines]

        tunebody_index = -1
        for i, line in enumerate(lines):
            if '[V:' in line:
                tunebody_index = i
                break

        metadata_lines = lines[ : tunebody_index]
        tunebody_lines = lines[tunebody_index : ]

        if self.stream:
            # Prefix each tunebody line with its index and remaining-line count.
            tunebody_lines = ['[r:' + str(line_index) + '/' + str(len(tunebody_lines) - line_index - 1) + ']' + line for line_index, line in
                              enumerate(tunebody_lines)]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)

            metadata_patches = [bos_patch] + metadata_patches
            tunebody_patches = tunebody_patches + [eos_patch]

        if self.stream:
            if len(metadata_patches) + len(tunebody_patches) > patch_length:
                # Cuts may only start at line boundaries (patches ending in '\n').
                available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if '\n' in patch]
                line_index_for_cut_index = list(range(len(available_cut_indexes)))
                # Earliest patch index from which a cut keeps the result within budget.
                end_index = len(metadata_patches) + len(tunebody_patches) - patch_length
                biggest_index = bisect.bisect_left(available_cut_indexes, end_index)
                available_cut_indexes = available_cut_indexes[:biggest_index + 1]

                if len(available_cut_indexes) == 1:
                    choices = ['head']
                elif len(available_cut_indexes) == 2:
                    choices = ['head', 'tail']
                else:
                    choices = ['head', 'tail', 'middle']
                choice = random.choice(choices)
                if choice == 'head':
                    # Keep the beginning; the trailing excess is trimmed below by `cut`.
                    patches = metadata_patches + tunebody_patches[0:]
                else:
                    if choice == 'tail':
                        cut_index = len(available_cut_indexes) - 1
                    else:
                        cut_index = random.choice(range(1, len(available_cut_indexes) - 1))

                    line_index = line_index_for_cut_index[cut_index]
                    stream_tunebody_lines = tunebody_lines[line_index : ]

                    # Re-patchilize from the cut line so patch boundaries stay valid.
                    stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')
                    if add_special_patches:
                        stream_tunebody_patches = stream_tunebody_patches + [eos_patch]
                    patches = metadata_patches + stream_tunebody_patches
            else:
                patches = metadata_patches + tunebody_patches
        else:
            patches = metadata_patches + tunebody_patches

        if cut:
            patches = patches[ : patch_length]
        else:
            pass

        # encode to ids, right-padding each patch with the special token
        id_patches = []
        for patch in patches:
            id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):
        """Encode a (possibly partial) ABC piece as the prompt for generation.

        Like encode_train, but the last bar is kept open and an unfinished
        final patch is left unpadded so the model can continue it.
        """
        lines = abc_code.split('\n')
        lines = list(filter(None, lines))

        tunebody_index = None
        for i, line in enumerate(lines):
            if line.startswith('[V:') or line.startswith('[r:'):
                tunebody_index = i
                break

        metadata_lines = lines[ : tunebody_index]
        tunebody_lines = lines[tunebody_index : ]

        metadata_lines = [line + '\n' for line in metadata_lines]
        if self.stream:
            # Keep the final line open (no '\n') when the prompt ends mid-line.
            if not abc_code.endswith('\n'):
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]]
            else:
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]
        else:
            tunebody_lines = [line + '\n' for line in tunebody_lines]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)

            metadata_patches = [bos_patch] + metadata_patches

        patches = metadata_patches + tunebody_patches
        patches = patches[ : patch_length]

        # encode to ids
        id_patches = []
        for patch in patches:
            if len(patch) < PATCH_SIZE and patch[-1] != chr(self.eos_token_id):
                # Unfinished final patch: leave it unpadded for continuation.
                id_patch = [ord(c) for c in patch]
            else:
                id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def decode(self, patches):
        """
        Decode patches into music.
        """
        return ''.join(self.patch2chars(patch) for patch in patches)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class PatchLevelDecoder(PreTrainedModel):
    """
    A Patch-level Decoder model for generating patch features in an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """
    def __init__(self, config):
        super().__init__(config)
        # Project a flattened one-hot patch (PATCH_SIZE chars x 128 char ids)
        # into the transformer's embedding space.
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = GPT2Model(config)

    def forward(self,
                patches: torch.Tensor,
                masks=None) -> torch.Tensor:
        """
        The forward pass of the patch-level decoder model.
        :param patches: the patches to be encoded
        :param masks: the masks for the patches
        :return: the encoded patches
        """
        # One-hot encode the char ids, flatten each patch, and embed it.
        patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)
        patches = patches.reshape(len(patches), -1, PATCH_SIZE * (128))
        patches = self.patch_embedding(patches.to(self.device))

        # Fixed: `masks == None` invoked the tensor's overloaded equality;
        # identity comparison with `is None` is the correct (and unambiguous)
        # way to test for an absent mask.
        if masks is None:
            return self.base(inputs_embeds=patches)
        else:
            return self.base(inputs_embeds=patches,
                             attention_mask=masks)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class CharLevelDecoder(PreTrainedModel):
    """
    A Char-level Decoder model for generating the chars within each patch in an auto-regressive manner
    based on the encoded patch features. It inherits PreTrainedModel from transformers.
    """
    def __init__(self, config):
        super().__init__(config)
        # Reserved char ids: 0 = special/pad, 1 = BOS.
        self.special_token_id = 0
        self.bos_token_id = 1

        self.base = GPT2LMHeadModel(config)

    def forward(self,
                encoded_patches: torch.Tensor,
                target_patches: torch.Tensor):
        """
        The forward pass of the char-level decoder model.
        :param encoded_patches: the encoded patches
        :param target_patches: the target patches
        :return: the output of the model (includes the LM loss via `labels`)
        """
        # preparing the labels for model training: prepend a BOS column
        target_patches = torch.cat((torch.ones_like(target_patches[:,0:1])*self.bos_token_id, target_patches), dim=1)
        # print('target_patches shape:', target_patches.shape)

        # Pad positions (special token) are excluded from the loss via -100.
        target_masks = target_patches == self.special_token_id
        labels = target_patches.clone().masked_fill_(target_masks, -100)

        # masking the labels for model training: attention mask is 0 exactly
        # where the label is ignored
        target_masks = torch.ones_like(labels)
        target_masks = target_masks.masked_fill_(labels == -100, 0)

        # select patches: optionally subsample a random batch of patches to
        # bound memory (indices kept sorted to preserve order)
        if PATCH_SAMPLING_BATCH_SIZE!=0 and PATCH_SAMPLING_BATCH_SIZE<target_patches.shape[0]:
            indices = list(range(len(target_patches)))
            random.shuffle(indices)
            selected_indices = sorted(indices[:PATCH_SAMPLING_BATCH_SIZE])

            target_patches = target_patches[selected_indices,:]
            target_masks = target_masks[selected_indices,:]
            encoded_patches = encoded_patches[selected_indices,:]

        # get input embeddings from the LM's tied embedding table
        inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)

        # concatenate the encoded patches with the input embeddings: the patch
        # feature replaces the BOS embedding at position 0
        inputs_embeds = torch.cat((encoded_patches.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1)

        output = self.base(inputs_embeds=inputs_embeds,
                           attention_mask=target_masks,
                           labels=labels)
        # output_hidden_states=True=True)

        return output

    def generate(self,
                 encoded_patch: torch.Tensor,  # [hidden_size]
                 tokens: torch.Tensor):  # [1]
        """
        The generate function for generating a patch based on the encoded patch and already generated tokens.
        :param encoded_patch: the encoded patch
        :param tokens: already generated tokens in the patch
        :return: the probability distribution of next token
        """
        encoded_patch = encoded_patch.reshape(1, 1, -1)  # [1, 1, hidden_size]
        tokens = tokens.reshape(1, -1)

        # Get input embeddings
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)

        # Concatenate the encoded patch with the input embeddings (the patch
        # feature stands in for the first token)
        tokens = torch.cat((encoded_patch, tokens[:,1:,:]), dim=1)

        # Get output from model
        outputs = self.base(inputs_embeds=tokens)

        # Get probabilities of next token from the final position's logits
        probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)

        return probs
|
| 333 |
+
|
| 334 |
+
def safe_normalize_probs(probs):
    """Coerce *probs* into a valid probability distribution.

    Invalid entries (NaN, +/-inf, negatives) are zeroed, a tiny epsilon keeps
    every entry strictly positive, and the result is renormalized to sum to 1.
    If nothing remains to normalize, all mass is placed on the first entry.

    :param probs: array-like of candidate probabilities
    :return: float64 numpy array summing to 1 (empty input returns empty)
    """
    epsilon = 1e-12
    probs = np.array(probs, dtype=np.float64)
    # Zero out entries that cannot be probabilities. Fixed: the original only
    # checked NaN/negative, so an inf entry made the sum inf and the division
    # produced an all-zero (invalid) "distribution".
    probs = np.where(~np.isfinite(probs) | (probs < 0), 0, probs)
    probs = probs + epsilon
    s = probs.sum()
    if s > 0:
        probs = probs / s
    else:
        # Degenerate case: fall back to a point mass on the first entry.
        # Fixed: guard against empty input, which previously raised IndexError.
        probs = np.zeros_like(probs)
        if probs.size:
            probs[0] = 1.0
    return probs
|
| 346 |
+
|
| 347 |
+
class NotaGenLMHeadModel(PreTrainedModel):
    """
    NotaGen is a language model with a hierarchical structure.
    It includes a patch-level decoder and a char-level decoder.
    The patch-level decoder is used to generate patch features in an auto-regressive manner.
    The char-level decoder is used to generate the chars within each patch in an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """
    def __init__(self, encoder_config, decoder_config):
        super().__init__(encoder_config)
        # Reserved char ids: 0 = special/pad, 1 = BOS, 2 = EOS.
        self.special_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.patch_level_decoder = PatchLevelDecoder(encoder_config)
        self.char_level_decoder = CharLevelDecoder(decoder_config)

    def forward(self,
                patches: torch.Tensor,
                masks: torch.Tensor):
        """
        The forward pass of the bGPT model.
        :param patches: the patches to be encoded
        :param masks: the masks for the patches
        :return: the decoded patches
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches, masks)["last_hidden_state"]

        # left_shift_masks selects every valid position except the last one
        # (the reversed cumsum counts remaining valid positions); encoded
        # features at position i are paired with the target patch at i+1.
        # NOTE: `masks` is modified in place here.
        left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)
        masks[:, 0] = 0

        encoded_patches = encoded_patches[left_shift_masks == 1]
        patches = patches[masks == 1]

        return self.char_level_decoder(encoded_patches, patches)

    def generate(self,
                 patches: torch.Tensor,
                 top_k=0,
                 top_p=1,
                 temperature=1.0):
        """
        The generate function for generating patches based on patches.
        :param patches: the patches to be encoded
        :param top_k: the top k for sampling
        :param top_p: the top p for sampling
        :param temperature: the temperature for sampling
        :return: the generated patch (list of char-id ints)
        """
        # If the prompt ends with an unfinished patch, peel off its chars as
        # already-generated tokens (BOS-prefixed) and encode only whole patches.
        if patches.shape[-1] % PATCH_SIZE != 0:
            tokens = patches[:,:,-(patches.shape[-1]%PATCH_SIZE):].squeeze(0, 1)
            tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1)
            patches = patches[:,:,:-(patches.shape[-1]%PATCH_SIZE)]
        else:
            tokens = torch.tensor([self.bos_token_id], device=self.device)

        patches = patches.reshape(len(patches), -1, PATCH_SIZE)  # [bs, seq, patch_size]
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]  # [bs, seq, hidden_size]
        generated_patch = []

        while True:
            # Condition on the last patch feature; re-normalize after every
            # sampling filter so the distribution stays valid.
            prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy()  # [128]
            prob = safe_normalize_probs(prob)
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)  # [128]
            prob = safe_normalize_probs(prob)
            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)  # [128]
            prob = safe_normalize_probs(prob)
            token = temperature_sampling(prob, temperature=temperature)  # int
            char = chr(token)  # NOTE(review): unused
            generated_patch.append(token)

            # Stop once the patch is full (EOS stop is intentionally disabled).
            if len(tokens) >= PATCH_SIZE:# or token == self.eos_token_id:
                break
            else:
                tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)

        return generated_patch
|
utils (3).py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import random
|
| 3 |
+
import bisect
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
import numpy as np
|
| 7 |
+
from config import *
|
| 8 |
+
from transformers import GPT2Model, GPT2LMHeadModel, LlamaModel, LlamaForCausalLM, PreTrainedModel
|
| 9 |
+
from samplings import top_p_sampling, top_k_sampling, temperature_sampling
|
| 10 |
+
from tokenizers import Tokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Patchilizer:
    """Converts ABC notation text to and from fixed-size character patches.

    Char ids are raw code points in [0, 128); ids 0/1/2 are reserved as
    special/BOS/EOS tokens respectively.
    """

    def __init__(self, stream=PATCH_STREAM):
        self.stream = stream
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.special_token_id = 0

    def split_bars(self, body_lines):
        """
        Split a body of music into individual bars (each bar keeps its
        leading barline delimiter). Best-effort: on malformed input the
        bars collected so far are returned.
        """
        new_bars = []
        try:
            for line in body_lines:
                line_bars = re.split(self.regexPattern, line)
                line_bars = list(filter(None, line_bars))

                if len(line_bars) == 1:
                    new_line_bars = line_bars
                else:
                    if line_bars[0] in self.delimiters:
                        new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]
                    else:
                        new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]
                    if 'V' not in new_line_bars[-1]:
                        # absorb the trailing "barline + '\n'" fragment into the previous bar
                        new_line_bars[-2] += new_line_bars[-1]
                        new_line_bars = new_line_bars[:-1]
                new_bars += new_line_bars
        except Exception:
            # deliberate best-effort: keep whatever bars were built before the failure
            pass

        return new_bars

    def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):
        """Chop text into patch_size chunks; EOS-terminate a ragged tail
        unless the last patch is still being generated."""
        if not generate_last and len(abc_text) % patch_size != 0:
            abc_text += chr(self.eos_token_id)
        return [abc_text[i: i + patch_size] for i in range(0, len(abc_text), patch_size)]

    def patch2chars(self, patch):
        """
        Convert a patch (sequence of char ids) back into text, stopping at EOS.
        """
        chars = ''
        for idx in patch:
            if idx == self.eos_token_id:
                break
            chars += chr(idx)
        return chars

    def patchilize_metadata(self, metadata_lines):
        """Patchilize the metadata (header) lines, one line at a time."""
        metadata_patches = []
        for line in metadata_lines:
            metadata_patches += self.split_patches(line)
        return metadata_patches

    def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):
        """Patchilize the tune body bar-by-bar. In 'generate' mode the last
        bar is left open (no EOS termination)."""
        tunebody_patches = []
        bars = self.split_bars(tunebody_lines)
        if encode_mode == 'train':
            for bar in bars:
                tunebody_patches += self.split_patches(bar)
        elif encode_mode == 'generate':
            for bar in bars[:-1]:
                tunebody_patches += self.split_patches(bar)
            tunebody_patches += self.split_patches(bars[-1], generate_last=True)
        return tunebody_patches

    def encode_train(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):
        """
        Encode a full ABC piece into id patches for training.

        In stream mode, each tunebody line is prefixed with an
        ``[r:index/remaining]`` marker, and over-long pieces are randomly
        truncated at a line boundary (head / middle / tail) so the model
        sees different windows of long scores.

        :param abc_text: the full ABC text
        :param patch_length: maximum number of patches
        :param patch_size: characters per patch
        :param add_special_patches: prepend a BOS patch and append an EOS patch
        :param cut: hard-truncate the result to patch_length patches
        :return: list of id patches, each right-padded with the special token
        """
        lines = abc_text.split('\n')
        lines = list(filter(None, lines))
        lines = [line + '\n' for line in lines]

        tunebody_index = -1
        for i, line in enumerate(lines):
            if '[V:' in line:
                tunebody_index = i
                break

        metadata_lines = lines[: tunebody_index]
        tunebody_lines = lines[tunebody_index:]

        if self.stream:
            # [r:i/remaining] keeps positional info when the head is cut off
            tunebody_lines = ['[r:' + str(line_index) + '/' + str(len(tunebody_lines) - line_index - 1) + ']' + line
                              for line_index, line in enumerate(tunebody_lines)]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)
            metadata_patches = [bos_patch] + metadata_patches
            tunebody_patches = tunebody_patches + [eos_patch]

        if self.stream:
            if len(metadata_patches) + len(tunebody_patches) > patch_length:
                # candidate cut points: start, plus every patch boundary after a newline
                available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if '\n' in patch]
                line_index_for_cut_index = list(range(len(available_cut_indexes)))
                end_index = len(metadata_patches) + len(tunebody_patches) - patch_length
                biggest_index = bisect.bisect_left(available_cut_indexes, end_index)
                available_cut_indexes = available_cut_indexes[:biggest_index + 1]

                if len(available_cut_indexes) == 1:
                    choices = ['head']
                elif len(available_cut_indexes) == 2:
                    choices = ['head', 'tail']
                else:
                    choices = ['head', 'tail', 'middle']
                choice = random.choice(choices)

                if choice == 'head':
                    patches = metadata_patches + tunebody_patches
                else:
                    if choice == 'tail':
                        cut_index = len(available_cut_indexes) - 1
                    else:
                        cut_index = random.choice(range(1, len(available_cut_indexes) - 1))

                    line_index = line_index_for_cut_index[cut_index]
                    stream_tunebody_lines = tunebody_lines[line_index:]

                    stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')
                    if add_special_patches:
                        stream_tunebody_patches = stream_tunebody_patches + [eos_patch]
                    patches = metadata_patches + stream_tunebody_patches
            else:
                patches = metadata_patches + tunebody_patches
        else:
            patches = metadata_patches + tunebody_patches

        if cut:
            patches = patches[: patch_length]

        # encode to ids, right-padding each patch with the special token
        id_patches = []
        for patch in patches:
            id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):
        """
        Encode a (possibly partial) ABC piece into id patches for generation.
        An unfinished final patch is left unpadded so generation can continue it.

        :param abc_code: the ABC text generated so far
        :param patch_length: maximum number of patches
        :param patch_size: characters per patch
        :param add_special_patches: prepend a BOS patch
        :return: list of id patches
        """
        lines = abc_code.split('\n')
        lines = list(filter(None, lines))

        tunebody_index = None
        for i, line in enumerate(lines):
            if line.startswith('[V:') or line.startswith('[r:'):
                tunebody_index = i
                break

        metadata_lines = lines[: tunebody_index]
        tunebody_lines = lines[tunebody_index:]

        metadata_lines = [line + '\n' for line in metadata_lines]
        if self.stream:
            if not abc_code.endswith('\n'):
                # last line is still being generated: keep it open (no trailing newline)
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]]
            else:
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]
        else:
            tunebody_lines = [line + '\n' for line in tunebody_lines]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            metadata_patches = [bos_patch] + metadata_patches

        patches = metadata_patches + tunebody_patches
        patches = patches[: patch_length]

        # encode to ids; fix: compare against the patch_size parameter, not the global
        id_patches = []
        for patch in patches:
            if len(patch) < patch_size and patch[-1] != chr(self.eos_token_id):
                # unfinished final patch: leave unpadded
                id_patch = [ord(c) for c in patch]
            else:
                id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def decode(self, patches):
        """
        Decode patches into music.
        """
        return ''.join(self.patch2chars(patch) for patch in patches)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class PatchLevelDecoder(PreTrainedModel):
    """
    A Patch-level Decoder model for generating patch features in an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        # project a one-hot encoded patch (PATCH_SIZE chars x 128 ids) to the model width
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = GPT2Model(config)

    def forward(self,
                patches: torch.Tensor,
                masks=None) -> torch.Tensor:
        """
        The forward pass of the patch-level decoder model.
        :param patches: the patches to be encoded (char ids in [0, 128))
        :param masks: optional attention mask over the patch sequence
        :return: the encoded patches (GPT2Model output)
        """
        patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)
        patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
        patches = self.patch_embedding(patches.to(self.device))

        # fix: use `is None` — `masks == None` on a tensor triggers an
        # elementwise comparison instead of an identity check
        if masks is None:
            return self.base(inputs_embeds=patches)
        return self.base(inputs_embeds=patches,
                         attention_mask=masks)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class CharLevelDecoder(PreTrainedModel):
    """
    A char-level decoder that autoregressively produces the characters inside
    a patch, conditioned on that patch's encoded feature vector.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        self.special_token_id = 0
        self.bos_token_id = 1

        self.base = GPT2LMHeadModel(config)

    def forward(self,
                encoded_patches: torch.Tensor,
                target_patches: torch.Tensor):
        """
        Teacher-forced forward pass of the char-level decoder.
        :param encoded_patches: per-patch feature vectors from the patch-level decoder
        :param target_patches: char ids of each target patch
        :return: the GPT2LMHeadModel output (includes the LM loss)
        """
        # prepend a BOS column so each patch's first char is predicted from
        # the patch feature slotted in below
        bos_column = torch.ones_like(target_patches[:, 0:1]) * self.bos_token_id
        target_patches = torch.cat((bos_column, target_patches), dim=1)

        # padding positions (special token) are excluded from the loss via -100
        pad_positions = target_patches == self.special_token_id
        labels = target_patches.clone().masked_fill_(pad_positions, -100)

        # attention mask: attend everywhere except the ignored positions
        target_masks = torch.ones_like(labels)
        target_masks = target_masks.masked_fill_(labels == -100, 0)

        # optionally sub-sample patches to bound the char-level batch size
        if PATCH_SAMPLING_BATCH_SIZE != 0 and PATCH_SAMPLING_BATCH_SIZE < target_patches.shape[0]:
            shuffled = list(range(len(target_patches)))
            random.shuffle(shuffled)
            selected = sorted(shuffled[:PATCH_SAMPLING_BATCH_SIZE])

            target_patches = target_patches[selected, :]
            target_masks = target_masks[selected, :]
            encoded_patches = encoded_patches[selected, :]

        # char embeddings for the (BOS-prefixed) targets
        inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)

        # replace the BOS embedding with the encoded patch feature
        inputs_embeds = torch.cat((encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1)

        return self.base(inputs_embeds=inputs_embeds,
                         attention_mask=target_masks,
                         labels=labels)

    def generate(self,
                 encoded_patch: torch.Tensor,  # [hidden_size]
                 tokens: torch.Tensor):        # [seq]
        """
        Predict the next-char distribution for a partially generated patch.
        :param encoded_patch: feature vector of the current patch
        :param tokens: char ids generated so far within the patch
        :return: softmax probability distribution over the 128 char ids
        """
        encoded_patch = encoded_patch.reshape(1, 1, -1)  # [1, 1, hidden_size]
        tokens = tokens.reshape(1, -1)

        token_embeds = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)

        # slot the encoded patch feature in place of the first (BOS) embedding
        token_embeds = torch.cat((encoded_patch, token_embeds[:, 1:, :]), dim=1)

        outputs = self.base(inputs_embeds=token_embeds)

        # distribution over the next char id
        return torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
|
| 333 |
+
|
| 334 |
+
def safe_normalize_probs(probs):
    """Sanitize and renormalize a probability vector.

    NaNs and negative entries are zeroed, a tiny epsilon keeps every entry
    strictly positive, and the vector is rescaled to sum to 1.  If the total
    is still non-positive, all mass is assigned to index 0.
    """
    eps = 1e-12  # smallest value to avoid log(0) and maintain precision
    arr = np.array(probs, dtype=np.float64)
    arr = np.where(np.isnan(arr) | (arr < 0), 0.0, arr) + eps
    total = arr.sum()
    if total > 0:
        return arr / total
    fallback = np.zeros_like(arr)
    fallback[0] = 1.0
    return fallback
|
| 346 |
+
|
| 347 |
+
class NotaGenLMHeadModel(PreTrainedModel):
    """
    NotaGen is a language model with a hierarchical structure:
    a patch-level decoder generates patch features auto-regressively, and a
    char-level decoder generates the chars within each patch auto-regressively.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, encoder_config, decoder_config):
        super().__init__(encoder_config)
        self.special_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.patch_level_decoder = PatchLevelDecoder(encoder_config)
        self.char_level_decoder = CharLevelDecoder(decoder_config)

    def forward(self,
                patches: torch.Tensor,
                masks: torch.Tensor):
        """
        Training forward pass of the hierarchical model.
        :param patches: the patches to be encoded
        :param masks: the masks for the patches
        :return: the char-level decoder output (includes the loss)
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches, masks)["last_hidden_state"]

        # shift so each patch is predicted from the PREVIOUS patch's encoding
        left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)
        masks[:, 0] = 0  # NOTE(review): mutates the caller's mask tensor in place

        encoded_patches = encoded_patches[left_shift_masks == 1]
        patches = patches[masks == 1]

        return self.char_level_decoder(encoded_patches, patches)

    def generate(self,
                 patches: torch.Tensor,
                 top_k=0,
                 top_p=1,
                 temperature=1.0):
        """
        Generate the chars of the next patch, conditioned on the patches so far.
        :param patches: the context patches (last patch may be partial)
        :param top_k: top-k sampling cutoff (0 disables)
        :param top_p: nucleus sampling threshold
        :param temperature: sampling temperature
        :return: list of generated char ids for the next patch
        """
        remainder = patches.shape[-1] % PATCH_SIZE
        if remainder != 0:
            # the trailing partial patch becomes the char-level context
            tokens = patches[:, :, -remainder:].squeeze(0, 1)
            tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1)
            patches = patches[:, :, :-remainder]
        else:
            tokens = torch.tensor([self.bos_token_id], device=self.device)

        patches = patches.reshape(len(patches), -1, PATCH_SIZE)                   # [bs, seq, patch_size]
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]  # [bs, seq, hidden_size]
        generated_patch = []

        while True:
            # sample the next char id: re-normalize after every filtering step
            prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy()  # [128]
            prob = safe_normalize_probs(prob)
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
            prob = safe_normalize_probs(prob)
            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
            prob = safe_normalize_probs(prob)
            token = temperature_sampling(prob, temperature=temperature)
            generated_patch.append(token)

            if len(tokens) >= PATCH_SIZE:  # patch full: stop
                break
            tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)

        return generated_patch
|
utils (4).py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import random
|
| 3 |
+
import bisect
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
import numpy as np
|
| 7 |
+
from config import *
|
| 8 |
+
from transformers import GPT2Model, GPT2LMHeadModel, LlamaModel, LlamaForCausalLM, PreTrainedModel
|
| 9 |
+
from samplings import top_p_sampling, top_k_sampling, temperature_sampling
|
| 10 |
+
from tokenizers import Tokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Patchilizer:
    """Converts ABC notation text to and from fixed-size character patches.

    Char ids are raw code points in [0, 128); ids 0/1/2 are reserved as
    special/BOS/EOS tokens respectively.
    """

    def __init__(self, stream=PATCH_STREAM):
        self.stream = stream
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.special_token_id = 0

    def split_bars(self, body_lines):
        """
        Split a body of music into individual bars (each bar keeps its
        leading barline delimiter). Best-effort: on malformed input the
        bars collected so far are returned.
        """
        new_bars = []
        try:
            for line in body_lines:
                line_bars = re.split(self.regexPattern, line)
                line_bars = list(filter(None, line_bars))

                if len(line_bars) == 1:
                    new_line_bars = line_bars
                else:
                    if line_bars[0] in self.delimiters:
                        new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]
                    else:
                        new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]
                    if 'V' not in new_line_bars[-1]:
                        # absorb the trailing "barline + '\n'" fragment into the previous bar
                        new_line_bars[-2] += new_line_bars[-1]
                        new_line_bars = new_line_bars[:-1]
                new_bars += new_line_bars
        except Exception:
            # deliberate best-effort: keep whatever bars were built before the failure
            pass

        return new_bars

    def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):
        """Chop text into patch_size chunks; EOS-terminate a ragged tail
        unless the last patch is still being generated."""
        if not generate_last and len(abc_text) % patch_size != 0:
            abc_text += chr(self.eos_token_id)
        return [abc_text[i: i + patch_size] for i in range(0, len(abc_text), patch_size)]

    def patch2chars(self, patch):
        """
        Convert a patch (sequence of char ids) back into text, stopping at EOS.
        """
        chars = ''
        for idx in patch:
            if idx == self.eos_token_id:
                break
            chars += chr(idx)
        return chars

    def patchilize_metadata(self, metadata_lines):
        """Patchilize the metadata (header) lines, one line at a time."""
        metadata_patches = []
        for line in metadata_lines:
            metadata_patches += self.split_patches(line)
        return metadata_patches

    def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):
        """Patchilize the tune body bar-by-bar. In 'generate' mode the last
        bar is left open (no EOS termination)."""
        tunebody_patches = []
        bars = self.split_bars(tunebody_lines)
        if encode_mode == 'train':
            for bar in bars:
                tunebody_patches += self.split_patches(bar)
        elif encode_mode == 'generate':
            for bar in bars[:-1]:
                tunebody_patches += self.split_patches(bar)
            tunebody_patches += self.split_patches(bars[-1], generate_last=True)
        return tunebody_patches

    def encode_train(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):
        """
        Encode a full ABC piece into id patches for training.

        In stream mode, each tunebody line is prefixed with an
        ``[r:index/remaining]`` marker, and over-long pieces are randomly
        truncated at a line boundary (head / middle / tail) so the model
        sees different windows of long scores.

        :param abc_text: the full ABC text
        :param patch_length: maximum number of patches
        :param patch_size: characters per patch
        :param add_special_patches: prepend a BOS patch and append an EOS patch
        :param cut: hard-truncate the result to patch_length patches
        :return: list of id patches, each right-padded with the special token
        """
        lines = abc_text.split('\n')
        lines = list(filter(None, lines))
        lines = [line + '\n' for line in lines]

        tunebody_index = -1
        for i, line in enumerate(lines):
            if '[V:' in line:
                tunebody_index = i
                break

        metadata_lines = lines[: tunebody_index]
        tunebody_lines = lines[tunebody_index:]

        if self.stream:
            # [r:i/remaining] keeps positional info when the head is cut off
            tunebody_lines = ['[r:' + str(line_index) + '/' + str(len(tunebody_lines) - line_index - 1) + ']' + line
                              for line_index, line in enumerate(tunebody_lines)]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)
            metadata_patches = [bos_patch] + metadata_patches
            tunebody_patches = tunebody_patches + [eos_patch]

        if self.stream:
            if len(metadata_patches) + len(tunebody_patches) > patch_length:
                # candidate cut points: start, plus every patch boundary after a newline
                available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if '\n' in patch]
                line_index_for_cut_index = list(range(len(available_cut_indexes)))
                end_index = len(metadata_patches) + len(tunebody_patches) - patch_length
                biggest_index = bisect.bisect_left(available_cut_indexes, end_index)
                available_cut_indexes = available_cut_indexes[:biggest_index + 1]

                if len(available_cut_indexes) == 1:
                    choices = ['head']
                elif len(available_cut_indexes) == 2:
                    choices = ['head', 'tail']
                else:
                    choices = ['head', 'tail', 'middle']
                choice = random.choice(choices)

                if choice == 'head':
                    patches = metadata_patches + tunebody_patches
                else:
                    if choice == 'tail':
                        cut_index = len(available_cut_indexes) - 1
                    else:
                        cut_index = random.choice(range(1, len(available_cut_indexes) - 1))

                    line_index = line_index_for_cut_index[cut_index]
                    stream_tunebody_lines = tunebody_lines[line_index:]

                    stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')
                    if add_special_patches:
                        stream_tunebody_patches = stream_tunebody_patches + [eos_patch]
                    patches = metadata_patches + stream_tunebody_patches
            else:
                patches = metadata_patches + tunebody_patches
        else:
            patches = metadata_patches + tunebody_patches

        if cut:
            patches = patches[: patch_length]

        # encode to ids, right-padding each patch with the special token
        id_patches = []
        for patch in patches:
            id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):
        """
        Encode a (possibly partial) ABC piece into id patches for generation.
        An unfinished final patch is left unpadded so generation can continue it.

        :param abc_code: the ABC text generated so far
        :param patch_length: maximum number of patches
        :param patch_size: characters per patch
        :param add_special_patches: prepend a BOS patch
        :return: list of id patches
        """
        lines = abc_code.split('\n')
        lines = list(filter(None, lines))

        tunebody_index = None
        for i, line in enumerate(lines):
            if line.startswith('[V:') or line.startswith('[r:'):
                tunebody_index = i
                break

        metadata_lines = lines[: tunebody_index]
        tunebody_lines = lines[tunebody_index:]

        metadata_lines = [line + '\n' for line in metadata_lines]
        if self.stream:
            if not abc_code.endswith('\n'):
                # last line is still being generated: keep it open (no trailing newline)
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]]
            else:
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]
        else:
            tunebody_lines = [line + '\n' for line in tunebody_lines]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            metadata_patches = [bos_patch] + metadata_patches

        patches = metadata_patches + tunebody_patches
        patches = patches[: patch_length]

        # encode to ids; fix: compare against the patch_size parameter, not the global
        id_patches = []
        for patch in patches:
            if len(patch) < patch_size and patch[-1] != chr(self.eos_token_id):
                # unfinished final patch: leave unpadded
                id_patch = [ord(c) for c in patch]
            else:
                id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def decode(self, patches):
        """
        Decode patches into music.
        """
        return ''.join(self.patch2chars(patch) for patch in patches)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class PatchLevelDecoder(PreTrainedModel):
    """
    A Patch-level Decoder model for generating patch features in an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        # project a one-hot encoded patch (PATCH_SIZE chars x 128 ids) to the model width
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = GPT2Model(config)

    def forward(self,
                patches: torch.Tensor,
                masks=None) -> torch.Tensor:
        """
        The forward pass of the patch-level decoder model.
        :param patches: the patches to be encoded (char ids in [0, 128))
        :param masks: optional attention mask over the patch sequence
        :return: the encoded patches (GPT2Model output)
        """
        patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)
        patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
        patches = self.patch_embedding(patches.to(self.device))

        # fix: use `is None` — `masks == None` on a tensor triggers an
        # elementwise comparison instead of an identity check
        if masks is None:
            return self.base(inputs_embeds=patches)
        return self.base(inputs_embeds=patches,
                         attention_mask=masks)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class CharLevelDecoder(PreTrainedModel):
    """
    A Char-level Decoder model for generating the chars within each patch in an auto-regressive manner
    based on the encoded patch features. It inherits PreTrainedModel from transformers.
    """
    def __init__(self, config):
        super().__init__(config)
        # Token id 0 pads patches; token id 1 marks the beginning of a patch.
        self.special_token_id = 0
        self.bos_token_id = 1

        # Underlying char-level language model.
        self.base = GPT2LMHeadModel(config)

    def forward(self,
                encoded_patches: torch.Tensor,
                target_patches: torch.Tensor):
        """
        The forward pass of the char-level decoder model.
        :param encoded_patches: the encoded patches
        :param target_patches: the target patches
        :return: the output of the model (includes the LM loss, since labels are passed)
        """
        # preparing the labels for model training: prepend a BOS token to every patch
        target_patches = torch.cat((torch.ones_like(target_patches[:,0:1])*self.bos_token_id, target_patches), dim=1)

        # pad positions (special_token_id) are excluded from the loss via label -100
        target_masks = target_patches == self.special_token_id
        labels = target_patches.clone().masked_fill_(target_masks, -100)

        # masking the labels for model training: attention mask is 0 wherever the label is ignored
        target_masks = torch.ones_like(labels)
        target_masks = target_masks.masked_fill_(labels == -100, 0)

        # select patches: subsample at most PATCH_SAMPLING_BATCH_SIZE patches to bound memory
        if PATCH_SAMPLING_BATCH_SIZE!=0 and PATCH_SAMPLING_BATCH_SIZE<target_patches.shape[0]:
            indices = list(range(len(target_patches)))
            random.shuffle(indices)
            selected_indices = sorted(indices[:PATCH_SAMPLING_BATCH_SIZE])

            target_patches = target_patches[selected_indices,:]
            target_masks = target_masks[selected_indices,:]
            encoded_patches = encoded_patches[selected_indices,:]

        # get input embeddings from the LM's token embedding table
        inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)

        # concatenate the encoded patches with the input embeddings:
        # the patch feature replaces the BOS embedding at position 0
        inputs_embeds = torch.cat((encoded_patches.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1)

        output = self.base(inputs_embeds=inputs_embeds,
                           attention_mask=target_masks,
                           labels=labels)

        return output

    def generate(self,
                 encoded_patch: torch.Tensor,  # [hidden_size]
                 tokens: torch.Tensor):  # [1]
        """
        The generate function for generating a patch based on the encoded patch and already generated tokens.
        :param encoded_patch: the encoded patch
        :param tokens: already generated tokens in the patch
        :return: the probability distribution of next token
        """
        encoded_patch = encoded_patch.reshape(1, 1, -1)  # [1, 1, hidden_size]
        tokens = tokens.reshape(1, -1)

        # Get input embeddings
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)

        # Concatenate the encoded patch with the input embeddings:
        # the patch feature takes the place of the first token's embedding
        tokens = torch.cat((encoded_patch, tokens[:,1:,:]), dim=1)

        # Get output from model
        outputs = self.base(inputs_embeds=tokens)

        # Get probabilities of next token (softmax over the last position's logits)
        probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)

        return probs
|
| 333 |
+
|
| 334 |
+
def safe_normalize_probs(probs):
    """
    Sanitize an array-like of sampling probabilities and return a valid,
    normalized float64 distribution.

    NaN, negative, and infinite entries are zeroed (originally +inf slipped
    through the NaN/negative filter and produced a NaN distribution); a tiny
    epsilon keeps all-zero inputs normalizable. If no mass remains, all
    probability is assigned to index 0. An empty input is returned unchanged
    (originally it raised IndexError in the fallback branch).

    :param probs: array-like of candidate probabilities
    :return: np.ndarray of float64 summing to 1 (or empty if input is empty)
    """
    epsilon = 1e-12
    probs = np.array(probs, dtype=np.float64)
    # Zero out anything that cannot be a probability: NaN, negative, +/-inf.
    probs = np.where(~np.isfinite(probs) | (probs < 0), 0, probs)
    if probs.size == 0:
        return probs
    probs = probs + epsilon
    s = probs.sum()
    if s > 0:
        probs = probs / s
    else:
        # Degenerate fallback: deterministic point mass on the first entry.
        probs = np.zeros_like(probs)
        probs[0] = 1.0
    return probs
|
| 346 |
+
|
| 347 |
+
class NotaGenLMHeadModel(PreTrainedModel):
    """
    NotaGen is a language model with a hierarchical structure.
    It includes a patch-level decoder and a char-level decoder.
    The patch-level decoder is used to generate patch features in an auto-regressive manner.
    The char-level decoder is used to generate the chars within each patch in an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, encoder_config, decoder_config):
        super().__init__(encoder_config)
        self.special_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.patch_level_decoder = PatchLevelDecoder(encoder_config)
        self.char_level_decoder = CharLevelDecoder(decoder_config)

    def forward(self,
                patches: torch.Tensor,
                masks: torch.Tensor):
        """
        The forward pass of the NotaGen model.
        :param patches: the patches to be encoded
        :param masks: the attention masks for the patches (NOTE: mutated in place below)
        :return: the char-level decoder output (includes the LM loss)
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches, masks)["last_hidden_state"]

        # The feature of patch i predicts the chars of patch i+1: drop the last
        # valid position on the encoder side and the first on the target side.
        left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)
        masks[:, 0] = 0  # in-place: callers should not reuse `masks` afterwards

        encoded_patches = encoded_patches[left_shift_masks == 1]
        patches = patches[masks == 1]

        return self.char_level_decoder(encoded_patches, patches)

    def generate(self,
                 patches: torch.Tensor,
                 top_k=0,
                 top_p=1,
                 temperature=1.0):
        """
        Generate one patch conditioned on the given patches.
        :param patches: the prompt patches
        :param top_k: the top k for sampling (0 disables top-k filtering)
        :param top_p: the top p for sampling (1 disables nucleus filtering)
        :param temperature: the temperature for sampling
        :return: the generated patch as a list of char ids
        """
        # If the prompt does not end on a patch boundary, the trailing chars seed
        # the char-level decoder and only the full patches feed the encoder.
        if patches.shape[-1] % PATCH_SIZE != 0:
            tokens = patches[:, :, -(patches.shape[-1] % PATCH_SIZE):].squeeze(0, 1)
            tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1)
            patches = patches[:, :, :-(patches.shape[-1] % PATCH_SIZE)]
        else:
            tokens = torch.tensor([self.bos_token_id], device=self.device)

        patches = patches.reshape(len(patches), -1, PATCH_SIZE)  # [bs, seq, patch_size]
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]  # [bs, seq, hidden_size]
        generated_patch = []

        while True:
            prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy()  # [128]
            # Re-normalize after every filtering step: the filters can zero out
            # mass or return unnormalized values, breaking the next sampler.
            prob = safe_normalize_probs(prob)
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)  # [128]
            prob = safe_normalize_probs(prob)
            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)  # [128]
            prob = safe_normalize_probs(prob)
            token = temperature_sampling(prob, temperature=temperature)  # int
            # (removed unused local `char = chr(token)`)
            generated_patch.append(token)

            if len(tokens) >= PATCH_SIZE:  # or token == self.eos_token_id:
                break
            tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)

        return generated_patch
|
utils (5).py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import random
|
| 3 |
+
import bisect
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
import numpy as np
|
| 7 |
+
from config import *
|
| 8 |
+
from transformers import GPT2Model, GPT2LMHeadModel, LlamaModel, LlamaForCausalLM, PreTrainedModel
|
| 9 |
+
from samplings import top_p_sampling, top_k_sampling, temperature_sampling
|
| 10 |
+
from tokenizers import Tokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Patchilizer:
    """
    Converts ABC notation text to fixed-size char-id patches and back.
    Token ids: 0 = pad, 1 = BOS, 2 = EOS; all other ids are raw char codes.
    """

    def __init__(self, stream=PATCH_STREAM):
        self.stream = stream
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.special_token_id = 0

    def split_bars(self, body_lines):
        """
        Split a body of music into individual bars.
        """
        new_bars = []
        try:
            for line in body_lines:
                line_bars = re.split(self.regexPattern, line)
                line_bars = list(filter(None, line_bars))
                new_line_bars = []

                if len(line_bars) == 1:
                    new_line_bars = line_bars
                else:
                    # Re-attach each barline delimiter to the bar content that follows it.
                    if line_bars[0] in self.delimiters:
                        new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]
                    else:
                        new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]
                    # A trailing fragment without a voice marker belongs to the previous bar.
                    if 'V' not in new_line_bars[-1]:
                        new_line_bars[-2] += new_line_bars[-1]
                        new_line_bars = new_line_bars[:-1]
                new_bars += new_line_bars
        except Exception:
            # Best-effort: malformed lines yield whatever bars were collected so far.
            # (Narrowed from a bare `except`, which also swallowed KeyboardInterrupt.)
            pass

        return new_bars

    def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):
        """
        Cut abc_text into patch_size-char chunks, appending an EOS char to fill
        the last chunk unless preparing a (possibly partial) patch for generation.
        """
        if not generate_last and len(abc_text) % patch_size != 0:
            abc_text += chr(self.eos_token_id)
        return [abc_text[i: i + patch_size] for i in range(0, len(abc_text), patch_size)]

    def patch2chars(self, patch):
        """
        Convert a patch (sequence of char ids) back into text, stopping at EOS.
        """
        # Renamed local from `bytes` (shadowed the builtin) and removed a dead
        # `if idx < self.eos_token_id: pass` branch that had no effect.
        chars = ''
        for idx in patch:
            if idx == self.eos_token_id:
                break
            chars += chr(idx)
        return chars

    def patchilize_metadata(self, metadata_lines):
        """Patchilize each metadata line independently."""
        metadata_patches = []
        for line in metadata_lines:
            metadata_patches += self.split_patches(line)
        return metadata_patches

    def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):
        """
        Patchilize the tunebody bar by bar. In 'generate' mode the final bar is
        left unterminated so the model can continue it.
        """
        tunebody_patches = []
        bars = self.split_bars(tunebody_lines)
        if encode_mode == 'train':
            for bar in bars:
                tunebody_patches += self.split_patches(bar)
        elif encode_mode == 'generate':
            for bar in bars[:-1]:
                tunebody_patches += self.split_patches(bar)
            tunebody_patches += self.split_patches(bars[-1], generate_last=True)
        return tunebody_patches

    def encode_train(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):
        """
        Encode an ABC tune into a list of patch id lists for training.
        In stream mode, over-long tunes are cut at a random line boundary
        (head / tail / middle) so training sees different windows.
        :param abc_text: full ABC source of one tune
        :param patch_length: maximum number of patches to keep
        :param patch_size: chars per patch
        :param add_special_patches: wrap the sequence in BOS/EOS patches
        :param cut: truncate to patch_length
        :return: list of patch id lists, each padded to patch_size
        """
        lines = abc_text.split('\n')
        lines = list(filter(None, lines))
        lines = [line + '\n' for line in lines]

        # The tunebody starts at the first voice line; everything before is metadata.
        tunebody_index = -1
        for i, line in enumerate(lines):
            if line.startswith('[V:'):
                tunebody_index = i
                break

        metadata_lines = lines[: tunebody_index]
        tunebody_lines = lines[tunebody_index:]

        if self.stream:
            # Prefix every tunebody line with its position: [r:index/lines_remaining]
            tunebody_lines = ['[r:' + str(line_index) + '/' + str(len(tunebody_lines) - line_index - 1) + ']' + line
                              for line_index, line in enumerate(tunebody_lines)]  # [r:n/n]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)

            metadata_patches = [bos_patch] + metadata_patches
            tunebody_patches = tunebody_patches + [eos_patch]

        if self.stream:
            if len(metadata_patches) + len(tunebody_patches) > patch_length:
                # Candidate cut points are line starts (patches containing a newline
                # end a line, so index + 1 starts the next one).
                available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if '\n' in patch]
                line_index_for_cut_index = list(range(len(available_cut_indexes)))
                end_index = len(metadata_patches) + len(tunebody_patches) - patch_length
                biggest_index = bisect.bisect_left(available_cut_indexes, end_index)
                available_cut_indexes = available_cut_indexes[:biggest_index + 1]

                if len(available_cut_indexes) == 1:
                    choices = ['head']
                elif len(available_cut_indexes) == 2:
                    choices = ['head', 'tail']
                else:
                    choices = ['head', 'tail', 'middle']
                choice = random.choice(choices)

                if choice == 'head':
                    # Fixed: was `tunebody_patches[0:]` — a redundant full copy.
                    patches = metadata_patches + tunebody_patches
                else:
                    if choice == 'tail':
                        cut_index = len(available_cut_indexes) - 1
                    else:
                        cut_index = random.choice(range(1, len(available_cut_indexes) - 1))

                    line_index = line_index_for_cut_index[cut_index]
                    stream_tunebody_lines = tunebody_lines[line_index:]

                    stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')
                    if add_special_patches:
                        stream_tunebody_patches = stream_tunebody_patches + [eos_patch]
                    patches = metadata_patches + stream_tunebody_patches
            else:
                patches = metadata_patches + tunebody_patches
        else:
            patches = metadata_patches + tunebody_patches

        if cut:
            patches = patches[: patch_length]

        # encode chars to ids, right-padding each patch with the pad id
        id_patches = []
        for patch in patches:
            id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):
        """
        Encode an ABC prompt into patch id lists for generation.
        Unlike encode_train, a partial trailing patch is left unpadded so the
        model can continue it char by char.
        """
        lines = abc_code.split('\n')
        lines = list(filter(None, lines))

        tunebody_index = None
        for i, line in enumerate(lines):
            if line.startswith('[V:') or line.startswith('[r:'):
                tunebody_index = i
                break

        metadata_lines = lines[: tunebody_index]
        tunebody_lines = lines[tunebody_index:]

        metadata_lines = [line + '\n' for line in metadata_lines]
        if self.stream:
            # Keep the final line open (no trailing newline) unless the prompt closed it.
            if not abc_code.endswith('\n'):
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]]
            else:
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]
        else:
            tunebody_lines = [line + '\n' for line in tunebody_lines]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            metadata_patches = [bos_patch] + metadata_patches

        patches = metadata_patches + tunebody_patches
        patches = patches[: patch_length]

        # encode chars to ids; an unfinished trailing patch stays unpadded
        id_patches = []
        for patch in patches:
            # Fixed: was the global `PATCH_SIZE` — use the patch_size parameter
            # consistently with the rest of this method (same value by default).
            if len(patch) < patch_size and patch[-1] != chr(self.eos_token_id):
                id_patch = [ord(c) for c in patch]
            else:
                id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def decode(self, patches):
        """
        Decode patches into music.
        """
        return ''.join(self.patch2chars(patch) for patch in patches)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class PatchLevelDecoder(PreTrainedModel):
    """
    A Patch-level Decoder model for generating patch features in an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        # Projects a flattened one-hot patch (PATCH_SIZE chars x 128 ids) to the model width.
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = GPT2Model(config)

    def forward(self,
                patches: torch.Tensor,
                masks=None) -> torch.Tensor:
        """
        The forward pass of the patch-level decoder model.
        :param patches: the patches to be encoded (char ids in [0, 128))
        :param masks: optional attention masks for the patches
        :return: the encoded patches (GPT2Model output)
        """
        # One-hot encode each char id, then flatten every patch into a single vector.
        patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)
        patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
        patches = self.patch_embedding(patches.to(self.device))

        # Fixed: was `masks==None`. Identity (`is None`) is the correct presence
        # test; `==` on a tensor operand is elementwise and unreliable here.
        if masks is None:
            return self.base(inputs_embeds=patches)
        return self.base(inputs_embeds=patches,
                         attention_mask=masks)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class CharLevelDecoder(PreTrainedModel):
    """
    A Char-level Decoder model for generating the chars within each patch in an auto-regressive manner
    based on the encoded patch features. It inherits PreTrainedModel from transformers.
    """
    def __init__(self, config):
        super().__init__(config)
        # Token id 0 pads patches; token id 1 marks the beginning of a patch.
        self.special_token_id = 0
        self.bos_token_id = 1

        # Underlying char-level language model.
        self.base = GPT2LMHeadModel(config)

    def forward(self,
                encoded_patches: torch.Tensor,
                target_patches: torch.Tensor):
        """
        The forward pass of the char-level decoder model.
        :param encoded_patches: the encoded patches
        :param target_patches: the target patches
        :return: the output of the model (includes the LM loss, since labels are passed)
        """
        # preparing the labels for model training: prepend a BOS token to every patch
        target_patches = torch.cat((torch.ones_like(target_patches[:,0:1])*self.bos_token_id, target_patches), dim=1)

        # pad positions (special_token_id) are excluded from the loss via label -100
        target_masks = target_patches == self.special_token_id
        labels = target_patches.clone().masked_fill_(target_masks, -100)

        # masking the labels for model training: attention mask is 0 wherever the label is ignored
        target_masks = torch.ones_like(labels)
        target_masks = target_masks.masked_fill_(labels == -100, 0)

        # select patches: subsample at most PATCH_SAMPLING_BATCH_SIZE patches to bound memory
        if PATCH_SAMPLING_BATCH_SIZE!=0 and PATCH_SAMPLING_BATCH_SIZE<target_patches.shape[0]:
            indices = list(range(len(target_patches)))
            random.shuffle(indices)
            selected_indices = sorted(indices[:PATCH_SAMPLING_BATCH_SIZE])

            target_patches = target_patches[selected_indices,:]
            target_masks = target_masks[selected_indices,:]
            encoded_patches = encoded_patches[selected_indices,:]

        # get input embeddings from the LM's token embedding table
        inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)

        # concatenate the encoded patches with the input embeddings:
        # the patch feature replaces the BOS embedding at position 0
        inputs_embeds = torch.cat((encoded_patches.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1)

        output = self.base(inputs_embeds=inputs_embeds,
                           attention_mask=target_masks,
                           labels=labels)

        return output

    def generate(self,
                 encoded_patch: torch.Tensor,
                 tokens: torch.Tensor):
        """
        The generate function for generating a patch based on the encoded patch and already generated tokens.
        :param encoded_patch: the encoded patch, shape [hidden_size]
        :param tokens: already generated tokens in the patch
        :return: the probability distribution of next token
        """
        encoded_patch = encoded_patch.reshape(1, 1, -1)  # [1, 1, hidden_size]
        tokens = tokens.reshape(1, -1)

        # Get input embeddings
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)

        # Concatenate the encoded patch with the input embeddings:
        # the patch feature takes the place of the first token's embedding
        tokens = torch.cat((encoded_patch, tokens[:,1:,:]), dim=1)

        # Get output from model
        outputs = self.base(inputs_embeds=tokens)

        # Get probabilities of next token (softmax over the last position's logits)
        probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)

        return probs
|
| 331 |
+
|
| 332 |
+
def safe_normalize_probs(probs):
    """
    Sanitize an array-like of sampling probabilities and return a valid,
    normalized float64 distribution.

    NaN, negative, and infinite entries are zeroed (originally +inf slipped
    through the NaN/negative filter and produced a NaN distribution); a tiny
    epsilon keeps all-zero inputs normalizable. If no mass remains, all
    probability is assigned to index 0. An empty input is returned unchanged
    (originally it raised IndexError in the fallback branch).

    :param probs: array-like of candidate probabilities
    :return: np.ndarray of float64 summing to 1 (or empty if input is empty)
    """
    epsilon = 1e-12
    probs = np.array(probs, dtype=np.float64)
    # Zero out anything that cannot be a probability: NaN, negative, +/-inf.
    probs = np.where(~np.isfinite(probs) | (probs < 0), 0, probs)
    if probs.size == 0:
        return probs
    probs = probs + epsilon
    s = probs.sum()
    if s > 0:
        probs = probs / s
    else:
        # Degenerate fallback: deterministic point mass on the first entry.
        probs = np.zeros_like(probs)
        probs[0] = 1.0
    return probs
|
| 344 |
+
|
| 345 |
+
class NotaGenLMHeadModel(PreTrainedModel):
    """
    NotaGen is a language model with a hierarchical structure.
    It includes a patch-level decoder and a char-level decoder.
    The patch-level decoder is used to generate patch features in an auto-regressive manner.
    The char-level decoder is used to generate the chars within each patch in an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, encoder_config, decoder_config):
        super().__init__(encoder_config)
        self.special_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.patch_level_decoder = PatchLevelDecoder(encoder_config)
        self.char_level_decoder = CharLevelDecoder(decoder_config)

    def forward(self,
                patches: torch.Tensor,
                masks: torch.Tensor):
        """
        The forward pass of the NotaGen model.
        :param patches: the patches to be encoded
        :param masks: the attention masks for the patches (NOTE: mutated in place below)
        :return: the char-level decoder output (includes the LM loss)
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches, masks)["last_hidden_state"]

        # The feature of patch i predicts the chars of patch i+1: drop the last
        # valid position on the encoder side and the first on the target side.
        left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)
        masks[:, 0] = 0  # in-place: callers should not reuse `masks` afterwards

        encoded_patches = encoded_patches[left_shift_masks == 1]
        patches = patches[masks == 1]

        return self.char_level_decoder(encoded_patches, patches)

    def generate(self,
                 patches: torch.Tensor,
                 top_k=0,
                 top_p=1,
                 temperature=1.0):
        """
        Generate one patch conditioned on the given patches.
        :param patches: the prompt patches
        :param top_k: the top k for sampling (0 disables top-k filtering)
        :param top_p: the top p for sampling (1 disables nucleus filtering)
        :param temperature: the temperature for sampling
        :return: the generated patch as a list of char ids
        """
        # If the prompt does not end on a patch boundary, the trailing chars seed
        # the char-level decoder and only the full patches feed the encoder.
        if patches.shape[-1] % PATCH_SIZE != 0:
            tokens = patches[:, :, -(patches.shape[-1] % PATCH_SIZE):].squeeze(0, 1)
            tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1)
            patches = patches[:, :, :-(patches.shape[-1] % PATCH_SIZE)]
        else:
            tokens = torch.tensor([self.bos_token_id], device=self.device)

        patches = patches.reshape(len(patches), -1, PATCH_SIZE)  # [bs, seq, patch_size]
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]  # [bs, seq, hidden_size]
        generated_patch = []

        while True:
            prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy()  # [128]
            # Re-normalize after every filtering step: the filters can zero out
            # mass or return unnormalized values, breaking the next sampler.
            prob = safe_normalize_probs(prob)
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
            prob = safe_normalize_probs(prob)
            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
            prob = safe_normalize_probs(prob)
            token = temperature_sampling(prob, temperature=temperature)  # int
            # (removed unused local `char = chr(token)`)
            generated_patch.append(token)

            if len(tokens) >= PATCH_SIZE:
                break
            tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)

        return generated_patch
|
utils.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import random
|
| 3 |
+
import bisect
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
import numpy as np
|
| 7 |
+
from config import *
|
| 8 |
+
from transformers import GPT2Model, GPT2LMHeadModel, LlamaModel, LlamaForCausalLM, PreTrainedModel
|
| 9 |
+
from samplings import top_p_sampling, top_k_sampling, temperature_sampling
|
| 10 |
+
from tokenizers import Tokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Patchilizer:
    """Converts ABC notation text into fixed-size patches of character ids and back.

    A "patch" is a string of at most PATCH_SIZE characters; ids 0/1/2 are
    reserved for padding / BOS / EOS. Bars are split on ABC barline delimiters
    so each patch stays aligned with musical structure.
    """

    def __init__(self, stream=PATCH_STREAM):
        self.stream = stream  # enables the sliding-window ("stream") training cut in encode()
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.special_token_id = 0  # padding id

    def split_bars(self, body_lines):
        """
        Split a body of music into individual bars.

        Delimiters are kept attached to the bar that follows (or precedes) them;
        a trailing fragment without a voice marker ('V') is merged into the
        previous bar.
        """
        new_bars = []
        try:
            for line in body_lines:
                line_bars = re.split(self.regexPattern, line)
                line_bars = list(filter(None, line_bars))
                new_line_bars = []

                if len(line_bars) == 1:
                    new_line_bars = line_bars
                else:
                    if line_bars[0] in self.delimiters:
                        new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]
                    else:
                        new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]
                    if 'V' not in new_line_bars[-1]:
                        new_line_bars[-2] += new_line_bars[-1]
                        new_line_bars = new_line_bars[:-1]
                new_bars += new_line_bars
        except Exception:
            # Best-effort: a malformed line yields whatever bars were collected
            # so far.  (Was a bare `except:`; narrowed so Ctrl-C still works.)
            pass

        return new_bars

    def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):
        """Chop abc_text into patch_size chunks; append EOS to an uneven tail
        unless the last patch is still being generated."""
        if not generate_last and len(abc_text) % patch_size != 0:
            abc_text += chr(self.eos_token_id)
        patches = [abc_text[i: i + patch_size] for i in range(0, len(abc_text), patch_size)]
        return patches

    def patch2chars(self, patch):
        """
        Convert a patch (sequence of char ids) back into text, stopping at EOS.
        BOS/padding ids below EOS are kept as-is, matching the original encoding.
        """
        chars = ''
        for idx in patch:
            if idx == self.eos_token_id:
                break
            chars += chr(idx)
        return chars

    def patchilize_metadata(self, metadata_lines):
        """Patchify each metadata line independently."""
        metadata_patches = []
        for line in metadata_lines:
            metadata_patches += self.split_patches(line)
        return metadata_patches

    def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):
        """Patchify the tune body bar by bar; in 'generate' mode the final bar
        is left open (no EOS) so generation can continue it."""
        tunebody_patches = []
        bars = self.split_bars(tunebody_lines)
        if encode_mode == 'train':
            for bar in bars:
                tunebody_patches += self.split_patches(bar)
        elif encode_mode == 'generate':
            for bar in bars[:-1]:
                tunebody_patches += self.split_patches(bar)
            tunebody_patches += self.split_patches(bars[-1], generate_last=True)
        return tunebody_patches

    def encode(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):
        """
        Encode a full ABC tune into a list of id patches for training.

        The tune body starts at the first line beginning with '[r:'.  When
        self.stream is set and the tune exceeds patch_length, a random cut
        ('head' / 'tail' / 'middle') keeps a window that fits.
        The `cut` parameter is currently unused but kept for interface
        compatibility.
        """
        lines = abc_text.split('\n')
        lines = list(filter(None, lines))
        lines = [line + '\n' for line in lines]

        tunebody_index = -1
        for i, line in enumerate(lines):
            if line.startswith('[r:'):
                tunebody_index = i
                break

        metadata_lines = lines[: tunebody_index]
        tunebody_lines = lines[tunebody_index:]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')

        if add_special_patches:
            # BOS patch: (patch_size-1) BOS chars + EOS; EOS patch: BOS + (patch_size-1) EOS chars.
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)

            metadata_patches = [bos_patch] + metadata_patches
            tunebody_patches = tunebody_patches + [eos_patch]

        if self.stream:
            if len(metadata_patches) + len(tunebody_patches) > patch_length:
                # Cut points are patch indexes just after a newline (bar-line boundaries).
                available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if
                                               '\n' in patch]
                line_index_for_cut_index = list(range(len(available_cut_indexes)))
                end_index = len(metadata_patches) + len(tunebody_patches) - patch_length
                biggest_index = bisect.bisect_left(available_cut_indexes, end_index)
                available_cut_indexes = available_cut_indexes[:biggest_index + 1]

                if len(available_cut_indexes) == 1:
                    choices = ['head']
                elif len(available_cut_indexes) == 2:
                    choices = ['head', 'tail']
                else:
                    choices = ['head', 'tail', 'middle']
                choice = random.choice(choices)
                if choice == 'head':
                    patches = metadata_patches + tunebody_patches[0:]
                else:
                    if choice == 'tail':
                        cut_index = len(available_cut_indexes) - 1
                    else:
                        cut_index = random.choice(range(1, len(available_cut_indexes) - 1))

                    line_index = line_index_for_cut_index[cut_index]
                    stream_tunebody_lines = tunebody_lines[line_index:]

                    stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')
                    if add_special_patches:
                        stream_tunebody_patches = stream_tunebody_patches + [eos_patch]
                    patches = metadata_patches + stream_tunebody_patches
            else:
                patches = metadata_patches + tunebody_patches
        else:
            patches = metadata_patches + tunebody_patches

        patches = patches[: patch_length]

        # Encode chars to ids and right-pad each patch to patch_size.
        id_patches = []
        for patch in patches:
            id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):
        """
        Encode a (possibly partial) ABC prompt for generation.

        The last patch is left unpadded when it is still open (shorter than
        patch_size and not EOS-terminated) so the model can continue it.
        """
        lines = abc_code.split('\n')
        lines = list(filter(None, lines))

        tunebody_index = None
        for i, line in enumerate(lines):
            if line.startswith('[V:') or line.startswith('[r:'):
                tunebody_index = i
                break

        metadata_lines = lines[: tunebody_index]
        tunebody_lines = lines[tunebody_index:]

        metadata_lines = [line + '\n' for line in metadata_lines]
        if self.stream:
            # Keep the final line open (no newline) when the prompt ends mid-line.
            if not abc_code.endswith('\n'):
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]]
            else:
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]
        else:
            tunebody_lines = [line + '\n' for line in tunebody_lines]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            metadata_patches = [bos_patch] + metadata_patches

        patches = metadata_patches + tunebody_patches
        patches = patches[: patch_length]

        # Encode to ids; only closed patches are padded.
        # (Fixed: was comparing against the global PATCH_SIZE instead of the
        # patch_size parameter — identical for the default call.)
        id_patches = []
        for patch in patches:
            if len(patch) < patch_size and patch[-1] != chr(self.eos_token_id):
                id_patch = [ord(c) for c in patch]
            else:
                id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def decode(self, patches):
        """
        Decode patches into music.
        """
        return ''.join(self.patch2chars(patch) for patch in patches)
| 215 |
+
|
| 216 |
+
class PatchLevelDecoder(PreTrainedModel):
    """
    A Patch-level Decoder model for generating patch features in an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        # Each patch is PATCH_SIZE chars one-hot encoded over 128 ids, then
        # linearly projected to the transformer hidden size.
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = GPT2Model(config)

    def forward(self,
                patches: torch.Tensor,
                masks=None) -> torch.Tensor:
        """
        The forward pass of the patch-level decoder model.
        :param patches: the patches to be encoded (char ids)
        :param masks: optional attention mask over patches
        :return: the transformer output over the encoded patches
        """
        patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)
        patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
        patches = self.patch_embedding(patches.to(self.device))

        # Fixed: identity check `is None` instead of `== None` (the latter
        # would attempt an element-wise comparison on a tensor mask).
        if masks is None:
            return self.base(inputs_embeds=patches)
        return self.base(inputs_embeds=patches,
                         attention_mask=masks)
| 246 |
+
|
| 247 |
+
class CharLevelDecoder(PreTrainedModel):
    """
    A Char-level Decoder model for generating the chars within each patch in an auto-regressive manner
    based on the encoded patch features. It inherits PreTrainedModel from transformers.
    """
    def __init__(self, config):
        super().__init__(config)
        self.special_token_id = 0  # padding id inside a patch
        self.bos_token_id = 1      # BOS id prepended to every patch
        self.base = GPT2LMHeadModel(config)

    def forward(self,
                encoded_patches: torch.Tensor,
                target_patches: torch.Tensor):
        """
        The forward pass of the char-level decoder model.
        :param encoded_patches: the encoded patches (one feature vector per patch)
        :param target_patches: the target patches (char ids)
        :return: the summed log-probability of all non-padding target chars
        """
        # Prepend a BOS column so each patch is conditioned on a start token.
        target_patches = torch.cat((torch.ones_like(target_patches[:, 0:1]) * self.bos_token_id,
                                    target_patches), dim=1) # [patch_len, patch_size + 1]

        # Positions holding the padding id are excluded from the loss (-100).
        target_masks = target_patches == self.special_token_id # [patch_len, patch_size + 1]
        # masked_fill_ mutates the clone in place; target_patches itself stays intact.
        labels = target_patches.clone().masked_fill_(target_masks, -100)

        # Rebind target_masks as the attention mask: 1 for real tokens, 0 for padding.
        target_masks = torch.ones_like(labels)
        target_masks = target_masks.masked_fill_(labels == -100, 0)

        # Embed the BOS-prefixed targets, then replace the BOS embedding with
        # the patch feature coming from the patch-level decoder.
        input_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)
        input_embeds = torch.cat((encoded_patches.unsqueeze(1), input_embeds[:, 1:, :]), dim=1)
        logits = self.base(inputs_embeds=input_embeds,
                           attention_mask=target_masks).logits # [patch_len, patch_size + 1, vocab_size]
        # Drop the last position so logits align with next-token targets.
        logits = logits[:, :-1, :]
        # Gather per-token log-probs of the actual targets, keep only real tokens.
        token_logps = torch.gather(logits.log_softmax(-1), dim=-1, index=target_patches[:, 1:].unsqueeze(-1)).squeeze(-1) # [patch_len, patch_size]
        token_logps = token_logps[target_masks[:, 1:] == 1]
        all_logps = token_logps.sum()

        return all_logps

    def generate(self,
                 encoded_patch: torch.Tensor,  # [hidden_size]
                 tokens: torch.Tensor):        # [1]
        """
        The generate function for generating a patch based on the encoded patch and already generated tokens.
        :param encoded_patch: the encoded patch
        :param tokens: already generated tokens in the patch
        :return: the probability distribution of next token
        """
        encoded_patch = encoded_patch.reshape(1, 1, -1)  # [1, 1, hidden_size]
        tokens = tokens.reshape(1, -1)

        # Get input embeddings
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)

        # Concatenate the encoded patch with the input embeddings (the patch
        # feature takes the place of the first/BOS token embedding).
        tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)

        # Get output from model
        outputs = self.base(inputs_embeds=tokens)

        # Get probabilities of next token (softmax over the last position).
        probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)

        return probs
+
|
| 314 |
+
def safe_normalize_probs(probs):
    """Return *probs* as a valid float64 probability distribution.

    NaN and negative entries are zeroed, a tiny epsilon is added to every
    entry, and the vector is renormalized to sum to 1.  Should the total
    mass still be non-positive, all mass goes to the first entry.
    """
    eps = 1e-12
    arr = np.asarray(probs, dtype=np.float64).copy()
    bad = np.isnan(arr) | (arr < 0)
    arr[bad] = 0
    arr += eps
    total = arr.sum()
    if total > 0:
        return arr / total
    fallback = np.zeros_like(arr)
    fallback[0] = 1.0
    return fallback
+
|
| 327 |
+
class NotaGenLMHeadModel(PreTrainedModel):
    """
    NotaGen is a language model with a hierarchical structure.
    It includes a patch-level decoder and a char-level decoder.
    The patch-level decoder is used to generate patch features in an auto-regressive manner.
    The char-level decoder is used to generate the chars within each patch in an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """
    def __init__(self, encoder_config, decoder_config):
        super().__init__(encoder_config)
        self.special_token_id = 0  # padding id
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.patch_level_decoder = PatchLevelDecoder(encoder_config)
        self.char_level_decoder = CharLevelDecoder(decoder_config)

    def forward(self,
                patches: torch.Tensor,
                masks: torch.Tensor):
        """
        The forward pass of the bGPT model.
        :param patches: the patches to be encoded
        :param masks: the masks for the patches
        :return: the char-level decoder output (summed target log-probs)
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches, masks)["last_hidden_state"]

        # left_shift_masks keeps every valid patch except the last one of each
        # sequence: each patch feature predicts the *next* patch's chars.
        left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)
        # Drop the first patch from the targets (it has no predecessor).
        # NOTE(review): this mutates the caller's `masks` tensor in place — confirm intended.
        masks[:, 0] = 0

        encoded_patches = encoded_patches[left_shift_masks == 1]
        patches = patches[masks == 1]

        return self.char_level_decoder(encoded_patches, patches)

    def generate(self,
                 patches: torch.Tensor,
                 top_k=0,
                 top_p=1,
                 temperature=1.0):
        """
        The generate function for generating patches based on patches.
        :param patches: the patches to be encoded (flattened char ids)
        :param top_k: the top k for sampling
        :param top_p: the top p for sampling
        :param temperature: the temperature for sampling
        :return: the generated patch (list of char ids)
        """
        # If the id stream does not end on a patch boundary, the remainder is
        # the partially generated patch: use it as the char-level prompt and
        # keep only whole patches for the patch-level decoder.
        if patches.shape[-1] % PATCH_SIZE != 0:
            # NOTE(review): multi-dim squeeze(0, 1) requires torch >= 2.0 — confirm.
            tokens = patches[:,:,-(patches.shape[-1]%PATCH_SIZE):].squeeze(0, 1)
            tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1)
            patches = patches[:,:,:-(patches.shape[-1]%PATCH_SIZE)]
        else:
            tokens = torch.tensor([self.bos_token_id], device=self.device)

        patches = patches.reshape(len(patches), -1, PATCH_SIZE)  # [bs, seq, patch_size]
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]  # [bs, seq, hidden_size]
        generated_patch = []

        # Sample one char at a time, conditioned on the last patch feature,
        # until the patch is full.
        while True:
            prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy()  # [128]
            # Renormalize between each filtering step to keep a valid distribution.
            prob = safe_normalize_probs(prob)
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)  # [128]
            prob = safe_normalize_probs(prob)
            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)  # [128]
            prob = safe_normalize_probs(prob)
            token = temperature_sampling(prob, temperature=temperature)  # int
            char = chr(token)  # (unused; kept for parity with the original)
            generated_patch.append(token)

            if len(tokens) >= PATCH_SIZE:# or token == self.eos_token_id:
                break
            else:
                tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)

        return generated_patch
|
| 405 |
+
|
| 406 |
+
|
xml2abc.py
ADDED
|
@@ -0,0 +1,1609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=latin-1
|
| 3 |
+
'''
|
| 4 |
+
Copyright (C) 2012-2018: W.G. Vree
|
| 5 |
+
Contributions: M. Tarenskeen, N. Liberg, Paul Villiger, Janus Meuris, Larry Myerscough,
|
| 6 |
+
Dick Jackson, Jan Wybren de Jong, Mark Zealey.
|
| 7 |
+
|
| 8 |
+
This program is free software; you can redistribute it and/or modify it under the terms of the
|
| 9 |
+
Lesser GNU General Public License as published by the Free Software Foundation;
|
| 10 |
+
|
| 11 |
+
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
| 12 |
+
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
| 13 |
+
See the Lesser GNU General Public License for more details. <http://www.gnu.org/licenses/lgpl.html>.
|
| 14 |
+
'''
|
| 15 |
+
|
| 16 |
+
'''Small revisions made for NotaGen to improve the success rate of conversion.'''
|
| 17 |
+
|
| 18 |
+
try: import xml.etree.cElementTree as E
|
| 19 |
+
except: import xml.etree.ElementTree as E
|
| 20 |
+
import os, sys, types, re, math
|
| 21 |
+
|
| 22 |
+
VERSION = 143  # xml2abc release number

# Python 2/3 compatibility shims: alias the concrete sequence types and the
# platform's largest integer so the rest of the module stays version-agnostic.
python3 = sys.version_info.major > 2
if python3:
    tupletype = tuple
    listtype = list
    max_int = sys.maxsize  # sys.maxint was removed in Python 3
else:
    tupletype = types.TupleType
    listtype = types.ListType
    max_int = sys.maxint
|
| 34 |
+
# Maps MusicXML notations/ sub-elements to their ABC decoration equivalents.
note_ornamentation_map = { # for notations/, modified from EasyABC
    'ornaments/trill-mark': 'T',
    'ornaments/mordent': 'M',
    'ornaments/inverted-mordent': 'P',
    'ornaments/turn': '!turn!',
    'ornaments/inverted-turn': '!invertedturn!',
    'technical/up-bow': 'u',
    'technical/down-bow': 'v',
    'technical/harmonic': '!open!',
    'technical/open-string': '!open!',
    'technical/stopped': '!plus!',
    'technical/snap-pizzicato': '!snap!',
    'technical/thumb-position': '!thumb!',
    'articulations/accent': '!>!',
    'articulations/strong-accent': '!^!',
    'articulations/staccato': '.',
    # Fixed: 'articulations/staccatissimo' was listed twice with the same
    # value; the duplicate literal key has been removed.
    'articulations/staccatissimo': '!wedge!', # not sure whether this is the right translation
    'articulations/scoop': '!slide!',
    'fermata': '!fermata!',
    'arpeggiate': '!arpeggio!',
    'articulations/tenuto': '!tenuto!',
    'articulations/spiccato': '!wedge!', # not sure whether this is the right translation
    'articulations/breath-mark': '!breath!', # this may need to be tested to make sure it appears on the right side of the note
    'articulations/detached-legato': '!tenuto!.',
}
|
| 61 |
+
# Maps MusicXML dynamics element names (direction/direction-type/dynamics/)
# to ABC decorations; every mark translates uniformly to '!<name>!'.
dynamics_map = dict(
    (mark, '!%s!' % mark)
    for mark in ('p', 'pp', 'ppp', 'pppp', 'f', 'ff', 'fff', 'ffff', 'mp', 'mf', 'sfz')
)
|
| 75 |
+
# SVG definitions emitted with the ABC output (a '%%beginsvg' block) for
# percussion noteheads (x, circle-x, triangle, square, diamond variants).
# NOTE(review): the empty <text> elements originally held latin-1 notehead
# glyph characters (the file declares coding=latin-1); verify the glyphs
# survived re-encoding of this file.
percSvg = '''%%beginsvg
<defs>
<text id="x" x="-3" y="0"></text>
<text id="x-" x="-3" y="0"></text>
<text id="x+" x="-3" y="0"></text>
<text id="normal" x="-3.7" y="0"></text>
<text id="normal-" x="-3.7" y="0"></text>
<text id="normal+" x="-3.7" y="0"></text>
<g id="circle-x"><text x="-3" y="0"></text><circle r="4" class="stroke"></circle></g>
<g id="circle-x-"><text x="-3" y="0"></text><circle r="4" class="stroke"></circle></g>
<path id="triangle" d="m-4 -3.2l4 6.4 4 -6.4z" class="stroke" style="stroke-width:1.4"></path>
<path id="triangle-" d="m-4 -3.2l4 6.4 4 -6.4z" class="stroke" style="stroke-width:1.4"></path>
<path id="triangle+" d="m-4 -3.2l4 6.4 4 -6.4z" class="stroke" style="fill:#000"></path>
<path id="square" d="m-3.5 3l0 -6.2 7.2 0 0 6.2z" class="stroke" style="stroke-width:1.4"></path>
<path id="square-" d="m-3.5 3l0 -6.2 7.2 0 0 6.2z" class="stroke" style="stroke-width:1.4"></path>
<path id="square+" d="m-3.5 3l0 -6.2 7.2 0 0 6.2z" class="stroke" style="fill:#000"></path>
<path id="diamond" d="m0 -3l4.2 3.2 -4.2 3.2 -4.2 -3.2z" class="stroke" style="stroke-width:1.4"></path>
<path id="diamond-" d="m0 -3l4.2 3.2 -4.2 3.2 -4.2 -3.2z" class="stroke" style="stroke-width:1.4"></path>
<path id="diamond+" d="m0 -3l4.2 3.2 -4.2 3.2 -4.2 -3.2z" class="stroke" style="fill:#000"></path>
</defs>
%%endsvg'''

# CSS and clearing rectangles used when rendering tablature fret numbers.
tabSvg = '''%%beginsvg
<style type="text/css">
.bf {font-family:sans-serif; font-size:7px}
</style>
<defs>
<rect id="clr" x="-3" y="-1" width="6" height="5" fill="white"></rect>
<rect id="clr2" x="-3" y="-1" width="11" height="5" fill="white"></rect>'''

# Templates for tablature note-head glyphs; the two %s slots are the glyph id
# suffix and the fret text (kopSvg2 uses the wider clearing rectangle).
kopSvg = '<g id="kop%s" class="bf"><use xlink:href="#clr"></use><text x="-2" y="3">%s</text></g>\n'
kopSvg2 = '<g id="kop%s" class="bf"><use xlink:href="#clr2"></use><text x="-2" y="3">%s</text></g>\n'
|
| 108 |
+
def info(s, warn=1):
    """Write a diagnostic message to stderr, prefixed '-- ' when warn is truthy."""
    prefix = '-- ' if warn else ''
    sys.stderr.write(prefix + s + '\n')
|
| 110 |
+
#-------------------
|
| 111 |
+
# data abstractions
|
| 112 |
+
#-------------------
|
| 113 |
+
class Measure:
    """Per-measure bookkeeping for one part while translating MusicXML."""

    def __init__(self, p):
        self.reset()
        self.ixp = p       # part number
        self.ixm = 0       # measure number
        self.mdur = 0      # measure duration (nominal metre value in divisions)
        self.divs = 0      # number of divisions per 1/4
        self.mtr = 4, 4    # meter

    def reset(self):
        """Clear the state that is re-read for every measure."""
        self.attr = ''     # measure signatures, tempo
        self.lline = ''    # left barline; only ':' at start of a repeat, otherwise empty
        self.rline = '|'   # right barline
        self.lnum = ''     # (left) volta number
| 128 |
+
class Note:
    """One note (or growing chord) event on the translation time line."""

    def __init__(self, dur=0, n=None):
        self.tijd = 0               # onset time in XML division units
        self.dur = dur              # duration of a note in XML divisions
        self.fact = None            # time modification for tuplet notes (num, div)
        self.tup = ['']             # start(s) and/or stop(s) of tuplet
        self.tupabc = ''            # abc tuplet string to issue before the note
        self.beam = 0               # 1 = beamed
        self.grace = 0              # 1 = grace note
        self.before = []            # abc string(s) that go before the note/chord
        self.after = ''             # the same after the note/chord
        self.ns = [n] if n else []  # notes collected into the chord
        self.lyrs = {}              # {number -> syllable}
        self.tab = None             # (string number, fret number)
        self.ntdec = ''             # !string!, !courtesy!
| 144 |
+
class Elem:
    """Any abc fragment that is not a note, placed on a voice's time line."""

    def __init__ (s, string):
        s.str = string  # the abc text of this element
        s.tijd = 0      # its time position in XML division units
|
| 148 |
+
|
| 149 |
+
class Counter:
    """Per-voice counters for real notes and skipped items within one part."""

    def inc (s, key, voice):
        """Increment counter `key` for `voice` (creating it at 0 if absent)."""
        s.counters [key][voice] = s.counters [key].get (voice, 0) + 1

    def clear (s, vnums):
        """Reset all counters to zero for every voice id in vnums."""
        zeroed = dict.fromkeys (vnums.keys (), 0)
        s.counters = {'note': dict (zeroed), 'nopr': dict (zeroed), 'nopt': dict (zeroed)}

    def getv (s, key, voice):
        """Return the current value of counter `key` for `voice`."""
        return s.counters[key][voice]

    def prcnt (s, ip):
        """Print a summary of all non-zero counters of part ip."""
        for iv in s.counters ['note']:
            if s.getv ('nopr', iv) != 0:
                info ( 'part %d, voice %d has %d skipped non printable notes' % (ip, iv, s.getv ('nopr', iv)))
            if s.getv ('nopt', iv) != 0:
                info ( 'part %d, voice %d has %d notes without pitch' % (ip, iv, s.getv ('nopt', iv)))
            if s.getv ('note', iv) == 0:    # no real notes counted in this voice
                info ( 'part %d, skipped empty voice %d' % (ip, iv))
|
| 163 |
+
|
| 164 |
+
class Music:
    """Accumulates the translated music of one part, voice by voice and
    measure by measure, and flushes it to the global abc output (abcOut).

    Time bookkeeping (s.tijd / s.maxtime / s.vtimes) is in XML division
    units; voice ids (s.vnums) are the XML voice numbers of one part.
    """
    def __init__(s, options):
        s.tijd = 0              # the current time
        s.maxtime = 0           # maximum time in a measure
        s.gMaten = []           # [voices,.. for all measures in a part]
        s.gLyrics = []          # [{num: (abc_lyric_string, melis)},.. for all measures in a part]
        s.vnums = {}            # all used voice id's in a part (xml voice id's == numbers)
        s.cnt = Counter ()      # global counter object
        s.vceCnt = 1            # the global voice count over all parts
        s.lastnote = None       # the last real note record inserted in s.voices
        s.bpl = options.b       # the max number of bars per line when writing abc
        s.cpl = options.n       # the number of chars per line when writing abc
        s.repbra = 0            # true if volta is used somewhere
        s.nvlt = options.v      # no volta on higher voice numbers
        s.jscript = options.j   # compatibility with javascript version

    def initVoices (s, newPart=0):
        """Reset the per-measure voice buffers; clear counters once per part."""
        s.vtimes, s.voices, s.lyrics = {}, {}, {}
        for v in s.vnums:
            s.vtimes [v] = 0    # {voice: the end time of the last item in each voice}
            s.voices [v] = []   # {voice: [Note|Elem, ..]}
            s.lyrics [v] = []   # {voice: [{num: syl}, ..]}
        if newPart: s.cnt.clear (s.vnums)   # clear counters once per part

    def incTime (s, dt):
        """Advance current time by dt divisions, clamping at 0 and tracking the maximum."""
        s.tijd += dt
        if s.tijd < 0: s.tijd = 0           # erroneous <backup> element
        if s.tijd > s.maxtime: s.maxtime = s.tijd

    def appendElemCv (s, voices, elem):
        """Append abc string elem to every voice in voices."""
        for v in voices:
            s.appendElem (v, elem)          # insert element in all voices

    def insertElem (s, v, elem):
        """Insert abc string elem at the START of voice v in the current measure."""
        obj = Elem (elem)
        obj.tijd = 0                        # because voice is sorted later
        s.voices [v].insert (0, obj)

    def appendObj (s, v, obj, dur):
        """Append a Note/Elem at the current time and advance time by dur."""
        obj.tijd = s.tijd
        s.voices [v].append (obj)
        s.incTime (dur)
        if s.tijd > s.vtimes[v]: s.vtimes[v] = s.tijd   # don't update for inserted earlier items

    def appendElem (s, v, elem, tel=0):
        """Append a zero-duration abc element; count it as a note when tel is set."""
        s.appendObj (v, Elem (elem), 0)
        if tel: s.cnt.inc ('note', v)   # count number of certain elements in each voice (in addition to notes)

    def appendElemT (s, v, elem, tijd):
        """Append an abc element at an explicitly specified time (not the current one)."""
        obj = Elem (elem)
        obj.tijd = tijd
        s.voices [v].append (obj)

    def appendNote (s, v, note, noot):
        """Append note (with abc pitch string noot) to voice v and advance time."""
        note.ns.append (note.ntdec + noot)
        s.appendObj (v, note, int (note.dur))
        s.lastnote = note                       # remember last note/rest for later modifications (chord, grace)
        if noot != 'z' and noot != 'x':         # real notes and grace notes
            s.cnt.inc ('note', v)               # count number of real notes in each voice
            if not note.grace:                  # for every real note
                s.lyrics[v].append (note.lyrs)  # even when it has no lyrics

    def getLastRec (s, voice):
        """Return the last record of `voice` in the previous measure, or None."""
        if s.gMaten: return s.gMaten[-1][voice][-1]     # the last record in the last measure
        return None                                     # no previous records in the first measure

    def getLastMelis (s, voice, num):
        """Return the melisma state of lyric line num at the end of the previous measure."""
        if s.gLyrics:
            lyrdict = s.gLyrics[-1][voice]              # the previous lyrics dict in this voice
            if num in lyrdict: return lyrdict[num][1]   # lyrdict = num -> (lyric string, melisma)
        return 0                                        # no previous lyrics in voice or line number

    def addChord (s, note, noot):
        """Merge note into the chord of s.lastnote.

        Careful: assumes chord notes follow immediately after the first one.
        """
        for d in note.before:                   # put all decorations before chord
            if d not in s.lastnote.before:
                s.lastnote.before += [d]
        s.lastnote.ns.append (note.ntdec + noot)

    def addBar (s, lbrk, m):
        """Close the current measure m: fix barlines/voltas, sort each voice on
        time, fold this measure's lyrics, and push everything onto the
        per-part accumulators (gMaten/gLyrics).  lbrk is an optional abc
        linebreak character appended after the previous barline."""
        if m.mdur and s.maxtime > m.mdur: info ('measure %d in part %d longer than metre' % (m.ixm+1, m.ixp+1))
        s.tijd = s.maxtime                      # the time of the bar lines inserted here
        for v in s.vnums:
            if m.lline or m.lnum:               # if left barline or left volta number
                p = s.getLastRec (v)            # get the previous barline record
                if p:                           # in measure 1 no previous measure is available
                    x = p.str                   # p.str is the ABC barline string
                    if m.lline:                 # append begin of repeat, m.lline == ':'
                        x = (x + m.lline).replace (':|:','::').replace ('||','|')
                    if s.nvlt == 3:             # add volta number only to lowest voice in part 0
                        if m.ixp + v == min (s.vnums): x += m.lnum
                    elif m.lnum:                # new behaviour with I:repbra 0
                        x += m.lnum             # add volta number(s) or text to all voices
                        s.repbra = 1            # signal occurrence of a volta
                    p.str = x                   # modify previous right barline
                elif m.lline:                   # begin of new part and left repeat bar is required
                    s.insertElem (v, '|:')
            if lbrk:
                p = s.getLastRec (v)            # get the previous barline record
                if p: p.str += lbrk             # insert linebreak char after the barlines+volta
            if m.attr:                          # insert signatures at front of buffer
                s.insertElem (v, '%s' % m.attr)
            s.appendElem (v, ' %s' % m.rline)   # insert current barline record at time maxtime
            s.voices[v] = sortMeasure (s.voices[v], m)  # make all times consistent
            lyrs = s.lyrics[v]                  # [{number: sylabe}, .. for all notes]
            lyrdict = {}                        # {number: (abc_lyric_string, melis)} for this voice
            nums = [num for d in lyrs for num in d.keys ()] # the lyrics numbers in this measure
            maxNums = max (nums + [0])          # the highest lyrics number in this measure
            for i in range (maxNums, 0, -1):
                xs = [syldict.get (i, '') for syldict in lyrs]  # collect the syllabi with number i
                melis = s.getLastMelis (v, i)   # get melisma from last measure
                lyrdict [i] = abcLyr (xs, melis)
            s.lyrics[v] = lyrdict               # {number: (abc_lyric_string, melis)} for this measure
            mkBroken (s.voices[v])
        s.gMaten.append (s.voices)
        s.gLyrics.append (s.lyrics)
        s.tijd = s.maxtime = 0
        s.initVoices ()

    def outVoices (s, divs, ip, isSib):
        """Emit all non-empty voices of part ip to abcOut, wrapping lines by
        -n (chars) or -b (bars); returns {xml voice number: abc voice number}."""
        vvmap = {}                              # xml voice number -> abc voice number (one part)
        vnum_keys = list (s.vnums.keys ())
        if s.jscript or isSib: vnum_keys.sort ()
        lvc = min (vnum_keys or [1])            # lowest xml voice number of this part
        for iv in vnum_keys:
            if s.cnt.getv ('note', iv) == 0:    # no real notes counted in this voice
                continue                        # skip empty voices
            if abcOut.denL: unitL = abcOut.denL # take the unit length from the -d option
            else: unitL = compUnitLength (iv, s.gMaten, divs)   # compute the best unit length for this voice
            abcOut.cmpL.append (unitL)          # remember for header output
            vn, vl = [], {}                     # for voice iv: collect all notes to vn and all lyric lines to vl
            for im in range (len (s.gMaten)):
                measure = s.gMaten [im][iv]
                vn.append (outVoice (measure, divs [im], im, ip, unitL))
                checkMelismas (s.gLyrics, s.gMaten, im, iv)
                for n, (lyrstr, melis) in s.gLyrics [im][iv].items ():
                    if n in vl:
                        while len (vl[n]) < im: vl[n].append ('')   # fill in skipped measures
                        vl[n].append (lyrstr)
                    else:
                        vl[n] = im * [''] + [lyrstr]    # must skip im measures
            for n, lyrs in vl.items ():         # fill up possibly empty lyric measures at the end
                mis = len (vn) - len (lyrs)
                lyrs += mis * ['']
            abcOut.add ('V:%d' % s.vceCnt)
            if s.repbra:
                if s.nvlt == 1 and s.vceCnt > 1: abcOut.add ('I:repbra 0')  # only volta on first voice
                if s.nvlt == 2 and iv > lvc: abcOut.add ('I:repbra 0')      # only volta on first voice of each part
            if s.cpl > 0: s.bpl = 0             # option -n (max chars per line) overrules -b (max bars per line)
            elif s.bpl == 0: s.cpl = 100        # the default: 100 chars per line
            bn = 0                              # count bars
            while vn:                           # while still measures available
                ib = 1
                chunk = vn [0]
                while ib < len (vn):
                    if s.cpl > 0 and len (chunk) + len (vn [ib]) >= s.cpl: break    # line full (number of chars)
                    if s.bpl > 0 and ib >= s.bpl: break                             # line full (number of bars)
                    chunk += vn [ib]
                    ib += 1
                bn += ib
                abcOut.add (chunk + ' %%%d' % bn)   # line with bar number
                del vn[:ib]                         # chop ib bars
                lyrlines = sorted (vl.items ())     # order the numbered lyric lines for output
                for n, lyrs in lyrlines:
                    abcOut.add ('w: ' + '|'.join (lyrs[:ib]) + '|')
                    del lyrs[:ib]
            vvmap [iv] = s.vceCnt               # xml voice number -> abc voice number
            s.vceCnt += 1                       # count voices over all parts
        s.gMaten = []                           # reset the following instance vars for each part
        s.gLyrics = []
        s.cnt.prcnt (ip+1)                      # print summary of skipped items in this part
        return vvmap
|
| 335 |
+
|
| 336 |
+
class ABCoutput:
    """Collects abc output lines for one tune, builds the tune header and
    writes everything to a file (when an output dir `pad` is given) or stdout.

    Fixes vs. the previous version:
    - mkHeader: `tbs` was a map object; `map_obj + list` raises TypeError on
      Python 3 — now built as a list.
    - __init__: the illegal-float message was never %-formatted (the tuple was
      passed as info's `warn` argument); also narrowed the bare except.
    - writeall: removed the duplicated python3 if/else (both branches were
      identical) and the shadowing of builtin `str`.
    """
    pagekeys = 'scale,pageheight,pagewidth,leftmargin,rightmargin,topmargin,botmargin'.split (',')

    def __init__ (s, fnmext, pad, X, options):
        s.fnmext = fnmext
        s.outlist = []          # list of ABC strings
        s.title = 'T:Title'
        s.key = 'none'
        s.clefs = {}            # clefs for all abc-voices
        s.mtr = 'none'
        s.tempo = 0             # 0 -> no tempo field
        s.tempo_units = (1,4)   # note type of tempo direction
        s.pad = pad             # the output path or none
        s.X = X + 1             # the abc tune number
        s.denL = options.d      # denominator of the unit length (L:) from -d option
        s.volpan = int (options.m)  # 0 -> no %%MIDI, 1 -> only program, 2 -> all %%MIDI
        s.cmpL = []             # computed optimal unit length for all voices
        s.jscript = options.j   # compatibility with javascript version
        s.tstep = options.t     # translate percmap to voicemap
        s.stemless = 0          # use U:s=!stemless!
        s.shiftStem = options.s # shift note heads 3 units left
        if pad:
            _, base_name = os.path.split (fnmext)
            s.outfile = open (os.path.join (pad, base_name), 'w', encoding='utf-8')
        else: s.outfile = sys.stdout
        if s.jscript: s.X = 1   # always X:1 in javascript version
        s.pageFmt = {}
        for k in s.pagekeys: s.pageFmt [k] = None
        if len (options.p) == 7:
            for k, v in zip (s.pagekeys, options.p):
                try: s.pageFmt [k] = float (v)
                except (TypeError, ValueError):     # fixed: actually format the message
                    info ('illegal float %s for %s' % (v, k)); continue

    def add (s, str):
        """Collect one line of ABC output (newline is appended here)."""
        s.outlist.append (str + '\n')   # collect all ABC output

    def mkHeader (s, stfmap, partlist, midimap, vmpdct, koppen):
        """Prepend the complete tune header (X:, %%score, L:, M:, K:, V: and
        %%MIDI lines) to s.outlist.  stfmap = [parts], part = [staves],
        stave = [voices]; koppen triggers the SVG defs needed for tablature."""
        accVce, accStf, staffs = [], [], stfmap[:]  # staffs is consumed
        for x in partlist:          # collect partnames into accVce and staff groups into accStf
            try: prgroupelem (x, ('', ''), '', stfmap, accVce, accStf)
            except Exception: info ('lousy musicxml: error in part-list')  # best effort on broken xml
        staves = ' '.join (accStf)
        clfnms = {}
        for part, (partname, partabbrv) in zip (staffs, accVce):
            if not part: continue   # skip empty part
            firstVoice = part[0][0] # the first voice number in this part
            nm = partname.replace ('\n','\\n').replace ('.:','.').strip (':')
            snm = partabbrv.replace ('\n','\\n').replace ('.:','.').strip (':')
            clfnms [firstVoice] = (nm and 'nm="%s"' % nm or '') + (snm and ' snm="%s"' % snm or '')
        hd = ['X:%d\n%s\n' % (s.X, s.title)]
        for i, k in enumerate (s.pagekeys):
            if s.jscript and k in ['pageheight','topmargin', 'botmargin']: continue
            if s.pageFmt [k] is not None: hd.append ('%%%%%s %.2f%s\n' % (k, s.pageFmt [k], i > 0 and 'cm' or ''))
        if staves and len (accStf) > 1: hd.append ('%%score ' + staves + '\n')
        tempo = s.tempo and 'Q:%d/%d=%s\n' % (s.tempo_units [0], s.tempo_units [1], s.tempo) or ''  # default no tempo field
        d = {}                      # determine the most frequently occurring unit length over all voices
        for x in s.cmpL: d[x] = d.get (x, 0) + 1
        if s.jscript: defLs = sorted (d.items (), key=lambda x: (-x[1], x[0]))  # when tie (1) sort on key (0)
        else: defLs = sorted (d.items (), key=lambda x: -x[1])
        defL = s.denL and s.denL or defLs [0][0]    # override default unit length with -d option
        hd.append ('L:1/%d\n%sM:%s\n' % (defL, tempo, s.mtr))
        hd.append ('K:%s\n' % s.key)
        if s.stemless: hd.append ('U:s=!stemless!\n')
        vxs = sorted (vmpdct.keys ())
        for vx in vxs: hd.extend (vmpdct [vx])
        s.dojef = 0                 # translate percmap to voicemap
        for vnum, clef in s.clefs.items ():
            ch, prg, vol, pan = midimap [vnum-1][:4]
            dmap = midimap [vnum - 1][4:]   # map of abc percussion notes to midi notes
            if dmap and 'perc' not in clef: clef = (clef + ' map=perc').strip ()
            hd.append ('V:%d %s %s\n' % (vnum, clef, clfnms.get (vnum, '')))
            if vnum in vmpdct:
                hd.append ('%%%%voicemap tab%d\n' % vnum)
                hd.append ('K:none\nM:none\n%%clef none\n%%staffscale 1.6\n%%flatbeams true\n%%stemdir down\n')
            if 'perc' in clef: hd.append ('K:none\n')   # no key for a perc voice
            if s.volpan > 1:        # option -m 2 -> output all recognized midi commands when needed and present in xml
                if ch > 0 and ch != vnum: hd.append ('%%%%MIDI channel %d\n' % ch)
                if prg > 0: hd.append ('%%%%MIDI program %d\n' % (prg - 1))
                if vol >= 0: hd.append ('%%%%MIDI control 7 %.0f\n' % vol)  # volume == 0 is possible ...
                if pan >= 0: hd.append ('%%%%MIDI control 10 %.0f\n' % pan)
            elif s.volpan > 0:      # default -> only output midi program command when present in xml
                if dmap and ch > 0: hd.append ('%%%%MIDI channel %d\n' % ch)    # also channel if percussion part
                if prg > 0: hd.append ('%%%%MIDI program %d\n' % (prg - 1))
            for abcNote, step, midiNote, notehead in dmap:
                if not notehead: notehead = 'normal'
                if abcMid (abcNote) != midiNote or abcNote != step:
                    if s.volpan > 0: hd.append ('%%%%MIDI drummap %s %s\n' % (abcNote, midiNote))
                    hd.append ('I:percmap %s %s %s %s\n' % (abcNote, step, midiNote, notehead))
                    s.dojef = s.tstep
            if defL != s.cmpL [vnum-1]: # only if computed unit length different from header
                hd.append ('L:1/%d\n' % s.cmpL [vnum-1])
        s.outlist = hd + s.outlist
        if koppen:                  # output SVG stuff needed for tablature
            k1 = kopSvg.replace ('-2','-5') if s.shiftStem else kopSvg  # shift note heads 3 units left
            k2 = kopSvg2.replace ('-2','-5') if s.shiftStem else kopSvg2
            tb = tabSvg.replace ('-3','-6') if s.shiftStem else tabSvg
            ks = sorted (koppen.keys ())    # javascript compatibility
            ks = [k2 % (k, k) if len (k) == 2 else k1 % (k, k) for k in ks]
            # fixed: must be a list — a py3 map object cannot be concatenated below
            tbs = [x.strip () + '\n' for x in tb.splitlines ()] # javascript compatibility
            s.outlist = tbs + ks + ['</defs>\n%%endsvg\n'] + s.outlist

    def writeall (s):
        """Write the entire collected ABC output and close the file (if any).
        Note: relies on s.dojef, which is set by mkHeader."""
        abctext = ''.join (s.outlist)
        if s.dojef: abctext = perc2map (abctext)
        s.outfile.write (abctext)
        if s.pad: s.outfile.close ()    # close each file with -o option
        else: s.outfile.write ('\n')    # add empty line between tunes on stdout
        info ('%s written with %d voices' % (s.fnmext, len (s.clefs)), warn=0)
|
| 445 |
+
|
| 446 |
+
#----------------
|
| 447 |
+
# functions
|
| 448 |
+
#----------------
|
| 449 |
+
def abcLyr (xs, melis):
    """Convert one measure of syllables (one entry per note) to an abc w: string.

    melis is the melisma state carried over from the previous measure; the
    updated state is returned together with the lyric string.
    """
    if not ''.join (xs): return '', 0               # no lyrics at all in this measure
    out = []
    for syl in xs:
        if syl == '':                               # note without lyrics
            syl = '_' if melis else '*'             # continue melisma or skip note
        elif syl.endswith ('_') and not syl.endswith ('\\_'):   # start of new melisma
            syl = syl.replace ('_', '')             # strip marker, remember state
            melis = 1                               # following empty slots become '_'
        else:
            melis = 0                               # melisma stops on first syllable
        out.append (syl)
    return ' '.join (out), melis
|
| 462 |
+
|
| 463 |
+
def simplify (a, b):
    """Reduce the fraction a/b by the greatest common divisor of a and b."""
    num, den = a, b
    while b:                    # Euclid: a ends up as gcd (num, den)
        a, b = b, a % b
    return num // a, den // a
|
| 467 |
+
|
| 468 |
+
def abcdur (nx, divs, uL):  # convert a musicXML duration d to abc units with L:1/uL
    """Return the abc duration string of note nx, given divs divisions per
    quarter note and unit note length 1/uL (e.g. '3/2', '/', '' for one unit)."""
    if nx.dur == 0: return ''   # when called for elements without duration
    num, den = simplify (uL * nx.dur, divs * 4) # L=1/8 -> uL = 8 units
    if nx.fact:                 # apply tuplet time modification
        numfac, denfac = nx.fact
        num, den = simplify (num * numfac, den * denfac)
    if den > 64:                # limit the denominator to a maximum of 64
        x = float (num) / den; n = math.floor (x);  # when just above an integer n
        if x - n < 0.1 * x: num, den = n, 1;        # round to n
        num64 = 64. * num / den + 1.0e-15   # to get Python2 behaviour of round
        num, den = simplify (int (round (num64)), 64)
    # build the shortest abc notation for num/den
    if num == 1:
        if den == 1: dabc = ''
        elif den == 2: dabc = '/'
        else: dabc = '/%d' % den
    elif den == 1: dabc = '%d' % num
    else: dabc = '%d/%d' % (num, den)
    return dabc
|
| 486 |
+
|
| 487 |
+
def abcMid (note):
    """Translate an abc note string to a midi pitch number (-1 when unparsable)."""
    m = re.search (r"([_^]*)([A-Ga-g])([',]*)", note)
    if not m: return -1
    acc, n, oct = m.groups ()
    base = n.upper ()
    p = 60 + [0,2,4,5,7,9,11]['CDEFGAB'.index (base)]
    if n != base: p += 12                               # lower case -> one octave up
    if acc: p += (1 if acc[0] == '^' else -1) * len (acc)   # sharps/flats
    if oct: p += (12 if oct[0] == "'" else -12) * len (oct) # octave marks
    return p
|
| 496 |
+
|
| 497 |
+
def staffStep (ptc, o, clef, tstep):
    """Map a pitch letter ptc with octave o to an abc note string, applying a
    diatonic shift for one-line staves and (without -t) for bass clefs."""
    shift = 0
    if 'stafflines=1' in clef: shift += 4                   # one line: E (xml) -> B (abc)
    if not tstep and clef.startswith ('bass'): shift += 12  # transpose bass -> treble (C3 -> A4)
    if shift:                                               # diatonic transposition == addition modulo 7
        names = 'C,D,E,F,G,A,B'.split (',')
        ix = names.index (ptc) + shift
        ptc, o = names [ix % 7], o + ix // 7
    # octave marks relative to the abc middle octave (o == 4/5)
    if o > 4: ptc = ptc.lower ()
    if o > 5: ptc = ptc + (o-5) * "'"
    if o < 4: ptc = ptc + (4-o) * ","
    return ptc
|
| 509 |
+
|
| 510 |
+
def setKey (fifths, mode):
    """Translate a key signature (fifths on the circle, mode name) into the
    abc key string and a dict of measure alterations {pitch letter: +1/-1}."""
    sharpness = ['Fb', 'Cb','Gb','Db','Ab','Eb','Bb','F','C','G','D','A', 'E', 'B', 'F#','C#','G#','D#','A#','E#','B#']
    offTab = {'maj':8, 'ion':8, 'm':11, 'min':11, 'aeo':11, 'mix':9, 'dor':10, 'phr':12, 'lyd':7, 'loc':13, 'non':8}
    mode = mode.lower ()[:3]                # only first three chars, no case
    off = offTab [mode]
    key = sharpness [off + fifths] + ('' if off == 8 else mode)   # major modes need no suffix
    accs = ['F','C','G','D','A','E','B']    # order of sharps; flats run backwards
    if fifths >= 0:
        msralts = {acc: 1 for acc in accs[:fifths]}
    else:
        msralts = {acc: -1 for acc in accs[fifths:]}
    return key, msralts
|
| 519 |
+
|
| 520 |
+
def insTup (ix, notes, fact):   # read one nested tuplet
    """Scan one (possibly nested) tuplet starting at notes[ix], count its
    elements, and prepend the abc tuplet prefix '(n' or '(p:q:r' to the first
    tuplet note.  fact is the xml time-modification of the enclosing level.
    Returns (index of last tuplet note, element count)."""
    tupcnt = 0
    nx = notes [ix]
    if 'start' in nx.tup:
        nx.tup.remove ('start')     # do recursive calls when starts remain
    tix = ix                        # index of first tuplet note
    fn, fd = fact                   # xml time-mod of the higher level
    fnum, fden = nx.fact            # xml time-mod of the current level
    tupfact = fnum//fn, fden//fd    # abc time mod of this level
    while ix < len (notes):
        nx = notes [ix]
        if isinstance (nx, Elem) or nx.grace:
            ix += 1                 # skip all non tuplet elements
            continue
        if 'start' in nx.tup:       # more nested tuplets to start
            ix, tupcntR = insTup (ix, notes, tupfact)   # ix is on the stop note!
            tupcnt += tupcntR
        elif nx.fact:
            tupcnt += 1             # count tuplet elements
        if 'stop' in nx.tup:
            nx.tup.remove ('stop')
            break
        if not nx.fact:             # stop on first non tuplet note
            ix = lastix             # back to last tuplet note
            break
        lastix = ix
        ix += 1
    # put abc tuplet notation before the recursive ones
    tup = (tupfact[0], tupfact[1], tupcnt)
    if tup == (3, 2, 3): tupPrefix = '(3'   # common triplet: use short form
    else: tupPrefix = '(%d:%d:%d' % tup
    notes [tix].tupabc = tupPrefix + notes [tix].tupabc
    return ix, tupcnt               # ix is on the last tuplet note
|
| 553 |
+
|
| 554 |
+
def mkBroken (vs):
    """Rewrite adjacent beamed notes with a 1:3 or 3:1 duration ratio as an
    abc broken rhythm ('<' or '>'), mutating the notes in place.
    vs: all objects of one voice in one measure."""
    notes = [x for x in vs if isinstance (x, Note)]
    k = 0
    while k < len (notes) - 1:
        left, right = notes[k], notes[k+1]          # scan all adjacent pairs
        # only outside tuplets, with a real duration, inside a beam
        if not left.fact and not right.fact and left.dur > 0 and right.beam:
            if left.dur * 3 == right.dur:           # short-long pair -> '<'
                right.dur = (2 * right.dur) // 3
                left.dur = left.dur * 2
                left.after = '<' + left.after
                k += 1                              # do not chain broken rhythms
            elif right.dur * 3 == left.dur:         # long-short pair -> '>'
                left.dur = (2 * left.dur) // 3
                right.dur = right.dur * 2
                left.after = '>' + left.after
                k += 1                              # do not chain broken rhythms
        k += 1
|
| 572 |
+
|
| 573 |
+
def outVoice (measure, divs, im, ip, unitL):    # note/elem objects of one measure in one voice
    """Render one measure of one voice to its abc string: resolve tuplets,
    chords, ties and durations, then clean up duplicated pedal/ottava marks."""
    ix = 0
    while ix < len (measure):   # set all (nested) tuplet annotations
        nx = measure [ix]
        if isinstance (nx, Note) and nx.fact and not nx.grace:
            ix, tupcnt = insTup (ix, measure, (1, 1))   # read one tuplet, insert annotation(s)
        ix += 1
    vs = []
    for nx in measure:
        if isinstance (nx, Note):
            durstr = abcdur (nx, divs, unitL)   # xml -> abc duration string
            chord = len (nx.ns) > 1
            cns = [nt[:-1] for nt in nx.ns if nt.endswith ('-')]
            tie = ''
            if chord and len (cns) == len (nx.ns):  # all chord notes tied
                nx.ns = cns     # chord notes without tie
                tie = '-'       # one tie for whole chord
            s = nx.tupabc + ''.join (nx.before)
            if chord: s += '['
            for nt in nx.ns: s += nt
            if chord: s += ']' + tie
            if s.endswith ('-'): s, tie = s[:-1], '-'   # split off tie
            s += durstr + tie   # and put it back again
            s += nx.after
            nospace = nx.beam   # beamed notes are output without separating space
        else:
            if isinstance (nx.str, listtype): nx.str = nx.str [0]
            s = nx.str
            nospace = 1
        if nospace: vs.append (s)
        else: vs.append (' ' + s)
    vs = ''.join (vs)           # ad hoc: remove multiple pedal directions
    while vs.find ('!ped!!ped!') >= 0: vs = vs.replace ('!ped!!ped!','!ped!')
    while vs.find ('!ped-up!!ped-up!') >= 0: vs = vs.replace ('!ped-up!!ped-up!','!ped-up!')
    while vs.find ('!8va(!!8va)!') >= 0: vs = vs.replace ('!8va(!!8va)!','')    # remove empty ottava's
    return vs
|
| 609 |
+
|
| 610 |
+
def sortMeasure (voice, m):
    """Sort the objects of one voice in measure m on time and make them
    strictly sequential: fill gaps with invisible rests, repair overlapping
    notes (shorten rests, merge into chords, or discard), and copy beam
    state into rests that lie between beamed notes.  Returns the new list."""
    voice.sort (key=lambda o: o.tijd)   # sort on time
    time = 0
    v = []
    rs = []                             # holds rests in between notes
    for i, nx in enumerate (voice):     # establish sequentiality
        if nx.tijd > time and chkbug (nx.tijd - time, m):
            v.append (Note (nx.tijd - time, 'x'))   # fill hole with invisble rest
            rs.append (len (v) - 1)
        if isinstance (nx, Elem):
            if nx.tijd < time: nx.tijd = time   # shift elems without duration to where they fit
            v.append (nx)
            time = nx.tijd
            continue
        if nx.tijd < time:              # overlapping element
            if nx.ns[0] == 'z': continue    # discard overlapping rest
            if v[-1].tijd <= nx.tijd:   # we can do something
                if v[-1].ns[0] == 'z':  # shorten rest
                    v[-1].dur = nx.tijd - v[-1].tijd
                    if v[-1].dur == 0: del v[-1]    # nothing left
                    info ('overlap in part %d, measure %d: rest shortened' % (m.ixp+1, m.ixm+1))
                else:                   # make a chord of overlap
                    v[-1].ns += nx.ns
                    info ('overlap in part %d, measure %d: added chord' % (m.ixp+1, m.ixm+1))
                nx.dur = (nx.tijd + nx.dur) - time  # the remains
                if nx.dur <= 0: continue            # nothing left
                nx.tijd = time          # append remains
            else:                       # give up
                info ('overlapping notes in one voice! part %d, measure %d, note %s discarded' % (m.ixp+1, m.ixm+1, isinstance (nx, Note) and nx.ns or nx.str))
                continue
        v.append (nx)
        if isinstance (nx, Note):
            if nx.ns [0] in 'zx':
                rs.append (len (v) - 1) # remember rests between notes
            elif len (rs):
                if nx.beam and not nx.grace:    # copy beam into rests
                    for j in rs: v[j].beam = nx.beam
                rs = []                 # clear rests on each note
        time = nx.tijd + nx.dur
    # when a measure contains no elements and no forwards -> no incTime -> s.maxtime = 0 -> right barline
    # is inserted at time == 0 (in addbar) and is only element in the voice when sortMeasure is called
    if time == 0: info ('empty measure in part %d, measure %d, it should contain at least a rest to advance the time!' % (m.ixp+1, m.ixm+1))
    return v
|
| 653 |
+
|
| 654 |
+
def getPartlist (ps):
    """Sanitize a <part-list> produced by buggy software: insert missing
    part-group stops, drop duplicate stops, and close groups left open."""
    fixed = []          # the corrected part-list
    open_groups = []    # stack of currently open group numbers
    for el in list (ps):
        if el.tag != 'part-group':
            fixed.append (el)
            continue
        num, type = el.get ('number'), el.get ('type')
        if type == 'start':
            if num in open_groups:  # same group started twice: insert the missing stop
                fixed.append (E.Element ('part-group', number = num, type = 'stop'))
            else:                   # normal start: remember it
                open_groups.append (num)
            fixed.append (el)
        elif num in open_groups:    # normal stop
            open_groups.remove (num)
            fixed.append (el)
        # else: double stop -> skip it
    for num in reversed (open_groups):  # fill missing stops at the end
        fixed.append (E.Element ('part-group', number = num, type = 'stop'))
    return fixed
|
| 676 |
+
|
| 677 |
+
def parseParts (xs, d, e):  # -> [elems on current level], rest of xs
    """Recursively turn the (sanitized) flat part-list xs into a nested list:
    part names as (name, abbrev) tuples, nested groups as sublists ending in
    their group data.  d maps group number -> group data; e is the stack of
    open group numbers."""
    if not xs: return [],[]
    x = xs.pop (0)
    if x.tag == 'part-group':
        num, type = x.get ('number'), x.get ('type')
        if type == 'start':     # go one level deeper
            s = [x.findtext (n, '') for n in ['group-symbol','group-barline','group-name','group-abbreviation']]
            d [num] = s         # remember groupdata by group number
            e.append (num)      # make stack of open group numbers
            elemsnext, rest1 = parseParts (xs, d, e)    # parse one level deeper to next stop
            elems, rest2 = parseParts (rest1, d, e)     # parse the rest on this level
            return [elemsnext] + elems, rest2
        else:                   # stop: close level and return group-data
            nums = e.pop ()     # last open group number in stack order
            if xs and xs[0].get ('type') == 'stop':     # two consequetive stops
                if num != nums: # in the wrong order (tempory solution)
                    d[nums], d[num] = d[num], d[nums]   # exchange values (only works for two stops!!!)
            sym = d[num]        # retrieve and return groupdata as last element of the group
            return [sym], xs
    else:
        elems, rest = parseParts (xs, d, e) # parse remaining elements on current level
        name = x.findtext ('part-name',''), x.findtext ('part-abbreviation','')
        return [name] + elems, rest
|
| 700 |
+
|
| 701 |
+
def bracePart (part):
    """Build the %%score tokens for one part: voices on a shared stave are
    parenthesized, staves are separated by '|', and a multi-staff part is
    wrapped in braces.  part = [stave], stave = [voice numbers]."""
    if not part: return []                  # empty part in the score
    score = []
    for stave in part:
        if len (stave) == 1:                # stave with a single voice
            score.append ('%s' % stave[0])
        else:                               # multiple voices on one stave
            score.append ('(')
            score.extend ('%s' % v for v in stave)
            score.append (')')
        score.append ('|')
    del score[-1]                           # no barline after the last stave
    if len (part) > 1:
        score = ['{'] + score + ['}']       # brace a multi-staff part
    return score
|
| 714 |
+
|
| 715 |
+
def prgroupelem (x, gnm, bar, pmap, accVce, accStf): # collect partnames (accVce) and %%score map (accStf)
    """Process one element of the nested part-list: a part (name tuple) is
    emitted directly, a group is delegated to prgrouplist.  gnm is the
    enclosing (group-name, group-abbrev); pmap is consumed one part at a time."""
    if type (x) == tupletype:   # partname-tuple = (part-name, part-abbrev)
        y = pmap.pop (0)
        if gnm[0]: x = [n1 + ':' + n2 for n1, n2 in zip (gnm, x)]   # put group-name before part-name
        accVce.append (x)
        accStf.extend (bracePart (y))
    elif len (x) == 2 and type (x[0]) == tupletype: # misuse of group just to add extra name to stave
        y = pmap.pop (0)
        nms = [n1 + ':' + n2 for n1, n2 in zip (x[0], x[1][2:])]    # x[0] = partname-tuple, x[1][2:] = groupname-tuple
        accVce.append (nms)
        accStf.extend (bracePart (y))
    else:
        prgrouplist (x, bar, pmap, accVce, accStf)
|
| 728 |
+
|
| 729 |
+
def prgrouplist (x, pbar, pmap, accVce, accStf): # collect partnames, scoremap for a part-group
    """Process a part-group list: the last element holds the group data,
    the preceding elements are the group's members (parts or sub-groups)."""
    sym, bar, gnm, gabbr = x[-1] # bracket symbol, continue barline, group-name-tuple
    bar = bar == 'yes' or pbar # pbar -> the parent has bar
    accStf.append (sym == 'brace' and '{' or '[') # open with { or [
    for z in x[:-1]:
        prgroupelem (z, (gnm, gabbr), bar, pmap, accVce, accStf)
        if bar: accStf.append ('|') # barline between group members
    if bar: del accStf [-1] # remove last one before close
    accStf.append (sym == 'brace' and '}' or ']') # close with } or ]
|
| 738 |
+
|
| 739 |
+
def compUnitLength (iv, maten, divs): # compute optimal unit length
    """Choose the ABC unit note length (L:) for voice iv that minimises
    the total length of all duration strings over all measures.

    iv    -- voice number
    maten -- list of measures; each measure maps voice -> list of notes/elements
    divs  -- xml divisions per measure (indexed by measure number)
    Returns 4, 8 or 16 (denominator of the unit length).
    """
    uLmin, minLen = 0, max_int
    for uL in [4,8,16]: # try 1/4, 1/8 and 1/16
        vLen = 0 # total length of abc duration strings in this voice
        for im, m in enumerate (maten): # all measures
            for e in m[iv]: # all notes in voice iv
                if isinstance (e, Elem) or e.dur == 0: continue # no real durations
                vLen += len (abcdur (e, divs [im], uL)) # add len of duration string
        if vLen < minLen: uLmin, minLen = uL, vLen # remember the smallest
    return uLmin
|
| 749 |
+
|
| 750 |
+
def doSyllable (syl):
    """Translate a musicXML <lyric> element into one ABC w: syllable.

    <elision> becomes '~'; <text> content is escaped for ABC lyrics
    (literal '_', '-' and space have special meaning in w: lines);
    a begin/middle <syllabic> appends '-' and an <extend> appends '_'
    (start of a melisma). Returns the syllable string ('' when empty).
    """
    txt = ''
    for e in syl:
        if e.tag == 'elision': txt += '~'
        elif e.tag == 'text': # escape - and space characters
            # fix: use a raw string for the '_' escape too ('\_' was an
            # invalid escape sequence, SyntaxWarning on modern Python)
            txt += (e.text or '').replace ('_', r'\_').replace ('-', r'\-').replace (' ', '~')
    if not txt: return txt
    if syl.findtext ('syllabic') in ['begin', 'middle']: txt += '-'
    if syl.find ('extend') is not None: txt += '_' # melisma continues
    return txt
|
| 760 |
+
|
| 761 |
+
def checkMelismas (lyrics, maten, im, iv):
    """Continue open melismas from measure im-1 into measure im for voice iv.

    When a lyric line of the previous measure ends with an open melisma and
    the current measure has no lyrics for that line number, synthesise a
    line of '_' marks covering the notes of the current measure.
    """
    if im == 0: return                  # no previous measure to continue from
    notes = maten [im][iv]              # notes of the current measure
    cur = lyrics [im][iv]               # lyrics dict of the current measure
    prev = lyrics [im-1][iv]            # lyrics dict of the previous measure
    pending = [n for n, (_, melis) in prev.items () if melis and n not in cur]
    for n in pending:                   # melisma required, but no lyrics present
        ms = getMelisma (notes)         # build '_' marks for the current measure
        if ms:
            cur [n] = (ms, 0)           # install as the n-th lyric line
|
| 770 |
+
|
| 771 |
+
def getMelisma (maat): # get melisma from notes in maat
    """Return a lyric line of '_' marks, one per real note in *maat*,
    stopping at the first rest; non-Note elements and grace notes are skipped."""
    marks = []
    for nt in maat:
        if not isinstance (nt, Note):   # skip Elem's
            continue
        if nt.grace:                    # grace notes carry no syllable
            continue
        if nt.ns [0] in 'zx':           # stop at the first (possibly invisible) rest
            break
        marks.append ('_')
    return ' '.join (marks)
|
| 779 |
+
|
| 780 |
+
def perc2map (abcIn):
    """Post-process the generated ABC for percussion staves.

    Collects all I:percmap lines per voice into %%map definitions, prepends
    the percSvg preamble (module constant defined elsewhere in this file),
    re-emits the tune with %%MIDI lines moved directly after their V: line,
    and inserts %%voicemap switches where a K:/V: line requests map=perc
    or map=off. Returns the rewritten ABC text.
    """
    fillmap = {'diamond':1, 'triangle':1, 'square':1, 'normal':1}
    # fix: wrap in list() -- on Python 3 `map` returns an iterator, which
    # broke the later `abc += ...` / `abc.append (...)` with a TypeError
    abc = list (map (lambda x: x.strip (), percSvg.splitlines ()))
    vid = 'default'                     # renamed from `id` (shadowed the builtin)
    maps = {'default': []}              # voice id -> %%map lines
    dmaps = {'default': []}             # voice id -> %%MIDI lines
    r1 = re.compile (r'V:\s*(\S+)')
    ls = abcIn.splitlines ()
    for x in ls:                        # pass 1: collect maps per voice
        if 'I:percmap' in x:
            noot, step, midi, kop = map (lambda w: w.strip (), x.split ()[1:])
            if kop in fillmap: kop = kop + '+' + ',' + kop
            x = '%%%%map perc%s %s print=%s midi=%s heads=%s' % (vid, noot, step, midi, kop)
            maps [vid].append (x)
        if '%%MIDI' in x: dmaps [vid].append (x)
        if 'V:' in x:
            r = r1.match (x)
            if r:
                vid = r.group (1)
                if vid not in maps: maps [vid] = []; dmaps [vid] = []
    for vid in sorted (maps.keys ()):   # emit all %%map blocks after the svg preamble
        abc += maps [vid]
    vid = 'default'
    for x in ls:                        # pass 2: re-emit the tune itself
        if 'I:percmap' in x: continue   # already translated into %%map
        if '%%MIDI' in x: continue      # re-inserted after the V: line below
        if 'V:' in x or 'K:' in x:
            r = r1.match (x)
            if r: vid = r.group (1)
            abc.append (x)
            if vid in dmaps and len (dmaps [vid]) > 0: abc.extend (dmaps [vid]); del dmaps [vid]
            if 'perc' in x and 'map=' not in x: x += ' map=perc'
            if 'map=perc' in x and len (maps [vid]) > 0: abc.append ('%%voicemap perc' + vid)
            if 'map=off' in x: abc.append ('%%voicemap')
        else:
            abc.append (x)
    return '\n'.join (abc) + '\n'
|
| 817 |
+
|
| 818 |
+
def addoct (ptc, o): # xml staff step, xml octave number
    """Map an xml step letter and octave number to an ABC pitch:
    octave 4 is plain upper case, 5 is lower case, higher octaves add
    apostrophes, lower octaves add commas. No accidental is included."""
    if o > 4:
        p = ptc.lower () + (o - 5) * "'"    # c, c', c'' ...
    elif o < 4:
        p = ptc + (4 - o) * ","             # C, C,, C,,, ...
    else:
        p = ptc                             # middle octave, plain upper case
    return p # abc pitch == abc note without accidental
|
| 824 |
+
|
| 825 |
+
def chkbug (dt, m):
    """Validate duration dt (in xml divisions) against measure m.

    Returns 1 when dt is longer than a 1/64 note; otherwise reports the
    known MuseScore export bug via info() and returns 0.
    """
    if dt > m.divs / 16:    # duration should be > 1/64 note
        return 1
    info ('MuseScore bug: incorrect duration, smaller then 1/64! in measure %d, part %d' % (m.ixm, m.ixp))
    return 0
|
| 829 |
+
|
| 830 |
+
#----------------
|
| 831 |
+
# parser
|
| 832 |
+
#----------------
|
| 833 |
+
#----------------
# parser
#----------------
class Parser:
    # Translates a parsed MusicXML tree into the Music/ABC abstractions.
    # note_alts: 3 alternative notations of the same note for tablature mapping
    # (row 0: naturals/sharps, rows 1-2: enharmonic spellings of the same pitch class)
    note_alts = [
        [x.strip () for x in '=C, ^C, =D, ^D, =E, =F, ^F, =G, ^G, =A, ^A, =B'.split (',')],
        [x.strip () for x in '^B, _D,^^C, _E, _F, ^E, _G,^^F, _A,^^G, _B, _C'.split (',')],
        [x.strip () for x in '__D,^^B,__E,__F,^^D,__G,^^E,__A,_/A,__B,__C,^^A'.split (',')] ]
    # step_map: note letter -> semitone offset within the octave (C = 0)
    step_map = {'C':0,'D':2,'E':4,'F':5,'G':7,'A':9,'B':11}
|
| 839 |
+
def __init__ (s, options):
    """Initialise all parser state from the parsed command line *options*
    (unfold repeats, number of chars per line, credit filter level, volta option, ...)."""
    s.slurBuf = {}              # dict of open slurs keyed by slur number
    s.dirStk = {}               # {direction-type + number -> (type, voice | time)} dict for proper closing
    s.ingrace = 0               # marks a sequence of grace notes
    s.msc = Music (options)     # global music data abstraction
    s.unfold = options.u        # turn unfolding repeats on
    s.ctf = options.c           # credit text filter level
    s.gStfMap = []              # [[abc voice numbers] for all parts]
    s.midiMap = []              # midi-settings for each abc voice, in order
    s.drumInst = {}             # inst_id -> midi pitch for channel 10 notes
    s.drumNotes = {}            # (xml voice, abc note) -> (midi note, note head)
    s.instMid = []              # [{inst id -> midi-settings} for all parts]
    s.midDflt = [-1,-1,-1,-91]  # default midi settings for channel, program, volume, panning
    s.msralts = {}              # xml-notenames (without octave) with accidentals from the key
    s.curalts = {}              # abc-notenames (with voice number) with passing accidentals
    s.stfMap = {}               # xml staff number -> [xml voice number]
    s.vce2stf = {}              # xml voice number -> allocated staff number
    s.clefMap = {}              # xml staff number -> abc clef (for header only)
    s.curClef = {}              # xml staff number -> current abc clef
    s.stemDir = {}              # xml voice number -> current stem direction
    s.clefOct = {}              # xml staff number -> current clef-octave-change
    s.curStf = {}               # xml voice number -> current xml staff number
    s.nolbrk = options.x;       # generate no linebreaks ($)
    s.jscript = options.j       # compatibility with javascript version
    s.ornaments = sorted (note_ornamentation_map.items ())
    s.doPageFmt = len (options.p) == 1  # translate xml page format
    s.tstep = options.t         # clef determines step on staff (percussion)
    s.dirtov1 = options.v1      # all directions to first voice of staff
    s.ped = options.ped         # render pedal directions
    s.wstems = options.stm      # translate stem elements
    s.pedVce = None             # voice for pedal directions
    s.repeat_str = {}           # staff number -> [measure number, repeat-text]
    s.tabVceMap = {}            # abc voice num -> [%%map ...] for tab voices
    s.koppen = {}               # noteheads needed for %%map
+
|
| 875 |
+
def matchSlur (s, type2, n, v2, note2, grace, stopgrace): # match slur number n in voice v2, add abc code to before/after
    """Pair slur start/stop events. When both ends of slur number n are seen
    in the same voice, '(' is prepended to the start note and ')' appended
    to the stop note. Grace-note slur edge cases are suppressed."""
    if type2 not in ['start', 'stop']: return # slur type continue has no abc equivalent
    if n == None: n = '1' # unnumbered slurs all share key '1'
    if n in s.slurBuf:
        type1, v1, note1, grace1 = s.slurBuf [n]
        if type2 != type1: # slur complete, now check the voice
            if v2 == v1: # begins and ends in the same voice: keep it
                if type1 == 'start' and (not grace1 or not stopgrace): # normal slur: start before stop and no grace slur
                    note1.before = ['('] + note1.before # keep left-right order!
                    note2.after += ')'
            # no else: don't bother with reversed stave spanning slurs
            del s.slurBuf [n] # slur finished, remove from stack
        else: # double definition, keep the last
            info ('double slur numbers %s-%s in part %d, measure %d, voice %d note %s, first discarded' % (type2, n, s.msr.ixp+1, s.msr.ixm+1, v2, note2.ns))
            s.slurBuf [n] = (type2, v2, note2, grace)
    else: # unmatched slur, put in dict
        s.slurBuf [n] = (type2, v2, note2, grace)
|
| 892 |
+
|
| 893 |
+
def doNotations (s, note, nttn, isTab):
    """Translate a <notations> element into ABC decorations on *note*.

    Handles ornaments, tremolos, fingering, string/fret annotations
    (tablature), wavy lines (trills) and glissando/slide spanners.
    """
    for key, val in s.ornaments:
        if nttn.find (key) != None: note.before += [val] # just concat all ornaments
    trem = nttn.find ('ornaments/tremolo')
    if trem != None:
        type = trem.get ('type')
        if type == 'single': # single-note tremolo: slashes on the note
            note.before.insert (0, '!%s!' % (int (trem.text) * '/'))
        else: # two-note tremolo
            note.fact = None # no time modification in ABC
            if s.tstep: # abc2svg version
                if type == 'stop': note.before.insert (0, '!trem%s!' % trem.text);
            else: # abc2xml version
                if type == 'start': note.before.insert (0, '!%s-!' % (int (trem.text) * '/'));
    fingering = nttn.findall ('technical/fingering')
    for finger in fingering: # handle multiple finger annotations
        if not isTab: note.before += ['!%s!' % finger.text] # fingering goes before chord (addChord)
    snaar = nttn.find ('technical/string')
    if snaar != None and isTab:
        if s.tstep: # string/fret pair kept for tablature allocation
            fret = nttn.find ('technical/fret')
            if fret != None: note.tab = (snaar.text, fret.text)
        else:
            deco = '!%s!' % snaar.text # no double string decos (bug in musescore)
            if deco not in note.ntdec: note.ntdec += deco
    wvln = nttn.find ('ornaments/wavy-line')
    if wvln != None:
        if wvln.get ('type') == 'start': note.before = ['!trill(!'] + note.before # keep left-right order!
        elif wvln.get ('type') == 'stop': note.before = ['!trill)!'] + note.before
    glis = nttn.find ('glissando')
    if glis == None: glis = nttn.find ('slide') # treat slide as glissando
    if glis != None:
        lt = '~' if glis.get ('line-type') =='wavy' else '-' # wavy or straight line
        if glis.get ('type') == 'start': note.before = ['!%s(!' % lt] + note.before # keep left-right order!
        elif glis.get ('type') == 'stop': note.before = ['!%s)!' % lt] + note.before
|
| 928 |
+
|
| 929 |
+
def tabnote (s, alt, ptc, oct, v, ntrec):
    """Allocate an ABC spelling for a tablature note so that each
    (voice, abc-note) pair maps to a unique (string, fret).

    Tries up to 4 enharmonic spellings (note_alts) of the same pitch;
    reuses an existing allocation when the string matches, otherwise
    takes the first free spelling. Returns the chosen abc note.
    """
    p = s.step_map [ptc] + int (alt or '0') # p in -2 .. 13
    if p > 11: oct += 1 # octave correction
    if p < 0: oct -= 1
    p = p % 12 # remap p into 0..11
    snaar_nw, fret_nw = ntrec.tab # the computed/annotated allocation of nt
    for i in range (4): # support same note on 4 strings
        na = s.note_alts [i % 3] [p] # get alternative representation of same note
        o = oct
        if na in ['^B', '^^B']: o -= 1 # because in adjacent octave
        if na in ['_C', '__C']: o += 1
        if '/' in na or i == 3: o = 9 # emergency notation for 4th string case
        nt = addoct (na, o)
        snaar, fret = s.tabmap.get ((v, nt), ('', '')) # the current allocation of nt
        if not snaar: break # note not yet allocated
        if snaar_nw == snaar: return nt # use present allocation
        if i == 3: # new allocation needed but none is free
            fmt = 'rejected: voice %d note %3s string %s fret %2s remains: string %s fret %s'
            info (fmt % (v, nt, snaar_nw, fret_nw, snaar, fret), 1)
            ntrec.tab = (snaar, fret)
    s.tabmap [v, nt] = ntrec.tab # for tablature map (voice, note) -> (string, fret)
    return nt # ABC code always in key C (with midi pitch alterations)
|
| 951 |
+
|
| 952 |
+
def ntAbc (s, ptc, oct, note, v, ntrec, isTab): # pitch, octave -> abc notation
    """Translate an xml pitch into an ABC note, deciding whether an
    accidental must be written given the key signature (s.msralts) and
    the accidentals already active in this measure/voice (s.curalts)."""
    acc2alt = {
        'double-flat': -2,
        'flat-flat': -2,
        'flat': -1,
        'natural-flat': -1,
        'natural': 0,
        'sharp': 1,
        'natural-sharp': 1,
        'sharp-sharp': 2,
        'double-sharp': 2
    }
    oct += s.clefOct.get (s.curStf [v], 0) # minus clef-octave-change value
    acc = note.findtext ('accidental') # should be the notated accidental
    alt = note.findtext ('pitch/alter') # pitch alteration (midi)
    if ntrec.tab: return s.tabnote (alt, ptc, oct, v, ntrec) # implies s.tstep is true (options.t was given)
    elif isTab and s.tstep: # tab staff but no string annotation available
        nt = ['__','_','','^','^^'][int (alt or '0') + 2] + addoct (ptc, oct)
        info ('no string notation found for note %s in voice %d' % (nt, v), 1)
    p = addoct (ptc, oct)
    if alt == None and s.msralts.get (ptc, 0): alt = 0 # no alt but key implies alt -> natural!!
    if alt == None and (p, v) in s.curalts: alt = 0 # no alt but previous note had one -> natural!!
    if acc == None and alt == None: return p # no acc, no alt
    elif acc != None:
        alt = acc2alt [acc] # acc takes precedence over the pitch here!
    else: # now see if we really must add an accidental
        alt = int (float (alt))
        if (p, v) in s.curalts: # the note in this voice has been altered before
            if alt == s.curalts [(p, v)]: return p # alteration still the same
        elif alt == s.msralts.get (ptc, 0): return p # alteration implied by the key
        tieElms = note.findall ('tie') + note.findall ('notations/tied') # in xml we have separate notated ties and playback ties
        if 'stop' in [e.get ('type') for e in tieElms]: return p # don't alter tied notes
        info ('accidental %d added in part %d, measure %d, voice %d note %s' % (alt, s.msr.ixp+1, s.msr.ixm+1, v+1, p))
    s.curalts [(p, v)] = alt
    p = ['__','_','=','^','^^'][alt+2] + p # and finally ... prepend the accidental
    return p
|
| 988 |
+
|
| 989 |
+
def doNote (s, n): # parse a musicXML note tag
    """Translate one <note> element: pitch/rest, duration, tuplets, grace
    sequences, ornaments, percussion noteheads, ties, lyrics, stems and
    staff changes; append the result to the Music abstraction (s.msc)."""
    note = Note ()
    v = int (n.findtext ('voice', '1'))
    if s.isSib: v += 100 * int (n.findtext ('staff', '1')) # repair bug in Sibelius
    chord = n.find ('chord') != None
    p = n.findtext ('pitch/step') or n.findtext ('unpitched/display-step')
    o = n.findtext ('pitch/octave') or n.findtext ('unpitched/display-octave')
    r = n.find ('rest')
    numer = n.findtext ('time-modification/actual-notes')
    if numer: # tuplet time modification
        denom = n.findtext ('time-modification/normal-notes')
        note.fact = (int (numer), int (denom))
    note.tup = [x.get ('type') for x in n.findall ('notations/tuplet')]
    dur = n.findtext ('duration')
    grc = n.find ('grace')
    note.grace = grc != None
    note.before, note.after = [], '' # strings with ABC stuff that goes before or after a note/chord
    if note.grace and not s.ingrace: # open a grace sequence
        s.ingrace = 1
        note.before = ['{']
        if grc.get ('slash') == 'yes': note.before += ['/'] # acciaccatura
    stopgrace = not note.grace and s.ingrace
    if stopgrace: # close the grace sequence
        s.ingrace = 0
        s.msc.lastnote.after += '}' # close grace on lastnote.after
    if dur == None or note.grace: dur = 0
    if r == None and n.get ('print-object') == 'no':
        if chord: return
        r = 1 # turn invisible notes (that advance the time) into invisible rests
    note.dur = int (dur)
    if r == None and (not p or not o): # not a rest and no pitch
        s.msc.cnt.inc ('nopt', v) # count unpitched notes
        o, p = 5,'E' # make it an E5 ??
    isTab = s.curClef and s.curClef.get (s.curStf [v], '').startswith ('tab')
    nttn = n.find ('notations') # add ornaments
    if nttn != None: s.doNotations (note, nttn, isTab)
    e = n.find ('stem') if r == None else None # no !stemless! before rest
    if e != None and e.text == 'none' and (not isTab or v in s.hasStems or s.tstep):
        # NOTE(review): original comment here was garbled (non-ASCII);
        # this marks the note with the !stemless! decoration
        note.before += ['!stemless!']; abcOut.stemless = 0;
    e = n.find ('accidental')
    if e != None and e.get ('parentheses') == 'yes': note.ntdec += '!courtesy!'
    if r != None: noot = 'x' if n.get ('print-object') == 'no' or isTab else 'z'
    else: noot = s.ntAbc (p, int (o), n, v, note, isTab)
    if n.find ('unpitched') != None: # percussion note: record percmap data
        clef = s.curClef [s.curStf [v]] # the current clef for this voice
        step = staffStep (p, int (o), clef, s.tstep) # (clef independent) step value of note on the staff
        instr = n.find ('instrument')
        instId = instr.get ('id') if instr != None else 'dummyId'
        midi = s.drumInst.get (instId, abcMid (noot))
        nh = n.findtext ('notehead', '').replace (' ','-') # replace spaces in xml notehead names for percmap
        if nh == 'x': noot = '^' + noot.replace ('^','').replace ('_','')
        if nh in ['circle-x','diamond','triangle']: noot = '_' + noot.replace ('^','').replace ('_','')
        if nh and n.find ('notehead').get ('filled','') == 'yes': nh += '+'
        if nh and n.find ('notehead').get ('filled','') == 'no': nh += '-'
        s.drumNotes [(v, noot)] = (step, midi, nh) # keep data for percussion map
    tieElms = n.findall ('tie') + n.findall ('notations/tied') # in xml we have separate notated ties and playback ties
    if 'start' in [e.get ('type') for e in tieElms]: # n can have stop and start tie
        noot = noot + '-'
    note.beam = sum ([1 for b in n.findall('beam') if b.text in ['continue', 'end']]) + int (note.grace)
    lyrlast = 0; rsib = re.compile (r'^.*verse')
    for e in n.findall ('lyric'):
        lyrnum = int (rsib.sub ('', e.get ('number', '1'))) # also do Sibelius numbers
        if lyrnum == 0: lyrnum = lyrlast + 1 # and correct Sibelius bugs
        else: lyrlast = lyrnum
        note.lyrs [lyrnum] = doSyllable (e)
    stemdir = n.findtext ('stem')
    if s.wstems and (stemdir == 'up' or stemdir == 'down'): # --stm option
        if stemdir != s.stemDir.get (v, ''): # only emit on a direction change
            s.stemDir [v] = stemdir
            s.msc.appendElem (v, '[I:stemdir %s]' % stemdir)
    if chord: s.msc.addChord (note, noot)
    else:
        xmlstaff = int (n.findtext ('staff', '1'))
        if s.curStf [v] != xmlstaff: # the note should go to another staff
            dstaff = xmlstaff - s.curStf [v] # relative new staff number
            s.curStf [v] = xmlstaff # remember the new staff for this voice
            s.msc.appendElem (v, '[I:staff %+d]' % dstaff) # insert a move before the note
        s.msc.appendNote (v, note, noot)
    for slur in n.findall ('notations/slur'): # s.msc.lastnote points to the last real note/chord inserted above
        s.matchSlur (slur.get ('type'), slur.get ('number'), v, s.msc.lastnote, note.grace, stopgrace) # match slur definitions
|
| 1070 |
+
|
| 1071 |
+
def doAttr (s, e): # parse a musicXML attribute tag
    """Translate an <attributes> element: divisions, key, meter,
    measure-repeat styles, transposition and clefs. First-measure values
    go to the ABC header (abcOut), later changes become inline [K:]/[M:]."""
    teken = {'C1':'alto1','C2':'alto2','C3':'alto','C4':'tenor','F4':'bass','F3':'bass3','G2':'treble','TAB':'tab','percussion':'perc'}
    dvstxt = e.findtext ('divisions')
    if dvstxt: s.msr.divs = int (dvstxt)
    steps = int (e.findtext ('transpose/chromatic', '0')) # for transposing instrument
    fifths = e.findtext ('key/fifths')
    first = s.msc.tijd == 0 and s.msr.ixm == 0 # first attributes in first measure
    if fifths:
        key, s.msralts = setKey (int (fifths), e.findtext ('key/mode','major'))
        if first and not steps and abcOut.key == 'none':
            abcOut.key = key # first measure -> header, if not transposing instrument or percussion part!
        elif key != abcOut.key or not first:
            s.msr.attr += '[K:%s]' % key # otherwise -> voice
    beats = e.findtext ('time/beats')
    if beats:
        unit = e.findtext ('time/beat-type')
        mtr = beats + '/' + unit
        if first: abcOut.mtr = mtr # first measure -> header
        else: s.msr.attr += '[M:%s]' % mtr # otherwise -> voice
        s.msr.mtr = int (beats), int (unit)
    s.msr.mdur = (s.msr.divs * s.msr.mtr[0] * 4) // s.msr.mtr[1] # duration of measure in xml-divisions
    for ms in e.findall('measure-style'):
        n = int (ms.get ('number', '1')) # staff number
        voices = s.stfMap [n] # all voices of staff n
        for mr in ms.findall('measure-repeat'):
            ty = mr.get('type')
            if ty == 'start': # remember start measure number and text for each staff
                s.repeat_str [n] = [s.msr.ixm, mr.text]
                for v in voices: # insert repeat into all voices, value will be overwritten at stop
                    s.msc.insertElem (v, s.repeat_str [n])
            elif ty == 'stop': # calculate repeat measure count for this staff n
                start_ix, text_ = s.repeat_str [n]
                repeat_count = s.msr.ixm - start_ix
                if text_:
                    mid_str = "%s " % text_
                    repeat_count /= int (text_) # float, but '%d' below truncates it again
                else:
                    mid_str = "" # overwrite repeat with final string
                s.repeat_str [n][0] = '[I:repeat %s%d]' % (mid_str, repeat_count)
                del s.repeat_str [n] # remove closed repeats
    toct = e.findtext ('transpose/octave-change', '')
    if toct: steps += 12 * int (toct) # extra transposition of toct octaves
    for clef in e.findall ('clef'): # a part can have multiple staves
        n = int (clef.get ('number', '1')) # local staff number for this clef
        sgn = clef.findtext ('sign')
        line = clef.findtext ('line', '') if sgn not in ['percussion','TAB'] else ''
        cs = teken.get (sgn + line, '')
        oct = clef.findtext ('clef-octave-change', '') or '0'
        if oct: cs += {-2:'-15', -1:'-8', 1:'+8', 2:'+15'}.get (int (oct), '')
        s.clefOct [n] = -int (oct); # xml playback pitch -> abc notation pitch
        if steps: cs += ' transpose=' + str (steps)
        stfdtl = e.find ('staff-details')
        if stfdtl and int (stfdtl.get ('number', '1')) == n:
            lines = stfdtl.findtext ('staff-lines')
            if lines:
                lns= '|||' if lines == '3' and sgn == 'TAB' else lines
                cs += ' stafflines=%s' % lns
                s.stafflines = int (lines) # remember for tab staves
            strings = stfdtl.findall ('staff-tuning')
            if strings:
                tuning = [st.findtext ('tuning-step') + st.findtext ('tuning-octave') for st in strings]
                cs += ' strings=%s' % ','.join (tuning)
            capo = stfdtl.findtext ('capo')
            if capo: cs += ' capo=%s' % capo
        s.curClef [n] = cs # keep track of current clef (for percmap)
        if first: s.clefMap [n] = cs # clef goes to header (where it is mapped to voices)
        else:
            voices = s.stfMap[n] # clef change to all voices of staff n
            for v in voices:
                if n != s.curStf [v]: # voice is not at its home staff n
                    dstaff = n - s.curStf [v]
                    s.curStf [v] = n # reset current staff at start of measure to home position
                    s.msc.appendElem (v, '[I:staff %+d]' % dstaff)
                s.msc.appendElem (v, '[K:%s]' % cs)
|
| 1145 |
+
|
| 1146 |
+
def findVoice (s, i, es):
    """Decide which staff/voice a direction at index i in element list es
    belongs to. Returns (staff, target voice, first voice of the staff):
    with --v1 always the first voice, otherwise the voice of the next
    <note> up to the next <backup>."""
    stfnum = int (es[i].findtext ('staff',1)) # directions belong to a staff
    vs = s.stfMap [stfnum] # voices in this staff
    v1 = vs [0] if vs else 1 # directions to first voice of staff
    if s.dirtov1: return stfnum, v1, v1 # option --v1
    for e in es [i+1:]: # or to the voice of the next note
        if e.tag == 'note':
            v = int (e.findtext ('voice', '1'))
            if s.isSib: v += 100 * int (e.findtext ('staff', '1')) # repair bug in Sibelius
            stf = s.vce2stf [v] # use our own staff allocation
            return stf, v, v1 # voice of next note, first voice of staff
        if e.tag == 'backup': break
    return stfnum, v1, v1 # no note found, fall back to v1
|
| 1159 |
+
|
| 1160 |
+
def doDirection (s, e, i, es): # parse a musicXML direction tag
    """Translate a <direction> element: tempo, dynamics, wedges, ottavas,
    pedal, coda/segno jumps and text annotations; paired spanners are
    matched through s.dirStk so start/stop land in the same voice."""
    def addDirection (x, vs, tijd, stfnum):
        # Emit decoration x into voice(s); at a recorded time when given.
        if not x: return
        vs = s.stfMap [stfnum] if '!8v' in x else [vs] # ottava's go to all voices of staff
        for v in vs:
            if tijd != None: # insert at time of encounter
                s.msc.appendElemT (v, x.replace ('(',')').replace ('ped','ped-up'), tijd)
            else:
                s.msc.appendElem (v, x)
    def startStop (dtype, vs, stfnum=1):
        # Match an opening/closing spanner of kind dtype (wedge, octave-shift, pedal).
        typmap = {'down':'!8va(!', 'up':'!8vb(!', 'crescendo':'!<(!', 'diminuendo':'!>(!', 'start':'!ped!'}
        type = t.get ('type', '')
        k = dtype + t.get ('number', '1') # key to match the closing direction
        if type in typmap: # opening the direction
            x = typmap [type]
            if k in s.dirStk: # closing direction already encountered
                stype, tijd = s.dirStk [k]; del s.dirStk [k]
                if stype == 'stop':
                    addDirection (x, vs, tijd, stfnum)
                else:
                    info ('%s direction %s has no stop in part %d, measure %d, voice %d' % (dtype, stype, s.msr.ixp+1, s.msr.ixm+1, vs+1))
                    s.dirStk [k] = ((type , vs)) # remember voice and type for closing
            else:
                s.dirStk [k] = ((type , vs)) # remember voice and type for closing
        elif type == 'stop':
            if k in s.dirStk: # matching open direction found
                type, vs = s.dirStk [k]; del s.dirStk [k] # into the same voice
                if type == 'stop':
                    info ('%s direction %s has double stop in part %d, measure %d, voice %d' % (dtype, type, s.msr.ixp+1, s.msr.ixm+1, vs+1))
                    x = ''
                else:
                    x = typmap [type].replace ('(',')').replace ('ped','ped-up')
            else: # closing direction found before opening
                s.dirStk [k] = ('stop', s.msc.tijd)
                x = '' # delay code generation until opening found
        elif type in ['continue', 'resume', 'discontinue', 'change']:
            # NOTE: original comment here was garbled (non-ASCII); these
            # intermediate spanner states have no ABC equivalent -> ignore
            # info('Ignoring unsupported direction type: %s' % type)
            x = ''
        else: raise ValueError ('wrong direction type')
        addDirection (x, vs, None, stfnum)
    tempo, wrdstxt = None, ''
    plcmnt = e.get ('placement')
    stf, vs, v1 = s.findVoice (i, es)
    jmp = '' # for jump sound elements: dacapo, dalsegno and family
    jmps = [('dacapo','D.C.'),('dalsegno','D.S.'),('tocoda','dacoda'),('fine','fine'),('coda','O'),('segno','S')]
    t = e.find ('sound') # there are many possible attributes for sound
    if t != None:
        minst = t.find ('midi-instrument')
        if minst:
            prg = t.findtext ('midi-instrument/midi-program')
            chn = t.findtext ('midi-instrument/midi-channel')
            vids = [v for v, id in s.vceInst.items () if id == minst.get ('id')]
            if vids: vs = vids [0] # direction for the identified voice, not the staff
            parm, inst = ('program', str (int (prg) - 1)) if prg else ('channel', chn)
            if inst and abcOut.volpan > 0: s.msc.appendElem (vs, '[I:MIDI= %s %s]' % (parm, inst))
        tempo = t.get ('tempo') # look for tempo attribute
        if tempo:
            tempo = '%.0f' % float (tempo) # hope it is a number and insert in voice 1
            tempo_units = (1,4) # always 1/4 for sound elements!
        for r, v in jmps:
            if t.get (r, ''): jmp = v; break
    dirtypes = e.findall ('direction-type')
    for dirtyp in dirtypes:
        units = { 'whole': (1,1), 'half': (1,2), 'quarter': (1,4), 'eighth': (1,8) }
        metr = dirtyp.find ('metronome')
        if metr != None:
            t = metr.findtext ('beat-unit', '')
            if t in units: tempo_units = units [t]
            else: tempo_units = units ['quarter']
            if metr.find ('beat-unit-dot') != None: # dotted beat unit -> * 3/2
                tempo_units = simplify (tempo_units [0] * 3, tempo_units [1] * 2)
            debugtext = metr.findtext ('per-minute')
            tmpro = None
            if metr.findtext ('per-minute'):
                tmpro = re.search ('[.\d]+', metr.findtext ('per-minute')) # look for a number
            if tmpro: tempo = tmpro.group () # overwrites the value set by the sound element of this direction
        t = dirtyp.find ('wedge')
        if t != None: startStop ('wedge', vs)
        allwrds = dirtyp.findall ('words') # insert text annotations
        if not allwrds: allwrds = dirtyp.findall ('rehearsal') # treat rehearsal mark as text annotation
        for wrds in allwrds:
            if jmp: # ignore the words when a jump sound element is present in this direction
                s.msc.appendElem (vs, '!%s!' % jmp , 1) # to voice
                break
            plc = plcmnt == 'below' and '_' or '^'
            if float (wrds.get ('default-y', '0')) < 0: plc = '_'
            wrdstxt += (wrds.text or '').replace ('"','\\"').replace ('\n', '\\n')
        wrdstxt = wrdstxt.strip ()
        for key, val in dynamics_map.items ():
            if dirtyp.find ('dynamics/' + key) != None:
                s.msc.appendElem (vs, val, 1) # to voice
        if dirtyp.find ('coda') != None: s.msc.appendElem (vs, 'O', 1)
        if dirtyp.find ('segno') != None: s.msc.appendElem (vs, 'S', 1)
        t = dirtyp.find ('octave-shift')
        if t != None: startStop ('octave-shift', vs, stf) # assume size == 8 for the time being
        t = dirtyp.find ('pedal')
        if t != None and s.ped:
            if not s.pedVce: s.pedVce = vs # all pedal marks go to one voice
            startStop ('pedal', s.pedVce)
        if dirtyp.findtext ('other-direction') == 'diatonic fretting': s.diafret = 1;
    if tempo:
        tempo = '%.0f' % float (tempo) # hope it is a number and insert in voice 1
        if s.msc.tijd == 0 and s.msr.ixm == 0: # first measure -> header
            abcOut.tempo = tempo
            abcOut.tempo_units = tempo_units
        else:
            s.msc.appendElem (v1, '[Q:%d/%d=%s]' % (tempo_units [0], tempo_units [1], tempo)) # otherwise -> 1st voice
    # plc was set inside the words loop whenever wrdstxt became non-empty
    if wrdstxt: s.msc.appendElem (vs, '"%s%s"' % (plc, wrdstxt), 1) # to voice, but after tempo
|
| 1271 |
+
|
| 1272 |
+
def doHarmony (s, e, i, es):    # parse a musicXMl harmony tag
    """Translate a MusicXML <harmony> element into an ABC chord symbol.

    e is the <harmony> element, i its index within es (the measure's child
    elements).  The target voice vt is taken from findVoice; the chord string
    (e.g. "C#m7b5/G") is appended to that voice as a text annotation.
    """
    _, vt, _ = s.findVoice (i, es)      # voice that receives the chord symbol
    # abbreviations for plain triad/seventh chord kinds
    short = {'major':'', 'minor':'m', 'augmented':'+', 'diminished':'dim', 'dominant':'7', 'half-diminished':'m7b5'}
    accmap = {'major':'maj', 'dominant':'', 'minor':'m', 'diminished':'dim', 'augmented':'+', 'suspended':'sus'}
    modmap = {'second':'2', 'fourth':'4', 'seventh':'7', 'sixth':'6', 'ninth':'9', '11th':'11', '13th':'13'}
    altmap = {'1':'#', '0':'', '-1':'b'}    # xml alter value -> accidental text
    root = e.findtext ('root/root-step','')
    alt = altmap.get (e.findtext ('root/root-alter'), '')
    sus = ''
    kind = e.findtext ('kind', '')
    if kind in short: kind = short [kind]
    elif '-' in kind:   # xml chord names: <triad name>-<modification>
        triad, mod = kind.split ('-')
        kind = accmap.get (triad, '') + modmap.get (mod, '')
        if kind.startswith ('sus'): kind, sus = '', kind    # sus-suffix goes to the end
    elif kind == 'none': kind = e.find ('kind').get ('text','')     # fall back to the element's display text
    degrees = e.findall ('degree')
    for d in degrees:   # chord alterations, e.g. b5, #9
        kind += altmap.get (d.findtext ('degree-alter'),'') + d.findtext ('degree-value','')
    kind = kind.replace ('79','9').replace ('713','13').replace ('maj6','6')    # normalize common combinations
    bass = e.findtext ('bass/bass-step','') + altmap.get (e.findtext ('bass/bass-alter'),'')
    s.msc.appendElem (vt, '"%s%s%s%s%s"' % (root, alt, kind, sus, bass and '/' + bass), 1)  # 1 -> to voice
def doBarline (s, e):           # 0 = no repeat, 1 = begin repeat, 2 = end repeat
    """Translate a MusicXML <barline> element.

    Updates the current measure s.msr (left/right barline style, volta
    number) and returns a repeat code: 0 = no repeat, 1 = begin repeat,
    2 = end repeat (only meaningful when s.unfold is set, where the caller
    uses it to unfold the repetition instead of writing repeat barlines).
    """
    rep = e.find ('repeat')
    if rep != None: rep = rep.get ('direction')
    if s.unfold:                # unfold repeat, don't translate barlines
        return rep and (rep == 'forward' and 1 or 2) or 0
    loc = e.get ('location', 'right')   # right is the default
    if loc == 'right':          # only change style for the right side
        style = e.findtext ('bar-style')
        if style == 'light-light': s.msr.rline = '||'
        elif style == 'light-heavy': s.msr.rline = '|]'
    if rep != None:             # repeat found
        if rep == 'forward': s.msr.lline = ':'
        else: s.msr.rline = ':|'        # override barline style
    end = e.find ('ending')
    if end != None:
        if end.get ('type') == 'start':
            n = end.get ('number', '1').replace ('.','').replace (' ','')
            # was a bare except: narrow to ValueError so only a genuinely
            # non-numeric volta number takes the quoted-string fallback
            try: list (map (int, n.split (',')))    # should be a list of integers
            except ValueError: n = '"%s"' % n.strip ()  # illegal musicXML
            s.msr.lnum = n      # assume a start is always at the beginning of a measure
        elif s.msr.rline == '|':        # stop and discontinue the same in ABC ?
            s.msr.rline = '||'          # to stop on a normal barline use || in ABC ?
    return 0
def doPrint (s, e):
    """Handle a <print> element: return '$' (an ABC line break) when the
    element starts a new system or a new page, unless line breaks were
    suppressed with the -x option (s.nolbrk); otherwise return None."""
    starts_new_line = 'yes' in (e.get ('new-system'), e.get ('new-page'))
    if starts_new_line and not s.nolbrk:
        return '$'      # a line break
def doPartList (s, e):  # translate the start/stop-event-based xml-partlist into proper tree
    """Collect per-part midi settings and build the nested part/group tree.

    Appends one {instrument-id -> [channel, program, volume, pan]} dict per
    score-part to s.instMid, records unpitched percussion pitches in
    s.drumInst, then converts the flat <part-list> into a nested partlist
    via the module-level helpers getPartlist/parseParts.
    """
    for sp in e.findall ('part-list/score-part'):
        midi = {}
        for m in sp.findall ('midi-instrument'):
            # channel, program, volume, pan — with defaults from s.midDflt
            x = [m.findtext (p, s.midDflt [i]) for i,p in enumerate (['midi-channel','midi-program','volume','pan'])]
            pan = float (x[3])
            if pan >= -90 and pan <= 90:    # would be better to map behind-pannings
                pan = (float (x[3]) + 90) / 180 * 127   # xml between -90 and +90
            midi [m.get ('id')] = [int (x[0]), int (x[1]), float (x[2]) * 1.27, pan]    # volume 100 -> midi 127
            up = m.findtext ('midi-unpitched')
            if up: s.drumInst [m.get ('id')] = int (up) - 1     # store midi-pitch for channel 10 notes
        s.instMid.append (midi)
    ps = e.find ('part-list')               # partlist = [groupelem]
    xs = getPartlist (ps)                   # groupelem = partname | grouplist
    partlist, _ = parseParts (xs, {}, [])   # grouplist = [groupelem, ..., groupdata]
    return partlist                         # groupdata = [group-symbol, group-barline, group-name, group-abbrev]
def mkTitle (s, e):
    """Build the ABC header title/composer/lyricist fields (T:, C:, Z:)
    from the MusicXML work/movement titles, creators and credit texts,
    and store the result in the global abcOut.title.

    Also sets s.isSib when the file was written by Sibelius (its voice
    numbering needs a workaround elsewhere).
    """
    def filterCredits (y):  # y == filter level, higher filters less
        cs = []
        for x in credits:   # skip redundant credit lines
            if y < 6 and (x in title or x in mvttl): continue   # sure skip
            if y < 5 and (x in composer or x in lyricist): continue     # almost sure skip
            if y < 4 and ((title and title in x) or (mvttl and mvttl in x)): continue   # may skip too much
            if y < 3 and ([1 for c in composer if c in x] or [1 for c in lyricist if c in x]): continue  # skips too much
            if y < 2 and re.match (r'^[\d\W]*$', x): continue   # line only contains numbers and punctuation
            cs.append (x)
        if y == 0 and (title + mvttl): cs = ''  # default: only credit when no title set
        return cs
    title = e.findtext ('work/work-title', '').strip ()
    mvttl = e.findtext ('movement-title', '').strip ()
    composer, lyricist, credits = [], [], []
    for creator in e.findall ('identification/creator'):
        if creator.text:
            if creator.get ('type') == 'composer':
                composer += [line.strip () for line in creator.text.split ('\n')]
            elif creator.get ('type') in ('lyricist', 'transcriber'):
                lyricist += [line.strip () for line in creator.text.split ('\n')]
    for rights in e.findall ('identification/rights'):
        if rights.text:     # rights lines go into the Z: field as well
            lyricist += [line.strip () for line in rights.text.split ('\n')]
    for credit in e.findall('credit'):
        # concatenate all credit-words of one credit, collapse line breaks
        cs = ''.join (e.text or '' for e in credit.findall('credit-words'))
        credits += [re.sub (r'\s*[\r\n]\s*', ' ', cs)]
    credits = filterCredits (s.ctf)     # s.ctf = credit filter level from the -c option
    if title: title = 'T:%s\n' % title.replace ('\n', '\nT:')
    if mvttl: title += 'T:%s\n' % mvttl.replace ('\n', '\nT:')
    if credits: title += '\n'.join (['T:%s' % c for c in credits]) + '\n'
    if composer: title += '\n'.join (['C:%s' % c for c in composer]) + '\n'
    if lyricist: title += '\n'.join (['Z:%s' % c for c in lyricist]) + '\n'
    if title: abcOut.title = title[:-1]     # strip the trailing newline
    s.isSib = 'Sibelius' in (e.findtext ('identification/encoding/software') or '')
    if s.isSib: info ('Sibelius MusicXMl is unreliable')
def doDefaults (s, e):
    """Translate the MusicXML <defaults> element into the abcm2ps page
    format values in the global abcOut.pageFmt (only when the -pf option,
    s.doPageFmt, was given; command-line values take precedence)."""
    if not s.doPageFmt: return  # return if -pf option absent
    d = e.find ('defaults')
    if d == None: return
    mils = d.findtext ('scaling/millimeters')   # mils == staff height (mm)
    tenths = d.findtext ('scaling/tenths')      # staff height in tenths
    if not mils or not tenths: return
    xmlScale = float (mils) / float (tenths) / 10   # tenths -> mm
    space = 10 * xmlScale   # space between staff lines == 10 tenths
    abcScale = space / 0.2117   # 0.2117 cm = 6pt = space between staff lines for scale = 1.0 in abcm2ps
    abcOut.pageFmt ['scale'] = abcScale
    eks = 2 * ['page-layout/'] + 4 * ['page-layout/page-margins/']
    eks = [a+b for a,b in zip (eks, 'page-height,page-width,left-margin,right-margin,top-margin,bottom-margin'.split (','))]
    for i in range (6):
        v = d.findtext (eks [i])
        k = abcOut.pagekeys [i+1]   # pagekeys [0] == scale already done, skip it
        if not abcOut.pageFmt [k] and v:    # command line overrides xml value
            # fixed: the message was passed as info(msg, tuple) which left the
            # %s placeholders unfilled; interpolate the values into the string.
            # Also narrowed the bare except to ValueError (float parse failure).
            try: abcOut.pageFmt [k] = float (v) * xmlScale  # -> cm
            except ValueError: info ('illegal value %s for xml element %s' % (v, eks [i])); continue     # just skip illegal values
def locStaffMap (s, part, maten):   # map voice to staff with majority voting
    """Assign each XML voice of this part to the staff on which most of its
    notes occur (majority voting, ties won by the higher staff number) and
    rebuild the per-part voice bookkeeping:

      s.msc.vnums  - all voice numbers used in this part
      s.vceInst    - voice -> instrument id
      s.hasStems   - voices having at least one stemmed note (for tab keys)
      s.stfMap     - staff -> [voices], s.clefMap reset to {}
      s.vce2stf / s.curStf - voice -> chosen staff
    """
    alloc = {}                      # {voice -> {staff -> occurrence count}}
    s.vceInst = {}                  # {voice -> instrument id} for this part
    s.msc.vnums = {}                # voice id's
    s.hasStems = {}                 # XML voice nums with at least one note with a stem (for tab key)
    s.stfMap, s.clefMap = {}, {}    # staff -> [voices], staff -> clef
    for note in part.findall ('measure/note'):  # count staff allocations for all notes
        v = int (note.findtext ('voice', '1'))
        if s.isSib:                 # repair bug in Sibelius voice numbering
            v += 100 * int (note.findtext ('staff', '1'))
        s.msc.vnums [v] = 1         # collect all used voice id's in this part
        sn = int (note.findtext ('staff', '1'))
        s.stfMap [sn] = []          # voices are filled in after the vote below
        counts = alloc.setdefault (v, {})
        counts [sn] = counts.get (sn, 0) + 1    # one more allocation of v on staff sn
        inst = note.find ('instrument')
        if inst != None: s.vceInst [v] = inst.get ('id')
        stem = note.findtext ('stem')
        if note.find ('rest') == None and stem != 'none':
            s.hasStems [v] = 1      # XML voice v has at least one stem
    voice_ids = list (alloc.keys ())
    if s.jscript or s.isSib: voice_ids.sort ()
    for v in voice_ids:             # choose staff with most allocations for each voice
        _, winner = max ((cnt, sn) for sn, cnt in alloc [v].items ())
        s.stfMap [winner].append (v)
        s.vce2stf [v] = winner      # reverse map
        s.curStf [v] = winner       # current staff of XML voice v
def addStaffMap (s, vvmap):     # vvmap: xml voice number -> global abc voice number
    """Append this part's staff layout to the global staff map s.gStfMap and
    assign a clef string to every ABC voice in abcOut.clefs.

    Each staff contributes one list of ABC voice numbers; tab clefs get
    'nostems' when no note of the voice had a stem, and 'diafret' when
    diatonic fretting was detected for the part.
    """
    part = []   # default: brace on staffs of one part
    for stf, voices in sorted (s.stfMap.items ()):  # s.stfMap has xml staff and voice numbers
        locmap = [vvmap [iv] for iv in voices if iv in vvmap]
        nostem = [(iv not in s.hasStems) for iv in voices if iv in vvmap]   # same order as locmap
        if locmap:  # abc voice number of staff stf
            part.append (locmap)
            clef = s.clefMap.get (stf, 'treble')    # {xml staff number -> clef}
            for i, iv in enumerate (locmap):
                clef_attr = ''
                if clef.startswith ('tab'):
                    if nostem [i] and 'nostems' not in clef: clef_attr = ' nostems'
                    if s.diafret and 'diafret' not in clef: clef_attr += ' diafret'     # for all voices in the part
                abcOut.clefs [iv] = clef + clef_attr    # add nostems when all notes of voice had no stem
    s.gStfMap.append (part)
def addMidiMap (s, ip, vvmap):  # map abc voices to midi settings
    """Extend s.midiMap with the midi settings (plus percussion note maps)
    of every ABC voice of part ip, in ABC voice order, and emit %%map
    definitions for tablature voices into s.tabVceMap.

    vvmap: xml voice number -> global abc voice number.
    """
    instr = s.instMid [ip]      # get the midi settings for this part
    if instr.values (): defInstr = list(instr.values ())[0]     # default settings = first instrument
    else: defInstr = s.midDflt  # no instruments defined
    xs = []
    for v, vabc in vvmap.items ():  # xml voice num, abc voice num
        ks = sorted (s.drumNotes.items ())
        ds = [(nt, step, midi, head) for (vd, nt), (step, midi, head) in ks if v == vd]     # map perc notes
        id = s.vceInst.get (v, '')  # get the instrument-id for part with multiple instruments
        if id in instr:             # id is defined as midi-instrument in part-list
            xs.append ((vabc, instr [id] + ds))     # get midi settings for id
        else: xs.append ((vabc, defInstr + ds))     # only one instrument for this part
    xs.sort ()  # put abc voices in order
    s.midiMap.extend ([midi for v, midi in xs])
    snaarmap = ['E','G','B','d', 'f', 'a', "c'", "e'"]  # string number -> printed string name
    diamap = '0,1-,1,1+,2,3,3,4,4,5,6,6+,7,8-,8,8+,9,10,10,11,11,12,13,13+,14'.split (',')  # chromatic -> diatonic fret
    for k in sorted (s.tabmap.keys ()):     # add %%map's for all tab voices
        v, noot = k;
        snaar, fret = s.tabmap [k];
        if s.diafret: fret = diamap [int (fret)]
        vabc = vvmap [v]
        snaar = s.stafflines - int (snaar)  # count strings from the top line
        xs = s.tabVceMap.get (vabc, [])
        xs.append ('%%%%map tab%d %s print=%s heads=kop%s\n' % (vabc, noot, snaarmap [snaar], fret))
        s.tabVceMap [vabc] = xs
        s.koppen [fret] = 1     # collect noteheads for SVG defs
def parse (s, fobj):
    """Parse one MusicXML file object and write its ABC translation.

    Iterates all <part> elements; for each part walks the measures,
    dispatching each child element to the appropriate do* handler, handles
    simple-repeat unfolding, then merges the per-part voice maps and emits
    the ABC header and body through the global abcOut.
    """
    vvmapAll = {}       # collect xml->abc voice maps (vvmap) of all parts
    e = E.parse (fobj)
    s.mkTitle (e)
    s.doDefaults (e)
    partlist = s.doPartList (e)
    parts = e.findall ('part')
    for ip, p in enumerate (parts):
        maten = p.findall ('measure')
        s.locStaffMap (p, maten)    # {voice -> staff} for this part
        s.drumNotes = {}    # (xml voice, abc note) -> (midi note, note head)
        s.clefOct = {}      # xml staff number -> current clef-octave-change
        s.curClef = {}      # xml staff number -> current abc clef
        s.stemDir = {}      # xml voice number -> current stem direction
        s.tabmap = {}       # (xml voice, abc note) -> (string, fret)
        s.diafret = 0       # use diatonic fretting
        s.stafflines = 5
        s.msc.initVoices (newPart = 1)  # create all voices
        aantalHerhaald = 0  # keep track of number of repetitions
        herhaalMaat = 0     # target measure of the repetition
        divisions = []      # current value of <divisions> for each measure
        s.msr = Measure (ip)    # various measure data
        while s.msr.ixm < len (maten):
            # (removed leftover debug code that printed a blank line
            #  for part 31, measure 405)
            maat = maten [s.msr.ixm]
            herhaal, lbrk = 0, ''
            s.msr.reset ()
            s.curalts = {}  # passing accidentals are reset each measure
            es = list (maat)
            for i, e in enumerate (es):
                if e.tag == 'note': s.doNote (e)
                elif e.tag == 'attributes': s.doAttr (e)
                elif e.tag == 'direction': s.doDirection (e, i, es)
                elif e.tag == 'sound': s.doDirection (maat, i, es)  # sound element directly in measure!
                elif e.tag == 'harmony': s.doHarmony (e, i, es)
                elif e.tag == 'barline': herhaal = s.doBarline (e)
                elif e.tag == 'backup':
                    dt = int (e.findtext ('duration'))
                    if chkbug (dt, s.msr): s.msc.incTime (-dt)
                elif e.tag == 'forward':
                    dt = int (e.findtext ('duration'))
                    if chkbug (dt, s.msr): s.msc.incTime (dt)
                elif e.tag == 'print': lbrk = s.doPrint (e)
            s.msc.addBar (lbrk, s.msr)
            divisions.append (s.msr.divs)
            if herhaal == 1:        # begin repeat: remember the jump target
                herhaalMaat = s.msr.ixm
                s.msr.ixm += 1
            elif herhaal == 2:      # end repeat
                if aantalHerhaald < 1:  # jump back once
                    s.msr.ixm = herhaalMaat
                    aantalHerhaald += 1
                else:
                    aantalHerhaald = 0  # reset
                    s.msr.ixm += 1      # just continue
            else: s.msr.ixm += 1        # on to the next measure
        for rv in s.repeat_str.values ():   # close hanging measure-repeats without stop
            rv [0] = '[I:repeat %s %d]' % (rv [1], 1)
        vvmap = s.msc.outVoices (divisions, ip, s.isSib)
        s.addStaffMap (vvmap)   # update global staff map
        s.addMidiMap (ip, vvmap)
        vvmapAll.update (vvmap)
    if vvmapAll:    # skip output if no part has any notes
        abcOut.mkHeader (s.gStfMap, partlist, s.midiMap, s.tabVceMap, s.koppen)
        abcOut.writeall ()
    else: info ('nothing written, %s has no notes ...' % abcOut.fnmext)
if __name__ == '__main__':
|
| 1546 |
+
from optparse import OptionParser
|
| 1547 |
+
from glob import glob
|
| 1548 |
+
from zipfile import ZipFile
|
| 1549 |
+
ustr = '%prog [-h] [-u] [-m] [-c C] [-d D] [-n CPL] [-b BPL] [-o DIR] [-v V]\n'
|
| 1550 |
+
ustr += '[-x] [-p PFMT] [-t] [-s] [-i] [--v1] [--noped] [--stems] <file1> [<file2> ...]'
|
| 1551 |
+
parser = OptionParser (usage=ustr, version=str(VERSION))
|
| 1552 |
+
parser.add_option ("-u", action="store_true", help="unfold simple repeats")
|
| 1553 |
+
parser.add_option ("-m", action="store", help="0 -> no %%MIDI, 1 -> minimal %%MIDI, 2-> all %%MIDI", default=0)
|
| 1554 |
+
parser.add_option ("-c", action="store", type="int", help="set credit text filter to C", default=0, metavar='C')
|
| 1555 |
+
parser.add_option ("-d", action="store", type="int", help="set L:1/D", default=0, metavar='D') # ??????????????L
|
| 1556 |
+
parser.add_option ("-n", action="store", type="int", help="CPL: max number of characters per line (default 100)", default=0, metavar='CPL')
|
| 1557 |
+
parser.add_option ("-b", action="store", type="int", help="BPL: max number of bars per line", default=0, metavar='BPL')
|
| 1558 |
+
parser.add_option ("-o", action="store", help="store abc files in DIR", default='', metavar='DIR')
|
| 1559 |
+
parser.add_option ("-v", action="store", type="int", help="set volta typesetting behaviour to V", default=0, metavar='V')
|
| 1560 |
+
parser.add_option ("-x", action="store_true", help="output no line breaks")
|
| 1561 |
+
parser.add_option ("-p", action="store", help="pageformat PFMT (cm) = scale, pageheight, pagewidth, leftmargin, rightmargin, topmargin, botmargin", default='', metavar='PFMT')
|
| 1562 |
+
parser.add_option ("-j", action="store_true", help="switch for compatibility with javascript version")
|
| 1563 |
+
parser.add_option ("-t", action="store_true", help="translate perc- and tab-staff to ABC code with %%map, %%voicemap")
|
| 1564 |
+
parser.add_option ("-s", action="store_true", help="shift node heads 3 units left in a tab staff")
|
| 1565 |
+
parser.add_option ("--v1", action="store_true", help="start-stop directions allways to first voice of staff")
|
| 1566 |
+
parser.add_option ("--noped", action="store_false", help="skip all pedal directions", dest='ped', default=True)
|
| 1567 |
+
parser.add_option ("--stems", action="store_true", help="translate stem directions", dest='stm', default=False)
|
| 1568 |
+
parser.add_option ("-i", action="store_true", help="read xml file from standard input")
|
| 1569 |
+
options, args = parser.parse_args ()
|
| 1570 |
+
if options.n < 0: parser.error ('only values >= 0')
|
| 1571 |
+
if options.b < 0: parser.error ('only values >= 0')
|
| 1572 |
+
if options.d and options.d not in [2**n for n in range (10)]:
|
| 1573 |
+
parser.error ('D should be on of %s' % ','.join ([str(2**n) for n in range (10)]))
|
| 1574 |
+
options.p = options.p and options.p.split (',') or [] # ==> [] | [string]
|
| 1575 |
+
if len (args) == 0 and not options.i: parser.error ('no input file given')
|
| 1576 |
+
pad = options.o
|
| 1577 |
+
if pad:
|
| 1578 |
+
if not os.path.exists (pad): os.mkdir (pad)
|
| 1579 |
+
if not os.path.isdir (pad): parser.error ('%s is not a directory' % pad)
|
| 1580 |
+
fnmext_list = []
|
| 1581 |
+
for i in args: fnmext_list += glob (i)
|
| 1582 |
+
if options.i: fnmext_list = ['stdin.xml']
|
| 1583 |
+
if not fnmext_list: parser.error ('none of the input files exist')
|
| 1584 |
+
for X, fnmext in enumerate (fnmext_list):
|
| 1585 |
+
fnm, ext = os.path.splitext (fnmext)
|
| 1586 |
+
if ext.lower () not in ('.xml','.mxl','.musicxml'):
|
| 1587 |
+
info ('skipped input file %s, it should have extension .xml or .mxl' % fnmext)
|
| 1588 |
+
continue
|
| 1589 |
+
if os.path.isdir (fnmext):
|
| 1590 |
+
info ('skipped directory %s. Only files are accepted' % fnmext)
|
| 1591 |
+
continue
|
| 1592 |
+
if fnmext == 'stdin.xml':
|
| 1593 |
+
fobj = sys.stdin
|
| 1594 |
+
elif ext.lower () == '.mxl': # extract .xml file from .mxl file
|
| 1595 |
+
z = ZipFile(fnmext)
|
| 1596 |
+
for n in z.namelist(): # assume there is always an xml file in a mxl archive !!
|
| 1597 |
+
if (n[:4] != 'META') and (n[-4:].lower() == '.xml'):
|
| 1598 |
+
fobj = z.open (n)
|
| 1599 |
+
break # assume only one MusicXML file per archive
|
| 1600 |
+
else:
|
| 1601 |
+
fobj = open (fnmext, 'rb') # open regular xml file
|
| 1602 |
+
|
| 1603 |
+
abcOut = ABCoutput (fnm + '.abc', pad, X, options) # create global ABC output object
|
| 1604 |
+
psr = Parser (options) # xml parser
|
| 1605 |
+
try:
|
| 1606 |
+
psr.parse (fobj) # parse file fobj and write abc to <fnm>.abc
|
| 1607 |
+
except:
|
| 1608 |
+
etype, value, traceback = sys.exc_info () # works in python 2 & 3
|
| 1609 |
+
info ('** %s occurred: %s in %s' % (etype, value, fnmext), 0)
|