|
|
|
|
|
|
|
|
'''
Train GPT-2 from scratch or fine-tune an existing checkpoint.

Modified based on the Hugging Face GPT-2 implementation.
'''
|
|
|
|
|
import json |
|
|
import os |
|
|
import sys |
|
|
import argparse |
|
|
import logging |
|
|
import time |
|
|
import tqdm |
|
|
import datetime |
|
|
import torch |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from os.path import join |
|
|
from torch.distributed import get_rank, get_world_size |
|
|
|
|
|
from lsp_model import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, Adam |
|
|
from gpt2_training.train_utils import load_model, boolean_string, set_lr, get_eval_list_same_length |
|
|
from gpt2_training.eval_utils import eval_model_loss |
|
|
|
|
|
from data_loader import BucketingDataLoader, DynamicBatchingLoader, DistributedBucketingDataLoader |
|
|
|
|
|
|
|
|
from gpt2_training.distributed import all_reduce_and_rescale_tensors, all_gather_list |
|
|
|
|
|
|
|
|
# Configure root logging once for the whole script.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Threshold above which a batch perplexity is considered to have exploded
# and is replaced by the running mean in the statistics.
INF = 100000000
# Call torch.cuda.empty_cache() every this many micro-batch steps.
CACHE_EMPTY_STEP = 10000
# NOTE(review): appears unused in this file — possibly kept for
# compatibility with sibling scripts; confirm before removing.
EVAL_STEP = 100000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ----------------------------------------------------------------------
# Command-line arguments.
# ----------------------------------------------------------------------
parser = argparse.ArgumentParser()

# Model / data inputs.
parser.add_argument('--model_name_or_path', type=str,
                    help='pretrained model name or path to local checkpoint')
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--max_seq_length", type=int, default=128)

parser.add_argument("--skip_eval", action='store_true',
                    help='If true, skip evaluation.')
parser.add_argument("--init_checkpoint", type=str)
parser.add_argument("--train_input_file", type=str)
parser.add_argument("--eval_input_file", type=str)
# Resume step counters from this optimizer-step number (0 = fresh start).
parser.add_argument("--continue_from", type=int, default=0)

# Batching and optimization.
parser.add_argument("--train_batch_size", type=int, default=4,
                    help="batch size now means per GPU per step")
parser.add_argument("--gradient_accumulation_steps", type=int, default=2,
                    help="to increase effective batch size "
                         "and reduce synchronization")
parser.add_argument("--eval_batch_size", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=1e-5)
parser.add_argument("--num_optim_steps", type=int, default=1000000,
                    help="new API specifies num update steps")
parser.add_argument("--valid_step", type=int, default=10000,
                    help="how many optim steps between validations")
parser.add_argument("--warmup_proportion", type=float, default=0.1)
parser.add_argument("--warmup_steps", type=int, default=16000)

# Mixed precision and LR schedule (boolean_string parses "true"/"false").
parser.add_argument("--normalize_data", type=boolean_string, default=True)
parser.add_argument("--fp16", type=boolean_string, default=True)
parser.add_argument("--lr_schedule", type=str,
                    choices=['noam', 'noamwd', 'BERT', 'None'], default='noam')
# 0 selects dynamic loss scaling under apex's FP16_Optimizer; any other
# value is used as a static loss scale.
parser.add_argument("--loss_scale", type=float, default=0)
parser.add_argument("--no_token_id", type=boolean_string, default=True)

# Output locations.
parser.add_argument("--output_dir", type=str)
parser.add_argument("--log_dir", type=str)
parser.add_argument('--pbar', type=boolean_string, default=True, help='turn on progress bar')

# Distributed training (the torch.distributed launcher injects --local_rank).
parser.add_argument('--local_rank', type=int, default=-1,
                    help='for torch.distributed')
parser.add_argument('--config', help='JSON config file')

args = parser.parse_args()
|
|
|
|
|
if args.config is not None:
    # Override argparse defaults with values from a JSON config file.
    # (Fix: the original used json.load(open(args.config)) which leaked
    # the file handle; a with-block closes it deterministically.)
    with open(args.config) as config_file:
        opts = json.load(config_file)
    for k, v in opts.items():
        if isinstance(v, str):
            # Expand Philly cluster placeholder paths from the environment.
            if 'PHILLY_JOB_DIRECTORY' in v:
                v = v.replace('PHILLY_JOB_DIRECTORY',
                              os.environ['PHILLY_JOB_DIRECTORY'])
            elif 'PHILLY_LOG_DIRECTORY' in v:
                v = v.replace('PHILLY_LOG_DIRECTORY',
                              os.environ['PHILLY_LOG_DIRECTORY'])
        setattr(args, k, v)

    # Flags explicitly given on the command line still win over the JSON
    # config: only keys literally present in argv are re-applied.
    argv = sys.argv[1:]
    overrides, _ = parser.parse_known_args(argv)
    for k, v in vars(overrides).items():
        if f'--{k}' in argv:
            setattr(args, k, v)
    # local_rank always comes from the command line (the distributed
    # launcher injects it), never from the config file.
    setattr(args, 'local_rank', overrides.local_rank)
