| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
"""The definition of the trainer.

Init phase:

1. Init batch generator.
2. Init optimizer (deepspeed).
3. Shard model.
4. Init optimizer (fsdp).
5. Init lr scheduler.

Train phase:

1. Train loop.
"""
|
|
import math
from abc import abstractmethod

import torch
import torch.nn.functional as F

from ..accelerator.helper import ReduceOp
from ..accelerator.interface import Dim, DistributedInterface
from ..config import TrainingArguments
from ..utils import logging
from ..utils.helper import compute_valid_tokens
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
from .utils.batching import BatchGenerator
from .utils.rendering import Renderer
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class BaseTrainer:
    """Base trainer covering the init and train phases.

    Init phase:
        1. Build the train batch generator.
        2. Create the optimizer / shard the model in the order the backend
           requires (deepspeed: optimizer first; otherwise: shard first).
        3. Create the lr scheduler.

    Train phase:
        1. Run the train loop (``fit``).

    Subclasses must implement :meth:`compute_loss`.
    """

    def __init__(
        self,
        args: TrainingArguments,
        model: HFModel,
        renderer: Renderer,
        train_dataset: TorchDataset,
    ) -> None:
        """Initialize the trainer.

        Args:
            args: Training hyper-parameters and backend configuration.
            model: The HF model to train.
            renderer: Provides the processor used for batching/tokenization.
            train_dataset: Dataset consumed by the batch generator.
        """
        self.args = args
        self.model = model
        self.renderer = renderer
        self.train_dataset = train_dataset

        # Global optimizer-step counter, incremented once per global batch.
        self.global_step = 0

        self.device = DistributedInterface().current_device
        self.dp_size = DistributedInterface().get_world_size(Dim.DP)
        self.model_input_names = self.renderer.processor.model_input_names

        self._create_batch_generator()
        self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)

        if self.args.enable_activation_checkpointing:
            self.model.gradient_checkpointing_enable({"use_reentrant": False})

        # Deepspeed needs the optimizer to exist before sharding; other backends
        # (e.g. fsdp) need the model sharded before the optimizer is created.
        if self.args.dist_config is not None:
            shard_need_optimizer = self.args.dist_config.name == "deepspeed"
        else:
            shard_need_optimizer = False

        if shard_need_optimizer:
            self._init_optimizer()
            self._shard_model()
        else:
            self._shard_model()
            self._init_optimizer()

        self._init_lr_scheduler()

    def _create_batch_generator(self) -> None:
        """Create the train batch generator from the training arguments."""
        self.train_batch_generator = BatchGenerator(
            dataset=self.train_dataset,
            renderer=self.renderer,
            micro_batch_size=self.args.micro_batch_size,
            global_batch_size=self.args.global_batch_size,
            cutoff_len=self.args.cutoff_len,
            batching_workers=self.args.batching_workers,
            batching_strategy=self.args.batching_strategy,
        )

    def _shard_model(self) -> None:
        """Shard the model. No-op in the base trainer; backends override this."""
        pass

    def _init_optimizer(self) -> None:
        """Init optimizer.

        Falls back to AdamW over trainable parameters when no optimizer config
        is given; otherwise delegates to the configured optimizer plugin.
        """
        if self.args.optim_config is None:
            _trainable_params = [p for p in self.model.parameters() if p.requires_grad]
            self.optimizer = torch.optim.AdamW(_trainable_params, lr=self.args.learning_rate)
        else:
            # Deferred import: plugins are only needed when explicitly configured.
            from ..plugins.trainer_plugins.optimizer import OptimizerPlugin

            self.optimizer = OptimizerPlugin(self.args.optim_config.name)(self.model, self.args.optim_config)

    def _init_lr_scheduler(self) -> None:
        """Init lr scheduler.

        Falls back to a constant schedule (factor 1.0) when no scheduler config
        is given; otherwise delegates to the configured lr scheduler plugin.
        """
        if self.args.lr_scheduler_config is None:
            self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda x: 1.0)
        else:
            # Deferred import: plugins are only needed when explicitly configured.
            from ..plugins.trainer_plugins.lr_scheduler import LRSchedulerPlugin

            self.lr_scheduler = LRSchedulerPlugin(self.args.lr_scheduler_config.name)(
                self.optimizer, self.num_training_steps, self.args.lr_scheduler_config
            )

    def compute_log_probs(self, model: HFModel, batch: BatchInput) -> Tensor:
        """Compute per-token log probs of the labels under the model.

        Args:
            model: Model to run the forward pass with (may differ from
                ``self.model``, e.g. a reference model).
            batch: Batch containing model inputs and a ``labels`` tensor of
                shape (batch_size, seq_len).

        Returns:
            log_probs: Tensor of shape (batch_size, seq_len - 1).
        """
        batch_size, _ = batch["labels"].shape
        model_inputs = {
            k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names
        }
        labels = batch["labels"].to(self.device, non_blocking=True)
        outputs: ModelOutput = model(**model_inputs)
        # Upcast to fp32 for a numerically stable cross entropy.
        logits = outputs.logits.float()
        # Shift so position t's logits are scored against the label at t+1.
        shift_labels = labels[..., 1:].contiguous().view(-1)
        shift_logits = logits[..., :-1, :].contiguous().view(shift_labels.size(0), -1)
        # Negated token-level cross entropy == token-level log probability.
        return -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1)

    @abstractmethod
    def compute_loss(self, batch: BatchInput) -> Tensor:
        """Compute the scalar loss."""
        ...

    def fit(self) -> None:
        """Train the model."""
        self.model.train()
        for epoch in range(self.args.num_train_epochs):
            self.train_batch_generator.set_epoch(epoch)
            for micro_batches in self.train_batch_generator:
                self.global_step += 1
                step_loss = 0
                # Valid-token count across all dp ranks for this global batch.
                step_valid_tokens = compute_valid_tokens(micro_batches)
                step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
                for micro_batch in micro_batches:
                    loss = self.compute_loss(micro_batch)
                    mini_step_valid_tokens = compute_valid_tokens([micro_batch])
                    # Weight each micro batch by its share of the global batch's
                    # valid tokens; the epsilon guards against division by zero.
                    loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)

                    loss.backward()
                    step_loss += loss.item()

                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
                # BUGFIX: grad_norm is a python float after .item(); torch.isfinite
                # requires a tensor and would raise TypeError — use math.isfinite.
                if not math.isfinite(grad_norm):
                    # Skip the optimizer step on a non-finite gradient norm.
                    logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
                else:
                    self.optimizer.step()

                self.lr_scheduler.step()
                self.optimizer.zero_grad()

                step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
                DistributedInterface().sync()
                # Log from rank 0 only (consistent with save_model) instead of
                # printing duplicate lines from every rank.
                logger.info_rank0(
                    f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}"
                )

    def save_model(self) -> None:
        """Save the model and processor to ``args.output_dir``."""
        self.model.save_pretrained(self.args.output_dir)
        self.renderer.processor.save_pretrained(self.args.output_dir)
        logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
|