from typing import Iterable, List

from torch import nn


# These helpers are taken from the transformers repo.
def grad_status(model: nn.Module) -> Iterable:
    """Yield the requires_grad flag of every parameter in the model."""
    return (par.requires_grad for par in model.parameters())


def freeze_params(model: nn.Module):
    """Disable gradient computation for all parameters of the module."""
    for par in model.parameters():
        par.requires_grad = False


def freeze_embeds(model: nn.Module):
    """Freeze token and positional embeddings for BART, just token embeddings for T5."""
    try:
        # BART-style models nest everything under model.model.
        freeze_params(model.model.shared)
        for d in [model.model.encoder, model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)
    except AttributeError:
        # T5-style models expose shared/encoder/decoder at the top level
        # and have no learned positional embeddings to freeze.
        freeze_params(model.shared)
        for d in [model.encoder, model.decoder]:
            freeze_params(d.embed_tokens)


def assert_not_all_frozen(model):
    """Fail loudly if every parameter in the model has been frozen."""
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"
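
# Usage sketch for the freezing helpers above (an addition, not from the
# original file). Assumes the `transformers` library is installed; any BART-
# or T5-style seq2seq model with the attribute layout handled by
# freeze_embeds should work:
#
#   from transformers import AutoModelForSeq2SeqLM
#
#   model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
#   freeze_embeds(model)          # embedding weights stop receiving gradients
#   assert_not_all_frozen(model)  # sanity check: some weights still train
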
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """From fairseq. lprobs are log-probabilities; epsilon is the smoothing weight."""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    # Per-token negative log-likelihood of the gold label.
    nll_loss = -lprobs.gather(dim=-1, index=target)
    # Sum of -log p over the whole vocabulary, i.e. the uniform-smoothing term.
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
        # Normalize by the number of non-pad tokens (the original divided by
        # the pad count, which has the wrong scale and divides by zero when a
        # batch contains no padding).
        bs = pad_mask.numel() - pad_mask.long().sum()
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
        bs = lprobs.shape[0]
    nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    # Interpolate between the gold-label NLL and the smoothing term.
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss / bs, nll_loss / bs
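
if __name__ == "__main__":
    # Minimal smoke test for label_smoothed_nll_loss (an addition, not from
    # the original file; the shapes and pad id are illustrative). Labels must
    # stay within the vocab range, since gather() cannot index -100, so
    # callers typically pass ignore_index=pad_token_id instead of the default.
    import torch
    import torch.nn.functional as F

    pad_id = 0
    logits = torch.randn(4, 12, 100)         # (batch, seq_len, vocab)
    labels = torch.randint(1, 100, (4, 12))  # gold token ids
    labels[:, -3:] = pad_id                  # pretend the tail is padding
    lprobs = F.log_softmax(logits, dim=-1)
    loss, nll = label_smoothed_nll_loss(lprobs, labels, epsilon=0.1, ignore_index=pad_id)
    print(f"smoothed loss: {loss.item():.4f}  nll: {nll.item():.4f}")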