|
|
|
|
|
|
|
|
# The requested batch size must divide evenly by the accumulation steps,
# since each optimizer update aggregates that many micro-batches.
assert args.train_batch_size % args.gradient_accumulation_steps == 0, \
    'batch size % gradient accumulation steps != 0!'
# From here on, train_batch_size means the per-forward-pass micro-batch.
args.train_batch_size //= args.gradient_accumulation_steps
logger.info('train batch size = {}, '
            'new train batch size (after gradient accumulation) = {}'.format(
                args.train_batch_size * args.gradient_accumulation_steps,
                args.train_batch_size))
|
|
|
|
|
|
|
|
# Device selection: single-process (possibly multi-GPU via DataParallel)
# when local_rank == -1, otherwise one process per GPU via NCCL.
if args.local_rank == -1:
    logger.info('CUDA available? {}'.format(str(torch.cuda.is_available())))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    args.device, args.n_gpu = device, n_gpu
else:
    # Pin this process to its GPU *before* initializing the process group.
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    torch.distributed.init_process_group(backend='nccl')
    # n_gpu here is the world size, but each process drives exactly one GPU,
    # hence args.n_gpu is set to 1.
    n_gpu = torch.distributed.get_world_size()
    args.device, args.n_gpu = device, 1
    logger.info("device: {} n_gpu: {}, distributed training: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, bool(args.local_rank != -1), args.fp16))

# Seed every RNG for reproducibility (numpy, torch CPU, torch CUDA).
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
if n_gpu > 0:
    torch.cuda.manual_seed_all(args.seed)
|
|
|
|
|
# Each run gets a unique, timestamped output directory so repeated
# launches never overwrite one another.
timestamp = datetime.datetime.now().strftime('%Y-%m-%d%H%M%S')
run_name = 'GPT2.{}.{}.{}gpu.{}'.format(
    args.learning_rate, args.train_batch_size, n_gpu, timestamp)
output_dir = join(args.output_dir, run_name)
# Fall back to the output directory when no explicit log dir is given.
if args.log_dir is not None and len(args.log_dir) > 0:
    log_dir = args.log_dir
else:
    log_dir = output_dir
# Only the main process creates directories in distributed runs.
if args.local_rank == -1 or get_rank() == 0:
    os.makedirs(output_dir, exist_ok=True)

# Dump the full resolved configuration for the record.
logger.info('Input Argument Information')
args_dict = vars(args)
for arg_name, arg_value in args_dict.items():
    logger.info('%-28s %s' % (arg_name, arg_value))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Tokenizer and model configuration come from the pretrained checkpoint dir.
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)

config = GPT2Config.from_json_file(
    join(args.model_name_or_path, 'config.json'))

# Training data: bucketing loader locally, sharded bucketing loader when
# running distributed (each rank reads its own shard).
if args.local_rank == -1:
    train_dataloader = BucketingDataLoader(args.train_input_file,
                                           args.train_batch_size,
                                           args.max_seq_length)
else:
    train_dataloader = DistributedBucketingDataLoader(
        get_rank(), get_world_size(),
        args.train_input_file, args.train_batch_size,
        args.max_seq_length)

# Loader used for loss/perplexity evaluation during training.
eval_dataloader_loss = DynamicBatchingLoader(
    args.eval_input_file, enc, args.normalize_data,
    args.eval_batch_size, args.max_seq_length)

