| | |
| | |
| | |
| | |
| | |
| |
|
| | from contextlib import nullcontext |
| | from math import ceil |
| | from typing import Callable, Optional, Union |
| |
|
| | import torch |
| | from torch import Tensor |
| | import gc |
| | from torch.nn import functional as F |
| |
|
| | from flow_matching.path import MixtureDiscreteProbPath |
| |
|
| | from flow_matching.solver.solver import Solver |
| | from flow_matching.utils import categorical, ModelWrapper |
| | from .utils import get_nearest_times |
| | from ..utils.multi_guidance import * |
| |
|
| | try: |
| | from tqdm import tqdm |
| |
|
| | TQDM_AVAILABLE = True |
| | except ImportError: |
| | TQDM_AVAILABLE = False |
| |
|
| |
|
class MixtureDiscreteEulerSolver(Solver):
    r"""Solver that simulates the CTMC process :math:`(X_t)_{t_{\text{init}}\leq t\leq t_{\text{final}}}` defined by :math:`p_t` the marginal probability path of ``path``.
    Given :math:`X_t \sim p_t`, the algorithm of solver step from :math:`t` to :math:`t+h` for the i-th coordinate is:

    .. math::

        \begin{align*}
            & X_1^i \sim p_{1|t}^i(\cdot|X_t)\\
            & \lambda^i \gets \sum_{x^i\ne X_t^i} u_t^i(x^i, X_t^i|X_1^i)\\
            & Z^i_{\text{change}} \sim U[0,1]\\
            & X_{t+h}^i \sim \begin{cases}
                \frac{u_t^i(\cdot, X_t^i|X_1^i)}{\lambda^i}(1-\delta_{X_t^i}(\cdot)) \text{ if $Z^i_{\text{change}}\le 1-e^{-h\lambda^i}$}\\
                \delta_{X_t^i}(\cdot) \text{ else }
            \end{cases}
        \end{align*}

    Where :math:`p_{1|t}(\cdot|X_t)` is the output of ``model``, and the conditional probability velocity of the mixture probability path is:

    .. math::

        u_t^i(x^i, y^i|x_1^i) = \hat{u}_t^i(x^i, y^i|x_1^i) + c_{\text{div\_free}}\left[\hat{u}_t^i(x^i, y^i|x_1^i) - \check{u}_t^i(x^i, y^i|x_1^i) \right],

    where

    .. math::
        \hat{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{1-\kappa_t} \left[ \delta_{x_1^i}(x^i) - \delta_{y^i}(x^i) \right],

    and

    .. math::

        \check{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{\kappa_t}\left[ \delta_{y^i}(x^i) - p(x^i) \right].

    The source distribution :math:`p(x^i)` is given by ``p``.

    Args:
        model (ModelWrapper): trained with x-prediction, outputting posterior probabilities (in the range :math:`[0,1]`), output must be [..., vocabulary_size].
        path (MixtureDiscreteProbPath): Probability path used for x-prediction training.
        vocabulary_size (int): size of the discrete vocabulary.
        source_distribution_p (Optional[Tensor], optional): Source distribution, must be of shape [vocabulary_size]. Required only when divergence-free term for the probability velocity is non-zero. Defaults to None.
    """

    def __init__(
        self,
        model: ModelWrapper,
        path: MixtureDiscreteProbPath,
        vocabulary_size: int,
        source_distribution_p: Optional[Tensor] = None,
    ):
        super().__init__()
        self.model = model
        self.path = path
        self.vocabulary_size = vocabulary_size

        if source_distribution_p is not None:
            assert source_distribution_p.shape == torch.Size(
                [vocabulary_size]
            ), f"Source distribution p dimension must match the vocabulary size {vocabulary_size}. Got {source_distribution_p.shape}."

        self.source_distribution_p = source_distribution_p

    @staticmethod
    def _discretize_time(
        x_init: Tensor,
        step_size: Optional[float],
        time_grid: Tensor,
        return_intermediates: bool,
    ):
        """Move ``time_grid`` to ``x_init``'s device and build the solver's time steps.

        Args:
            x_init (Tensor): Initial state; only its device is used.
            step_size (Optional[float]): Uniform step size, or None to step exactly on ``time_grid``.
            time_grid (Tensor): Times spanning the integration interval.
            return_intermediates (bool): Whether intermediate states will be recorded.

        Returns:
            Tuple ``(time_grid, t_discretization, n_steps, order)``. When
            ``return_intermediates`` is True, ``time_grid`` is snapped onto
            ``t_discretization`` via ``get_nearest_times`` and ``order`` is the
            ``argsort`` of the un-snapped grid; otherwise ``order`` is None.

        Raises:
            AssertionError: If ``step_size`` is not smaller than the time interval.
        """
        time_grid = time_grid.to(device=x_init.device)

        if step_size is None:
            # Step exactly on the user-supplied grid.
            t_discretization = time_grid
            n_steps = len(time_grid) - 1
        else:
            # Uniform steps of size ``step_size``; the last interval (ending at t_final)
            # is at most ``step_size`` long.
            t_init = time_grid[0].item()
            t_final = time_grid[-1].item()
            assert (
                t_final - t_init
            ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."

            n_steps = ceil((t_final - t_init) / step_size)
            t_discretization = torch.tensor(
                [t_init + step_size * i for i in range(n_steps)] + [t_final],
                device=x_init.device,
            )

        order = None
        if return_intermediates:
            # Remember the caller's ordering of time_grid before snapping it.
            order = torch.argsort(time_grid)
            time_grid = get_nearest_times(
                time_grid=time_grid, t_discretization=t_discretization
            )

        return time_grid, t_discretization, n_steps, order

    @staticmethod
    def _progress_context(verbose: bool, t_discretization: Tensor):
        """Return a tqdm bar over simulated time when ``verbose``, else a no-op context.

        Raises:
            ImportError: If ``verbose`` is True but tqdm is not installed.
        """
        if not verbose:
            return nullcontext()
        if not TQDM_AVAILABLE:
            raise ImportError(
                "tqdm is required for verbose mode. Please install it."
            )
        # The last discretization point is t_final in both step_size modes. The original
        # code read a ``t_final`` local that only existed when step_size was given, so
        # verbose mode crashed with NameError for step_size=None.
        return tqdm(total=t_discretization[-1].item(), desc="NFE: 0")

    @staticmethod
    def _score_current_state(score_models: list, x_t: Tensor) -> list:
        """Evaluate every score model on ``x_t`` and return the scores as Python scalars.

        A score model whose ``forward`` (or, lacking ``forward``, the callable itself)
        declares a ``t`` parameter is called as ``s(x_t, 1)``; otherwise as ``s(x_t)``.
        Tuple outputs contribute one entry per element. Every score is converted with
        ``.item()``, so each must be a single-element tensor.
        """
        import inspect

        scores = []
        for score_model in score_models:
            target = score_model.forward if hasattr(score_model, "forward") else score_model
            if "t" in inspect.signature(target).parameters:
                candidate_scores = score_model(x_t, 1)
            else:
                candidate_scores = score_model(x_t)

            if isinstance(candidate_scores, tuple):
                scores.extend(score.item() for score in candidate_scores)
            else:
                scores.append(candidate_scores.item())
        return scores

    @torch.no_grad()
    def sample(
        self,
        x_init: Tensor,
        step_size: Optional[float],
        div_free: Union[float, Callable[[float], float]] = 0.0,
        dtype_categorical: torch.dtype = torch.float32,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        verbose: bool = False,
        **model_extras,
    ) -> Tensor:
        """
        Sample a sequence of discrete values from the given model.

        .. code-block:: python

            import torch
            from flow_matching.path import MixtureDiscreteProbPath
            from flow_matching.path.scheduler import PolynomialConvexScheduler
            from flow_matching.utils import ModelWrapper
            from flow_matching.solver import MixtureDiscreteEulerSolver

            class DummyModel(ModelWrapper):
                def __init__(self):
                    super().__init__(None)
                def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
                    return ...

            model = DummyModel()
            path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=1.0))
            solver = MixtureDiscreteEulerSolver(model=model, path=path, vocabulary_size=1024)

            x_init = torch.LongTensor([122, 725])
            step_size = 0.001
            time_grid = torch.tensor([0.0, 1.0])

            result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)

        Args:
            x_init (Tensor): The initial state.
            step_size (Optional[float]): If float then time discretization is uniform with the given step size. If None then time discretization is set to be time_grid.
            div_free (Union[float, Callable[[float], float]]): The coefficient of the divergence-free term in the probability velocity. Can be either a float or a time dependent function. Defaults to 0.0.
            dtype_categorical (torch.dtype): Precision to use for categorical sampler. Defaults to torch.float32.
            time_grid (Tensor): The CTMC process is solved in the interval [time_grid[0], time_grid[-1]] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0,1.0]).
            return_intermediates (bool): If True then return intermediate time steps according to time_grid. Defaults to False.
            verbose (bool): Whether to print progress bars. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Tensor: The sampled sequence of discrete values.

        Raises:
            ImportError: To run in verbose mode, tqdm must be installed.
            AssertionError: If a divergence-free term is requested without a source
                distribution, or if step_size is not smaller than the time interval.
        """
        if not div_free == 0.0:
            assert (
                self.source_distribution_p is not None
            ), "Source distribution p must be specified in order to add a divergence-free term to the probability velocity."

        time_grid, t_discretization, n_steps, order = self._discretize_time(
            x_init=x_init,
            step_size=step_size,
            time_grid=time_grid,
            return_intermediates=return_intermediates,
        )

        x_t = x_init.clone()
        steps_counter = 0
        res = [x_init.clone()] if return_intermediates else []

        ctx = self._progress_context(verbose, t_discretization)

        with ctx:
            for i in range(n_steps):
                t = t_discretization[i : i + 1]
                h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1]

                # Draw X_1 ~ p_{1|t}(.|X_t) from the model's posterior.
                p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras)
                x_1 = categorical(p_1t.to(dtype=dtype_categorical))

                if i == n_steps - 1:
                    # Final step: take the predicted sample directly.
                    x_t = x_1
                else:
                    scheduler_output = self.path.scheduler(t=t)

                    k_t = scheduler_output.alpha_t
                    d_k_t = scheduler_output.d_alpha_t

                    delta_1 = F.one_hot(x_1, num_classes=self.vocabulary_size).to(
                        k_t.dtype
                    )
                    u = d_k_t / (1 - k_t) * delta_1

                    # Optional divergence-free component of the probability velocity.
                    div_free_t = div_free(t) if callable(div_free) else div_free

                    if div_free_t > 0:
                        p_0 = self.source_distribution_p[(None,) * x_t.dim()]
                        u = u + div_free_t * d_k_t / (k_t * (1 - k_t)) * (
                            (1 - k_t) * p_0 + k_t * delta_1
                        )

                    # Zero the rate of "jumping" to the state a coordinate already has.
                    delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size)
                    u = torch.where(
                        delta_t.to(dtype=torch.bool), torch.zeros_like(u), u
                    )

                    # A coordinate jumps with probability 1 - exp(-h * lambda), where
                    # lambda is its total outgoing rate; jumpers resample proportional to u.
                    intensity = u.sum(dim=-1)
                    mask_jump = torch.rand(
                        size=x_t.shape, device=x_t.device
                    ) < 1 - torch.exp(-h * intensity)

                    if mask_jump.sum() > 0:
                        x_t[mask_jump] = categorical(
                            u[mask_jump].to(dtype=dtype_categorical)
                        )

                steps_counter += 1
                t = t + h

                if return_intermediates and (t in time_grid):
                    res.append(x_t.clone())

                if verbose:
                    ctx.n = t.item()
                    ctx.refresh()
                    ctx.set_description(f"NFE: {steps_counter}")

        if return_intermediates:
            if step_size is None:
                return torch.stack(res, dim=0)
            else:
                return torch.stack(res, dim=0)[order]
        else:
            return x_t

    @torch.no_grad()
    def multi_guidance_sample(
        self,
        args,
        x_init: Tensor,
        step_size: Optional[float],
        div_free: Union[float, Callable[[float], float]] = 0.0,
        dtype_categorical: torch.dtype = torch.float32,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        verbose: bool = False,
        score_models: list = None,
        num_objectives: int = 1,
        weights: list = None,
        **model_extras,
    ) -> Tensor:
        """Sample with multi-objective guidance from ``score_models``.

        Each non-final step does the following:

        1. Build the mixture-path rate ``d_k_t / (1 - k_t) * p_{1|t}`` from the model posterior.
        2. Re-score candidate transitions with ``guided_transition_scoring``.
        3. Filter them with ``adaptive_hypercone_filtering``.
        4. Move with ``euler_sample``.

        The final interval applies no transition; only the step counter and time advance.
        After every step, all score models are evaluated on the current state and the
        scores are printed.

        Args:
            args: Configuration object. Its ``Phi_init`` is read, and its ``num_div`` is
                read when ``weights`` is None. It is also forwarded to the guidance helpers.
            x_init (Tensor): The initial state.
            step_size (Optional[float]): Uniform step size, or None to step on ``time_grid``.
            div_free: Must be 0.0; a divergence-free term is not supported here.
            dtype_categorical (torch.dtype): Accepted for signature parity with ``sample``;
                currently unused.
            time_grid (Tensor): Integration interval / grid, as in ``sample``.
            return_intermediates (bool): If True, return states at the ``time_grid`` times.
            verbose (bool): Whether to show a tqdm progress bar.
            score_models (list): Callables scoring a state. They are forwarded to
                ``guided_transition_scoring`` and evaluated after each step. Required.
            num_objectives (int): Passed to ``select_random_weight_vector`` when ``weights`` is None.
            weights (list): Explicit objective weight vector. If None, one is drawn with
                ``select_random_weight_vector``.
            **model_extras: Additional input for the model.

        Returns:
            Tensor: The final state, or stacked intermediates if ``return_intermediates``.

        Raises:
            NotImplementedError: If ``div_free`` is non-zero.
            ValueError: If ``score_models`` is None.
            ImportError: If ``verbose`` is True and tqdm is not installed.
        """
        if not div_free == 0.0:
            raise NotImplementedError(
                "multi_guidance_sample does not support a divergence-free term; div_free must be 0.0."
            )
        if score_models is None:
            raise ValueError("multi_guidance_sample requires score_models.")

        time_grid, t_discretization, n_steps, order = self._discretize_time(
            x_init=x_init,
            step_size=step_size,
            time_grid=time_grid,
            return_intermediates=return_intermediates,
        )

        x_t = x_init.clone()
        steps_counter = 0
        res = [x_init.clone()] if return_intermediates else []

        ctx = self._progress_context(verbose, t_discretization)

        # Objective weight vector: user-supplied, or drawn at random.
        if weights is not None:
            w = torch.tensor(weights, device=x_init.device)
        else:
            w, _ = select_random_weight_vector(num_objectives, args.num_div)
            w = w.to(device=x_init.device)
        print(f"Weight Vector: {w}")

        Phi = args.Phi_init
        ema_r_t = None

        with ctx:
            for i in range(n_steps):
                t = t_discretization[i : i + 1]
                h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1]

                # The last interval performs no transition. The model is therefore only
                # queried on the steps that use its output. The previously computed but
                # unused categorical draw is gone too, which also shifts the RNG stream
                # relative to older runs.
                if i != n_steps - 1:
                    p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras)

                    scheduler_output = self.path.scheduler(t=t)
                    k_t = scheduler_output.alpha_t
                    d_k_t = scheduler_output.d_alpha_t
                    u_t = d_k_t / (1 - k_t) * p_1t

                    # NOTE(review): ``w`` is passed twice; confirm against the
                    # guided_transition_scoring signature.
                    (
                        guided_u_t,
                        pos_indices,
                        cand_tokens,
                        improvement_values,
                        delta_S,
                    ) = guided_transition_scoring(
                        x_t, u_t, w, score_models, t, w, args
                    )

                    best_candidate, _, _, Phi, ema_r_t = adaptive_hypercone_filtering(
                        improvement_values,
                        cand_tokens,
                        delta_S,
                        w,
                        Phi,
                        args,
                        ema_r_t=ema_r_t,
                    )

                    x_t = euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h)

                steps_counter += 1
                t = t + h

                print(self._score_current_state(score_models, x_t))

                if return_intermediates and (t in time_grid):
                    res.append(x_t.clone())

                if verbose:
                    ctx.n = t.item()
                    ctx.refresh()
                    ctx.set_description(f"NFE: {steps_counter}")

        if return_intermediates:
            if step_size is None:
                return torch.stack(res, dim=0)
            else:
                return torch.stack(res, dim=0)[order]
        else:
            return x_t