| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the CC-by-NC license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from contextlib import nullcontext | |
| from math import ceil | |
| from typing import Callable, Optional, Union | |
| import torch | |
| from torch import Tensor | |
| from torch.nn import functional as F | |
| from flow_matching.path import MixtureDiscreteProbPath | |
| from flow_matching.solver.solver import Solver | |
| from flow_matching.utils import categorical, ModelWrapper | |
| from .utils import get_nearest_times | |
| try: | |
| from tqdm import tqdm | |
| TQDM_AVAILABLE = True | |
| except ImportError: | |
| TQDM_AVAILABLE = False | |
class MixtureDiscreteEulerSolver(Solver):
    r"""Solver that simulates the CTMC process :math:`(X_t)_{t_{\text{init}}\leq t\leq t_{\text{final}}}` defined by :math:`p_t` the marginal probability path of ``path``.

    Given :math:`X_t \sim p_t`, the algorithm of solver step from :math:`t` to :math:`t+h` for the i-th coordinate is:

    .. math::

        \begin{align*}
            & X_1^i \sim p_{1|t}^i(\cdot|X_t)\\
            & \lambda^i \gets \sum_{x^i\ne X_t^i} u_t^i(x^i, X_t^i|X_1^i)\\
            & Z^i_{\text{change}} \sim U[0,1]\\
            & X_{t+h}^i \sim \begin{cases}
                \frac{u_t^i(\cdot, X_t^i|X_1^i)}{\lambda^i}(1-\delta_{X_t^i}(\cdot)) \text{ if $Z^i_{\text{change}}\le 1-e^{-h\lambda^i}$}\\
                \delta_{X_t^i}(\cdot) \text{ else }
            \end{cases}
        \end{align*}

    where :math:`p_{1|t}(\cdot|X_t)` is the output of ``model``, and the conditional probability velocity of the mixture probability path is:

    .. math::

        u_t^i(x^i, y^i|x_1^i) = \hat{u}_t^i(x^i, y^i|x_1^i) + c_{\text{div\_free}}\left[\hat{u}_t^i(x^i, y^i|x_1^i) - \check{u}_t^i(x^i, y^i|x_1^i) \right],

    where

    .. math::

        \hat{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{1-\kappa_t} \left[ \delta_{x_1^i}(x^i) - \delta_{y^i}(x^i) \right],

    and

    .. math::

        \check{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{\kappa_t}\left[ \delta_{y^i}(x^i) - p(x^i) \right].

    The source distribution :math:`p(x^i)` is given by ``source_distribution_p``.

    Args:
        model (ModelWrapper): trained with x-prediction, outputting posterior probabilities (in the range :math:`[0,1]`), output must be [..., vocabulary_size].
        path (MixtureDiscreteProbPath): Probability path used for x-prediction training.
        vocabulary_size (int): size of the discrete vocabulary.
        source_distribution_p (Optional[Tensor], optional): Source distribution, must be of shape [vocabulary_size]. Required only when divergence-free term for the probability velocity is non-zero. Defaults to None.
    """

    def __init__(
        self,
        model: ModelWrapper,
        path: MixtureDiscreteProbPath,
        vocabulary_size: int,
        source_distribution_p: Optional[Tensor] = None,
    ):
        super().__init__()
        self.model = model
        self.path = path
        self.vocabulary_size = vocabulary_size

        if source_distribution_p is not None:
            assert source_distribution_p.shape == torch.Size(
                [vocabulary_size]
            ), f"Source distribution p dimension must match the vocabulary size {vocabulary_size}. Got {source_distribution_p.shape}."

        # Stored as given; it is moved to the sampling device lazily in `sample`.
        self.source_distribution_p = source_distribution_p

    def sample(
        self,
        x_init: Tensor,
        step_size: Optional[float],
        div_free: Union[float, Callable[[float], float]] = 0.0,
        dtype_categorical: torch.dtype = torch.float32,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        verbose: bool = False,
        **model_extras,
    ) -> Tensor:
        """
        Sample a sequence of discrete values from the given model.

        .. code-block:: python

            import torch
            from flow_matching.utils import ModelWrapper
            from flow_matching.solver import MixtureDiscreteEulerSolver

            class DummyModel(ModelWrapper):
                def __init__(self):
                    super().__init__(None)

                def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
                    return ...

            model = DummyModel()
            path = ...  # the MixtureDiscreteProbPath used to train `model`
            solver = MixtureDiscreteEulerSolver(model=model, path=path, vocabulary_size=1024)

            x_init = torch.LongTensor([122, 725])
            step_size = 0.001
            time_grid = torch.tensor([0.0, 1.0])

            result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)

        Args:
            x_init (Tensor): The initial state.
            step_size (Optional[float]): If float then time discretization is uniform with the given step size. If None then time discretization is set to be time_grid.
            div_free (Union[float, Callable[[float], float]]): The coefficient of the divergence-free term in the probability velocity. Can be either a float or a time dependent function; a callable is invoked with the current time as a 1-element tensor. Defaults to 0.0.
            dtype_categorical (torch.dtype): Precision to use for categorical sampler. Defaults to torch.float32.
            time_grid (Tensor): The CTMC process is solved in the interval [time_grid[0], time_grid[-1]] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0,1.0]).
            return_intermediates (bool): If True then return intermediate time steps according to time_grid. Defaults to False.
            verbose (bool): Whether to print progress bars. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Tensor: The sampled sequence of discrete values. With ``return_intermediates`` the states are stacked along a new leading dimension.

        Raises:
            ImportError: To run in verbose mode, tqdm must be installed.
            AssertionError: If a non-zero ``div_free`` is requested without ``source_distribution_p``, or if ``step_size`` is not smaller than the time interval.
        """
        if div_free != 0.0:
            assert (
                self.source_distribution_p is not None
            ), "Source distribution p must be specified in order to add a divergence-free term to the probability velocity."

        # Initialize the current state `x_t` with the initial state `X_0`.
        time_grid = time_grid.to(device=x_init.device)

        if step_size is None:
            # If step_size is None then set the t discretization to time_grid.
            t_discretization = time_grid
            n_steps = len(time_grid) - 1
        else:
            # If step_size is float then t discretization is uniform with step size set by step_size.
            t_init = time_grid[0].item()
            t_final = time_grid[-1].item()
            assert (
                t_final - t_init
            ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."

            n_steps = ceil((t_final - t_init) / step_size)
            t_discretization = torch.tensor(
                [t_init + step_size * i for i in range(n_steps)] + [t_final],
                device=x_init.device,
            )

            if return_intermediates:
                # get order of intermediate steps:
                order = torch.argsort(time_grid)
                # Compute intermediate steps to return via nearest points in t_discretization to time_grid.
                time_grid = get_nearest_times(
                    time_grid=time_grid, t_discretization=t_discretization
                )

        x_t = x_init.clone()
        steps_counter = 0
        res = []

        if return_intermediates:
            res = [x_init.clone()]

        if verbose:
            if not TQDM_AVAILABLE:
                raise ImportError(
                    "tqdm is required for verbose mode. Please install it."
                )
            # The bar tracks simulated time; the last discretization point is the
            # final time in both discretization modes (unlike `t_final`, which is
            # only defined when `step_size` is given).
            ctx = tqdm(
                total=t_discretization[-1].item(), desc=f"NFE: {steps_counter}"
            )
        else:
            ctx = nullcontext()

        with ctx:
            for i in range(n_steps):
                t = t_discretization[i : i + 1]
                t_next = t_discretization[i + 1 : i + 2]
                h = t_next - t

                # Sample x_1 ~ p_1|t( \cdot |x_t)
                p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras)
                x_1 = categorical(p_1t.to(dtype=dtype_categorical))

                # Checks if final step
                if i == n_steps - 1:
                    x_t = x_1
                else:
                    # Compute u_t(x|x_t,x_1)
                    scheduler_output = self.path.scheduler(t=t)

                    k_t = scheduler_output.alpha_t
                    d_k_t = scheduler_output.d_alpha_t

                    delta_1 = F.one_hot(x_1, num_classes=self.vocabulary_size).to(
                        k_t.dtype
                    )
                    u = d_k_t / (1 - k_t) * delta_1

                    # Add divergence-free part. Off the diagonal (x != x_t) the
                    # bracket [u_hat - u_check] reduces to
                    # d_k_t / (k_t (1 - k_t)) * ((1 - k_t) p + k_t delta_{x_1}).
                    div_free_t = div_free(t) if callable(div_free) else div_free

                    if div_free_t > 0:
                        # Broadcast p to [1, ..., 1, vocabulary_size] on the
                        # same device as the velocity.
                        p_0 = self.source_distribution_p.to(device=u.device)[
                            (None,) * x_t.dim()
                        ]
                        u = u + div_free_t * d_k_t / (k_t * (1 - k_t)) * (
                            (1 - k_t) * p_0 + k_t * delta_1
                        )

                    # Set u_t(x_t|x_t,x_1) = 0
                    delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size)
                    u = torch.where(
                        delta_t.to(dtype=torch.bool), torch.zeros_like(u), u
                    )

                    # Sample x_t ~ u_t( \cdot |x_t,x_1)
                    intensity = u.sum(dim=-1)  # Assuming u_t(xt|xt,x1) := 0
                    mask_jump = torch.rand(
                        size=x_t.shape, device=x_t.device
                    ) < 1 - torch.exp(-h * intensity)

                    if mask_jump.sum() > 0:
                        x_t[mask_jump] = categorical(
                            u[mask_jump].to(dtype=dtype_categorical)
                        )

                steps_counter += 1
                # Advance to the exact next grid point rather than computing
                # `t + h`: float rounding in `t + h` can miss the grid value and
                # make the exact membership test below drop an intermediate.
                t = t_next

                if return_intermediates and (t in time_grid):
                    res.append(x_t.clone())

                if verbose:
                    ctx.n = t.item()
                    ctx.refresh()
                    ctx.set_description(f"NFE: {steps_counter}")

        if return_intermediates:
            if step_size is None:
                return torch.stack(res, dim=0)
            else:
                # `res` holds states in increasing-time order; `order` maps sorted
                # positions to time_grid positions, so its inverse permutation
                # (argsort of a permutation) maps results back to time_grid order.
                return torch.stack(res, dim=0)[torch.argsort(order)]
        else:
            return x_t
Xet Storage Details
- Size:
- 10.5 kB
- Xet hash:
- 6c211511fdc43e9c195862996c90006c974a0f51910a48965f1623fccd6aeefe
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.