| "Implements [mixup](https://arxiv.org/abs/1710.09412) training method" | |
| from ..torch_core import * | |
| from ..callback import * | |
| from ..basic_train import Learner, LearnerCallback | |
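# mixup (https://arxiv.org/abs/1710.09412) trains on convex combinations of
# pairs of examples and their labels:
#   x_tilde = lambd * x_i + (1 - lambd) * x_j
#   y_tilde = lambd * y_i + (1 - lambd) * y_j
# with lambd drawn from Beta(alpha, alpha).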
class MixUpCallback(LearnerCallback):
    "Callback that creates the mixed-up input and target."
    def __init__(self, learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True):
        super().__init__(learn)
        self.alpha,self.stack_x,self.stack_y = alpha,stack_x,stack_y

    def on_train_begin(self, **kwargs):
        # Wrap the loss so it can handle the stacked (y, y_shuffled, lambda) target.
        if self.stack_y: self.learn.loss_func = MixUpLoss(self.learn.loss_func)

    def on_batch_begin(self, last_input, last_target, train, **kwargs):
        "Applies mixup to `last_input` and `last_target` if `train`."
        if not train: return
        # Draw one lambda per sample, then take max(lambda, 1-lambda) so the first
        # element of each pair always dominates the mix.
        lambd = np.random.beta(self.alpha, self.alpha, last_target.size(0))
        lambd = np.concatenate([lambd[:,None], 1-lambd[:,None]], 1).max(1)
        lambd = last_input.new(lambd)
        # Pair each sample with another one drawn by shuffling the batch.
        shuffle = torch.randperm(last_target.size(0)).to(last_input.device)
        x1, y1 = last_input[shuffle], last_target[shuffle]
        if self.stack_x:
            new_input = [last_input, x1, lambd]
        else:
            # Broadcast lambda over all non-batch dimensions of the input.
            out_shape = [lambd.size(0)] + [1 for _ in range(len(x1.shape) - 1)]
            new_input = (last_input * lambd.view(out_shape) + x1 * (1-lambd).view(out_shape))
        if self.stack_y:
            new_target = torch.cat([last_target[:,None].float(), y1[:,None].float(), lambd[:,None].float()], 1)
        else:
            if len(last_target.shape) == 2:
                lambd = lambd.unsqueeze(1).float()
            new_target = last_target.float() * lambd + y1.float() * (1-lambd)
        return {'last_input': new_input, 'last_target': new_target}

    def on_train_end(self, **kwargs):
        # Restore the original, unwrapped loss function.
        if self.stack_y: self.learn.loss_func = self.learn.loss_func.get_old()
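# Example usage (illustrative sketch, assuming a fastai v1 `Learner` named `learn`
# built elsewhere, e.g. via `cnn_learner`):
#
#   learn.fit_one_cycle(1, callbacks=[MixUpCallback(learn, alpha=0.4)])
#
# The same behaviour is available through the `Learner.mixup` convenience method.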
class MixUpLoss(Module):
    "Adapt the loss function `crit` to go with mixup."

    def __init__(self, crit, reduction='mean'):
        super().__init__()
        if hasattr(crit, 'reduction'):
            self.crit = crit
            self.old_red = crit.reduction
            setattr(self.crit, 'reduction', 'none')
        else:
            self.crit = partial(crit, reduction='none')
            self.old_crit = crit
        self.reduction = reduction

    def forward(self, output, target):
        if len(target.size()) == 2:
            # Stacked target: column 0 is the original class, column 1 the shuffled
            # class, column 2 the mixing coefficient lambda.
            loss1, loss2 = self.crit(output,target[:,0].long()), self.crit(output,target[:,1].long())
            # Keep per-sample losses unreduced here; reduction happens below, so the
            # 'sum' and 'none' branches stay correct.
            d = loss1 * target[:,2] + loss2 * (1-target[:,2])
        else: d = self.crit(output, target)
        if self.reduction == 'mean': return d.mean()
        elif self.reduction == 'sum': return d.sum()
        return d

    def get_old(self):
        "Return the original loss function, restoring its reduction if needed."
        if hasattr(self, 'old_crit'): return self.old_crit
        elif hasattr(self, 'old_red'):
            setattr(self.crit, 'reduction', self.old_red)
            return self.crit
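# Quick sanity check for `MixUpLoss` (illustrative sketch; the tensors below are
# made up and assume `torch` and `nn` are in scope, as exported by `torch_core`):
#
#   output = torch.randn(4, 10)                      # logits for 4 samples, 10 classes
#   y, y1  = torch.randint(0, 10, (4,)), torch.randint(0, 10, (4,))
#   lambd  = torch.rand(4)
#   # Stacked target layout produced by `MixUpCallback` with `stack_y=True`:
#   # column 0 = original class, column 1 = shuffled class, column 2 = lambda.
#   target = torch.stack([y.float(), y1.float(), lambd], dim=1)
#   MixUpLoss(nn.CrossEntropyLoss())(output, target)  # scalar mean of mixed losses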