import logging
import math
import time
import copy
import torch
import torch.distributed as dist
try:
    import wandb
except ImportError:
    # wandb is optional; it is only needed when logging with args.wandb is enabled.
    wandb = None
from open_lm.data import sample_chunk
from open_lm.distributed import is_master
from open_lm.precision import get_autocast
from open_lm.meters import (
    AverageMeter,
    ConfidenceIntervalMeter,
    gather_meters,
)
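
# Evaluation utilities for open_lm: compute validation loss and perplexity, with
# optional bootstrap confidence intervals, aggregated across distributed workers.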


@torch.inference_mode()
def evaluate(model, data, start_epoch, args, writer):
    """Evaluate loss and perplexity on validation data."""
    if is_master(args):
        print("=> begin evaluation")
    device = torch.device(args.device)
    autocast = get_autocast(args.precision)

    model.eval()

    data["val"].set_epoch(start_epoch)  # set epoch in a process-safe manner via sampler or shared_epoch
    dataloader = data["val"].dataloader

    # NOTE: dataloader.num_batches == 0 corresponds to exhausting the iterator, by convention.
    exhaust_loader = dataloader.num_batches == 0
    losses_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    sps_m = AverageMeter()
    spspg_m = AverageMeter()
    losses_seq_ci_m = ConfidenceIntervalMeter()
    losses_tok_ci_m = ConfidenceIntervalMeter()

    end = time.time()
    # Per-element losses; reduction happens manually below so that ignored positions can be masked out.
    loss = torch.nn.CrossEntropyLoss(reduction="none")

    # By default (num_batches == 0) the dataloader will be exhausted.
    for i, batch in enumerate(dataloader):
        if i == dataloader.num_batches and not exhaust_loader:
            break
        data_time_m.update(time.time() - end)

        with autocast():
            if args.dataset_type == "jsonl":
                inputs, targets = batch
                inputs = torch.LongTensor(inputs).to(device)
                targets = torch.LongTensor(targets).to(device)
                # Shift by one token for next-token prediction.
                inputs = inputs[:, :-1]
                targets = targets[:, 1:]
            else:
                (texts,) = batch
                texts = torch.LongTensor(texts).to(device)
                inputs, targets = sample_chunk(texts, args)

            out, _, _ = model(inputs)  # [per_gpu_bs, seq_len, vocab_size]
            bs, seq_len = targets.shape

            targets = targets.reshape(-1)
            total_loss = loss(out.reshape(-1, args.vocab_size), targets)  # [bs * seq_len]

            # Cross entropy assigns zero loss to targets equal to -100 (its default
            # ignore_index), so mask those positions out before averaging.
            mask = targets != -100

            # Reshape and average for per-sequence losses.
            sum_loss_per_seq = torch.sum(total_loss.reshape(bs, seq_len), -1)
            num_toks_per_seq = torch.sum(mask.reshape(bs, seq_len), -1).float()
            losses_seq_ci_m.update((sum_loss_per_seq / num_toks_per_seq).cpu().numpy())

            # Individual token losses.
            losses_tok_ci_m.update(total_loss[mask].cpu().numpy())

            # Average loss for the mini-batch.
            total_loss = total_loss[mask].mean()
            losses_m.update(total_loss.item(), n=inputs.shape[0])

        batch_time_m.update(time.time() - end)
        end = time.time()
        sps_m.update(inputs.numel() * args.world_size / batch_time_m.val)
        spspg_m.update(inputs.numel() / batch_time_m.val)
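
    # Note: inputs.numel() above counts tokens, so sps_m and spspg_m effectively track
    # token throughput (cluster-wide and per GPU) despite the "samples_per_second" names.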
    if args.distributed:
        dist.barrier()

    if args.world_size > 1:
        # In the multi-GPU case we need to gather the losses; for simplicity, the
        # meters are gathered on the main process only.
        meters = [
            losses_m,
            batch_time_m,
            data_time_m,
            sps_m,
            spspg_m,
            losses_seq_ci_m,
            losses_tok_ci_m,
        ]
        # Meters on master become global meters; meters on other ranks remain local.
        losses_m, batch_time_m, data_time_m, sps_m, spspg_m, losses_seq_ci_m, losses_tok_ci_m = gather_meters(
            meters, args
        )

    if args.distributed:
        dist.barrier()
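
    # Optional bootstrap 95% confidence intervals over the pooled per-sequence and
    # per-token losses (resampling is handled inside compute_bootstrap_ci); a value
    # of -1.0 marks an interval that was not computed.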
    lower_seq, upper_seq, lower_tok, upper_tok = -1.0, -1.0, -1.0, -1.0
    if args.val_seq_ci:
        lower_seq, upper_seq = losses_seq_ci_m.compute_bootstrap_ci(args.val_max_pop_ci, args.val_iter_ci)
    if args.val_tok_ci:
        lower_tok, upper_tok = losses_tok_ci_m.compute_bootstrap_ci(args.val_max_pop_ci, args.val_iter_ci)

    num_seqs = sum(len(p) for p in losses_seq_ci_m.points)
    num_toks = sum(len(p) for p in losses_tok_ci_m.points)

    # Save eval loss / etc.
    log_data = {
        "loss": losses_m.avg,
        "data_time": data_time_m.avg,
        "batch_time": batch_time_m.avg,
        "samples_per_second": sps_m.avg,
        "samples_per_second_per_gpu": spspg_m.avg,
        "loss_sequences_lower_95": lower_seq,
        "loss_sequences_upper_95": upper_seq,
        "loss_tokens_lower_95": lower_tok,
        "loss_tokens_upper_95": upper_tok,
        "sequences": num_seqs,
        "tokens": num_toks,
    }
    if args.train_num_samples is not None:
        log_data["train_tokens"] = start_epoch * args.train_num_samples * args.seq_len

    for name, val in log_data.items():
        name = "valid/" + name
        if writer is not None:
            writer.add_scalar(name, val, start_epoch)
        if args.wandb and is_master(args):
            assert wandb is not None, "Please install wandb."
            wandb.log({name: val, "epoch": start_epoch, "tokens": log_data["tokens"]})

    if is_master(args):
        # Meters on master are global at this point.
        print(f"evaluation on: {args.val_data}")
        print(f"evaluation loss: {losses_m.avg}")
        print(f"num loss point evaluations: {losses_m.count}")
        print(f"evaluation perplexity: {math.exp(losses_m.avg)}")
        print(f"num seqs: {num_seqs}")
        print(f"num tokens: {num_toks}")

    log_data["checkpoint_path"] = args.resume
    log_data["val_data"] = args.val_data
    log_data["model"] = args.hf_model if args.hf_model else args.model

    return log_data


def evaluate_loop(model, data_list, start_epoch, args, writer):
    """Run evaluate() once per validation set in data_list and collect the per-set logs."""
    log_data_list = []
    for i, data in enumerate(data_list):
        args_copy = copy.deepcopy(args)
        args_copy.val_data = [args.val_data[i]]
        args_copy.val_data_key = args.val_data_key[i]

        if args.distributed:
            dist.barrier()

        log_data_list.append(evaluate(model, data, start_epoch, args_copy, writer))

    if args.distributed:
        dist.barrier()

    return log_data_list
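

# A minimal usage sketch (illustrative only, not part of this module): `data_list`
# holds one {"val": ...} mapping per validation set, built by open_lm's data
# pipeline, and `args` is the parsed open_lm argument namespace with val_data /
# val_data_key set to parallel lists.
#
#     metrics = evaluate_loop(model, data_list, start_epoch=0, args=args, writer=None)
#     for m in metrics:
#         print(m["val_data"], m["loss"], m["tokens"])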