File size: 7,072 Bytes
12b2634 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import os
import gc
import time
import math
import json
import wandb
import torch
import random
import numpy as np
from abctoolkit.transpose import Key2index, Key2Mode
from utils import *
from config import *
from data import generate_preference_dict
from tqdm import tqdm
from copy import deepcopy
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Config, get_scheduler, get_constant_schedule_with_warmup
# Select GPU when available; every model/tensor below is moved to this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Set random seed
# Seed every RNG source (python, numpy, torch CPU + all GPUs) and force
# deterministic cuDNN kernels so training runs are reproducible.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Project tokenizer that converts ABC notation text into patch sequences.
patchilizer = Patchilizer()
# Patch-level (encoder) transformer configuration.
# NOTE(review): this config sets the width via `n_embd` while char_config uses
# `hidden_size` — both are GPT2Config aliases for the same field; presumably
# intentional, but confirm both resolve to HIDDEN_SIZE.
patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
                          max_length=PATCH_LENGTH,
                          max_position_embeddings=PATCH_LENGTH,
                          n_embd=HIDDEN_SIZE,
                          num_attention_heads=HIDDEN_SIZE//64,
                          vocab_size=1)
# Character-level (decoder) transformer configuration; vocab_size=128 —
# presumably the ASCII byte range emitted by the patchilizer (verify).
char_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
                         max_length=PATCH_SIZE+1,
                         max_position_embeddings=PATCH_SIZE+1,
                         hidden_size=HIDDEN_SIZE,
                         num_attention_heads=HIDDEN_SIZE//64,
                         vocab_size=128)
# Two copies of the same architecture: `model` is trained, `model_ref` serves
# as the frozen reference in the preference loss (see process_one_batch).
model_ref = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=char_config)
model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=char_config)
model_ref = model_ref.to(device)
model = model.to(device)
# print parameter number
print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
def collate_batch(input_batches):
    """Collate one preference pair into device-resident batch tensors.

    Takes the 4-tuple produced by ``NotaGenDataset.__getitem__`` —
    (chosen patches, chosen masks, rejected patches, rejected masks) —
    adds a leading batch dimension of 1 to each tensor, runs each through
    ``pad_sequence`` (a no-op pad for a batch of one, kept for shape
    consistency), and moves everything to ``device``.

    Returns the four tensors in the same order as the input.
    """
    def _prepare(tensor):
        # unsqueeze(0) -> (1, ...); pad_sequence over that singleton batch
        # reproduces the same shape, then the result is moved to the device.
        batched = torch.nn.utils.rnn.pad_sequence(
            tensor.unsqueeze(0), batch_first=True, padding_value=0)
        return batched.to(device)

    pos_patches, pos_masks, neg_patches, neg_masks = input_batches
    return tuple(_prepare(t) for t in (pos_patches, pos_masks, neg_patches, neg_masks))
class NotaGenDataset(Dataset):
    """Dataset of (chosen, rejected) ABC-notation file pairs for preference training.

    ``preference_dict`` must map ``'chosen'`` and ``'rejected'`` to lists of
    file paths. Every chosen file is paired with every rejected file
    (Cartesian product), so ``len(dataset) == len(chosen) * len(rejected)``.

    Each item is a 4-tuple of long tensors:
    (chosen patch ids, chosen mask, rejected patch ids, rejected mask),
    where the masks are all-ones vectors matching the patch-sequence length.
    """

    def __init__(self, preference_dict):
        self.preference_dict = preference_dict
        # Cartesian product of chosen x rejected file paths.
        self.pair_list = [
            {'chosen': pos_filepath, 'rejected': neg_filepath}
            for pos_filepath in preference_dict['chosen']
            for neg_filepath in preference_dict['rejected']
        ]

    def __len__(self):
        return len(self.pair_list)

    def _load_pair(self, pair):
        # Read both ABC files, encode them with the module-level patchilizer,
        # and build the (bytes, mask) tensors. Raises on I/O or encode errors.
        with open(pair['chosen'], 'r', encoding='utf-8') as f:
            pos_abc_text = f.read()
        with open(pair['rejected'], 'r', encoding='utf-8') as f:
            neg_abc_text = f.read()
        pos_file_bytes = patchilizer.encode(pos_abc_text)
        neg_file_bytes = patchilizer.encode(neg_abc_text)
        return (torch.tensor(pos_file_bytes, dtype=torch.long),
                torch.tensor([1] * len(pos_file_bytes), dtype=torch.long),
                torch.tensor(neg_file_bytes, dtype=torch.long),
                torch.tensor([1] * len(neg_file_bytes), dtype=torch.long))

    def __getitem__(self, idx):
        # Best-effort fallback: on a bad file (I/O or encode failure), try the
        # next pair. Implemented as a bounded loop rather than the original
        # self-recursion, which could hit RecursionError when many consecutive
        # files are corrupt; if every pair fails we raise explicitly.
        n = len(self.pair_list)
        for offset in range(n):
            try:
                return self._load_pair(self.pair_list[(idx + offset) % n])
            except Exception as e:
                print(e)
        raise RuntimeError('All preference pairs failed to load.')
