| import os
|
| import gc
|
| import time
|
| import math
|
| import json
|
| import wandb
|
| import torch
|
| import random
|
| import numpy as np
|
| from abctoolkit.transpose import Key2index, Key2Mode
|
| from utils import *
|
| from config import *
|
| from tqdm import tqdm
|
| from copy import deepcopy
|
| from torch.cuda.amp import autocast, GradScaler
|
| from torch.utils.data import Dataset, DataLoader
|
| from transformers import GPT2Config, LlamaConfig, get_scheduler, get_constant_schedule_with_warmup
|
| import torch.distributed as dist
|
| from torch.nn.parallel import DistributedDataParallel as DDP
|
| from torch.utils.data.distributed import DistributedSampler
|
|
|
# Reverse lookup from key index to key name. Indices 1 and 11 are left out on
# purpose: NotaGenDataset.__getitem__ chooses the spelling for those itself.
Index2Key = dict(
    (key_index, key_name)
    for key_name, key_index in Key2index.items()
    if key_index != 1 and key_index != 11
)

# Flattened lookup from every mode entry listed in Key2Mode to its owning key.
Mode2Key = dict(
    (mode_name, key_name)
    for key_name, mode_names in Key2Mode.items()
    for mode_name in mode_names
)
|
|
|
|
|
# Process topology is read from the environment; a missing variable means a
# plain single-process run.
world_size = int(os.environ.get('WORLD_SIZE', 1))
global_rank = int(os.environ.get('RANK', 0))
local_rank = int(os.environ.get('LOCAL_RANK', 0))

if world_size <= 1:
    # Single process: default CUDA device when one is available, else the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    # Multi-process: bind this process to its local GPU, then join the NCCL
    # process group.
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    dist.init_process_group(backend='nccl')
|
|
|
|
|
# Per-process seed: base seed 0 shifted by the process's global rank, applied
# to Python's, NumPy's and PyTorch's (CPU and CUDA) random number generators.
seed = 0 + global_rank
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# cuDNN: autotuning benchmark off, deterministic mode on.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
|
|
|
batch_size = BATCH_SIZE

patchilizer = Patchilizer()

# Both configurations use HIDDEN_SIZE // 64 attention heads.
_attention_heads = HIDDEN_SIZE // 64

# Configuration passed to the model as ``encoder_config``.
patch_config = GPT2Config(
    vocab_size=1,
    n_embd=HIDDEN_SIZE,
    num_attention_heads=_attention_heads,
    num_hidden_layers=PATCH_NUM_LAYERS,
    max_length=PATCH_LENGTH,
    max_position_embeddings=PATCH_LENGTH,
)

# Configuration passed to the model as ``decoder_config``.
char_config = GPT2Config(
    vocab_size=128,
    hidden_size=HIDDEN_SIZE,
    num_attention_heads=_attention_heads,
    num_hidden_layers=CHAR_NUM_LAYERS,
    max_length=PATCH_SIZE + 1,
    max_position_embeddings=PATCH_SIZE + 1,
)

model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=char_config).to(device)

# Report the number of trainable parameters.
_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Parameter Number: {_trainable_params}")

# Wrap the network for data-parallel training when several processes run.
if world_size > 1:
    model = DDP(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True,
    )

scaler = GradScaler()
is_autocast = True
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
|
|
|
|
def clear_unused_tensors():
    """Best-effort release of CUDA memory held by stray tensors.

    Every CUDA tensor known to the garbage collector that is neither a model
    parameter nor a tensor stored in the optimizer state is detached in place.
    Afterwards a garbage collection is forced and PyTorch's CUDA cache is
    emptied.

    The sweep never raises: tensors that refuse an in-place detach are
    skipped, and any other failure is reported as a ``RuntimeWarning`` so the
    periodic cleanup cannot abort training. The final collection and cache
    release always run.
    """
    import warnings  # local import: only this helper needs it

    # Keep the collector from running while its object list is being walked.
    gc.disable()
    try:
        # DDP keeps the wrapped network under ``.module``.
        network = model.module if hasattr(model, "module") else model

        # Identities of tensors that must stay untouched.
        protected_ids = {id(param) for param in network.parameters()}
        protected_ids.update(
            id(value)
            for per_param_state in optimizer.state.values()
            for value in per_param_state.values()
            if isinstance(value, torch.Tensor)
        )

        stray_tensors = [
            obj
            for obj in gc.get_objects()
            if isinstance(obj, torch.Tensor) and obj.is_cuda and id(obj) not in protected_ids
        ]

        # Pop as we go so this list no longer pins the tensors by the time the
        # collection in ``finally`` runs.
        while stray_tensors:
            tensor = stray_tensors.pop()
            try:
                tensor.detach_()
            except RuntimeError:
                # Some tensors (e.g. views) cannot be detached in place; one
                # such tensor must not stop the rest of the sweep.
                pass
            del tensor
    except Exception as exc:
        # Deliberately broad: inspecting arbitrary gc-tracked objects can fail
        # in unexpected ways, and cleanup stays best-effort. Unlike a bare
        # ``except``, this lets KeyboardInterrupt/SystemExit through and makes
        # the failure visible.
        warnings.warn(f"clear_unused_tensors: tensor sweep aborted ({exc!r})", RuntimeWarning)
    finally:
        gc.enable()
        gc.collect()
        torch.cuda.empty_cache()
