World_Model / URSA /diffnext /schedulers /scheduling_dfm.py
BryanW's picture
Add files using upload-large-folder tool
b6ff324 verified
# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Simple implementation of discrete flow matching schedulers."""
import dataclasses
import os
from typing import Union
from typing_extensions import Self
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_outputs import BaseOutput
from diffusers.schedulers.scheduling_utils import SchedulerMixin
import torch
@dataclasses.dataclass
class KineticOptimalSchedulerOutput(BaseOutput):
"""Output for scheduler's `step` function output."""
prev_sample: torch.LongTensor
class DiscreteProbPath(object):
"""Define a general discrete probability path."""
def __init__(self, emb):
"""Create a ``DiscreteProbPath``.
Args:
emb (Union[torch.Tensor, torch.nn.Embedding])
The codebook embeddings.
"""
self.generator = None
self.emb = emb.weight if isinstance(emb, torch.nn.Embedding) else emb
def categorical(self, prob) -> torch.Tensor:
"""Categorical sampling according to weights in the last dimension.
Args:
prob (torch.Tensor)
The sample token probability, shape (bsz, ..., codebook_size).
Returns:
torch.Tensor: The sample token index, shape (bsz, ...).
"""
return prob.flatten(0, -2).multinomial(1, generator=self.generator).view(*prob.shape[:-1])
class MixtureDiscreteProbPath(DiscreteProbPath):
"""Define a mixture discrete probability path."""
def sample(self, x_1, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""Sample from the affine probability path.
Args:
x_1 (torch.Tensor)
The target token index, shape (bsz, ...).
t (float or torch.Tensor)
The timestep ``t``, shape (bsz,).
Returns:
torch.Tensor: The sample token index at time t, shape (bsz, ...).
"""
t = t.to(self.emb).view([-1] + [1] * (x_1.dim() - 1)) if hasattr(t, "cpu") else t
x_0 = x_1.new_empty(x_1.shape).random_(to=self.emb.shape[0], generator=self.generator)
return x_0.where(t.new_empty(x_1.shape).uniform_(generator=self.generator).lt(1 - t), x_1)
def get_velocity(self, logits, x_t, t: float, x_1=None) -> torch.Tensor:
"""Return the velocity by converting the factorized posterior.
Args:
logits (torch.Tensor)
The sample token logits at time t+1, shape (bsz, ..., codebook_size).
x_t (torch.Tensor)
The sample token index at time t, shape (bsz, ...).
t (float)
The timestep ``t``.
x_1 (torch.Tensor, optional):
The sample token index at time t+1, shape (bsz, ...).
Returns:
torch.Tensor: The velocity ``v``.
"""
x_1 = self.categorical(logits.softmax(-1)) if x_1 is None else x_1
return logits.zero_().scatter_(-1, x_1.unsqueeze(-1), 1 / (1 - t))
class MetricDiscreteProbPath(DiscreteProbPath):
"""Define a metric-induced discrete probability path."""
def __init__(self, emb, alpha=0.9, c=3, eps=1e-5):
"""Create a ``MetricDiscreteProbPath``.
Args:
emb (Union[torch.Tensor, torch.nn.Embedding])
The codebook embeddings.
alpha (float)
The value to ``alpha``.
c (float)
The value to ``c``.
eps (float, *optional*, defaults to 1e-5):
A small value to clip the L2 normalization denominator.
"""
self.alpha, self.c, self.eps, self.generator = alpha, c, eps, None
emb = emb.weight if isinstance(emb, torch.nn.Embedding) else emb
self.emb = torch.nn.functional.normalize(emb, dim=-1, eps=eps)
self.emb_sumsq = self.emb.square().sum(-1)
self.emb_mul2t = self.emb.mul(2).T.contiguous()
def get_dist(self, emb_1: torch.Tensor, emb_2: torch.Tensor = None) -> torch.Tensor:
"""Return the distance between two input embeddings.
Args:
emb_1 (torch.Tensor)
The input1 embeddings, shape (bsz, ..., dim).
emb_2 (torch.Tensor, optional)
The input2 embeddings, shape (bsz, ..., dim) or (bsz, ..., codebook_size).
Returns:
torch.Tensor: The distance, shape (bsz, ..., 1) or (bsz, ..., codebook_size).
"""
emb_1 = torch.nn.functional.normalize(emb_1, dim=-1, eps=self.eps)
if emb_2 is None or emb_1.size() != emb_2.size(): # Distance between input and codebook.
emb_1_sumsq, emb_2_sumsq = emb_1.square().sum(-1, True), self.emb_sumsq
return torch.add(emb_1_sumsq, emb_2_sumsq, out=emb_2).sub_(emb_1 @ self.emb_mul2t)
emb_2 = torch.nn.functional.normalize(emb_2, dim=-1, eps=self.eps)
return emb_1.sub(emb_2).abs_().square_().sum(-1, keepdim=True)
def get_prob(self, emb: torch.Tensor, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""Return the metric-induced probability.
Args:
emb (torch.Tensor)
The input embeddings, shape (bsz, ..., dim).
t (float or torch.Tensor)
The timestep ``t``, shape (bsz,).
Returns:
torch.Tensor: The probability at timestep ``t``, shape (bsz, ..., codebook_size).
"""
beta = self.c * (t / (1 - t)) ** self.alpha
beta = beta.to(emb).view([-1] + [1] * (emb.dim() - 1)) if hasattr(t, "cpu") else beta
return self.get_dist(emb).mul_(-beta).softmax(-1)
def get_prob_by_dist(self, dist: torch.Tensor, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""Return the metric-induced probability by distance.
Args:
dist (torch.Tensor)
The distance, shape (bsz, ..., codebook_size).
t (float or torch.Tensor)
The timestep ``t``, shape (bsz,).
Returns:
torch.Tensor: The probability at timestep ``t``, shape (bsz, ..., codebook_size).
"""
beta = self.c * (t / (1 - t)) ** self.alpha
beta = beta.to(dist).view([-1] + [1] * (dist.dim() - 1)) if hasattr(t, "cpu") else beta
return dist.mul(-beta).softmax(-1)
def sample(self, x_1, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""Sample from the affine probability path.
Args:
x_1 (torch.Tensor)
The target token index, shape (bsz, ...).
t (float or torch.Tensor)
The timestep ``t``, shape (bsz,).
Returns:
torch.Tensor: The sample token index at time t, shape (bsz, ...).
"""
return self.categorical(self.get_prob(self.emb[x_1], t))
def get_velocity(self, logits, x_t, t: float, x_1=None) -> torch.Tensor:
"""Return the velocity by converting the factorized posterior.
Args:
logits (torch.Tensor)
The sample token logits at time t+1, shape (bsz, ..., codebook_size).
x_t (torch.Tensor)
The sample token index at time t, shape (bsz, ...).
t (float)
The timestep ``t``.
x_1 (torch.Tensor, optional):
The sample token index at time t+1, shape (bsz, ...).
Returns:
torch.Tensor: The velocity ``v``, shape (bsz, ..., codebook_size).
"""
numerator = self.c * self.alpha * (t ** (self.alpha - 1)) if t > 0 else 0
d_beta_t = numerator / (1 - t) ** (self.alpha + 1)
emb_x_1 = self.emb[self.categorical(logits.softmax(-1)) if x_1 is None else x_1]
dist_x_1_x = self.get_dist(emb_x_1, logits) # (bsz, ..., codebook_size)
prob_x_1_x = self.get_prob_by_dist(dist_x_1_x, t) # (bsz, ..., codebook_size)
dist_x_t_x_1 = self.get_dist(self.emb[x_t], emb_x_1) # (bsz, ..., 1)
dist = torch.nn.functional.relu(dist_x_1_x.sub_(dist_x_t_x_1).neg_(), inplace=True)
return prob_x_1_x.mul_(d_beta_t).mul_(dist) # (bsz, ..., codebook_size)
class KineticOptimalScheduler(SchedulerMixin, ConfigMixin):
"""Kinetic optimal scheduler with general discrete paths."""
@register_to_config
def __init__(self, alpha=None, c=None, shift=1.0, eps=1e-5, **kwargs):
self.alpha, self.c, self.shift, self.eps = alpha, c, shift, eps
self.init_args, self.path, self.codebook_size = kwargs or {}, None, 0
self.init_args.setdefault("shift", shift) if shift != 1 else None
def __repr__(self) -> str:
"""Return the extra representation of this scheduler."""
s = f"{self.__class__.__name__}"
if self.alpha is None: # Fallback to ``MixtureDiscreteProbPath``.
return s + "(shift={shift})".format(**self.__dict__)
return s + "(alpha={alpha}, c={c}, shift={shift})".format(**self.__dict__)
@classmethod
def from_pretrained(cls, pretrained_path, device=None, dtype=None, **kwargs) -> Self:
"""Instantiate the scheduler from a pretrained model vocabulary."""
return KineticOptimalScheduler().load_pretrained(pretrained_path, device, dtype, **kwargs)
def load_pretrained(self, pretrained_path=None, device=None, dtype=None, **kwargs) -> Self:
"""Load the scheduler from a pretrained model vocabulary."""
pretrained_path = self.init_args.get("pretrained_path", None) or pretrained_path
pretrained_args = super().from_pretrained(pretrained_path, **kwargs).__dict__
pretrained_args.update({"init_args": self.init_args, **self.init_args})
self.__dict__.update(pretrained_args)
model_file = os.path.join(pretrained_path, "scheduler_model.pth")
emb = torch.load(model_file, weights_only=False)["path.emb"]
emb = emb.to(device).to(dtype=dtype or torch.float16)
self.path = MetricDiscreteProbPath(emb=emb, alpha=self.alpha, c=self.c, eps=self.eps)
self.path = MixtureDiscreteProbPath(emb=emb) if self.alpha is None else self.path
self.codebook_size = self.path.emb.size(0)
return self
def to(self, device=None, dtype=None) -> Self:
"""Convert to given device and dtype."""
for k, v in self.path.__dict__.items():
self.path.__dict__[k] = v.to(device, dtype) if isinstance(v, torch.Tensor) else v
return self
def sample_timesteps(self, size, device=None, generator=None) -> torch.Tensor:
"""Sample a batch of timesteps for training.
Args:
size (Tuple[int])
The sample size of timesteps.
device (torch.device, optional)
The output device.
generator (torch.Generator, optional):
The random generator.
"""
sigma = 1 - torch.rand(size, device=device, generator=generator).mul_(0.999)
return 1 - self.shift * sigma / (1 + (self.shift - 1) * sigma)
def set_timesteps(self, num_inference_steps, *args, **kwargs):
"""Set the inference timesteps for sampling.
Args:
num_inference_steps (int)
The number of inference steps.
"""
self.num_inference_steps = num_inference_steps
self.timesteps = torch.arange(num_inference_steps).tolist()
def add_noise(self, original_samples, timesteps, generator=None) -> torch.Tensor:
"""Add forward noise to samples.
Args:
original_samples (torch.Tensor)
The sample token index, shape (bsz, ...).
t (float or torch.Tensor)
The timestep ``t``, shape (bsz,).
generator (torch.Generator, optional):
The random generator.
Returns:
torch.Tensor: The sample token index at time t, shape (bsz, ...).
"""
self.path.generator = generator if generator else self.path.generator
return self.path.sample(original_samples, timesteps)
def timestep_to_t(self, timestep) -> float:
"""Return the ``t`` for given timestep.
Args:
timestep (int)
The discrete timestep index.
Returns:
float: The continuous timestep in [0, 1).
"""
sigma = 1 - self.timesteps[timestep] / self.num_inference_steps
return 1 - self.shift * sigma / (1 + (self.shift - 1) * sigma)
def step(
self,
model_output,
timestep,
sample,
generator=None,
return_dict=True,
) -> KineticOptimalSchedulerOutput:
"""Predict the sample from the previous timestep.
Args:
model_output (torch.Tensor)
The sample token logits at time t+1, shape (bsz, ..., codebook_size).
timestep (int)
The discrete timestep index.
sample (torch.Tensor)
The sample token index at time t, shape (bsz, ...).
generator (torch.Generator, optional):
The random generator.
return_dict (bool, optional)
Whether return the output in a dict.
Returns:
torch.Tensor: The sample token index at time t+1, shape (bsz, ...).
"""
self.path.generator = generator if generator else self.path.generator
if timestep == self.num_inference_steps - 1:
prev_sample = self.path.categorical(model_output.softmax(-1))
else:
t = self.timestep_to_t(timestep)
dt = self.timestep_to_t(timestep + 1) - t
v = self.path.get_velocity(model_output, sample, t)
u_dist = torch.empty_like(sample, dtype=v.dtype).uniform_(generator=generator)
jump_thresh = 1 - v.scatter_(-1, sample[..., None], 0).sum(-1).mul_(-dt).exp_()
prev_sample, jump_index = sample.clone(), u_dist < jump_thresh
prev_sample[jump_index] = self.path.categorical(v[jump_index])
if not return_dict:
return (prev_sample,)
return KineticOptimalSchedulerOutput(prev_sample=prev_sample)