|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
from typing import Iterable, Optional |
|
|
|
|
|
from internlm.core.engine import Engine |
|
|
from internlm.core.scheduler import ( |
|
|
BaseScheduler, |
|
|
InterleavedPipelineScheduler, |
|
|
NonPipelineScheduler, |
|
|
PipelineScheduler, |
|
|
) |
|
|
|
|
|
|
|
|
class TrainState:
    """
    Record of the current training progress, kept separate from the dataloader
    so it can serve as the anchor point for checkpoint save/reload.

    Args:
        config (Config): internlm config.
        batch_sampler (torch.utils.data.Sampler): Because the dataloader loading is
            asynchronous and prefetched, the batch_sampler state maintained inside the
            dataloader is faster than the actual training progress, so we copy the
            batch_sampler as the anchor point of ckpt reload.
    """

    # Counters that are serialized into (and restored from) a checkpoint.
    _COUNTER_FIELDS = (
        "batch_count",
        "num_consumed_samples_in_epoch",
        "num_consumed_tokens",
        "inf_nan_skip_batches",
        "step_count",
    )

    def __init__(self, config, batch_sampler) -> None:
        """
        Args:
            config (Config): internlm config.
            batch_sampler (torch.utils.data.Sampler): copied as the checkpoint
                anchor (see class docstring); may be None/empty to skip copying.
        """
        # Progress counters: all start at zero for a fresh run.
        self.batch_count: int = 0
        self.num_consumed_samples_in_epoch: int = 0
        self.num_consumed_tokens: int = 0
        self.inf_nan_skip_batches: int = 0
        self.step_count: int = 0

        # Static run configuration read once from the config object.
        self.total_steps: int = config.data.total_steps
        self.resume_tb_folder = config.resume_tb_folder
        self.tensorboard_folder = config.tensorboard_folder
        self.lr = config.adam.lr

        # Only snapshot the sampler when one is actually provided.
        if batch_sampler:
            self.init_batch_sampler(batch_sampler)

    def init_batch_sampler(self, batch_sampler):
        """
        Snapshot the given sampler and create an iterator over the copy.

        Args:
            batch_sampler (torch.utils.data.Sampler): sampler to copy.
        """
        # Copy so the training state is decoupled from the dataloader's
        # prefetch-advanced sampler.
        self.batch_sampler = batch_sampler.copy()
        self.batch_sampler_iter = iter(self.batch_sampler)

    def __str__(self) -> str:
        """Returns a string representation of the training state in JSON format."""
        snapshot = {name: getattr(self, name) for name in self._COUNTER_FIELDS}
        return json.dumps(snapshot, indent=4, sort_keys=True)

    def load_state_dict(self, other_stuffs):
        """
        Resumes training from a checkpoint.

        Args:
            other_stuffs (dict): Other information needed to resume training.
        """
        self.num_consumed_samples_in_epoch = other_stuffs["num_consumed_samples_in_epoch"]
        self.num_consumed_tokens = other_stuffs["num_consumed_tokens"]
        self.inf_nan_skip_batches = other_stuffs["inf_nan_skip_batches"]

        # Resume from the batch after the last completed one.
        self.batch_count = other_stuffs["batch_count"] + 1
        # Older checkpoints may lack "step_count"; fall back to batch_count.
        self.step_count = other_stuffs.get("step_count", self.batch_count)

        # Resume tensorboard writing from the checkpoint's tensorboard folder.
        self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)

    def state_dict(self):
        """Export the information needed to resume training from a checkpoint."""
        return dict(
            batch_count=self.batch_count,
            num_consumed_samples_in_epoch=self.num_consumed_samples_in_epoch,
            num_consumed_tokens=self.num_consumed_tokens,
            inf_nan_skip_batches=self.inf_nan_skip_batches,
            step_count=self.step_count,
            tensorboard_folder=self.tensorboard_folder,
        )
|
|
|
|
|
|
|
|
class Trainer:
    """This is a class tending for easy deployments of users' training and evaluation instead of
    writing their own scripts.

    Args:
        engine (:class:`Engine`): Engine responsible for the process function.
        schedule (:class:`BaseScheduler`, optional): Runtime schedule. Defaults to None.
    """

    def __init__(
        self,
        engine: Engine,
        schedule: Optional[BaseScheduler] = None,
    ):
        """Initializes the Trainer class.

        Args:
            engine (Engine): The engine responsible for the process function.
            schedule (Optional[BaseScheduler], optional): The runtime schedule. Defaults to None.

        Raises:
            AssertionError: If ``schedule`` is given but is not a :class:`BaseScheduler` instance.
        """
        self._engine = engine

        # Fall back to the non-pipeline scheduler when no schedule is supplied.
        if schedule is None:
            self._schedule = NonPipelineScheduler()
        else:
            assert isinstance(
                schedule, BaseScheduler
            ), f"expected schedule to be of type BaseScheduler, but got {type(schedule)}"
            self._schedule = schedule

        # Let the scheduler validate/prepare the engine before training starts.
        self._schedule.pre_processing(self._engine)

    @property
    def engine(self):
        """Returns the engine that responsible for managing the training and evaluation process."""
        return self._engine

    @property
    def schedule(self):
        """Returns the runtime scheduler."""
        return self._schedule

    @property
    def uses_pipeline(self):
        """Returns whether the pipeline parallel is used or not."""
        return isinstance(self._schedule, (PipelineScheduler, InterleavedPipelineScheduler))

    def train(self):
        """Sets the model to training mode."""
        self._engine.train()

    def eval(self):
        """Sets the model to evaluation mode."""
        self._engine.eval()

    def zero_grad(self):
        """Sets the gradient of all parameters in the model to zero."""
        self._engine.zero_grad()

    def step(self):
        """Executes the parameter update step."""
        return self._engine.step()

    def execute_schedule(self, data_iter: Iterable, **kwargs):
        """Runs the forward, loss computation, and backward for the model.

        Args:
            data_iter (Iterable): The data iterator.
            **kwargs: Additional keyword arguments forwarded to the scheduler's
                ``forward_backward_step``.

        Returns:
            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
        """
        # Delegate the full fwd/bwd step to the configured scheduler.
        output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
        return output, label, loss
|
|