def process_one_batch(batch):
    """Compute the preference (DPO-style) loss for one collated batch.

    ``batch`` is the 4-tuple from ``collate_batch``: (chosen patches, chosen
    masks, rejected patches, rejected masks). The trainable policy ``model``
    and the frozen ``model_ref`` each score the chosen and rejected sequences;
    the loss is -logsigmoid(BETA * (margin - LAMBDA * penalty)) where the
    margin is the policy-vs-reference log-prob difference and the penalty
    keeps the policy's chosen log-prob from falling below the reference's.

    Returns the loss tensor (gradient flows through the policy terms only).
    """
    pos_input_patches, pos_input_masks, neg_input_patches, neg_input_masks = batch
    # Defensive clones for the reference forward pass, preserved from the
    # original in case the model mutates its inputs in place.
    pos_input_patches_ref = pos_input_patches.clone()
    pos_input_masks_ref = pos_input_masks.clone()
    neg_input_patches_ref = neg_input_patches.clone()
    neg_input_masks_ref = neg_input_masks.clone()
    # Policy (trainable) log-probabilities.
    policy_pos_logps = model(pos_input_patches, pos_input_masks)
    policy_neg_logps = model(neg_input_patches, neg_input_masks)
    # Reference log-probabilities: no gradients needed.
    with torch.no_grad():
        ref_pos_logps = model_ref(pos_input_patches_ref, pos_input_masks_ref).detach()
        ref_neg_logps = model_ref(neg_input_patches_ref, neg_input_masks_ref).detach()
    logits = (policy_pos_logps - policy_neg_logps) - (ref_pos_logps - ref_neg_logps)
    # torch.clamp replaces the original built-in max(0, tensor): max() relies
    # on bool() of an elementwise comparison, which only works for one-element
    # tensors and would raise for batch size > 1. Identical value for batch 1.
    penalty = torch.clamp(ref_pos_logps - policy_pos_logps, min=0)
    loss = -torch.nn.functional.logsigmoid(BETA * (logits - LAMBDA * penalty))
    return loss
# train
if __name__ == "__main__":
    # Initialize wandb experiment tracking (optional).
    if WANDB_LOGGING:
        wandb.login(key=WANDB_KEY)
        wandb.init(project="notagen",
                   name=WANDB_NAME)

    # load data: index of chosen/rejected ABC file paths.
    with open(DATA_INDEX_PATH, 'r') as f:
        preference_dict = json.loads(f.read())
    train_set = NotaGenDataset(preference_dict)

    # Load model actor/ref from the same pre-trained checkpoint. Loading a
    # CPU-mapped state dict directly into the device-resident models is safe;
    # the original deepcopy round-trip through a CPU copy was redundant.
    if os.path.exists(PRETRAINED_PATH):
        checkpoint = torch.load(PRETRAINED_PATH, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        model_ref.load_state_dict(checkpoint['model'])
    else:
        raise Exception('No pre-trained model loaded.')

    model.train()
    # The reference model must be frozen: eval() disables dropout so its
    # log-probs are deterministic, and requires_grad_(False) avoids storing
    # gradients it never needs (the original left it in train mode).
    model_ref.eval()
    model_ref.requires_grad_(False)

    total_train_loss = 0
    tqdm_set = tqdm(range(OPTIMIZATION_STEPS))
    for i in tqdm_set:
        # One randomly-sampled preference pair per step (batch size 1).
        idx = random.randint(0, len(train_set) - 1)
        batch = collate_batch(train_set[idx])
        loss = process_one_batch(batch)
        total_train_loss += loss.item()
        loss.backward()
        # clip_grad_norm_ is the in-place, supported API; the original
        # torch.nn.utils.clip_grad_norm is deprecated (removed in newer torch).
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        model.zero_grad(set_to_none=True)
        tqdm_set.set_postfix({'train_loss': total_train_loss / (i + 1)})
        # Log the training loss to wandb
        if WANDB_LOGGING:
            wandb.log({"train_loss": total_train_loss / (i + 1)}, step=i+1)

    # Persist final weights, unwrapping a DataParallel/DDP '.module' if present.
    checkpoint = {'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict()}
    torch.save(checkpoint, WEIGHTS_PATH)
|