File size: 3,332 Bytes
b5a0bec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass
from typing import Literal, Optional

import torch
import torch.distributions as D
from fairseq2.logging import get_log_writer
from torch import Tensor

from lcm.nn.schedulers import DDIMScheduler

SUPPORTED_SAMPLERS = Literal["uniform", "beta"]
SUPPORTED_WEIGHTINGS = Literal["none", "clamp_snr"]

logger = get_log_writer(__name__)


def beta_function(a, b):
    """Evaluate the Beta function B(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b).

    The value is assembled in log-space from ``torch.lgamma`` terms and then
    exponentiated, so ``a`` and ``b`` are expected to be tensors accepted by
    ``torch.lgamma``. The result is computed element-wise.
    """
    log_beta = torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a + b)
    return torch.exp(log_beta)


@dataclass
class StepsSamplerConfig:
    """Options read by ``StepsSampler`` to draw diffusion steps and scale losses."""

    # How step indices are weighted when sampling: "uniform" gives every step
    # the same weight; "beta" weights steps by a Beta(beta_a, beta_b) density.
    sampling: SUPPORTED_SAMPLERS = "uniform"
    # Per-step loss scaling: "none" disables it; "clamp_snr" uses the noise
    # scheduler's SNRs clamped to [min_gamma, max_gamma].
    weighting: SUPPORTED_WEIGHTINGS = "none"
    # Shape parameters of the Beta density; only read when sampling == "beta".
    beta_a: float = 0.8
    beta_b: float = 1
    # Upper / lower clamp bounds on the SNR; only read when
    # weighting == "clamp_snr".
    max_gamma: float = 5.0
    min_gamma: float = 0


class StepsSampler:
    """Draw diffusion training steps and provide optional per-step loss scales.

    Step indices are drawn from a categorical distribution over the
    ``noise_scheduler.num_diffusion_train_steps`` training steps. The
    distribution is either uniform or shaped by a Beta density, depending on
    ``config.sampling``. When ``config.weighting`` is ``"clamp_snr"`` a
    per-step factor (the scheduler's SNR clamped to
    ``[config.min_gamma, config.max_gamma]``) is kept for scaling the loss.

    Raises:
        ValueError: if ``config.sampling`` or ``config.weighting`` is not a
            supported value, or if the resulting sampling weights cannot be
            normalized into a probability distribution.
    """

    def __init__(
        self,
        config: StepsSamplerConfig,
        noise_scheduler: DDIMScheduler,
    ):
        num_diffusion_train_steps = noise_scheduler.num_diffusion_train_steps

        # Resolve both configuration choices before logging or building any
        # state, so an unsupported option fails fast with a clear error
        # instead of a later AttributeError on a half-initialized object.
        weights = self._build_sampling_weights(config, num_diffusion_train_steps)
        gamma_per_step = self._build_gamma_per_step(config, noise_scheduler)

        logger.info(f"Training with sampling weights={weights}")
        self.distrib = D.Categorical(
            probs=weights / weights.sum(),
        )

        # None means "no loss re-weighting"; see get_loss_scales.
        self.gamma_per_step: Optional[Tensor] = gamma_per_step
        logger.info(f"Training with Gamma={self.gamma_per_step}")

    @staticmethod
    def _build_sampling_weights(
        config: StepsSamplerConfig, num_diffusion_train_steps: int
    ) -> Tensor:
        """Return the unnormalized sampling weight of every training step."""
        if config.sampling == "uniform":
            weights = torch.ones(
                num_diffusion_train_steps,
            )

        elif config.sampling == "beta":
            # As motivated in https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/00328.pdf
            # a=1, b=1 -> uniform
            # The paper empirically chooses b=1, a=0.8 < 1
            # Coerce to float so integer config values (e.g. beta_b=1) do not
            # produce integer tensors that are then fed to lgamma.
            a = torch.tensor([float(config.beta_a)])
            b = torch.tensor([float(config.beta_b)])

            # Step index i is evaluated at the normalized position (i + 1) / N,
            # so positions lie in (0, 1] and the last one is exactly 1.
            steps = (
                torch.arange(1, num_diffusion_train_steps + 1)
                / num_diffusion_train_steps
            )
            weights = (
                1 / beta_function(a, b) * (steps ** (a - 1)) * ((1 - steps) ** (b - 1))
            )

        else:
            raise ValueError(
                f"Unsupported sampling scheme {config.sampling!r}; "
                "expected 'uniform' or 'beta'."
            )

        # E.g. beta_b < 1 makes the density infinite at position 1, which
        # would turn every probability into NaN after normalization.
        if not bool(torch.isfinite(weights).all()) or float(weights.sum()) <= 0:
            raise ValueError(
                "The sampling weights cannot be normalized into a probability "
                f"distribution (sampling={config.sampling!r}, weights={weights})."
            )

        return weights

    @staticmethod
    def _build_gamma_per_step(
        config: StepsSamplerConfig, noise_scheduler: DDIMScheduler
    ) -> Optional[Tensor]:
        """Return the per-step loss factor, or None when weighting is disabled."""
        if config.weighting == "none":
            return None

        if config.weighting == "clamp_snr":
            # Min-SNR scheme from
            # https://arxiv.org/abs/2303.09556
            snrs = noise_scheduler.get_snrs()
            # gamma(t) = min(max_gamma, snr(t)), additionally floored at min_gamma
            return torch.clamp(snrs, max=config.max_gamma, min=config.min_gamma)

        raise ValueError(
            f"Unsupported weighting scheme {config.weighting!r}; "
            "expected 'none' or 'clamp_snr'."
        )

    @property
    def _training_weights(self) -> Tensor:
        """Normalized probability of drawing each training step."""
        return self.distrib.probs

    def sample(self, size: torch.Size, device: torch.device):
        """Draw step indices of shape ``size`` and move them to ``device``.

        Args:
            size: sample shape forwarded to the categorical distribution.
            device: device the returned tensor is moved to.

        Returns:
            A tensor of step indices drawn from ``self.distrib``.
        """
        samples = self.distrib.sample(size).to(device)
        return samples

    def get_loss_scales(self, steps):
        """Return the loss scale of each entry in ``steps``.

        Args:
            steps: tensor of step indices used to index the per-step gammas.

        Returns:
            ``None`` when no weighting is configured (equivalent to a constant
            scale of 1). Otherwise a tensor shaped like ``steps`` holding the
            selected gammas, rescaled to sum to ``steps.numel()``.

        Note:
            The rescaling divides by the sum of the selected gammas, so the
            result is NaN if every selected gamma is zero.
        """
        if self.gamma_per_step is None:
            return None

        # With no weighting every scale is implicitly 1, so the scales sum to
        # steps.numel(). Normalize to the same total so both modes carry the
        # same overall loss mass.
        gamma = self.gamma_per_step.to(steps.device)[steps]
        gamma = gamma / gamma.sum() * steps.numel()
        return gamma