Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from typing import Optional, Sequence, List | |
class ImageInpaintingL1Loss(nn.Module):
    """L1 inpainting loss that keeps the known (unmasked) pixels.

    The given signal (unmasked pixels, where ``mask`` is True) is copied
    straight from the target into the final prediction, so the loss only
    penalises the model on the obstructed region — the "free lunch".
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        predicted_image: torch.Tensor,
        target_image: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute ``L1((target * mask) + (pred * ~mask), target)``.

        :param predicted_image: (B, H, W) model output.
        :param target_image: (B, H, W) ground truth.
        :param mask: (B, H, W) boolean mask; True = given pixel,
            0 (False) = obstructed pixel.
        :return: scalar mean L1 loss.
        """
        # Known pixels come straight from the target (mask == True).
        given_pixels = target_image * mask
        # The model only fills in the obstructed region (mask == False).
        pred_pixels = predicted_image * ~mask
        final_prediction = given_pixels + pred_pixels
        return F.l1_loss(final_prediction, target_image)
def get_final_prediction(
    predicted_image: torch.Tensor, target_image: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """Merge the known target pixels with the model's predictions.

    Returns ``(target * mask) + (pred * ~mask)``: wherever the mask is
    True the target value is kept, elsewhere the prediction fills in.
    """
    known = target_image * mask          # y_sparse [given]
    filled_in = predicted_image * ~mask  # model's guess for the holes
    return known + filled_in
class VAELoss(nn.Module):
    """Variational Autoencoder loss: batch-averaged sum-MSE reconstruction
    plus a KL-divergence term weighted by 0.002."""

    def __init__(self):
        super().__init__()

    def forward(self, output, target, mu, logvar):
        # Sum-of-squares reconstruction error, averaged over the batch.
        batch_size = target.size(0)
        reconstruction = F.mse_loss(output, target, reduction="sum") / batch_size
        # KL divergence between q(z|x) = N(mu, exp(logvar)) and N(0, I).
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return reconstruction + 0.002 * kl
# https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch
class DiceLoss(nn.Module):
    """Dice loss for binary segmentation: 1 - (smoothed) Dice coefficient.

    ``weight`` and ``size_average`` are accepted for signature
    compatibility with common loss constructors but are unused.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        # Comment out if your model already contains a sigmoid or
        # equivalent activation layer.
        # FIX: F.sigmoid is deprecated; torch.sigmoid is the supported API.
        inputs = torch.sigmoid(inputs)
        # Flatten label and prediction tensors.
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice
# https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/8
class FocalLoss(nn.Module):
    """Focal Loss, as described in https://arxiv.org/abs/1708.02002.

    It is essentially an enhancement to cross entropy loss and is
    useful for classification tasks when there is a large class imbalance.
    x is expected to contain raw, unnormalized scores for each class.
    y is expected to contain class labels.

    Shape:
        - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
        - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
    """

    def __init__(
        self,
        alpha: Optional[torch.Tensor] = None,
        gamma: float = 0.0,
        reduction: str = "mean",
        ignore_index: int = -100,
    ):
        """Constructor.

        Args:
            alpha (Tensor, optional): Weights for each class. Defaults to None.
            gamma (float, optional): A constant, as described in the paper.
                Defaults to 0.
            reduction (str, optional): 'mean', 'sum' or 'none'.
                Defaults to 'mean'.
            ignore_index (int, optional): class label to ignore.
                Defaults to -100.

        Raises:
            ValueError: if ``reduction`` is not one of the allowed modes.
        """
        if reduction not in ("mean", "sum", "none"):
            raise ValueError('Reduction must be one of: "mean", "sum", "none".')
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.nll_loss = nn.NLLLoss(
            weight=alpha, reduction="none", ignore_index=ignore_index
        )

    def __repr__(self):
        arg_keys = ["alpha", "gamma", "ignore_index", "reduction"]
        arg_vals = [self.__dict__[k] for k in arg_keys]
        arg_strs = [f"{k}={v!r}" for k, v in zip(arg_keys, arg_vals)]
        arg_str = ", ".join(arg_strs)
        return f"{type(self).__name__}({arg_str})"

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if x.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
            y = y.view(-1)
        unignored_mask = y != self.ignore_index
        y = y[unignored_mask]
        if len(y) == 0:
            # FIX: match the input's device/dtype — a bare CPU float32
            # tensor causes a device mismatch when losses are summed on GPU.
            return torch.tensor(0.0, device=x.device, dtype=x.dtype)
        x = x[unignored_mask]
        # compute weighted cross entropy term: -alpha * log(pt)
        # (alpha is already part of self.nll_loss)
        log_p = F.log_softmax(x, dim=-1)
        ce = self.nll_loss(log_p, y)
        # get true class column from each row
        # FIX: create the row index on the same device as x
        # (torch.arange defaults to CPU and breaks CUDA inputs).
        all_rows = torch.arange(len(x), device=x.device)
        log_pt = log_p[all_rows, y]
        # compute focal term: (1 - pt)^gamma
        pt = log_pt.exp()
        focal_term = (1 - pt) ** self.gamma
        # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss
def focal_loss(
    alpha: Optional[Sequence] = None,
    gamma: float = 0.0,
    reduction: str = "mean",
    ignore_index: int = -100,
    device="cpu",
    dtype=torch.float32,
) -> FocalLoss:
    """Factory function for FocalLoss.

    Args:
        alpha (Sequence, optional): Weights for each class. Will be converted
            to a Tensor if not None. Defaults to None.
        gamma (float, optional): A constant, as described in the paper.
            Defaults to 0.
        reduction (str, optional): 'mean', 'sum' or 'none'.
            Defaults to 'mean'.
        ignore_index (int, optional): class label to ignore.
            Defaults to -100.
        device (str, optional): Device to move alpha to. Defaults to 'cpu'.
        dtype (torch.dtype, optional): dtype to cast alpha to.
            Defaults to torch.float32.

    Returns:
        A FocalLoss object
    """
    class_weights = alpha
    if class_weights is not None:
        # Normalise any sequence into a tensor on the requested device/dtype.
        if not isinstance(class_weights, torch.Tensor):
            class_weights = torch.tensor(class_weights)
        class_weights = class_weights.to(device=device, dtype=dtype)
    return FocalLoss(
        alpha=class_weights,
        gamma=gamma,
        reduction=reduction,
        ignore_index=ignore_index,
    )
def vae_loss_function(output, x, mu, logvar):
    """VAE objective: batch-averaged sum-MSE reconstruction + 0.002 * KL.

    Functional twin of ``VAELoss`` for callers that prefer a plain function.
    """
    # Reconstruction term (sum of squared errors, averaged over the batch).
    recon_loss = F.mse_loss(output, x, reduction="sum") / x.size(0)
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + 0.002 * kl_loss
| def _center(x: torch.Tensor) -> torch.Tensor: | |
| """Zero‑centre each (H, W) map independently.""" | |
| return x - x.mean(dim=(-2, -1), keepdim=True) | |
def rms_roughness(x: torch.Tensor) -> torch.Tensor:
    """Root-mean-square roughness of each (H, W) map: B x H x W -> B."""
    spatial = (-2, -1)
    # Zero-centre each map, then take the RMS of the deviations.
    centered = x - x.mean(dim=spatial, keepdim=True)
    return centered.pow(2).mean(dim=spatial).sqrt()
def mean_roughness(x: torch.Tensor) -> torch.Tensor:
    """Mean absolute deviation of each (H, W) map: B x H x W -> B."""
    spatial = (-2, -1)
    # Zero-centre each map, then average the absolute deviations.
    deviation = x - x.mean(dim=spatial, keepdim=True)
    return deviation.abs().mean(dim=spatial)
def roughness_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    dataset_min: float,
    dataset_max: float,
    use_metrics: Optional[List[str]] = None,
    weights: Optional[List[float]] = None,
) -> torch.Tensor:
    """
    Surface-roughness consistency loss.

    Parameters
    ----------
    pred, target : (B, H, W) tensors
        Normalised to [0, 1]. This function rescales them to physical units
        using `dataset_min` / `dataset_max` before computing roughness.
    dataset_min, dataset_max : float
        Global minimum / maximum of the *unnormalised* topography maps.
    use_metrics : list[str], optional
        Any subset of {"rms", "mean"}. Defaults to ["rms", "mean"].
    weights : list[float], optional
        Per-metric weights, same order as `use_metrics`.
        Defaults to [1.0, 1.0].
    """
    # FIX: mutable list defaults are shared across calls; use None
    # sentinels instead (behaviour is identical for all callers).
    if use_metrics is None:
        use_metrics = ["rms", "mean"]
    if weights is None:
        weights = [1.0, 1.0]
    # ------------------------------------------------------------
    # 1) un-normalise to original scale (e.g. nanometres)
    # ------------------------------------------------------------
    scale = dataset_max - dataset_min
    pred_phys = (pred * scale + dataset_min) * 1e9
    target_phys = (target * scale + dataset_min) * 1e9
    # ------------------------------------------------------------
    # 2) compute roughness metrics
    # ------------------------------------------------------------
    loss_terms: List[torch.Tensor] = []
    if "rms" in use_metrics:
        rms_diff = (rms_roughness(pred_phys) - rms_roughness(target_phys)).abs()
        loss_terms.append(weights[0] * rms_diff)
    if "mean" in use_metrics:
        mean_diff = (mean_roughness(pred_phys) - mean_roughness(target_phys)).abs()
        # if both metrics are used, weights[1] applies; else weights[0]
        w = weights[1] if len(use_metrics) > 1 else weights[0]
        loss_terms.append(w * mean_diff)
    # ------------------------------------------------------------
    # 3) aggregate to a scalar: (B, n_metrics) -> scalar
    # ------------------------------------------------------------
    return torch.stack(loss_terms, dim=-1).mean()
def rotation_invariant_l1_loss(
    model: torch.nn.Module,
    X: torch.Tensor,
    X_sparse: torch.Tensor,
    _min: float,
    _max: float,
) -> torch.Tensor:
    """
    Rotation-averaged reconstruction loss.

    Feeds the four right-angle rotations (0°, 90°, 180°, 270°) of the
    sparse input through ``model`` and averages ``roughness_loss``
    between each output and the (unrotated) dense target ``X``.

    NOTE(review): despite the function name, the per-view criterion is
    ``roughness_loss``, not a raw L1 difference; the name is kept for
    backward compatibility with existing callers.

    Args
    ----
    model : torch.nn.Module
        Network mapping a tensor shaped like ``X_sparse`` to an image
        shaped like ``X``.
    X : torch.Tensor
        Dense target; image-like tensor with at least (H, W) spatial dims.
    X_sparse : torch.Tensor
        Sparse/masked input fed to the model.
    _min, _max : float
        Dataset min/max, forwarded to ``roughness_loss``.

    Returns
    -------
    torch.Tensor
        Scalar mean loss (requires_grad=True if model parameters do).

    Raises
    ------
    ValueError
        If ``X`` has fewer than 2 dimensions.
    """
    if X.ndim < 2:
        raise ValueError("X must have at least 2 spatial dimensions.")
    rot_dims = (0, 1) if X.ndim == 2 else (-2, -1)  # pick spatial axes
    # Pre-compute the four rotated views: X, R90(X), R180(X), R270(X)
    views = [X_sparse] + [torch.rot90(X_sparse, k, rot_dims) for k in range(1, 4)]
    # Evaluate model and loss for each view, then average
    losses = [roughness_loss(model(v), X, _min, _max) for v in views]
    return torch.stack(losses).mean()
def rotation_plus_flip_invariant_loss(
    model: torch.nn.Module,
    X: torch.Tensor,
    X_sparse: torch.Tensor,
    _min: float,
    _max: float,
) -> torch.Tensor:
    """
    Roughness loss averaged over the four right-angle rotations and the
    horizontal flips of the sparse input (8 views in total).

    Args
    ----
    model : torch.nn.Module
        Network mapping a tensor shaped like ``X_sparse`` to an image
        shaped like ``X``.
    X : torch.Tensor
        Dense target; image-like tensor with at least (H, W) spatial dims.
    X_sparse : torch.Tensor
        Sparse/masked input fed to the model.
    _min, _max : float
        Dataset min/max, forwarded to ``roughness_loss``.

    Returns
    -------
    torch.Tensor
        Scalar mean loss (requires_grad=True if model parameters do).

    Raises
    ------
    ValueError
        If ``X`` has fewer than 2 dimensions.
    """
    if X.ndim < 2:
        raise ValueError("X must have at least 2 spatial dimensions.")
    rot_dims = (0, 1) if X.ndim == 2 else (-2, -1)  # pick spatial axes
    # BUG FIX: the model must see the *sparse* input, as in
    # rotation_invariant_l1_loss. Previously the dense target X was
    # rotated/flipped and fed to the model, leaving X_sparse unused.
    views = [X_sparse] + [torch.rot90(X_sparse, k, rot_dims) for k in range(1, 4)]
    flipped_views = [torch.flip(v, dims=[rot_dims[-1]]) for v in views]  # flips
    all_views = views + flipped_views
    losses = [roughness_loss(model(v), X, _min, _max) for v in all_views]
    return torch.stack(losses).mean()
if __name__ == "__main__":
    # Library module of loss functions — nothing to run directly.
    pass