| """ |
| Custom HuggingFace Trainer subclass. |
| Uses the model's built-in cross-entropy loss (computed during forward pass) |
| instead of recomputing it, saving ~60MB of VRAM. |
| """ |
|
|
| from transformers import Trainer |
| import torch |
| from loguru import logger |
|
|
|
|
class CorrectionTrainer(Trainer):
    """Custom trainer that uses the model's built-in loss directly."""

    def __init__(self, loss_fn, fingerprinter, tokenizer, **kwargs):
        super().__init__(**kwargs)
        # Kept for interface compatibility; compute_loss below relies on
        # the model's built-in loss rather than calling self.loss_fn.
        self.loss_fn = loss_fn
        self.fingerprinter = fingerprinter
        # Stored under a distinct name so it doesn't collide with the
        # Trainer's own `tokenizer` attribute.
        self.correction_tokenizer = tokenizer

    def _strip_custom_fields(self, inputs):
        """Remove dataset fields that T5's forward pass doesn't accept.

        e.g. {"input_ids": ..., "style_vector": ...} -> {"input_ids": ...}
        """
        # Whitelisting drops every custom column (style_vector, input_text,
        # target_text, and anything added later) without mutating the
        # caller's dict.
        return {
            k: v
            for k, v in inputs.items()
            if k in ("input_ids", "attention_mask", "labels")
        }

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Use the model's built-in CE loss to avoid computing it twice.

        **kwargs absorbs extras such as num_items_in_batch on newer
        transformers versions.
        """
        model_inputs = self._strip_custom_fields(inputs)

        # With `labels` present, T5's forward pass computes cross-entropy
        # internally and exposes it as outputs.loss.
        outputs = model(**model_inputs)
        loss = outputs.loss
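        # A sketch of the recomputation this avoids (assuming the standard
        # CE-over-logits pattern; F.cross_entropy's default ignore_index of
        # -100 matches HF's label padding):
        #     import torch.nn.functional as F
        #     loss = F.cross_entropy(
        #         outputs.logits.view(-1, outputs.logits.size(-1)),
        #         model_inputs["labels"].view(-1),
        #     )
        # The buffers for that second pass over the logits are what the
        # module docstring's ~60MB figure refers to.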

        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Compute eval loss directly: strip custom fields, run a forward pass.

        The parent's prediction_step doesn't return eval_loss when custom
        fields are present, so we handle it ourselves.
        """
        model_inputs = self._strip_custom_fields(inputs)
        # The default prediction_step would normally move tensors to the
        # right device; since we bypass it, do that explicitly here.
        model_inputs = self._prepare_inputs(model_inputs)

        with torch.no_grad():
            outputs = model(**model_inputs)
            loss = outputs.loss.detach()

        # (loss, logits, labels): logits and labels are omitted, so this
        # trainer supports loss-only evaluation (no compute_metrics).
        return (loss, None, None)
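

# Minimal usage sketch. `train_ds`, `eval_ds`, and `fingerprinter` are
# hypothetical stand-ins for objects the surrounding project provides;
# the rest is the standard transformers API.
#
#   from transformers import (
#       AutoModelForSeq2SeqLM,
#       AutoTokenizer,
#       DataCollatorForSeq2Seq,
#       TrainingArguments,
#   )
#
#   tokenizer = AutoTokenizer.from_pretrained("t5-small")
#   model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
#   trainer = CorrectionTrainer(
#       loss_fn=None,                 # unused: compute_loss reads outputs.loss
#       fingerprinter=fingerprinter,  # project-specific helper (assumed)
#       tokenizer=tokenizer,
#       model=model,
#       args=TrainingArguments(output_dir="outputs"),
#       train_dataset=train_ds,
#       eval_dataset=eval_ds,
#       data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
#   )
#   trainer.train()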