from typing import Any, Dict, Union

import torch
from torch import nn

from transformers import Trainer as HFTrainer
from transformers.file_utils import is_apex_available

if is_apex_available():
    from apex import amp

from utils import label_smoothed_nll_loss
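
# NOTE: `label_smoothed_nll_loss` lives in a local `utils` module that is not
# shown here; this file assumes it takes (lprobs, target, epsilon, ignore_index)
# and returns a (smoothed_loss, nll_loss) tuple. A hypothetical sketch of such a
# helper is included at the bottom of this file for reference.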


class Trainer(HFTrainer):
    def __init__(self, label_smoothing: float = 0, **kwargs):
        super().__init__(**kwargs)
        self.label_smoothing = label_smoothing

    # Overrides the base Trainer's training step to support label smoothing.
    def _training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
    ) -> float:
        model.train()
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        # Our model outputs do not work with DataParallel, so force tuple returns.
        if isinstance(model, nn.DataParallel):
            inputs["return_tuple"] = True

        if self.label_smoothing == 0:
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuples in transformers (see the docs)
        else:
            # DataParallel does not proxy attribute access, so unwrap it to reach the config.
            config = model.module.config if isinstance(model, nn.DataParallel) else model.config
            labels = inputs.pop("labels")
            # Replace the -100 ignore-index used by the data collator with the pad token id,
            # since label_smoothed_nll_loss masks on pad_token_id instead.
            labels[labels == -100] = config.pad_token_id
            outputs = model(**inputs)
            lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, labels, self.label_smoothing, ignore_index=config.pad_token_id
            )

        if self.args.n_gpu > 1:
            loss = loss.mean()  # average across GPUs in multi-GPU parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()
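

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: the real
# `label_smoothed_nll_loss` lives in the repo's `utils` module, which is not
# shown here. The hypothetical helper below is a minimal fairseq-style
# implementation matching the call in `_training_step`: it spreads `epsilon`
# of the probability mass uniformly over the vocabulary and returns both the
# smoothed loss and the plain NLL loss. The actual `utils` version may differ.
def _sketch_label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None):
    if target.dim() == lprobs.dim() - 1:
        # Align target with lprobs for gather: (..., 1).
        target = target.unsqueeze(-1)
    # Negative log-probability of the gold token at each position.
    nll_loss = -lprobs.gather(dim=-1, index=target)
    # Negative sum of log-probs over the vocabulary: the uniform smoothing term.
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        # Zero out padding positions before reduction.
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    # Mix (1 - epsilon) of the gold-token loss with epsilon spread over all classes.
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss


# Hypothetical usage, assuming `model`, `training_args`, and the datasets are
# built elsewhere in the usual transformers fashion:
#
#     trainer = Trainer(
#         label_smoothing=0.1,
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         eval_dataset=eval_dataset,
#     )
#     trainer.train()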