| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from torch import Tensor |
| |
|
| | from flow_matching.path.path import ProbPath |
| |
|
| | from flow_matching.path.path_sample import DiscretePathSample |
| | from flow_matching.path.scheduler import ConvexScheduler |
| | from flow_matching.utils import expand_tensor_like, unsqueeze_to_match |
| |
|
| |
|
class MixtureDiscreteProbPath(ProbPath):
    r"""The ``MixtureDiscreteProbPath`` class defines a factorized discrete probability path.

    This path remains constant at the source data point :math:`X_0` until a random time, determined by the scheduler, when it flips to the target data point :math:`X_1`.
    The scheduler determines the flip probability using the parameter :math:`\sigma_t`, which is a function of time `t`. Specifically, :math:`\sigma_t` represents the probability of remaining at :math:`X_0`, while :math:`1 - \sigma_t` is the probability of flipping to :math:`X_1`:

    .. math::

        P(X_t = X_0) = \sigma_t \quad \text{and} \quad P(X_t = X_1) = 1 - \sigma_t,

    where :math:`\sigma_t` is provided by the scheduler. The choice between
    :math:`X_0` and :math:`X_1` is drawn independently for every element.

    Example:

    .. code-block:: python

        >>> x_0 = torch.zeros((1, 3, 3))
        >>> x_1 = torch.ones((1, 3, 3))

        >>> path = MixtureDiscreteProbPath(PolynomialConvexScheduler(n=1.0))

        >>> # Sampling is random: the outputs below are one possible draw each.
        >>> result = path.sample(x_0, x_1, t=torch.tensor([0.1])).x_t
        >>> result
        tensor([[[0.0, 0.0, 0.0],
                 [0.0, 0.0, 1.0],
                 [0.0, 0.0, 0.0]]])

        >>> result = path.sample(x_0, x_1, t=torch.tensor([0.5])).x_t
        >>> result
        tensor([[[1.0, 0.0, 1.0],
                 [0.0, 1.0, 0.0],
                 [0.0, 1.0, 0.0]]])

        >>> result = path.sample(x_0, x_1, t=torch.tensor([1.0])).x_t
        >>> result
        tensor([[[1.0, 1.0, 1.0],
                 [1.0, 1.0, 1.0],
                 [1.0, 1.0, 1.0]]])

    Args:
        scheduler (ConvexScheduler): The scheduler that provides :math:`\sigma_t`.
    """

    def __init__(self, scheduler: ConvexScheduler):
        # Kept as an ``assert`` so existing callers still observe AssertionError.
        # NOTE: assertions are skipped when Python runs with ``-O``.
        assert isinstance(
            scheduler, ConvexScheduler
        ), "Scheduler for MixtureDiscreteProbPath must be a ConvexScheduler."

        self.scheduler = scheduler

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
        r"""Sample from the mixture discrete probability path:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler providing :math:`\sigma_t`.
        | return :math:`X_0, X_1, t`, and :math:`X_t \sim p_t`.

        Each element of :math:`X_t` is taken from ``x_0`` with probability
        :math:`\sigma_t` and from ``x_1`` otherwise, independently per element.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            DiscretePathSample: a conditional sample at :math:`X_t ~ p_t`.
        """
        # Inherited check on the inputs before any sampling happens.
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)

        sigma_t = self.scheduler(t).sigma_t

        # sigma_t is per-sample; it is expanded against x_1 so it can be
        # compared elementwise with noise of x_1's shape.
        sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1)

        # True where the element stays at the source value x_0.
        source_indices = torch.rand(size=x_1.shape, device=x_1.device) < sigma_t
        x_t = torch.where(condition=source_indices, input=x_0, other=x_1)

        return DiscretePathSample(x_t=x_t, x_1=x_1, x_0=x_0, t=t)

    def posterior_to_velocity(
        self, posterior_logits: Tensor, x_t: Tensor, t: Tensor
    ) -> Tensor:
        r"""Convert the factorized posterior to velocity.

        | given :math:`p(X_1|X_t)`. In the factorized case: :math:`\prod_i p(X_1^i | X_t)`.
        | return :math:`u_t`.

        The velocity is computed as

        .. math::

            u_t = \frac{\dot{\kappa}_t}{1 - \kappa_t} \left( p(X_1 | X_t) - \delta_{X_t} \right),

        where :math:`\kappa_t` and :math:`\dot{\kappa}_t` are the scheduler's
        ``alpha_t`` and ``d_alpha_t``, and :math:`\delta_{X_t}` is the one-hot
        encoding of ``x_t``. The factor divides by :math:`1 - \kappa_t`, so the
        result is not finite at times where :math:`\kappa_t = 1`.

        Args:
            posterior_logits (Tensor): logits of the x_1 posterior conditional on x_t, shape (..., vocab size).
            x_t (Tensor): path sample at time t, shape (...). Must hold integer class indices accepted by ``F.one_hot``.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: velocity, with the same trailing vocabulary dimension as ``posterior_logits``.
        """
        posterior = torch.softmax(posterior_logits, dim=-1)
        vocabulary_size = posterior.shape[-1]
        # One-hot encode the current state so it can be subtracted from the posterior.
        x_t = F.one_hot(x_t, num_classes=vocabulary_size)
        # Add trailing dimensions to t so the scheduler outputs broadcast against x_t.
        t = unsqueeze_to_match(source=t, target=x_t)

        scheduler_output = self.scheduler(t)

        kappa_t = scheduler_output.alpha_t
        d_kappa_t = scheduler_output.d_alpha_t

        return (d_kappa_t / (1 - kappa_t)) * (posterior - x_t)
| |
|