|
|
|
def collate_batch(input_batches):
    """Merge ``(patches, masks)`` samples into one zero-padded batch.

    Args:
        input_batches: iterable of ``(patches, masks)`` tensor pairs.

    Returns:
        tuple: ``(input_patches, input_masks)``; each is padded with zeros to
        the longest sequence in the batch (batch dimension first) and moved to
        the module-level ``device``.
    """
    pad = torch.nn.utils.rnn.pad_sequence

    # Transpose the list of pairs into one tuple of patches and one of masks.
    patch_seqs, mask_seqs = zip(*input_batches)

    padded_patches = pad(patch_seqs, batch_first=True, padding_value=0)
    padded_masks = pad(mask_seqs, batch_first=True, padding_value=0)

    return padded_patches.to(device), padded_masks.to(device)
|
|
|
def split_into_minibatches(input_patches, input_masks, minibatch_size):
    """Cut a batch into consecutive ``(patches, masks)`` chunks.

    Args:
        input_patches: sliceable batch of patches; its length sets the number
            of chunks.
        input_masks: sliceable batch of masks, sliced with the same bounds.
        minibatch_size: maximum number of leading-dimension items per chunk.

    Returns:
        list: ``(patches_chunk, masks_chunk)`` tuples in order; the final chunk
        may be shorter than ``minibatch_size``.
    """
    chunk_starts = range(0, len(input_patches), minibatch_size)
    return [
        (
            input_patches[lo:lo + minibatch_size],
            input_masks[lo:lo + minibatch_size],
        )
        for lo in chunk_starts
    ]
