import os
import json
import wandb
import torch
import random
import numpy as np
from utils import *
from config import *
from tqdm import tqdm
from torch.utils.data import Dataset
from transformers import GPT2Config

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Set random seeds for reproducibility
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

patchilizer = Patchilizer()

# Patch-level decoder consumes patch embeddings rather than predicting tokens
# (hence vocab_size=1); the character-level decoder predicts bytes (vocab 128).
patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
                          max_length=PATCH_LENGTH,
                          max_position_embeddings=PATCH_LENGTH,
                          n_embd=HIDDEN_SIZE,
                          num_attention_heads=HIDDEN_SIZE // 64,
                          vocab_size=1)
char_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
                         max_length=PATCH_SIZE + 1,
                         max_position_embeddings=PATCH_SIZE + 1,
                         n_embd=HIDDEN_SIZE,
                         num_attention_heads=HIDDEN_SIZE // 64,
                         vocab_size=128)

# Policy model (trained) and frozen reference model share the same architecture.
model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=char_config).to(device)
model_ref = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=char_config).to(device)
model_ref.eval()
for param in model_ref.parameters():
    param.requires_grad = False

# Print the trainable parameter count
print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)


def collate_batch(input_batches):
    pos_input_patches, pos_input_masks, neg_input_patches, neg_input_masks = input_batches

    # Add a batch dimension; with a batch size of 1 the pad_sequence calls
    # below are effectively no-ops but keep the shapes consistent.
    pos_input_patches = pos_input_patches.unsqueeze(0)
    pos_input_masks = pos_input_masks.unsqueeze(0)
    neg_input_patches = neg_input_patches.unsqueeze(0)
    neg_input_masks = neg_input_masks.unsqueeze(0)

    pos_input_patches = torch.nn.utils.rnn.pad_sequence(pos_input_patches, batch_first=True, padding_value=0)
    pos_input_masks = torch.nn.utils.rnn.pad_sequence(pos_input_masks, batch_first=True, padding_value=0)
    neg_input_patches = torch.nn.utils.rnn.pad_sequence(neg_input_patches, batch_first=True, padding_value=0)
    neg_input_masks = torch.nn.utils.rnn.pad_sequence(neg_input_masks, batch_first=True, padding_value=0)

    return (pos_input_patches.to(device), pos_input_masks.to(device),
            neg_input_patches.to(device), neg_input_masks.to(device))


class NotaGenDataset(Dataset):
    def __init__(self, preference_dict):
        self.preference_dict = preference_dict
        # Pair every chosen file with every rejected file (Cartesian product).
        self.pair_list = []
        for pos_filepath in self.preference_dict['chosen']:
            for neg_filepath in self.preference_dict['rejected']:
                self.pair_list.append({'chosen': pos_filepath, 'rejected': neg_filepath})

    def __len__(self):
        return len(self.pair_list)

    def __getitem__(self, idx):
        try:
            pair = self.pair_list[idx]
            pos_filepath = pair['chosen']
            neg_filepath = pair['rejected']

            with open(pos_filepath, 'r', encoding='utf-8') as f:
                pos_abc_text = f.read()
            with open(neg_filepath, 'r', encoding='utf-8') as f:
                neg_abc_text = f.read()

            pos_file_bytes = patchilizer.encode(pos_abc_text)
            pos_file_masks = [1] * len(pos_file_bytes)
            neg_file_bytes = patchilizer.encode(neg_abc_text)
            neg_file_masks = [1] * len(neg_file_bytes)

            pos_file_bytes = torch.tensor(pos_file_bytes, dtype=torch.long)
            pos_file_masks = torch.tensor(pos_file_masks, dtype=torch.long)
            neg_file_bytes = torch.tensor(neg_file_bytes, dtype=torch.long)
            neg_file_masks = torch.tensor(neg_file_masks, dtype=torch.long)

            return pos_file_bytes, pos_file_masks, neg_file_bytes, neg_file_masks
        except Exception as e:
            # Fall back to the next pair if a file is missing or unreadable.
            print(e)
            return self.__getitem__((idx + 1) % len(self.pair_list))
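
# For reference, a minimal sketch of the preference index this script expects
# (built offline, e.g. by generate_preference_dict in data.py; the paths below
# are placeholders, not real files):
#
#     {
#         "chosen":   ["abc/piece_a_good.abc", "abc/piece_b_good.abc"],
#         "rejected": ["abc/piece_a_bad.abc"]
#     }
#
# NotaGenDataset pairs every "chosen" file against every "rejected" file, so
# the number of training pairs is len(chosen) * len(rejected).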

def process_one_batch(batch):
    pos_input_patches, pos_input_masks, neg_input_patches, neg_input_masks = batch

    pos_input_patches_ref = pos_input_patches.clone()
    pos_input_masks_ref = pos_input_masks.clone()
    neg_input_patches_ref = neg_input_patches.clone()
    neg_input_masks_ref = neg_input_masks.clone()

    # Sequence log-probabilities under the policy and the frozen reference model
    policy_pos_logps = model(pos_input_patches, pos_input_masks)
    policy_neg_logps = model(neg_input_patches, neg_input_masks)
    with torch.no_grad():
        ref_pos_logps = model_ref(pos_input_patches_ref, pos_input_masks_ref).detach()
        ref_neg_logps = model_ref(neg_input_patches_ref, neg_input_masks_ref).detach()

    # DPO logits plus a DPOP-style penalty that keeps the policy from falling
    # below the reference on the chosen sample. torch.clamp replaces the
    # original built-in max(0, ...), which breaks on non-scalar tensors.
    logits = (policy_pos_logps - policy_neg_logps) - (ref_pos_logps - ref_neg_logps)
    penalty = torch.clamp(ref_pos_logps - policy_pos_logps, min=0)
    loss = -torch.nn.functional.logsigmoid(BETA * (logits - LAMBDA * penalty))

    return loss


# train
if __name__ == "__main__":
    # Initialize wandb
    if WANDB_LOGGING:
        wandb.login(key=WANDB_KEY)
        wandb.init(project="notagen", name=WANDB_NAME)

    # Load the preference data index
    with open(DATA_INDEX_PATH, 'r') as f:
        preference_dict = json.load(f)
    train_set = NotaGenDataset(preference_dict)

    # Initialize both the policy and the reference model from the same
    # pre-trained checkpoint.
    if os.path.exists(PRETRAINED_PATH):
        checkpoint = torch.load(PRETRAINED_PATH, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        model_ref.load_state_dict(checkpoint['model'])
    else:
        raise Exception('No pre-trained model loaded.')

    model.train()
    total_train_loss = 0
    tqdm_set = tqdm(range(OPTIMIZATION_STEPS))

    for i in tqdm_set:
        # Sample one chosen/rejected pair per optimization step
        idx = random.randint(0, len(train_set) - 1)
        batch = train_set[idx]
        batch = collate_batch(batch)

        loss = process_one_batch(batch)
        total_train_loss += loss.item()
        loss.backward()

        # clip_grad_norm (without the trailing underscore) is deprecated
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        model.zero_grad(set_to_none=True)

        tqdm_set.set_postfix({'train_loss': total_train_loss / (i + 1)})

        # Log the running training loss to wandb
        if WANDB_LOGGING:
            wandb.log({"train_loss": total_train_loss / (i + 1)}, step=i + 1)

    checkpoint = {'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict()}
    torch.save(checkpoint, WEIGHTS_PATH)
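
# Usage sketch (illustrative, not executed here): reloading the fine-tuned
# weights saved above for inference, assuming the same config constants and
# model classes as this script.
#
#     checkpoint = torch.load(WEIGHTS_PATH, map_location='cpu')
#     model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=char_config)
#     model.load_state_dict(checkpoint['model'])
#     model = model.to(device).eval()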