Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
def kl_loss(p, q):
    """Return the symmetric (averaged two-way) KL divergence of *p* and *q*.

    Both arguments are passed to ``F.kl_div`` as its ``input``, which expects
    log-probabilities. Each is also exponentiated to serve as the other's
    target distribution. Each direction is reduced with ``reduction="sum"``.

    Args:
        p: Log-probability tensor.
        q: Log-probability tensor, same shape as ``p``.

    Returns:
        Scalar tensor ``(KL(exp(q) || exp(p)) + KL(exp(p) || exp(q))) / 2``.
    """
    # F.kl_div(input, target) computes KL(target || exp(input)).
    q_to_p = F.kl_div(p, q.exp(), reduction="sum")
    p_to_q = F.kl_div(q, p.exp(), reduction="sum")
    return 0.5 * (q_to_p + p_to_q)
def label_smoothed_nll_loss(
    logits,
    target,
    epsilon,
    reduction="sum",
):
    """Compute label-smoothed negative log-likelihood directly from logits.

    The gold class is weighted ``1 - epsilon``. The remaining ``epsilon`` is
    spread evenly over the other ``V - 1`` classes, where
    ``V = logits.size(-1)``.

    Args:
        logits: Unnormalized scores; the class dimension is the last one.
        target: Integer class indices (int64, as required by ``gather``).
            Either shaped like ``logits`` without its last dimension, or
            with an extra trailing singleton dimension.
        epsilon: Label-smoothing factor; ``0.0`` yields plain NLL.
        reduction: ``"sum"`` (default) sums the per-position losses,
            ``"mean"`` averages them over all positions, and ``"none"``
            returns the unreduced per-position losses.

    Returns:
        A scalar tensor for ``"sum"``/``"mean"``. For ``"none"``, a tensor
        shaped like ``logits`` without its last dimension.

    Raises:
        ValueError: If ``reduction`` is not ``"sum"``, ``"mean"`` or ``"none"``.
        ZeroDivisionError: If ``logits`` has a single class (``V - 1 == 0``).
    """
    # Fail fast instead of silently treating unknown values as "sum".
    if reduction not in ("sum", "mean", "none"):
        raise ValueError(f"unsupported reduction: {reduction!r}")

    lprobs = F.log_softmax(logits, dim=-1)
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    # Per-position NLL of the gold class, and negated sum of all log-probs.
    nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
    smooth_loss = -lprobs.sum(dim=-1)
    # smooth_loss also covers the gold class, so eps_i is subtracted from the
    # nll weight. Gold ends up with (1 - eps - eps_i) + eps_i = 1 - eps, each
    # other class with eps_i, and the weights sum to 1.
    eps_i = epsilon / (lprobs.size(-1) - 1)
    loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss

    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    return loss.sum()