|
|
|
class NotaGenDataset(Dataset):
    """Dataset that serves a randomly chosen transposed copy of each ABC file.

    ``filenames`` is an indexable collection of dicts with at least a
    ``'path'`` and a ``'key'`` entry. Each access picks a destination key near
    the original one at random and reads the matching file from
    ``<folder>/<des_key>/<name>_<des_key>.abc``.
    """

    def __init__(self, filenames):
        self.filenames = filenames

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(file_bytes, file_masks)`` long tensors for entry ``idx``.

        ``file_bytes`` is the output of ``patchilizer.encode_train`` on the
        chosen file; ``file_masks`` holds a 1 for every element of it.
        """
        entry = self.filenames[idx]
        filepath = entry['path']
        ori_key_index = Key2index[Mode2Key[entry['key']]]

        # Candidate destination indices: offsets -3..+3 around the original
        # (mod 12), weighted 1:2:3:4:3:2:1 so the original key is most likely.
        candidates = [(ori_key_index + shift) % 12 for shift in range(-3, 4)]
        weights = (1, 2, 3, 4, 3, 2, 1)

        draw = random.random()
        upper_bound = 0.0
        for candidate, weight in zip(candidates, weights):
            upper_bound += weight / 16
            if draw < upper_bound:
                des_key_index = candidate
                break

        # Indices with two spellings: (probability of first, first, second).
        dual_spellings = {
            1: (0.8, 'Db', 'C#'),
            11: (0.8, 'B', 'Cb'),
            6: (0.5, 'F#', 'Gb'),
        }
        spelling = dual_spellings.get(des_key_index)
        if spelling is None:
            des_key = Index2Key[des_key_index]
        else:
            threshold, first_name, second_name = spelling
            des_key = first_name if random.random() < threshold else second_name

        folder, name = os.path.split(filepath)
        des_filepath = os.path.join(folder, des_key, '_'.join((name, des_key)) + '.abc')

        with open(des_filepath, 'r', encoding='utf-8') as abc_file:
            abc_text = abc_file.read()

        encoded = patchilizer.encode_train(abc_text)
        file_bytes = torch.tensor(encoded, dtype=torch.long)
        file_masks = torch.ones(len(encoded), dtype=torch.long)

        return file_bytes, file_masks
|
|
|
|
|
def process_one_batch(batch):
    """Run the model on one ``(patches, masks)`` pair and return its loss.

    In a single-process run the model's ``.loss`` is returned unchanged. With
    several processes the loss is reshaped to one element, reduced onto rank 0,
    divided by ``world_size`` and then overwritten on every rank with rank 0's
    value via a broadcast.
    """
    patches, masks = batch
    loss = model(patches, masks).loss

    if world_size <= 1:
        return loss

    # The collectives below are applied to the one-element form of the loss.
    loss = loss.unsqueeze(0)
    dist.reduce(loss, dst=0)
    loss = loss / world_size
    dist.broadcast(loss, src=0)
    return loss
|
|
|
|
|
|
|
def train_epoch(epoch):
    """Train for one pass over ``train_set`` and return the mean batch loss.

    Each loaded batch is cut into minibatches of
    ``BATCH_SIZE // ACCUMULATION_STEPS`` items. Their scaled losses are
    back-propagated one after another and a single optimizer and scheduler
    step follows per batch.

    Args:
        epoch: 1-based epoch number; used to derive the global step for wandb.

    Returns:
        float: sum of the per-minibatch losses divided by the number of batches.
    """
    progress = tqdm(train_set)
    model.train()

    minibatch_size = BATCH_SIZE // ACCUMULATION_STEPS
    steps_before_epoch = (epoch - 1) * len(train_set)
    postfix_key = str(global_rank) + '_train_loss'

    loss_sum = 0
    step = 0
    for step, batch in enumerate(progress, start=1):
        # Accumulate gradients over the minibatches of this batch.
        for minibatch in split_into_minibatches(batch[0], batch[1], minibatch_size):
            with autocast():
                loss = process_one_batch(minibatch) / ACCUMULATION_STEPS
            scaler.scale(loss).backward()
            loss_sum += loss.item()

        # One parameter update per full batch.
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()
        model.zero_grad(set_to_none=True)

        running_loss = loss_sum / step
        progress.set_postfix({postfix_key: running_loss})

        if global_rank == 0 and WANDB_LOGGING:
            wandb.log({"train_loss": running_loss}, step=steps_before_epoch + step)

        # Periodic memory cleanup, triggered after batches 999, 1999, ...
        if (step + 1) % 1000 == 0:
            clear_unused_tensors()

    return loss_sum / step
|
|
|
|
|
def eval_epoch():
    """Evaluate on ``eval_set`` and return the mean batch loss.

    Each loaded batch is cut into minibatches of
    ``BATCH_SIZE // ACCUMULATION_STEPS`` items. Every minibatch loss is divided
    by ``ACCUMULATION_STEPS`` and summed, and the total is averaged over the
    number of batches. No gradients are recorded.

    Returns:
        float: summed minibatch losses divided by the number of batches.

    Raises:
        RuntimeError: if ``eval_set`` yields no batches, since there is then
            nothing to average.
    """
    progress = tqdm(eval_set)
    model.eval()

    minibatch_size = BATCH_SIZE // ACCUMULATION_STEPS
    postfix_key = str(global_rank) + '_eval_loss'

    total_eval_loss = 0
    num_batches = 0
    with torch.no_grad():
        for num_batches, batch in enumerate(progress, start=1):
            for minibatch in split_into_minibatches(batch[0], batch[1], minibatch_size):
                loss = process_one_batch(minibatch) / ACCUMULATION_STEPS
                total_eval_loss += loss.item()
            progress.set_postfix({postfix_key: total_eval_loss / num_batches})

    if num_batches == 0:
        # Previously this surfaced as a bare ZeroDivisionError with no context.
        raise RuntimeError("eval_epoch: eval_set produced no batches, so there is no loss to average")

    return total_eval_loss / num_batches
|
|
|
|
|
# Script entry point: load the file indices, build the data loaders, restore
# weights (a pre-trained checkpoint, or a training checkpoint when resuming),
# then alternate training and evaluation epochs, saving whenever the
# evaluation loss improves.
if __name__ == "__main__":

    # Only rank 0 reports to Weights & Biases.
    if WANDB_LOGGING and global_rank == 0:
        wandb.login(key=WANDB_KEY)
        wandb.init(project="notagen", name=WANDB_NAME)

    # Each index file holds one JSON object per line.
    with open(DATA_TRAIN_INDEX_PATH, "r", encoding="utf-8") as f:
        print("Loading Data...")
        train_files = [json.loads(line) for line in f]

    with open(DATA_EVAL_INDEX_PATH, "r", encoding="utf-8") as f:
        print("Loading Data...")
        eval_files = [json.loads(line) for line in f]

    # No separate evaluation index: carve one out of the training files.
    # NOTE(review): if split_data draws from the global RNG (seeded per rank
    # above), the split differs between ranks - confirm in utils.
    if not eval_files:
        train_files, eval_files = split_data(train_files)

    # Shuffle with a rank-independent RNG. The global RNG is seeded with the
    # rank, which would give every process a different file order (and a
    # different truncation below), while DistributedSampler assumes all ranks
    # index the same dataset.
    shared_rng = random.Random(0)
    shared_rng.shuffle(train_files)
    shared_rng.shuffle(eval_files)

    # Keep whole batches only.
    train_files = train_files[:len(train_files) // batch_size * batch_size]
    eval_files = eval_files[:len(eval_files) // batch_size * batch_size]

    # Fail before training starts instead of dividing by zero at the end of
    # the first epoch.
    if not train_files or not eval_files:
        raise ValueError(
            "Not enough data: the training and evaluation sets each need at least "
            f"one full batch of {batch_size} files."
        )

    train_dataset = NotaGenDataset(train_files)
    eval_dataset = NotaGenDataset(eval_files)

    # Shard by the global rank: the local rank repeats on every node, so it
    # would hand identical shards to different nodes in a multi-node run.
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=global_rank)
    eval_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=global_rank)

    # train_epoch() / eval_epoch() read these loaders through the module-level
    # names train_set / eval_set. Ordering is fully delegated to the samplers.
    train_set = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_batch, sampler=train_sampler)
    eval_set = DataLoader(eval_dataset, batch_size=batch_size, collate_fn=collate_batch, sampler=eval_sampler)

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # Build the scheduler on the optimizer that is actually stepped. Creating
    # it before the optimizer is rebound ties the warmup to a discarded object
    # and leaves the live optimizer at a constant learning rate.
    lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=1000)

    # DDP keeps the real network under ``.module``; it is only present when
    # the model was wrapped, regardless of how many GPUs the machine has.
    bare_model = model.module if hasattr(model, "module") else model

    if not LOAD_FROM_CHECKPOINT:
        # Fresh fine-tuning run: start from the pre-trained weights.
        if not os.path.exists(PRETRAINED_PATH):
            raise FileNotFoundError('Pre-trained Checkpoint not found. Please check your pre-trained ckpt path.')

        checkpoint = torch.load(PRETRAINED_PATH, map_location='cpu')
        bare_model.load_state_dict(checkpoint['model'])
        print(f"Successfully Loaded Pretrained Checkpoint at Epoch {checkpoint['epoch']} with Loss {checkpoint['min_eval_loss']}")

        pre_epoch = 0
        best_epoch = 0
        # Any first evaluation counts as an improvement.
        min_eval_loss = float('inf')
    else:
        # Resume: restore weights, optimizer, scheduler and bookkeeping.
        if not os.path.exists(WEIGHTS_PATH):
            raise FileNotFoundError('Checkpoint not found to continue training. Please check your parameter settings.')

        checkpoint = torch.load(WEIGHTS_PATH, map_location='cpu')
        bare_model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_sched'])
        pre_epoch = checkpoint['epoch']
        best_epoch = checkpoint['best_epoch']
        min_eval_loss = checkpoint['min_eval_loss']
        print("Successfully Loaded Checkpoint from Epoch %d" % pre_epoch)

    # Drop the CPU copy of the checkpoint before training starts.
    checkpoint = None

    for epoch in range(1 + pre_epoch, NUM_EPOCHS + 1):
        # Re-seed the samplers so every epoch uses a different ordering.
        train_sampler.set_epoch(epoch)
        eval_sampler.set_epoch(epoch)
        print('-' * 21 + "Epoch " + str(epoch) + '-' * 21)

        train_loss = train_epoch(epoch)
        eval_loss = eval_epoch()

        # Logging and checkpointing happen on rank 0 only.
        if global_rank == 0:
            with open(LOGS_PATH, 'a', encoding='utf-8') as f:
                f.write(
                    f"Epoch {epoch}\ntrain_loss: {train_loss}\neval_loss: {eval_loss}"
                    f"\ntime: {time.asctime()}\n\n"
                )

            if eval_loss < min_eval_loss:
                best_epoch = epoch
                min_eval_loss = eval_loss
                checkpoint = {
                    'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_sched': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'best_epoch': best_epoch,
                    'min_eval_loss': min_eval_loss,
                }
                torch.save(checkpoint, WEIGHTS_PATH)

        # Keep the other ranks from racing ahead while rank 0 writes to disk.
        if world_size > 1:
            dist.barrier()

    if global_rank == 0:
        print("Best Eval Epoch : " + str(best_epoch))
        print("Min Eval Loss : " + str(min_eval_loss))
|
|
|