# commit: ee3e701
from contextlib import contextmanager
import torch
import torch.distributed as dist
from tqdm import tqdm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.scheduler import SchedulerMetricHook
from internlm.model.metrics import AccPerplex
@contextmanager
def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size, metric_hook_list):
    """Temporarily reconfigure the non-pipeline scheduler for evaluation.

    Swaps in evaluation-time gradient-accumulation settings and metric hooks,
    clears the training data-process function, and restores everything on exit
    (even if the body raises). A no-op when pipeline parallelism is active.

    Args:
        trainer: object exposing the scheduler as ``trainer.schedule``.
        grad_accum_size: accumulation size to use during evaluation.
        grad_accum_batch_size: accumulation batch size to use during evaluation.
        metric_hook_list: scheduler hooks (e.g. metric hooks) for evaluation.
    """
    if not gpc.is_using_pp():
        prev_data_process_func = trainer.schedule.data_process_func
        prev_grad_accum_size = trainer.schedule._grad_accum_size
        prev_grad_accum_batch_size = trainer.schedule._grad_accum_batch_size
        prev_metric_hooks = trainer.schedule._hooks
        try:
            trainer.schedule.data_process_func = None
            trainer.schedule._grad_accum_size = grad_accum_size
            trainer.schedule._grad_accum_batch_size = grad_accum_batch_size
            trainer.schedule._hooks = metric_hook_list
            yield
        finally:
            trainer.schedule.data_process_func = prev_data_process_func
            trainer.schedule._grad_accum_size = prev_grad_accum_size
            trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size
            trainer.schedule._hooks = prev_metric_hooks
    else:
        # A @contextmanager generator must yield exactly once. Without this
        # branch, entering the context while pipeline parallelism is enabled
        # would raise RuntimeError("generator didn't yield").
        yield
@contextmanager
def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape, metric_hook_list):
    """Temporarily reconfigure the pipeline scheduler for evaluation.

    Swaps in evaluation-time microbatch count, tensor shape and metric hooks,
    clears the training data-process function, and restores everything on exit
    (even if the body raises). A no-op when pipeline parallelism is inactive.

    Args:
        trainer: object exposing the scheduler as ``trainer.schedule``.
        num_microbatches: number of microbatches to use during evaluation.
        tensor_shape: communicated activation tensor shape for the pipeline.
        metric_hook_list: scheduler hooks (e.g. metric hooks) for evaluation.
    """
    if gpc.is_using_pp():
        # Renamed from ``pre_data_process_func`` for consistency with the
        # ``prev_*`` convention used by the sibling context manager.
        prev_data_process_func = trainer.schedule.data_process_func
        prev_num_microbatches = trainer.schedule.num_microbatches
        prev_tensor_shape = trainer.schedule.tensor_shape
        prev_metric_hooks = trainer.schedule._hooks
        try:
            trainer.schedule.data_process_func = None
            trainer.schedule.num_microbatches = num_microbatches
            trainer.schedule.tensor_shape = tensor_shape
            trainer.schedule._hooks = metric_hook_list
            yield
        finally:
            trainer.schedule.data_process_func = prev_data_process_func
            trainer.schedule.num_microbatches = prev_num_microbatches
            trainer.schedule.tensor_shape = prev_tensor_shape
            trainer.schedule._hooks = prev_metric_hooks
    else:
        # A @contextmanager generator must yield exactly once. Without this
        # branch, entering the context without pipeline parallelism would
        # raise RuntimeError("generator didn't yield").
        yield
@contextmanager
def switch_sequence_parallel_mode():
    """Disable sequence parallelism while the context is active.

    Saves the current ``gpc.config.parallel.sequence_parallel`` flag, forces
    it to ``False`` for the duration of the ``with`` block, and restores the
    saved value on exit, even if the body raises.
    """
    saved_flag = gpc.config.parallel.sequence_parallel
    try:
        gpc.config.parallel.sequence_parallel = False
        yield
    finally:
        gpc.config.parallel.sequence_parallel = saved_flag