# Loader for generation-style evaluation; built here but only referenced
# by the commented-out beam-search block further down.
eval_dataloader_gen = get_eval_list_same_length(
    args.eval_input_file, enc, args.eval_batch_size, True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build the model and optionally restore weights from init_checkpoint.
model = load_model(GPT2LMHeadModel(config), args.init_checkpoint,
                   args, verbose=True)
if args.local_rank != -1:
    # When distributed, average the initial weights across workers so
    # every rank starts from identical parameters.
    params = [p.data for p in model.parameters()]
    all_reduce_and_rescale_tensors(
        params, float(torch.distributed.get_world_size()))

# Count trainable parameters. A single generator expression replaces the
# original filter(lambda ...) + sum([...]) pair: no throwaway list, no
# one-shot filter object left dangling.
total_params = sum(np.prod(p.size()) for p in model.parameters()
                   if p.requires_grad)
logger.info('Number of parameter = {}'.format(total_params))
|
|
|
|
|
# Group parameters for decoupled weight decay: biases and LayerNorm
# parameters (names containing 'ln') are exempt from decay.
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'ln']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

if args.fp16:
    logger.info('in fp16, using FusedAdam')
    try:
        # NOTE(review): FP16_Optimizer is importable from apex.optimizers
        # only in older apex releases (newer ones moved it to
        # apex.fp16_utils) — confirm the pinned apex version matches.
        from apex.optimizers import FP16_Optimizer
        from apex.optimizers import FusedAdam
    except ImportError:
        raise ImportError(
            "Please install apex from https://www.github.com/nvidia/apex "
            "to use distributed and fp16 training.")

    optimizer = FusedAdam(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          bias_correction=False,
                          max_grad_norm=1.0)
    # loss_scale == 0 selects dynamic loss scaling; any other value is
    # used as a fixed (static) loss scale.
    if args.loss_scale == 0:
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True,
                                   verbose=False)
    else:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   verbose=False)
