Spaces: build error — the source below was recovered from the failed-build page (table markup stripped).
import torch

from trl import SFTTrainer


class MistakeWeightTrainer(SFTTrainer):
    """SFT trainer whose loss is computed ONLY on mispredicted tokens.

    For each (shifted) position, the model's greedy prediction (argmax over
    the vocabulary) is compared against the ground-truth label. Positions
    where the prediction already matches the label contribute zero loss, so
    gradients flow only from the model's mistakes. Padding positions
    (label == -100, the Hugging Face convention) are always excluded.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Mean cross-entropy over mistaken, non-padding token positions.

        Args:
            model: The causal LM being trained; called as ``model(**inputs)``.
            inputs: Batch dict; must contain ``"labels"`` (token ids with
                -100 marking positions to ignore).
            return_outputs: If True, return ``(loss, outputs)`` — matches the
                Hugging Face ``Trainer.compute_loss`` contract.
            num_items_in_batch: Accepted for Trainer API compatibility
                (newer ``transformers`` versions pass it); unused here
                because we normalize by the mistake count instead.

        Returns:
            Scalar loss tensor, or ``(loss, outputs)`` when
            ``return_outputs`` is True.

        Raises:
            ValueError: If the batch carries no ``"labels"`` entry.
        """
        labels = inputs.get("labels")
        if labels is None:
            # Fail loudly instead of crashing later on `None[..., 1:]`.
            raise ValueError("MistakeWeightTrainer requires 'labels' in inputs.")

        outputs = model(**inputs)
        logits = outputs.get("logits")

        # 1. Shift for causal LM: position t predicts token t+1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # 2. Per-token cross-entropy (no reduction; masking happens below).
        #    F.cross_entropy's default ignore_index=-100 already zeroes
        #    padding entries, but we mask explicitly for the denominator.
        raw_loss = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction="none",
        )

        # 3. Mistake mask: 1.0 where the greedy prediction is wrong, 0.0
        #    where it is right. No gradient should flow through the mask.
        with torch.no_grad():
            preds = torch.argmax(shift_logits, dim=-1)
            mistake_mask = (preds != shift_labels).view(-1).float()
            # Exclude padding tokens (-100 is the HF ignore label).
            padding_mask = (shift_labels.view(-1) != -100).float()
            final_mask = mistake_mask * padding_mask

        # 4. Average over mistaken positions only; the epsilon keeps the
        #    division finite when the batch contains no mistakes (loss -> 0,
        #    i.e. no weight update, which is the intended behavior).
        mistake_loss = (raw_loss * final_mask).sum() / (final_mask.sum() + 1e-9)

        return (mistake_loss, outputs) if return_outputs else mistake_loss