| import os
|
| import time
|
| import torch
|
| import re
|
| import difflib
|
| from utils import *
|
| from config import *
|
| from transformers import GPT2Config, BitsAndBytesConfig
|
| from bitsandbytes.nn import Linear8bitLt
|
| from bitsandbytes.optim import GlobalOptimManager
|
| from abctoolkit.utils import Exclaim_re, Quote_re, SquareBracket_re, Barline_regexPattern
|
| from abctoolkit.transpose import Note_list, Pitch_sign_list
|
| from abctoolkit.duration import calculate_bartext_duration
|
| import requests
|
| import torch
|
| from huggingface_hub import hf_hub_download
|
| import logging
|
|
|
|
|
# Module-level setup: logging, device selection, tokenizer and model construction.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Extend the pitch symbols with ABC rest tokens ('z' = visible rest, 'x' = invisible rest).
# NOTE(review): this rebinding shadows the Note_list imported from abctoolkit.transpose.
Note_list = Note_list + ['z', 'x']

# Run on GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():

    device = torch.device("cuda")

else:

    device = torch.device("cpu")

# Patchilizer comes from the project's `utils` wildcard import; it converts
# ABC text to/from fixed-size "patches" of character codes.
patchilizer = Patchilizer()

# Patch-level decoder: operates on patch embeddings, so vocab_size=1 (the real
# vocabulary lives at the byte level). PATCH_* / HIDDEN_SIZE come from config.
patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
                          max_length=PATCH_LENGTH,
                          max_position_embeddings=PATCH_LENGTH,
                          n_embd=HIDDEN_SIZE,
                          num_attention_heads=HIDDEN_SIZE // 64,
                          vocab_size=1)

# Byte-level decoder: one position per character in a patch (+1), ASCII vocab (128).
byte_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
                         max_length=PATCH_SIZE + 1,
                         max_position_embeddings=PATCH_SIZE + 1,
                         hidden_size=HIDDEN_SIZE,
                         num_attention_heads=HIDDEN_SIZE // 64,
                         vocab_size=128)

# 8-bit quantization settings, keeping the patch embedding in full precision.
# NOTE(review): this config is built but not passed to the model constructor
# below — confirm whether it is consumed elsewhere or is dead configuration.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["patch_embedding"],
    bnb_4bit_use_double_quant=True
)

# Hierarchical model: patch-level encoder config + byte-level decoder config.
model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config).to(device)
|
|
|
def download_model_weights():
    """Ensure the NotaGen checkpoint is present locally and return its path.

    Checks the current working directory first; if the .pth file is missing,
    downloads it from the HuggingFace Hub repo ``ElectricAlexis/NotaGen``.

    Returns:
        str: Path to the checkpoint file on disk.

    Raises:
        RuntimeError: If the download fails; the original exception is
            attached as ``__cause__`` for debugging.
    """
    weights_path = "weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth"
    local_weights_path = os.path.join(os.getcwd(), weights_path)

    # Skip the (large) download when a previous run already fetched the file.
    if os.path.exists(local_weights_path):
        logger.info("Model weights already exist at %s", local_weights_path)
        return local_weights_path

    logger.info("Downloading model weights from HuggingFace Hub...")
    try:
        downloaded_path = hf_hub_download(
            repo_id="ElectricAlexis/NotaGen",
            filename=weights_path,
            local_dir=os.getcwd(),
            # Deprecated upstream (ignored by recent huggingface_hub versions);
            # kept for compatibility with older versions so a real copy lands in cwd.
            local_dir_use_symlinks=False
        )
        logger.info("Model weights downloaded successfully to %s", downloaded_path)
        return downloaded_path
    except Exception as e:
        # logger.exception records the traceback; chain the cause so callers
        # catching RuntimeError can still inspect the original failure.
        logger.exception("Error downloading model weights")
        raise RuntimeError(f"Failed to download model weights: {str(e)}") from e
|
|
|
|
|
def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
    """Prepare a model for k-bit training.

    Performs three steps:
      1. Cast the whole model to half precision (FP16).
      2. Freeze any parameter still stored in FP32.
      3. Optionally enable gradient checkpointing.

    Returns the same model object, cast and configured.
    """
    # Cast everything to FP16 up front.
    model = model.to(dtype=torch.float16)

    # Freeze leftover FP32 parameters.
    # NOTE(review): after the cast above no parameter should remain float32,
    # so this loop is expected to be a no-op; kept to mirror the original flow.
    fp32_params = (p for p in model.parameters() if p.dtype == torch.float32)
    for p in fp32_params:
        p.requires_grad = False

    # Trade compute for memory during backprop, when requested.
    if use_gradient_checkpointing:
        model.gradient_checkpointing_enable()

    return model
|
|
|
|
|
# Cast the model to FP16 for inference; checkpointing is a training-only feature.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=False
)

