# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass
from typing import Literal, Optional

import torch
import torch.distributions as D
from fairseq2.logging import get_log_writer
from torch import Tensor

from lcm.nn.schedulers import DDIMScheduler

SUPPORTED_SAMPLERS = Literal["uniform", "beta"]
SUPPORTED_WEIGHTINGS = Literal["none", "clamp_snr"]

logger = get_log_writer(__name__)


def beta_function(a: Tensor, b: Tensor) -> Tensor:
    """Evaluate the Beta function B(a, b) = Gamma(a)Gamma(b) / Gamma(a + b),
    computed in log-space via lgamma for numerical stability."""
    return torch.exp(torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a + b))
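
# Illustrative sanity check (not in the original module): B(a, 1) = 1/a, so
# beta_function(torch.tensor([0.8]), torch.tensor([1.0])) ≈ tensor([1.2500]),
# and with a = b = 1, B(1, 1) = 1, i.e. the Beta density is uniform on [0, 1].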


@dataclass
class StepsSamplerConfig:
    """Configuration for sampling diffusion timesteps and weighting their losses."""

    sampling: SUPPORTED_SAMPLERS = "uniform"
    weighting: SUPPORTED_WEIGHTINGS = "none"
    # Parameters of the Beta distribution used when sampling="beta":
    beta_a: float = 0.8
    beta_b: float = 1.0
    # Clamping bounds for the min-SNR loss weights used when weighting="clamp_snr":
    max_gamma: float = 5.0
    min_gamma: float = 0.0


class StepsSampler:
    """Samples diffusion timesteps during training, optionally with
    non-uniform probabilities and per-step loss scaling."""

    def __init__(
        self,
        config: StepsSamplerConfig,
        noise_scheduler: DDIMScheduler,
    ):
        num_diffusion_train_steps = noise_scheduler.num_diffusion_train_steps

        weights: Optional[Tensor] = None
        if config.sampling == "uniform":
            weights = torch.ones(num_diffusion_train_steps)
        elif config.sampling == "beta":
            # Beta-distributed timestep sampling, as motivated in
            # https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/00328.pdf
            a = torch.tensor([config.beta_a])
            b = torch.tensor([config.beta_b])
            # a=1, b=1 recovers uniform sampling; the paper empirically
            # chooses b=1 and a=0.8 < 1.
            steps = (
                torch.arange(1, num_diffusion_train_steps + 1)
                / num_diffusion_train_steps
            )
            # Beta(a, b) density evaluated at each normalized timestep:
            weights = (
                1 / beta_function(a, b) * (steps ** (a - 1)) * ((1 - steps) ** (b - 1))
            )

        assert weights is not None, "The sampling weights were not properly set!"
        logger.info(f"Training with sampling weights={weights}")

        self.distrib = D.Categorical(probs=weights / weights.sum())

        # Set up per-step weights for loss scaling:
        if config.weighting == "none":
            self.gamma_per_step = None
        elif config.weighting == "clamp_snr":
            # Min-SNR weighting scheme from https://arxiv.org/abs/2303.09556
            snrs = noise_scheduler.get_snrs()
            # gamma(t) = min(max_gamma, snr(t)), floored at min_gamma
            self.gamma_per_step = torch.clamp(
                snrs, max=config.max_gamma, min=config.min_gamma
            )
            logger.info(f"Training with Gamma={self.gamma_per_step}")

    @property
    def _training_weights(self) -> Tensor:
        return self.distrib.probs

    def sample(self, size: torch.Size, device: torch.device) -> Tensor:
        """Draw `size` timestep indices from the sampling distribution."""
        return self.distrib.sample(size).to(device)

    def get_loss_scales(self, steps: Tensor) -> Optional[Tensor]:
        """Return per-sample loss scales for the given timesteps, or None if
        no weighting is configured."""
        if self.gamma_per_step is None:
            return None
        # With a constant Gamma=1 (the `None` case above), the loss scales
        # would sum to steps.numel(). To match that total mass, we normalize
        # the clamped gammas so they also sum to steps.numel().
        gamma = self.gamma_per_step.to(steps.device)[steps]
        gamma = gamma / gamma.sum() * steps.numel()
        return gamma
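

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of driving StepsSampler, assuming only the two scheduler
# attributes it actually reads (`num_diffusion_train_steps` and `get_snrs()`).
# `_ToySchedule` is a hypothetical stand-in; the real DDIMScheduler's
# constructor and SNR schedule may differ.
if __name__ == "__main__":

    class _ToySchedule:
        num_diffusion_train_steps = 100

        def get_snrs(self) -> Tensor:
            # Fake, monotonically decreasing SNRs, for illustration only.
            return torch.linspace(20.0, 0.05, self.num_diffusion_train_steps)

    config = StepsSamplerConfig(sampling="beta", weighting="clamp_snr")
    sampler = StepsSampler(config, _ToySchedule())  # type: ignore[arg-type]

    # Sample a batch of 8 timesteps and fetch their normalized loss scales.
    steps = sampler.sample(torch.Size([8]), device=torch.device("cpu"))
    scales = sampler.get_loss_scales(steps)
    print("steps:", steps.tolist())
    print("loss scales:", None if scales is None else scales.tolist())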