File size: 7,040 Bytes
ee3e701
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from contextlib import contextmanager

import torch
import torch.distributed as dist
from tqdm import tqdm

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.scheduler import SchedulerMetricHook
from internlm.model.metrics import AccPerplex


@contextmanager
def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size, metric_hook_list):
    """Temporarily reconfigure the non-pipeline scheduler for evaluation.

    Swaps the scheduler's data-processing function (disabled), gradient
    accumulation settings, and hook list for the duration of the context,
    restoring the previous values afterwards even if the body raises.

    Args:
        trainer: trainer whose ``schedule`` is a non-pipeline scheduler.
        grad_accum_size: accumulation size to use during evaluation.
        grad_accum_batch_size: per-accumulation-step batch size to use.
        metric_hook_list: scheduler hooks (e.g. metric hooks) to install.
    """
    if not gpc.is_using_pp():
        prev_data_process_func = trainer.schedule.data_process_func
        prev_grad_accum_size = trainer.schedule._grad_accum_size
        prev_grad_accum_batch_size = trainer.schedule._grad_accum_batch_size
        prev_metric_hooks = trainer.schedule._hooks
        try:
            trainer.schedule.data_process_func = None
            trainer.schedule._grad_accum_size = grad_accum_size
            trainer.schedule._grad_accum_batch_size = grad_accum_batch_size
            trainer.schedule._hooks = metric_hook_list
            yield
        finally:
            trainer.schedule.data_process_func = prev_data_process_func
            trainer.schedule._grad_accum_size = prev_grad_accum_size
            trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size
            trainer.schedule._hooks = prev_metric_hooks
    else:
        # Nothing to switch under pipeline parallelism, but a @contextmanager
        # generator must still yield exactly once — otherwise contextlib
        # raises RuntimeError("generator didn't yield") on entry.
        yield


@contextmanager
def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape, metric_hook_list):
    """Temporarily reconfigure the pipeline scheduler for evaluation.

    Swaps the scheduler's data-processing function (disabled), microbatch
    count, communicated tensor shape, and hook list for the duration of the
    context, restoring the previous values afterwards even if the body raises.

    Args:
        trainer: trainer whose ``schedule`` is a pipeline scheduler.
        num_microbatches: number of microbatches to run during evaluation.
        tensor_shape: shape of the activation tensors exchanged between
            pipeline stages for this batch.
        metric_hook_list: scheduler hooks (e.g. metric hooks) to install.
    """
    if gpc.is_using_pp():
        # Renamed from `pre_data_process_func` for consistency with the
        # other `prev_*` saved values in this module.
        prev_data_process_func = trainer.schedule.data_process_func
        prev_num_microbatches = trainer.schedule.num_microbatches
        prev_tensor_shape = trainer.schedule.tensor_shape
        prev_metric_hooks = trainer.schedule._hooks
        try:
            trainer.schedule.data_process_func = None
            trainer.schedule.num_microbatches = num_microbatches
            trainer.schedule.tensor_shape = tensor_shape
            trainer.schedule._hooks = metric_hook_list
            yield
        finally:
            trainer.schedule.data_process_func = prev_data_process_func
            trainer.schedule.num_microbatches = prev_num_microbatches
            trainer.schedule.tensor_shape = prev_tensor_shape
            trainer.schedule._hooks = prev_metric_hooks
    else:
        # Nothing to switch without pipeline parallelism, but a
        # @contextmanager generator must still yield exactly once —
        # otherwise contextlib raises RuntimeError("generator didn't yield").
        yield


@contextmanager
def switch_sequence_parallel_mode():
    """Run the enclosed code with sequence parallelism disabled.

    The previous ``gpc.config.parallel.sequence_parallel`` flag is saved on
    entry and restored on exit, even if the body raises.
    """
    parallel_cfg = gpc.config.parallel
    saved_flag = parallel_cfg.sequence_parallel
    try:
        parallel_cfg.sequence_parallel = False
        yield
    finally:
        parallel_cfg.sequence_parallel = saved_flag