def evaluate_on_val_dls(
    trainer,
    val_dls,
    writer,
    logger,
    step_count,
    update_panel: bool = False,
    streaming: bool = False,
):
    """Run forward-only evaluation over every validation dataloader.

    For each entry in ``val_dls``, runs the trainer's schedule in
    forward-only mode (with sequence parallelism disabled), accumulates an
    accuracy/perplexity metric across ranks, and reports loss/acc/perplexity
    through ``writer`` and ``logger`` on the logging rank. Restores training
    mode before returning.

    Args:
        trainer: trainer exposing ``eval()``/``train()`` and
            ``execute_schedule``.
        val_dls: mapping of validation-set name -> dataloader.
        writer: metric writer with an ``add_scalar(key, value, step)`` method.
        logger: logger used for progress and result reporting.
        step_count: global training step, used as the reporting step.
        update_panel: if True, attach structured ``extra`` fields to the log
            record (for panel/UI consumption).
        streaming: if True, dataloader length is unknown; skip ``len()``-based
            checks and progress totals.
    """
    with switch_sequence_parallel_mode():
        torch.cuda.empty_cache()
        trainer.eval()
        verbose = gpc.is_rank_for_log()
        data_cfg = gpc.config.data

        for val_name, val_dl in val_dls.items():
            if not streaming and len(val_dl) == 0:
                # Every rank must skip an empty dataset together. Previously
                # only the logging rank (`verbose`) hit this `continue`; the
                # other ranks fell through, failed `assert val_idx != -1`,
                # and diverged across the barrier below (deadlock risk).
                if verbose:
                    logger.info(f"Validation dataset: {val_name} is empty")
                continue

            val_metric = AccPerplex(
                device=torch.cuda.current_device(),
                tp_pg=gpc.get_group(ParallelMode.TENSOR),
                dp_pg=gpc.get_group(ParallelMode.DATA),
            )
            val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)

            val_loss = 0
            val_idx = -1
            for val_idx, batch in tqdm(
                enumerate(val_dl),
                desc="Val.",
                total=len(val_dl) if not streaming else None,
                position=1,
                disable=not verbose,
                leave=False,
            ):
                with torch.inference_mode():
                    if gpc.is_using_pp():
                        total_val_bsz = len(batch[1])
                        assert total_val_bsz % data_cfg.micro_bsz == 0
                        num_microbatches = total_val_bsz // data_cfg.micro_bsz
                        # Shape of the activation tensors exchanged between
                        # pipeline stages for this batch.
                        tensor_shape = torch.Size(
                            [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
                        )

                        with switch_evaluation_pipeline_scheduler(
                            trainer=trainer,
                            num_microbatches=num_microbatches,
                            tensor_shape=tensor_shape,
                            metric_hook_list=[val_sche_metric_hook],
                        ):
                            _, _, loss = trainer.execute_schedule(
                                batch, forward_only=True, return_loss=True, return_output_label=False
                            )
                    else:
                        total_val_bsz = len(batch[1])
                        assert total_val_bsz % data_cfg.micro_bsz == 0
                        grad_accum_size = total_val_bsz // data_cfg.micro_bsz
                        grad_accum_batch_size = data_cfg.micro_bsz

                        with switch_evaluation_no_pipeline_scheduler(
                            trainer=trainer,
                            grad_accum_size=grad_accum_size,
                            grad_accum_batch_size=grad_accum_batch_size,
                            metric_hook_list=[val_sche_metric_hook],
                        ):
                            _, _, loss = trainer.execute_schedule(
                                batch, forward_only=True, return_loss=True, return_output_label=False
                            )
                if verbose:
                    if isinstance(loss, dict):
                        loss = sum(loss.values())
                    val_loss += loss.item()

            # Every rank must have processed at least one batch; with
            # `streaming` an empty source would still trip this.
            assert val_idx != -1
            dist.barrier()

            val_res = val_metric.get_metric()
            if verbose and (streaming or len(val_dl) != 0):
                # Epsilon keeps the division safe; val_idx + 1 >= 1 here.
                val_loss = val_loss / (val_idx + 1 + 1e-6)
                infos = {
                    "step": step_count,
                    f"val/{val_name}_loss": val_loss,
                    f"val/{val_name}_acc": val_res["acc"],
                    f"val/{val_name}_plex": val_res["perplexity"],
                }

                for key, value in infos.items():
                    writer.add_scalar(key=key, value=value, step=step_count)

                if update_panel:
                    logger.info(
                        f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
                        extra={
                            "step": step_count,
                            "val_loss": val_loss,
                            "val_acc": val_res["acc"],
                            "val_perplexity": val_res["perplexity"],
                        },
                    )
                else:
                    logger.info(
                        f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()])
                    )

        trainer.train()
        torch.cuda.empty_cache()
        dist.barrier()
|