| # Copyright (C) 2024-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # | |
| # -------------------------------------------------------- | |
| # losses for sparse ga | |
| # -------------------------------------------------------- | |
| import torch | |
| import numpy as np | |
def l05_loss(x, y):
    """Return the square root of the Euclidean distance between x and y.

    The distance is the 2-norm of ``x - y`` taken over the last dimension,
    so the result has the broadcast shape of the inputs minus that last
    dimension.
    """
    delta = x - y
    distance = torch.linalg.norm(delta, dim=-1)
    return torch.sqrt(distance)
def l1_loss(x, y):
    """Return the distance between x and y along the last dimension.

    NOTE: despite the name, this is the norm of the difference vector as
    computed by ``torch.linalg.norm`` with its default order (the Euclidean
    2-norm for a single ``dim``), not a sum of absolute differences.
    """
    delta = x - y
    return torch.linalg.norm(delta, dim=-1)
def gamma_loss(gamma, mul=1, offset=None, clip=np.inf):
    """Build a loss function that raises the (scaled, clipped) distance to a power.

    The returned callable computes, for inputs ``x`` and ``y``::

        (mul * l1_loss(x, y).clip(max=clip) + offset) ** gamma - offset ** gamma

    so the loss is zero when the distance is zero.

    Args:
        gamma: Exponent applied to the shifted distance.
        mul: Multiplicative factor applied to the clipped distance.
        offset: Shift added before exponentiation. When ``None`` it is set to
            the point where the derivative of ``x ** gamma`` equals 1.
        clip: Upper bound applied to the distance before scaling.

    Returns:
        A function ``loss_func(x, y)``. For ``gamma == 1`` with the default
        ``mul``, ``offset`` and ``clip`` this is ``l1_loss`` itself.
    """
    if offset is None:
        if gamma == 1:
            if mul == 1 and clip == np.inf:
                # Plain distance: nothing to scale, clip, or shift.
                return l1_loss
            # The automatic offset below is undefined for gamma == 1
            # (division by zero), but any offset cancels out when the
            # exponent is 1. Use 0 so that `mul` and `clip` are still
            # honoured instead of being silently dropped.
            offset = 0
        else:
            # d(x**p)/dx = 1 ==> p * x**(p-1) == 1 ==> x = (1/p)**(1/(p-1))
            offset = (1 / gamma) ** (1 / (gamma - 1))

    def loss_func(x, y):
        distance = l1_loss(x, y).clip(max=clip)
        return (mul * distance + offset) ** gamma - offset ** gamma

    return loss_func
def meta_gamma_loss():
    """Return a factory that maps an exponent ``alpha`` to ``gamma_loss(alpha)``.

    Only the exponent is forwarded; ``gamma_loss`` is called with its
    default ``mul``, ``offset`` and ``clip``.
    """
    def build_loss(alpha):
        return gamma_loss(alpha)

    return build_loss