from typing import Callable, Optional

import torch
import torch.nn as nn
from torch import Tensor


def get_beta_scheduler(name: str) -> Optional[Callable]:
    """Look up a beta scheduler class by name.

    Args:
        name: Scheduler identifier. Only 'linear' is recognised.

    Returns:
        The ``BetaLinear`` class (not an instance) for 'linear', or ``None``
        for any other name.
    """
    if name == 'linear':
        return BetaLinear
    # Unknown names yield None rather than raising; callers must check.
    return None


def get_loss_weighting(name: str) -> Optional[Callable]:
    """Look up a loss-weighting function by name.

    Args:
        name: Weighting identifier. Only 'exponential' is recognised.

    Returns:
        ``exponential_loss_weighting`` for 'exponential', or ``None`` for any
        other name.
    """
    if name == 'exponential':
        return exponential_loss_weighting
    # Unknown names yield None rather than raising; callers must check.
    return None


class BetaLinear(nn.Module):
    """
    Linear scheduling for beta. Input t is always from interval [0, 1].

    Args:
        start: Lower bound (float)
        end: Upper bound (float)
    """

    def __init__(self, start: float, end: float):
        super().__init__()
        self.start = start
        self.end = end

    def forward(self, t: Tensor) -> Tensor:
        """Return beta(t) = start * (1 - t) + end * t."""
        return self.start * (1 - t) + self.end * t

    def integral(self, t: Tensor) -> Tensor:
        """Return the integral of beta from 0 to t.

        Closed form of the antiderivative of ``forward``:
        start * t + 0.5 * (end - start) * t**2.
        """
        return 0.5 * (self.end - self.start) * t.square() + self.start * t


def exponential_loss_weighting(beta_fn, i: Tensor) -> Tensor:
    """Compute the weight 1 - exp(-B(i)), where B = ``beta_fn.integral``.

    Args:
        beta_fn: Object exposing an ``integral(t)`` method (e.g. ``BetaLinear``).
        i: Value(s) passed to ``beta_fn.integral``.

    Returns:
        Tensor holding 1 - exp(-beta_fn.integral(i)).
    """
    # -expm1(-x) equals 1 - exp(-x) mathematically, but avoids the
    # catastrophic cancellation that rounds the result to exactly 0 when the
    # integral is tiny (e.g. below ~6e-8 in float32).
    return -torch.expm1(-beta_fn.integral(i))