# BetterTrainer / trainer_logic.py
# Author: goodgoals — commit 8912b06 ("Create trainer_logic.py"), verified
import torch
from trl import SFTTrainer
class MistakeWeightTrainer(SFTTrainer):
    """Trainer whose loss is averaged only over tokens the model got wrong.

    Standard causal-LM training averages cross-entropy over every non-padding
    token. Here, tokens the model already predicts correctly
    (argmax == label) are masked out, so gradient updates are driven
    exclusively by mistakes.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute cross-entropy averaged over mistake tokens only.

        Args:
            model: Model under training; its output must expose ``logits``.
            inputs: Batch dict. Expected to contain ``labels`` with -100
                marking positions to ignore (HF convention).
            return_outputs: If True, also return the raw model outputs.
            num_items_in_batch: Accepted for Trainer API compatibility;
                unused, because the loss is normalized by the number of
                mistake tokens rather than the batch size.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)

        # Fall back to the model's own loss when the batch carries no
        # labels, instead of crashing on the slicing below (mirrors the
        # behavior of the stock HF Trainer.compute_loss).
        if labels is None:
            loss = outputs.get("loss")
            return (loss, outputs) if return_outputs else loss

        logits = outputs.get("logits")

        # Shift for causal LM: position t predicts token t+1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Per-token cross-entropy; no reduction so we can mask afterwards.
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        raw_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1)
        )

        # Build the mistake mask without tracking gradients — it acts as a
        # constant multiplier as far as backprop is concerned.
        with torch.no_grad():
            # Greedy prediction per position.
            preds = torch.argmax(shift_logits, dim=-1)
            # 1.0 where the model is WRONG, 0.0 where it is RIGHT.
            mistake_mask = (preds != shift_labels).view(-1).float()
            # Drop padding/ignored positions (-100 is standard for HF).
            padding_mask = (shift_labels.view(-1) != -100).float()
            final_mask = mistake_mask * padding_mask

        # Mean over mistake tokens only; the epsilon guards against a zero
        # denominator when the batch contains no mistakes at all.
        mistake_loss = (raw_loss * final_mask).sum() / (final_mask.sum() + 1e-9)
        return (mistake_loss, outputs) if return_outputs else mistake_loss