#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# adapted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/initialize

from typing import Callable, Iterable, List, Optional, Tuple

from torch import nn
from torch.nn.modules.loss import _Loss
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from internlm.core.context import global_context as gpc
from internlm.core.context import ParallelMode
from internlm.core.engine import Engine, KDEngine
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
from internlm.core.scheduler import (InterleavedPipelineScheduler, KDNonPipelineScheduler, KDPipelineScheduler,
                                     NonPipelineScheduler, PipelineScheduler, SchedulerHook)
from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape
from internlm.core.trainer import Trainer
from internlm.data.utils import unpack_data
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
from internlm.utils.common import get_current_device


def initialize_kd_trainer(
    model: nn.Module,
    teacher: nn.Module,
    optimizer: Optimizer,
    criterion: Optional[_Loss] = None,
    kd_criterion: Optional[_Loss] = None,
    train_dataloader: Optional[Iterable] = None,
    test_dataloader: Optional[Iterable] = None,
    lr_scheduler: Optional[_LRScheduler] = None,
    beta2_scheduler: Optional[Beta2Scheduler] = None,
    scheduler_hooks: Optional[List[SchedulerHook]] = None,
) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
    """Core function to wrap the essential training components with our functionality based on the config which is
    loaded into gpc.config.

    Args:
        model (:class:`torch.nn.Module` or `Callable`): Your model instance or a function to build the model.
        optimizer (:class:`BaseOptimizer`): Your optimizer for training.
        criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
        train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
        test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
        lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.

    Returns:
        Tuple (trainer, train_dataloader, test_dataloader, lr_scheduler):
            A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)``
            where only ``trainer`` could not be None.
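
    Example:
        A minimal usage sketch; the student/teacher modules, optimizer, losses and dataloader below
        are placeholders assumed to be built by the caller, not provided by this module::

            trainer, train_dl, test_dl, lr_sched = initialize_kd_trainer(
                model=student_model,
                teacher=teacher_model,
                optimizer=optimizer,
                criterion=loss_fn,
                kd_criterion=kd_loss_fn,
                train_dataloader=train_dataloader,
            )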
    """

    if isinstance(model, nn.Module):
        # first sync model across dp ranks
        model.to(get_current_device())
    elif isinstance(model, Callable):
        model = model().to(get_current_device())

    # clip grad norm
    clip_grad_norm = gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0)

    assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"

    # gradient handler, only support PipelineSharedModuleGradientHandler now
    if gpc.is_using_pp():
        gpc.config.gradient_handler = [dict(type="PipelineSharedModuleGradientHandler")]
    gradient_handler_cfg = gpc.config.get("gradient_handler", [])
    gradient_handlers = []
    assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
    for config in gradient_handler_cfg:
        if isinstance(config, dict) and config.get("type") == "PipelineSharedModuleGradientHandler":
            handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
            gradient_handlers.append(handler)

    # initialize scheduler for trainer
    scheduler = None
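    # data is only unpacked when flash attention is disabled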
    if gpc.config.model.use_flash_attn:
        data_fn = None
    else:
        data_fn = unpack_data
    if gpc.is_using_pp():
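        # pipeline parallelism: each training step is split into gpc.config.data.micro_num microbatches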
        gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
        tensor_shape = get_tensor_shape()
        use_interleaved = (
            hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1
        )
        scatter_gather = gpc.is_initialized(ParallelMode.TENSOR)
        if use_interleaved:
            raise NotImplementedError("InterleavedPipelineScheduler for KD is not implemented")
        else:
            scheduler = KDPipelineScheduler(
                data_process_func=data_fn,
                num_microbatches=gpc.config.NUM_MICRO_BATCHES,
                dtype=gpc.config.model["dtype"],
                tensor_shape=tensor_shape,
                scatter_gather_tensors=scatter_gather,
                scheduler_hooks=scheduler_hooks,
            )
    else:
        scheduler = KDNonPipelineScheduler(
            data_process_func=data_fn,
            gradient_accumulation_size=gpc.config.data.gradient_accumulation,
            scheduler_hooks=scheduler_hooks,
        )

    # initialize engine for trainer
    engine = KDEngine(
        model=model,
        teacher=teacher,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        beta2_scheduler=beta2_scheduler,
        criterion=criterion,
        kd_criterion=kd_criterion,
        gradient_handlers=gradient_handlers,
        clip_grad_norm=clip_grad_norm,
    )

    trainer = Trainer(engine, scheduler)

    return trainer, train_dataloader, test_dataloader, lr_scheduler


def initialize_trainer(
    model: nn.Module,
    optimizer: Optimizer,
    criterion: Optional[_Loss] = None,
    train_dataloader: Optional[Iterable] = None,
    test_dataloader: Optional[Iterable] = None,
    lr_scheduler: Optional[_LRScheduler] = None,
    beta2_scheduler: Optional[Beta2Scheduler] = None,
    scheduler_hooks: Optional[List[SchedulerHook]] = None,
) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
    """Core function to wrap the essential training components with our functionality based on the config which is
    loaded into gpc.config.

    Args:
        model (:class:`torch.nn.Module` or `Callable`): Your model instance or a function to build the model.
        optimizer (:class:`BaseOptimizer`): Your optimizer for training.
        criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
        train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
        test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
        lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.

    Returns:
        Tuple (trainer, train_dataloader, test_dataloader, lr_scheduler):
            A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)``
            where only ``trainer`` could not be None.
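
    Example:
        A minimal usage sketch; the model, optimizer, loss and dataloader below are placeholders
        assumed to be built by the caller, not provided by this module::

            trainer, train_dl, test_dl, lr_sched = initialize_trainer(
                model=model,
                optimizer=optimizer,
                criterion=loss_fn,
                train_dataloader=train_dataloader,
            )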
    """

    if isinstance(model, nn.Module):
        # first sync model across dp ranks
        model.to(get_current_device())
    elif isinstance(model, Callable):
        model = model().to(get_current_device())

    # clip grad norm
    clip_grad_norm = gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0)

    assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"

    # gradient handler, only support PipelineSharedModuleGradientHandler now
    if gpc.is_using_pp():
        gpc.config.gradient_handler = [dict(type="PipelineSharedModuleGradientHandler")]
    gradient_handler_cfg = gpc.config.get("gradient_handler", [])
    gradient_handlers = []
    assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
    for config in gradient_handler_cfg:
        if isinstance(config, dict) and config.get("type") == "PipelineSharedModuleGradientHandler":
            handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
            gradient_handlers.append(handler)

    # initialize scheduler for trainer
    scheduler = None
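    # data is only unpacked when flash attention is disabled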
    if gpc.config.model.use_flash_attn:
        data_fn = None
    else:
        data_fn = unpack_data
    if gpc.is_using_pp():
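        # pipeline parallelism: each training step is split into gpc.config.data.micro_num microbatches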
        gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
        tensor_shape = get_tensor_shape()
        use_interleaved = (
            hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1
        )
        scatter_gather = gpc.is_initialized(ParallelMode.TENSOR)
        if use_interleaved:
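            # the interleaved scheduler expects the model as a list of chunks, so wrap a single Sequential stage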
            if isinstance(model, nn.Sequential):
                model = nn.ModuleList([model])

            communication_overlap = gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
            scheduler = InterleavedPipelineScheduler(
                num_microbatches=gpc.config.NUM_MICRO_BATCHES,
                num_chunks=gpc.config.model.num_chunks,
                dtype=gpc.config.model["dtype"],
                tensor_shape=tensor_shape,
                scatter_gather_tensors=scatter_gather,
                scheduler_hooks=scheduler_hooks,
                communication_overlap=communication_overlap,
            )
        else:
            scheduler = PipelineScheduler(
                data_process_func=data_fn,
                num_microbatches=gpc.config.NUM_MICRO_BATCHES,
                dtype=gpc.config.model["dtype"],
                tensor_shape=tensor_shape,
                scatter_gather_tensors=scatter_gather,
                scheduler_hooks=scheduler_hooks,
            )
    else:
        scheduler = NonPipelineScheduler(
            data_process_func=data_fn,
            gradient_accumulation_size=gpc.config.data.gradient_accumulation,
            scheduler_hooks=scheduler_hooks,
        )

    # initialize engine for trainer
    engine = Engine(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        beta2_scheduler=beta2_scheduler,
        criterion=criterion,
        gradient_handlers=gradient_handlers,
        clip_grad_norm=clip_grad_norm,
    )

    trainer = Trainer(engine, scheduler)

    return trainer, train_dataloader, test_dataloader, lr_scheduler