else:
    # fp32 path: project-local Adam with gradient clipping built in.
    optimizer = Adam(optimizer_grouped_parameters, args.learning_rate,
                     max_grad_norm=1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Only the main process writes CSV-style training/eval logs.
# buffering=1 means line-buffered, so metrics hit disk immediately.
if args.local_rank == -1 or get_rank() == 0:
    train_logger = open(join(log_dir, 'train_log.txt'), 'a+', buffering=1)
    eval_logger = open(join(log_dir, 'eval_log.txt'), 'a+', buffering=1)
    print('epoch,global_step,step,mean_loss,mean_ppl,n_token_real,'
          'n_token_total,epoch_time', file=train_logger)
    print('epoch,global_step,step,eval_loss,eval_ppl', file=eval_logger)

global_step = 0  # number of optimizer updates performed
step = 0         # number of micro-batches consumed
epoch = 0

if args.continue_from:
    global_step = args.continue_from
    # Each optimizer update consumes gradient_accumulation_steps
    # micro-batches. (Fix: the original hard-coded `global_step*2 - 1`,
    # which was only correct for the default of 2 accumulation steps;
    # this generalizes it and is identical for that default.)
    step = global_step * args.gradient_accumulation_steps - 1
|
|
|
|
|
|
|
|
# Each distributed worker drives exactly one GPU.
if args.local_rank != -1:
    n_gpu = 1
# The progress bar exists only on the main process, and only if requested.
if args.local_rank == -1 or get_rank() == 0:
    pbar = tqdm.tqdm(total=args.num_optim_steps, desc="training") if args.pbar else None
|
|
|
|
|
# ----------------------------------------------------------------------
# Main training loop: runs until num_optim_steps optimizer updates.
# `step` counts micro-batches; `global_step` counts optimizer updates
# (one per gradient_accumulation_steps micro-batches).
# ----------------------------------------------------------------------
while True:
    model.train()
    (tr_loss, tr_ppl, mean_ppl, nb_tr_examples, nb_tr_steps) = 0.0, 0.0, 0.0, 0, 0
    n_token_real, n_token_total = 0, 0
    train_start_time_epoch = time.time()
    for batch in train_dataloader:
        # Batch layout: (input_ids, position_ids, token_ids, label_ids, ...)
        # — extra trailing fields are ignored via *_ below.
        # seq_len is computed but not used later in this loop.
        seq_len = batch[0].shape[1]
        batch = tuple(t.to(device) for t in batch)
        input_ids, position_ids, token_ids, label_ids, *_ = batch
        if args.no_token_id:
            token_ids = None
        loss, ppl = model(input_ids, position_ids, token_ids, label_ids)

        if n_gpu > 1:
            # DataParallel returns one value per GPU; average them.
            loss = loss.mean()
            ppl = ppl.mean()
        # Rescale so a short final micro-batch contributes proportionally
        # to the accumulated gradient.
        loss = loss / (args.train_batch_size / input_ids.shape[0])
        if args.fp16:
            # apex FP16_Optimizer applies loss scaling inside backward().
            optimizer.backward(loss)
        else:
            loss.backward()

        # Undo the rescaling for logging so tr_loss tracks the raw loss.
        tr_loss += float(loss.item()) * (args.train_batch_size / input_ids.shape[0])
        nb_tr_examples += input_ids.size(0)
        nb_tr_steps += 1
        mean_loss = tr_loss / nb_tr_steps
        # Clamp exploding perplexities: substitute the running mean so a
        # single bad batch cannot dominate the statistics.
        if ppl.item() < INF:
            tr_ppl += ppl.item()
        else:
            tr_ppl += mean_ppl
        mean_ppl = tr_ppl / nb_tr_steps

        n_token_total += input_ids.shape[0] * input_ids.shape[1]
        # Token id 0 is treated as padding here — TODO confirm against the
        # tokenizer's actual pad id.
        n_token_real += (input_ids != 0).sum().item()

        # gradient update
        step += 1
        if step % args.gradient_accumulation_steps == 0:
            set_lr(optimizer, global_step,
                   args.lr_schedule, args.learning_rate,
                   args.warmup_steps, args.warmup_proportion,
                   config.n_embd, args.num_optim_steps)

            if args.local_rank != -1:
                # Manual gradient all-reduce across workers (no DDP wrapper).
                grads = [p.grad.data for p in model.parameters()
                         if p.requires_grad and p.grad is not None]
                all_reduce_and_rescale_tensors(grads, float(1))

            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            # Aggregate logging statistics across workers.
            if args.local_rank != -1:
                mean_loss = sum(all_gather_list(mean_loss)) / get_world_size()
                mean_ppl = sum(all_gather_list(mean_ppl)) / get_world_size()
                n_token_real_all_proc = sum(all_gather_list(n_token_real))
                n_token_total_all_proc = sum(all_gather_list(n_token_total))
            else:
                n_token_real_all_proc = n_token_real
                n_token_total_all_proc = n_token_total

            # Only the main process reports progress and writes logs.
            if args.local_rank == -1 or get_rank() == 0:
                epoch_time = time.time() - train_start_time_epoch
                if pbar is not None:
                    pbar.set_postfix_str(
                        f"tok/s: {n_token_real_all_proc//epoch_time//1000}k "
                        f"ppl: {mean_ppl:.2f} epoch: {epoch}")
                    pbar.update(1)
                print('{},{},{},{},{},{},{},{}'.format(
                    epoch+1, global_step+1, step+1, mean_loss, mean_ppl,
                    n_token_real_all_proc, n_token_total_all_proc, epoch_time),
                    file=train_logger)

            # Periodic checkpoint + evaluation (main process only).
            if global_step % args.valid_step == 0:
                if args.local_rank == -1 or get_rank() == 0:
                    # Save a CPU-side copy of the state dict. NOTE(review):
                    # the 'GP2' in the filename looks like a typo for
                    # 'GPT2', but downstream tooling may rely on it —
                    # confirm before renaming.
                    torch.save(
                        {k: (v.cpu() if v is not None else None)
                         for k, v in model.state_dict().items()},
                        join(output_dir,
                             f'GP2-pretrain-step-{global_step}.pkl'))

                    eval_loss, eval_ppl = eval_model_loss(
                        model, enc, eval_dataloader_loss, epoch, args)

                    '''
                    # probably use beam search only for test set
                    if False:
                        gen_response_beam = eval_model_generation(
                            model, enc, eval_dataloader_gen, epoch, args,
                            use_beam_search=True, beam_width=3)
                    '''
                    print('{},{},{},{},{}'.format(
                        epoch+1, global_step+1, step+1, eval_loss, eval_ppl),
                        file=eval_logger)
                    logger.info('current learning rate: '
                                + str(optimizer.param_groups[0]['lr']))
                    # Presumably eval_model_loss switches the model to eval
                    # mode; restore training mode — verify in eval_utils.
                    model.train()
            if global_step >= args.num_optim_steps:
                break

        # Periodically release cached GPU memory back to the allocator.
        if (step+1) % CACHE_EMPTY_STEP == 0:
            torch.cuda.empty_cache()

    if global_step >= args.num_optim_steps:
        break
    epoch += 1
|
|
|
|
|
|
|
|
# Final cleanup: only the main process owns the progress bar and log files.
if args.local_rank == -1 or get_rank() == 0:
    if pbar is not None:
        pbar.close()
    for log_handle in (train_logger, eval_logger):
        log_handle.close()
|
|
|