# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass
from typing import Literal, Optional

import torch
import torch.distributions as D
from fairseq2.logging import get_log_writer
from torch import Tensor

from lcm.nn.schedulers import DDIMScheduler

SUPPORTED_SAMPLERS = Literal["uniform", "beta"]
SUPPORTED_WEIGHTINGS = Literal["none", "clamp_snr"]

logger = get_log_writer(__name__)


def beta_function(a: Tensor, b: Tensor) -> Tensor:
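    """Euler Beta function B(a, b) = Γ(a)·Γ(b) / Γ(a + b). (Docstring added.)

    Evaluated in log space via lgamma so that large arguments do not
    overflow the intermediate Gamma values.
    """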
    return torch.exp(torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a + b))
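

# Quick sanity check (illustrative, not part of the original module): with
# a = b = 1 the Beta pdf is flat, so the function should return 1:
#   >>> beta_function(torch.tensor([1.0]), torch.tensor([1.0]))
#   tensor([1.])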


@dataclass
class StepsSamplerConfig:
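    """Configuration for diffusion timestep sampling. (Docstring added.)

    - sampling: how training timesteps are drawn ("uniform" or "beta").
    - weighting: optional loss re-weighting ("none" or Min-SNR "clamp_snr").
    - beta_a / beta_b: shape parameters of the Beta sampling distribution.
    - max_gamma / min_gamma: clamp bounds on the SNR-based loss weights.
    """
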
    sampling: SUPPORTED_SAMPLERS = "uniform"
    weighting: SUPPORTED_WEIGHTINGS = "none"
    beta_a: float = 0.8
    beta_b: float = 1.0
    max_gamma: float = 5.0
    min_gamma: float = 0.0


class StepsSampler:
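    """Samples diffusion timesteps and optional per-step loss scales.
    (Docstring added.)

    Timesteps are drawn from a uniform or Beta-weighted categorical
    distribution over the scheduler's training steps; with "clamp_snr",
    loss scales follow the Min-SNR scheme from the reference below.
    """
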
    def __init__(
        self,
        config: StepsSamplerConfig,
        noise_scheduler: DDIMScheduler,
    ):
        num_diffusion_train_steps = noise_scheduler.num_diffusion_train_steps
        weights: Optional[Tensor] = None

        if config.sampling == "uniform":
            weights = torch.ones(num_diffusion_train_steps)
        elif config.sampling == "beta":
            # Beta-distributed timesteps, as motivated in
            # https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/00328.pdf
            # a=1, b=1 recovers the uniform distribution; the paper
            # empirically chooses b=1 and a=0.8 < 1.
            a = torch.tensor([config.beta_a])
            b = torch.tensor([config.beta_b])
            steps = (
                torch.arange(1, num_diffusion_train_steps + 1)
                / num_diffusion_train_steps
            )
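            # Beta(a, b) pdf evaluated at the normalized steps t in (0, 1]
            # (comment added): f(t) = t**(a - 1) * (1 - t)**(b - 1) / B(a, b)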
            weights = (
                1 / beta_function(a, b) * (steps ** (a - 1)) * ((1 - steps) ** (b - 1))
            )

        assert weights is not None, "The sampling weights were not properly set!"
        logger.info(f"Training with sampling weights={weights}")
        self.distrib = D.Categorical(probs=weights / weights.sum())
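        # Note (added): D.Categorical normalizes `probs` internally as well;
        # the explicit division above just makes the normalization visible.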

        # Set up per-step weights for loss scaling:
        if config.weighting == "none":
            self.gamma_per_step: Optional[Tensor] = None
        elif config.weighting == "clamp_snr":
            # Min-SNR weighting scheme from https://arxiv.org/abs/2303.09556
            # gamma(t) = clamp(snr(t), min=min_gamma, max=max_gamma)
            snrs = noise_scheduler.get_snrs()
            self.gamma_per_step = torch.clamp(
                snrs, max=config.max_gamma, min=config.min_gamma
            )
            logger.info(f"Training with Gamma={self.gamma_per_step}")

    def _training_weights(self) -> Tensor:
        return self.distrib.probs

    def sample(self, size: torch.Size, device: torch.device) -> Tensor:
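        """Draw timestep indices in [0, T - 1] with shape `size`. (Docstring added.)"""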
        samples = self.distrib.sample(size).to(device)
        return samples

    def get_loss_scales(self, steps: Tensor) -> Optional[Tensor]:
        if self.gamma_per_step is None:
            return None
        # With a constant gamma = 1 (the `None` case above), the loss scales
        # would implicitly sum to steps.numel(); to match that total mass, we
        # normalize the per-step gammas to sum to steps.numel() as well.
        gamma = self.gamma_per_step.to(steps.device)[steps]
        gamma = gamma / gamma.sum() * steps.numel()
        return gamma
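

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# DDIMScheduler constructor is not shown here, so a stub stands in for it,
# providing only the two members StepsSampler actually touches; the SNR
# curve below is hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _StubScheduler:
        num_diffusion_train_steps = 100

        def get_snrs(self) -> Tensor:
            # Hypothetical monotonically decreasing SNR curve.
            return torch.linspace(10.0, 0.01, self.num_diffusion_train_steps)

    sampler = StepsSampler(
        config=StepsSamplerConfig(sampling="beta", weighting="clamp_snr"),
        noise_scheduler=_StubScheduler(),  # type: ignore[arg-type]
    )
    steps = sampler.sample(torch.Size([8]), device=torch.device("cpu"))
    print("steps:", steps)
    print("loss scales:", sampler.get_loss_scales(steps))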