# Report the trainable parameter count (informational only — this script infers).
print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

# Fetch the checkpoint (download on first run) and load it.
# strict=False tolerates missing/unexpected keys in the state dict.
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — acceptable only because the checkpoint comes from a known HF repo.
model_weights_path = download_model_weights()
checkpoint = torch.load(model_weights_path, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'], strict=False)

# Move to the target device and switch to inference mode (disables dropout etc.).
model = model.to(device)
model.eval()
|
|
|
|
|
def postprocess_inst_names(abc_text):
    """Normalize instrument names in ABC voice headers.

    Reads the canonical instrument names from ``standard_inst_names.txt`` and
    a fuzzy-match mapping from ``instrument_mapping.json`` (both in the current
    working directory), then rewrites every ``nm="..."`` field on ``V:`` lines
    whose name is not already canonical, substituting the mapping value of the
    closest matching key. Returns the processed text with every non-empty line
    newline-terminated.
    """
    with open('standard_inst_names.txt', 'r', encoding='utf-8') as f:
        standard_names = [ln.strip() for ln in f if ln.strip()]

    with open('instrument_mapping.json', 'r', encoding='utf-8') as f:
        name_map = json.load(f)

    # Drop empty lines and re-terminate each remaining line with '\n'.
    lines = [ln + '\n' for ln in abc_text.split('\n') if ln]

    for idx, ln in enumerate(lines):
        # Only voice-definition lines that carry an instrument name are candidates.
        if not (ln.startswith('V:') and 'nm=' in ln):
            continue
        m = re.search(r'nm="([^"]*)"', ln)
        if m is None:
            continue
        inst = m.group(1)
        if inst in standard_names:
            continue  # already canonical — leave untouched

        # Fuzzy-match the non-canonical name against the mapping keys.
        close = difflib.get_close_matches(inst, list(name_map.keys()), n=1, cutoff=0.6)
        if close:
            lines[idx] = ln.replace(f'nm="{inst}"', f'nm="{name_map[close[0]]}"')

    return ''.join(lines)
|
|
|
|
|
def complete_brackets(s):
    """Append the closing brackets needed to balance *s*.

    Tracks unmatched openers of the kinds ``{}``, ``[]`` and ``()`` while
    scanning the string, then returns the input with the matching closers
    appended innermost-first. A closer that does not match the current
    innermost opener is ignored, as is every non-bracket character.
    """
    openers = {'{': '}', '[': ']', '(': ')'}
    closers = {close: open_ for open_, close in openers.items()}
    pending = []

    for ch in s:
        if ch in openers:
            pending.append(ch)
        elif ch in closers:
            # Pop only when this closer matches the innermost open bracket;
            # stray/mismatched closers are deliberately left alone.
            if pending and pending[-1] == closers[ch]:
                pending.pop()

    return s + ''.join(openers[c] for c in reversed(pending))
|
|
|
|
|
|
|
def rest_unreduce(abc_lines):
    """Expand a "reduced" ABC tunebody so every bar lists every voice.

    The model emits tunebody lines that omit silent voices; this reinserts
    them as whole-bar rests ('z' for the first voice of a group, 'x' for the
    remaining voices), matching the reference duration of the bar.

    Args:
        abc_lines: newline-terminated ABC lines; metadata first, then the
            tunebody (first line containing '[V:').

    Returns:
        The metadata lines followed by the un-reduced tunebody lines.
    """
    # Locate the first tunebody line; also repair any truncated %%score line
    # by closing its unbalanced brackets.
    tunebody_index = None
    for i in range(len(abc_lines)):
        if abc_lines[i].startswith('%%score'):
            abc_lines[i] = complete_brackets(abc_lines[i])
        if '[V:' in abc_lines[i]:
            tunebody_index = i
            break

    # NOTE(review): if no line contains '[V:', tunebody_index stays None and
    # the slices below silently treat the whole input as tunebody.
    metadata_lines = abc_lines[: tunebody_index]
    tunebody_lines = abc_lines[tunebody_index:]

    # Parse the voice layout: %%score round-bracket groups define grouped
    # voices; each V: header registers a part, and ungrouped voices become
    # singleton groups.
    part_symbol_list = []
    voice_group_list = []
    for line in metadata_lines:
        if line.startswith('%%score'):
            for round_bracket_match in re.findall(r'\((.*?)\)', line):
                voice_group_list.append(round_bracket_match.split())
            existed_voices = [item for sublist in voice_group_list for item in sublist]
        if line.startswith('V:'):
            symbol = line.split()[0]
            part_symbol_list.append(symbol)
            # symbol[2:] strips the 'V:' prefix, leaving the bare voice id.
            if symbol[2:] not in existed_voices:
                voice_group_list.append([symbol[2:]])
    # First voice of each group gets a visible rest 'z'; the rest get
    # invisible rests 'x' when a bar must be synthesized.
    z_symbol_list = []
    x_symbol_list = []
    for voice_group in voice_group_list:
        z_symbol_list.append('V:' + voice_group[0])
        for j in range(1, len(voice_group)):
            x_symbol_list.append('V:' + voice_group[j])

    # Numeric sort by voice id so output bars appear in voice order.
    part_symbol_list.sort(key=lambda x: int(x[2:]))

    unreduced_tunebody_lines = []

    for i, line in enumerate(tunebody_lines):
        unreduced_line = ''

        # Strip the leading bar-index annotation, e.g. '[r:3/42]'.
        line = re.sub(r'^\[r:[^\]]*\]', '', line)

        # Split the line into per-voice bar texts: '[V:<n>]<bartext>'.
        pattern = r'\[V:(\d+)\](.*?)(?=\[V:|$)'
        matches = re.findall(pattern, line)

        line_bar_dict = {}
        for match in matches:
            key = f'V:{match[0]}'
            value = match[1]
            line_bar_dict[key] = value

        # Vote on the bar duration: count the duration of each present voice
        # and take the most common one as the reference for synthesized rests.
        dur_dict = {}
        for symbol, bartext in line_bar_dict.items():
            # The last two split pieces re-join into the closing barline token.
            right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])
            bartext = bartext[:-len(right_barline)]
            try:
                bar_dur = calculate_bartext_duration(bartext)
            except:
                # Unparsable bar text — skip it in the vote.
                bar_dur = None
            if bar_dur is not None:
                if bar_dur not in dur_dict.keys():
                    dur_dict[bar_dur] = 1
                else:
                    dur_dict[bar_dur] += 1

        try:
            ref_dur = max(dur_dict, key=dur_dict.get)
        except:
            # NOTE(review): when no duration could be computed, ref_dur (and
            # right_barline, if line_bar_dict was empty) may be unbound or
            # stale from the previous iteration when used below — relies on
            # the caller's try/except around this function.
            pass

        # Only the first tunebody line may carry a left barline prefix that
        # must be replicated onto synthesized bars.
        if i == 0:
            prefix_left_barline = line.split('[V:')[0]
        else:
            prefix_left_barline = ''

        # Emit every part in order; missing parts become whole-bar rests of
        # the reference duration, closed with the last-seen right barline.
        for symbol in part_symbol_list:
            if symbol in line_bar_dict.keys():
                symbol_bartext = line_bar_dict[symbol]
            else:
                if symbol in z_symbol_list:
                    symbol_bartext = prefix_left_barline + 'z' + str(ref_dur) + right_barline
                elif symbol in x_symbol_list:
                    symbol_bartext = prefix_left_barline + 'x' + str(ref_dur) + right_barline
            unreduced_line += '[' + symbol + ']' + symbol_bartext

        unreduced_tunebody_lines.append(unreduced_line + '\n')

    unreduced_lines = metadata_lines + unreduced_tunebody_lines

    return unreduced_lines
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference_patch(period, composer, instrumentation):
    """Generate one ABC piece conditioned on period, composer and instrumentation.

    Autoregressively samples patches from the module-level `model`, streaming
    characters to stdout as they are produced. Retries from scratch (outer
    loop) whenever a generation attempt fails; on success the reduced ABC is
    expanded via rest_unreduce() and returned as a single string.

    Args:
        period: e.g. 'Classical' — becomes the first '%' prompt line.
        composer: e.g. 'Beethoven, Ludwig van'.
        instrumentation: e.g. 'Orchestral'.

    Returns:
        str: the un-reduced ABC text, prefixed with 'X:1'.
    """
    # The conditioning prompt: three '%'-prefixed metadata lines.
    prompt_lines=[
        '%' + period + '\n',
        '%' + composer + '\n',
        '%' + instrumentation + '\n']

    # Retry loop: restart generation until an attempt succeeds.
    while True:

        failure_flag = False

        # A BOS patch: PATCH_SIZE-1 bos ids terminated by one eos id.
        bos_patch = [patchilizer.bos_token_id] * (PATCH_SIZE - 1) + [patchilizer.eos_token_id]

        start_time = time.time()

        prompt_patches = patchilizer.patchilize_metadata(prompt_lines)
        # byte_list accumulates the full output; the other two track the
        # metadata vs. tunebody portions separately (needed for re-priming).
        byte_list = list(''.join(prompt_lines))
        context_tunebody_byte_list = []
        metadata_byte_list = []

        print(''.join(byte_list), end='')

        # Convert each patch to character codes, right-padded to PATCH_SIZE
        # with the special token id.
        prompt_patches = [[ord(c) for c in patch] + [patchilizer.special_token_id] * (PATCH_SIZE - len(patch)) for patch
                          in prompt_patches]
        prompt_patches.insert(0, bos_patch)

        input_patches = torch.tensor(prompt_patches, device=device).reshape(1, -1)

        end_flag = False
        cut_index = None

        # Set once the model starts emitting the tunebody (first '[r:' patch).
        tunebody_flag = False

        with torch.inference_mode():

            # Inner loop: one patch per iteration.
            while True:
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    predicted_patch = model.generate(input_patches.unsqueeze(0),
                                                     top_k=TOP_K,
                                                     top_p=TOP_P,
                                                     temperature=TEMPERATURE)
                # First tunebody patch: force the bar counter to start at
                # '[r:0/' by re-generating with that prefix appended.
                if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'):
                    tunebody_flag = True
                    r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)
                    temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
                    predicted_patch = model.generate(temp_input_patches.unsqueeze(0),
                                                     top_k=TOP_K,
                                                     top_p=TOP_P,
                                                     temperature=TEMPERATURE)
                    predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch
                # A bos+eos pair at the start signals end of generation.
                if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
                    end_flag = True
                    break
                next_patch = patchilizer.decode([predicted_patch])

                # Stream the decoded characters and bucket them into
                # metadata vs. tunebody accumulators.
                for char in next_patch:
                    byte_list.append(char)
                    if tunebody_flag:
                        context_tunebody_byte_list.append(char)
                    else:
                        metadata_byte_list.append(char)
                    print(char, end='')

                # After the first eos in the patch, overwrite the tail with
                # special tokens so stale ids are not fed back to the model.
                patch_end_flag = False
                for j in range(len(predicted_patch)):
                    if patch_end_flag:
                        predicted_patch[j] = patchilizer.special_token_id
                    if predicted_patch[j] == patchilizer.eos_token_id:
                        patch_end_flag = True

                predicted_patch = torch.tensor([predicted_patch], device=device)
                input_patches = torch.cat([input_patches, predicted_patch], dim=1)

                # Safety limits: abort on runaway output (>100 KiB) or after
                # 20 minutes of wall-clock time.
                if len(byte_list) > 102400:
                    failure_flag = True
                    break
                if time.time() - start_time > 20 * 60:
                    failure_flag = True
                    break

                # Context window full but piece unfinished: re-prime with the
                # metadata plus the second half of the tunebody so far.
                if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:

                    print('Stream generating...')

                    metadata = ''.join(metadata_byte_list)
                    context_tunebody = ''.join(context_tunebody_byte_list)

                    # Cannot truncate on a line boundary — give up this attempt.
                    if '\n' not in context_tunebody:

                        break

                    # Re-terminate every complete line with '\n', leaving a
                    # trailing partial line (if any) untouched.
                    context_tunebody_liness = context_tunebody.split('\n')
                    if not context_tunebody.endswith('\n'):
                        context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness) - 1)] + [context_tunebody_liness[-1]]
                    else:
                        context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness))]

                    # Keep the most recent half of the tunebody as context.
                    cut_index = len(context_tunebody_liness) // 2
                    abc_code_slice = metadata + ''.join(context_tunebody_liness[-cut_index:])

                    input_patches = patchilizer.encode_generate(abc_code_slice)

                    # Flatten the list-of-patches and restore the (1, N) shape.
                    input_patches = [item for sublist in input_patches for item in sublist]
                    input_patches = torch.tensor([input_patches], device=device)
                    input_patches = input_patches.reshape(1, -1)

                    # NOTE(review): resetting this list means characters kept
                    # in the re-primed context are no longer counted toward
                    # future truncations — presumably intentional; confirm.
                    context_tunebody_byte_list = []

        if not failure_flag:
            abc_text = ''.join(byte_list)

            # Normalize to non-empty, newline-terminated lines before
            # expanding the reduced tunebody.
            abc_lines = abc_text.split('\n')
            abc_lines = list(filter(None, abc_lines))
            abc_lines = [line + '\n' for line in abc_lines]
            try:
                unreduced_abc_lines = rest_unreduce(abc_lines)
            except:
                # Expansion failed — mark failure and let the outer loop retry.
                failure_flag = True
                pass
            else:
                # Drop '%' comment lines (but keep '%%' directives), prepend
                # the mandatory ABC index header, and return the result.
                unreduced_abc_lines = [line for line in unreduced_abc_lines if not(line.startswith('%') and not line.startswith('%%'))]
                unreduced_abc_lines = ['X:1\n'] + unreduced_abc_lines
                unreduced_abc_text = ''.join(unreduced_abc_lines)
                return unreduced_abc_text
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke run: generate one piece conditioned on period / composer / instrumentation.
    inference_patch('Classical', 'Beethoven, Ludwig van', 'Orchestral')
|
|
|
|
|