| |
|
| | from transformers import Trainer, TrainingArguments
|
| | from torch.utils.data import DataLoader
|
| | import torch
|
| | import torch.nn.functional as F
|
| |
|
class MTPTrainer(Trainer):
    """Custom trainer for Multi-Token Prediction (MTP) training.

    The ``use_mtp_training`` attribute is forwarded to the model as a keyword
    argument on every forward pass performed by :meth:`compute_loss`. It
    starts out as ``True`` and is switched by :meth:`train_step_with_mtp` /
    :meth:`train_step_with_lm`.
    """

    def __init__(self, model, args=None, train_dataset=None, eval_dataset=None, **kwargs):
        """Initialise the underlying ``Trainer`` and enable MTP mode.

        All arguments are passed through unchanged to ``Trainer.__init__``.
        """
        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            **kwargs,
        )
        self.use_mtp_training = True

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the training loss for both MTP and standard LM training.

        Args:
            model: Callable model; invoked with the batch as keyword arguments
                plus ``labels`` and ``use_mtp_training``.
            inputs: Mapping of keyword arguments for the model's forward pass.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just ``loss``.
            num_items_in_batch: Accepted so that ``transformers`` releases
                which pass this argument to ``compute_loss`` do not fail;
                it is not used here.

        Returns:
            ``outputs.loss``, or the tuple ``(outputs.loss, outputs)`` when
            ``return_outputs`` is true.
        """
        # Build one kwargs mapping instead of ``model(**inputs, labels=...)``:
        # when ``inputs`` already holds "labels" the latter supplies the same
        # keyword twice and raises TypeError. Missing labels are still passed
        # explicitly as None, as before.
        forward_kwargs = {
            **inputs,
            "labels": inputs.get("labels"),
            "use_mtp_training": self.use_mtp_training,
        }
        outputs = model(**forward_kwargs)
        loss = outputs.loss

        return (loss, outputs) if return_outputs else loss

    def train_step_with_mtp(self, model, inputs):
        """Run one training step with MTP enabled.

        Enables MTP on the model via ``model.set_mtp_training(True)`` and
        sets ``self.use_mtp_training`` to match, then delegates to
        ``training_step``. Both toggles persist after the call.
        """
        model.set_mtp_training(True)
        # Keep the flag forwarded by compute_loss in sync with the model.
        self.use_mtp_training = True
        return self.training_step(model, inputs)

    def train_step_with_lm(self, model, inputs):
        """Run one standard language-modeling training step.

        Disables MTP on the model via ``model.set_mtp_training(False)`` and
        sets ``self.use_mtp_training`` to match, then delegates to
        ``training_step``. Both toggles persist after the call.
        """
        model.set_mtp_training(False)
        # Without this, compute_loss would still forward use_mtp_training=True
        # during what is meant to be a plain LM step.
        self.use_mtp_training = False
        return self.training_step(model, inputs)
|
| |
|
| |
|
class RLTrainer:
    """Standalone trainer for Reinforcement Learning training.

    Kept separate from the supervised trainer. It is intended to work with
    the model's ``generate_with_logprobs`` method.
    """

    def __init__(self, model, tokenizer, rl_config=None):
        """Store the model, tokenizer and RL configuration.

        A missing or empty ``rl_config`` is replaced by a fresh empty dict.
        """
        self.rl_config = rl_config if rl_config else {}
        self.tokenizer = tokenizer
        self.model = model

    def rl_training_step(self, input_ids, old_log_probs, old_action_probs, **kwargs):
        """Perform an RL training step (placeholder).

        Intended to use the model's ``generate_with_logprobs`` method together
        with the reward/loss helpers from ``rl_utils``. At present it only
        imports those helpers and returns ``None``.
        """
        # Imported lazily at call time; none of these names are used yet
        # because the step body has not been implemented.
        from .rl_utils import (
            AdaptiveKLController,
            batch_compute_rewards,
            compute_entropy_bonus,
            compute_kl_divergence,
            compute_ppo_loss,
        )

        return None