def evaluate_on_val_dls(
    trainer,
    val_dls,
    writer,
    logger,
    step_count,
    update_panel: bool = False,
    streaming: bool = False,
):
    """Run one evaluation pass over every validation dataloader.

    Puts the trainer in eval mode, runs forward-only passes (pipeline or
    non-pipeline, depending on the parallel config) with an accuracy/
    perplexity metric hook installed, logs and writes the aggregated metrics
    on the logging rank, then restores train mode.

    Args:
        trainer: trainer exposing ``eval()``/``train()`` and
            ``execute_schedule``.
        val_dls (dict): mapping of validation-set name -> dataloader.
        writer: metric writer exposing ``add_scalar(key=, value=, step=)``.
        logger: logger for progress and result messages.
        step_count (int): current training step, used to tag metrics.
        update_panel (bool): if True, attach structured ``extra`` fields to
            the result log record; otherwise log a plain message.
        streaming (bool): if True, dataloaders have no usable ``len()``, so
            the empty-dataset check and the tqdm total are skipped.
    """
    with switch_sequence_parallel_mode():
        torch.cuda.empty_cache()
        trainer.eval()
        # Only one rank per log group prints progress bars / results.
        verbose = gpc.is_rank_for_log()
        data_cfg = gpc.config.data

        for val_name, val_dl in val_dls.items():
            if not streaming and len(val_dl) == 0 and verbose:
                logger.info(f"Validation dataset: {val_name} is empty")
                continue

            # Metric accumulates across tensor- and data-parallel groups.
            val_metric = AccPerplex(
                device=torch.cuda.current_device(),
                tp_pg=gpc.get_group(ParallelMode.TENSOR),
                dp_pg=gpc.get_group(ParallelMode.DATA),
            )
            val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)

            val_loss = 0
            # -1 sentinel: detects a dataloader that yielded no batches.
            val_idx = -1
            for val_idx, batch in tqdm(
                enumerate(val_dl),
                desc="Val.",
                total=len(val_dl) if not streaming else None,
                position=1,
                disable=not verbose,
                leave=False,
            ):
                with torch.inference_mode():
                    if gpc.is_using_pp():
                        # batch is (inputs_dict, labels); total batch size
                        # must split evenly into micro-batches.
                        total_val_bsz = len(batch[1])
                        assert total_val_bsz % data_cfg.micro_bsz == 0
                        num_microbatches = total_val_bsz // data_cfg.micro_bsz
                        # Shape of activations exchanged between pipeline
                        # stages: (micro_bsz, seq_len, hidden).
                        tensor_shape = torch.Size(
                            [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
                        )

                        with switch_evaluation_pipeline_scheduler(
                            trainer=trainer,
                            num_microbatches=num_microbatches,
                            tensor_shape=tensor_shape,
                            metric_hook_list=[val_sche_metric_hook],
                        ):
                            _, _, loss = trainer.execute_schedule(
                                batch, forward_only=True, return_loss=True, return_output_label=False
                            )
                    else:
                        total_val_bsz = len(batch[1])
                        assert total_val_bsz % data_cfg.micro_bsz == 0
                        # Reuse grad-accumulation machinery to process the
                        # batch in micro_bsz-sized chunks (forward only).
                        grad_accum_size = total_val_bsz // data_cfg.micro_bsz
                        grad_accum_batch_size = data_cfg.micro_bsz
                        with switch_evaluation_no_pipeline_scheduler(
                            trainer=trainer,
                            grad_accum_size=grad_accum_size,
                            grad_accum_batch_size=grad_accum_batch_size,
                            metric_hook_list=[val_sche_metric_hook],
                        ):
                            _, _, loss = trainer.execute_schedule(
                                batch, forward_only=True, return_loss=True, return_output_label=False
                            )
                if verbose:
                    # loss may be a dict of per-component losses; sum them
                    # before accumulating the scalar.
                    if isinstance(loss, dict):
                        loss = sum(loss.values())
                    val_loss += loss.item()

            # Fails if the dataloader produced no batches at all.
            assert val_idx != -1
            # All ranks must finish this dataset before metrics are reduced.
            dist.barrier()

            val_res = val_metric.get_metric()
            if verbose and (streaming or len(val_dl) != 0):
                # +1e-6 guards the division; val_idx+1 is the batch count.
                val_loss = val_loss / (val_idx + 1 + 1e-6)
                infos = {
                    "step": step_count,
                    f"val/{val_name}_loss": val_loss,
                    f"val/{val_name}_acc": val_res["acc"],
                    f"val/{val_name}_plex": val_res["perplexity"],
                }

                for key, value in infos.items():
                    writer.add_scalar(key=key, value=value, step=step_count)

                if update_panel:
                    logger.info(
                        f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
                        extra={
                            "step": step_count,
                            "val_loss": val_loss,
                            "val_acc": val_res["acc"],
                            "val_perplexity": val_res["perplexity"],
                        },
                    )
                else:
                    logger.info(
                        f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()])
                    )

        # Restore training mode and resynchronize all ranks before returning.
        trainer.train()
        torch.cuda.empty_cache()
        dist.barrier()