from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from fairseq2.generation.generator import (
    GenerationCounters,
    Hypothesis,
    SequenceGeneratorOutput,
)
from fairseq2.logging import get_log_writer

from lcm.datasets.batch import EmbeddingsBatch, PaddingMask
from lcm.inference.lcm.generator import (
    LCMGenerator,
    LCMGeneratorOptions,
)
from lcm.models.abstract_lcm import AbstractLCModel
from lcm.models.two_tower_diffusion_lcm.builder import TwoTowerDiffusionLCModel
from lcm.nn.incremental_state import LCMIncrementalStateBag

logger = get_log_writer(__name__)

@dataclass
class DiffusionLCMGeneratorOptions(LCMGeneratorOptions):
    """Holds the options to pass to a diffusion-based sequence generator."""

    guidance_scale: float = 1.0
    """The weight of the regular classifier-free guidance.
    Here `guidance_scale` is defined as the guidance weight `w` of
    Equation (2) of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf.
    `guidance_scale = 1` corresponds to doing no classifier-free guidance.
    A higher guidance scale encourages the model to generate outputs
    closely related to the `prompt`, at the expense of lower quality."""
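    # Illustrative sketch (an assumption about the sampler, not code from this
    # module): with classifier-free guidance, the conditional and unconditional
    # predictions are combined as
    #   pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
    # so `guidance_scale = 1.0` reduces to the purely conditional prediction.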

    guidance_rescale: float = 0.0
    """The rescaling factor for classifier-free guidance with rescale
    (Algorithm 2 of https://arxiv.org/pdf/2305.08891)."""
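    # Illustrative sketch (assumed, per Algorithm 2 of the paper above): the
    # guided prediction is rescaled towards the conditional prediction's std,
    #   rescaled = pred * (pred_cond.std() / pred.std())
    #   pred = guidance_rescale * rescaled + (1 - guidance_rescale) * pred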

    ddim_eta: float = 0.0
    """The weight of the noise added in each diffusion step.
    It controls the level of interpolation between a deterministic
    DDIM (at eta = 0.0) and a stochastic DDPM (at eta = 1.0).
    See Section 5 of the DDIM paper: https://arxiv.org/pdf/2010.02502."""
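    # Illustrative sketch (assumed, following Eq. (16) of the DDIM paper): eta
    # scales the per-step noise level
    #   sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)
    # where `a_t` is the cumulative alpha at step t; sigma_t = 0 makes the
    # sampling step deterministic.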

    epsilon_scaling: Optional[float] = None
    """If not None, the predicted epsilon will be scaled down by the
    provided factor, as introduced in https://arxiv.org/pdf/2308.15321."""
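    # Illustrative sketch (assumed, per the epsilon-scaling paper above): at
    # each sampling step the predicted noise is damped,
    #   eps = eps / epsilon_scaling
    # with a scaling factor slightly above 1.0.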

    initial_noise_scale: float = 1.0
    """For diffusion models, the scaling of the initial noise."""

    inference_timesteps: int = 100
    """For diffusion models, the number of denoising timesteps."""

    clip_noise: int = 100
    """For diffusion models, the factor to clip the noise of the sampling steps."""

    thresholding: bool = False
    """Whether to use the "dynamic thresholding" method.
    This is unsuitable for latent-space diffusion models such as Stable Diffusion."""

    dynamic_thresholding_ratio: float = 0.995
    """The ratio for the dynamic thresholding method. Valid only when `thresholding=True`."""

    sample_max_value: float = 6.0
    """The threshold value for dynamic thresholding. Valid only when `thresholding=True`."""


class TwoTowerDiffusionLCMGenerator(LCMGenerator):
    """Generates with a Two-tower Diffusion LCM model."""

    options: DiffusionLCMGeneratorOptions

    def __init__(
        self,
        model: AbstractLCModel,
        options: Optional[LCMGeneratorOptions] = None,
        eos_vec: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__(model, options, eos_vec)

        assert isinstance(self.model, TwoTowerDiffusionLCModel), (
            "The TwoTowerDiffusionLCMGenerator expects a Diffusion LCM"
        )

        logger.info(
            f"Setting up the model with decoding_options: {options} -- {type(options)}"
        )
        model.prep_for_denoising(options)

    @torch.inference_mode()
    def __call__(
        self,
        batch_input: EmbeddingsBatch,
        max_gen_len: Optional[int] = None,
        min_gen_len: Optional[int] = None,
        temperature: float = 0.0,
        disable_cache: bool = False,
        **kwargs,
    ) -> SequenceGeneratorOutput:
        """
        :param batch_input:
            Embedded and padded tensor sequence of the inputs.
        :param max_gen_len:
            Maximum length to be generated for the given input.
        :param min_gen_len:
            Minimum length to be generated for the given input.
        :param temperature:
            Temperature to control the generation.
        :param disable_cache:
            If True, do not use kv-caching.

        :returns:
            The output of the LCM generator, consisting of :math:`N` lists of
            hypotheses for :math:`N` prompts. Each list has 1 `Hypothesis`
            (beam size = 1), of which `seq` has the *Shape:* :math:`(S+T, D)`
            (:math:`S` is the prompt length, :math:`T` the length of the
            generated sequence after the prompt and :math:`D` the model
            dimension).
        """
        if self.options.seed:
            torch.manual_seed(self.options.seed)

        batch_size, self.max_prompt_len, embed_dim = batch_input.seqs.size()
        prompt_padding_mask = batch_input.padding_mask
        if prompt_padding_mask is None:
            self.min_prompt_len = self.max_prompt_len
            self.prompt_padding_mask = None
            self.prompt_seq_lens = None
        else:
            self.prompt_seq_lens = prompt_padding_mask.seq_lens
            assert self.prompt_seq_lens is not None, (
                "Expecting a valid `self.prompt_seq_lens` Tensor, found `None`"
            )
            self.min_prompt_len = int(torch.min(self.prompt_seq_lens, dim=0)[0].item())

            self.prompt_padding_mask = prompt_padding_mask.materialize()

        if not max_gen_len:
            # If not provided, generate up to the model's maximum length.
            max_gen_len = self.max_seq_len

        assert max_gen_len <= self.max_seq_len, (
            f"max_gen_len must be at most max_seq_len={self.max_seq_len}, max_gen_len={max_gen_len}"
        )
        self.max_gen_len = max_gen_len

        if not min_gen_len:
            min_gen_len = self.min_seq_len

        assert min_gen_len > 0, (
            f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}"
        )
        self.min_gen_len = min_gen_len

        if temperature == 0.0:
            # Fall back to the default temperature from the decoding options.
            temperature = self.options.lcm_temperature

        dtype = self.model.dtype
        device = batch_input.seqs.device
        self.temperature = temperature

        if disable_cache:
            self.state_bag = None
            self.context_state_bag = None
        else:
            self.state_bag = LCMIncrementalStateBag(
                self.max_prompt_len + self.max_gen_len
            )
            self.context_state_bag = LCMIncrementalStateBag(
                self.max_prompt_len + self.max_gen_len
            )

        # Holds the generated sequences, their step scores and the
        # sample-dependent lengths (-1 marks an unfinished sequence).
        self.seqs = torch.zeros(
            (batch_size, self.max_prompt_len + self.max_gen_len, embed_dim),
            device=device,
            dtype=dtype,
        )
        self.step_scores = torch.zeros(
            (batch_size, self.max_prompt_len + self.max_gen_len),
            device=device,
        )
        self.lengths = torch.zeros(batch_size, dtype=torch.int, device=device) - 1

        # Maps positions in the (shrinking) active batch to original samples.
        self.sample_indices = torch.arange(batch_size, device=device)

        self.hypotheses: List[List[Hypothesis]] = [[] for _ in range(batch_size)]

        # Bootstrap the sequences with the prompt.
        self.seqs[:, : self.max_prompt_len] = batch_input.seqs[:, : self.max_prompt_len]
        self.step_nr = self.min_prompt_len

        # Prefill the context encoder's cache with the common prompt prefix.
        self.prefill()

        for self.step_nr in range(
            self.min_prompt_len, self.max_prompt_len + self.max_gen_len
        ):
            if not self._step():
                break

        return SequenceGeneratorOutput(self.hypotheses, counters=GenerationCounters())

    def state_bag_reorder(self, new_order: torch.Tensor) -> None:
        if self.state_bag is not None:
            self.state_bag.reorder(new_order)

        if self.context_state_bag is not None:
            self.context_state_bag.reorder(new_order)

    @torch.inference_mode()
    def prefill(self, **kwargs) -> None:
        """Encode the prefix with the context encoder."""

        assert self.context_state_bag is not None, (
            "Expecting a context state bag to prefill"
        )

        context: EmbeddingsBatch

        prefill_len = self.step_nr - 1
        if prefill_len > 0:
            input_seqs = self.seqs[:, :prefill_len]
            # Normalize the embeddings if the model uses a SONAR normalizer.
            if self.model.config.sonar_normalizer_name is not None:
                input_seqs = self.model.sonar_normalizer.normalize(input_seqs)

            context = self.model.encode(
                EmbeddingsBatch(input_seqs, None),
                state_bag=self.context_state_bag,
                **kwargs,
            )

            self.context_state_bag.increment_step_nr(prefill_len)

        else:
            logger.warning(
                f"Skipping prefill since only a context size of {self.step_nr} is provided in the prefix"
            )
            context = EmbeddingsBatch(
                torch.empty(
                    (self.seqs.shape[0], 0, self.model.model_dim),
                    dtype=self.seqs.dtype,
                    device=self.seqs.device,
                )
            )

        self.context = context

    @torch.inference_mode()
    def _decode(
        self,
        seqs: torch.Tensor,
        padding_mask: Optional[PaddingMask] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        output, context = self.model.predict_next_sentence(
            batch=EmbeddingsBatch(seqs, padding_mask),
            context=self.context,
            temperature=self.temperature,
            state_bag=self.state_bag,
            context_state_bag=self.context_state_bag,
            **kwargs,
        )
        self.context = context

        # Dummy scores: the diffusion generator does not score its outputs.
        scores = torch.zeros(seqs.shape[:-1], device=seqs.device)
        return output.seqs, scores

    def _step(self) -> bool:
        """Generate the next embedding and check the stopping criteria."""

        if self.state_bag is None:
            # Without incremental decoding (kv-caching disabled),
            # re-feed the full prefix at every step.
            model_output, step_score = self._decode(
                seqs=self.seqs[:, : self.step_nr],
                padding_mask=None,
            )
        else:
            # With incremental decoding, only the last generated
            # embedding is fed to the model.
            model_output, step_score = self._decode(
                seqs=self.seqs[:, self.step_nr - 1 : self.step_nr],
                padding_mask=None,
            )

            self.state_bag.increment_step_nr()

        return self.finalize_step(model_output, step_score)

    def finalize_step(
        self, model_output: torch.Tensor, step_score: torch.Tensor
    ) -> bool:
        """Post-process and finalize a step by checking all stopping criteria.

        Takes the model's output embeddings (`model_output`)
        and their associated scores (`step_score`).
        Returns True if generation should keep stepping, else False.
        """
        already_finished = self.lengths > -1
        should_finish_now = torch.zeros_like(already_finished)

        model_last_output = model_output[:, -1]
        device = model_last_output.device

        # While still consuming ragged prompts, overwrite the prediction with
        # the prompt's embedding for samples whose prompt covers this step,
        # and keep those samples going.
        must_keep_going = None
        if self.step_nr < self.max_prompt_len:
            assert self.prompt_padding_mask is not None, (
                f"If self.prompt_padding_mask is None, then self.step_nr should start from self.max_prompt_len={self.max_prompt_len} - currently self.step_nr = {self.step_nr}"
            )
            mask = self.prompt_padding_mask[:, self.step_nr]
            model_last_output[mask] = self.seqs[mask, self.step_nr]
            must_keep_going = mask

        # Stopping criterion 1: cosine similarity to the EOS vector.
        if self.eos_threshold is not None and self.eos_vec is not None:
            sim2eos = torch.nn.functional.cosine_similarity(
                self.eos_vec.to(device), model_last_output
            )
            logger.debug(f"Similarity to eos vector: {sim2eos} vs {self.eos_threshold}")
            should_finish_now = should_finish_now | sim2eos.ge(self.eos_threshold)

        # Stopping criterion 2: repetition, i.e. cosine similarity to the
        # previously generated embedding.
        if (
            self.options.stop_on_repetition_cosine_threshold is not None
            and self.step_nr > 0
        ):
            sim2prev = torch.nn.functional.cosine_similarity(
                self.seqs[:, self.step_nr - 1], model_last_output
            )
            logger.debug(
                f"Similarity to prev vector: {sim2prev} vs {self.options.stop_on_repetition_cosine_threshold}"
            )
            should_finish_now = should_finish_now | sim2prev.ge(
                self.options.stop_on_repetition_cosine_threshold
            )

        if must_keep_going is not None:
            logger.debug(
                f"Must keep going (to cover max_prompt_len={self.max_prompt_len}) is not None = {must_keep_going}"
            )
            should_finish_now = should_finish_now & ~must_keep_going

        # Only allow finishing once at least `min_gen_len` embeddings have
        # been generated past the prompt.
        if self.prompt_seq_lens is not None:
            longer_than_min_gen_len = (self.step_nr - self.prompt_seq_lens).ge(
                self.min_gen_len
            )
        else:
            longer_than_min_gen_len = (
                self.step_nr - self.max_prompt_len
            ) >= self.min_gen_len

        logger.debug(
            f"Longer than min_gen_len ({self.min_gen_len}) = {longer_than_min_gen_len}"
        )
        should_finish_now = should_finish_now & longer_than_min_gen_len
        stopped_on_eos = should_finish_now

        # Force-finish any sample that reached `max_gen_len`.
        if self.prompt_seq_lens is not None:
            exceeds_max_gen_len = (self.step_nr - self.prompt_seq_lens + 1).ge(
                self.max_gen_len
            )
            logger.debug(
                f"step: {self.step_nr}; max_gen_len: {self.max_gen_len}; prompt_lens: {self.prompt_seq_lens}; steps exceeded: {self.max_gen_len + self.prompt_seq_lens}"
            )

        else:
            exceeds_max_gen_len = (
                self.step_nr - self.max_prompt_len + 1
            ) >= self.max_gen_len
            logger.debug(
                f"step: {self.step_nr}; max_gen_len: {self.max_gen_len}; prompt_lens: None (unique length: {self.max_prompt_len}); steps exceeded: {self.max_prompt_len + self.max_gen_len}"
            )

        logger.debug(
            f"Stopping criteria: {should_finish_now}; exceeds max len: {exceeds_max_gen_len}; already finished: {already_finished}"
        )

        should_finish_now = should_finish_now | exceeds_max_gen_len

        # Record the final length of sequences finishing at this step.
        should_finish_now = should_finish_now & ~already_finished
        self.lengths[should_finish_now] = self.step_nr + 1

        # Write this step's embedding and score into the running buffers.
        self.seqs[:, self.step_nr] = model_last_output.squeeze(1)
        self.step_scores[:, self.step_nr - self.min_prompt_len] = step_score[:, -1]

        # Finalize finished hypotheses and drop them from the active batch.
        finished_mask = self.lengths.ne(-1)
        finished_indices = finished_mask.nonzero()

        if len(finished_indices) > 0:
            for idx in finished_indices:
                self.finish_sequence(int(idx), is_eos=bool(stopped_on_eos[int(idx)]))

        active_mask = ~finished_mask
        active_indices = active_mask.nonzero().squeeze(-1)

        if len(active_indices) == 0:
            return False

        self.reorder_state(active_indices)

        return True

    def finish_sequence(self, idx: int, is_eos: bool = False) -> None:
        seq_len = int(self.lengths[idx].item())

        if self.options.trim_hypotheses and self.lengths[idx].item() > -1 and is_eos:
            # Optionally drop the EOS embedding from the finalized hypothesis.
            seq_len = int(self.lengths[idx].item()) - int(
                not self.options.include_eos_token
            )

        sample_idx = int(self.sample_indices[idx])
        self.hypotheses[sample_idx] = [
            Hypothesis(
                seq=self.seqs[idx, :seq_len],
                score=None,
                step_scores=self.step_scores[idx],
            )
        ]

    def reorder_state(self, new_order: torch.Tensor) -> None:
        self.state_bag_reorder(new_order)

        self.context = EmbeddingsBatch(
            self.context.seqs.index_select(dim=0, index=new_order),
            self.context.padding_mask,
        )

        self.seqs = self.seqs.index_select(dim=0, index=new_order)
        self.sample_indices = self.sample_indices.index_select(dim=0, index=new_order)
        self.step_scores = self.step_scores.index_select(dim=0, index=new_order)
        self.lengths = self.lengths.index_select(dim=0, index=new_order)

        if self.prompt_padding_mask is not None:
            self.prompt_padding_mask = self.prompt_padding_mask.index_select(
                dim=0, index=new_order
            )

        if self.prompt_seq_lens is not None:
            self.prompt_seq_lens = self.prompt_seq_lens.index_select(
                dim=0, index=new_order
            )
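

# Example usage (an illustrative sketch, not part of this module; `model`,
# `seqs` and `padding_mask` are assumed to be built elsewhere, e.g. a loaded
# TwoTowerDiffusionLCModel checkpoint and a SONAR-embedded document):
#
#   options = DiffusionLCMGeneratorOptions(guidance_scale=1.5, inference_timesteps=40)
#   generator = TwoTowerDiffusionLCMGenerator(model, options)
#   output = generator(EmbeddingsBatch(seqs, padding_mask), max_gen_len=64)
#   hypothesis = output.hypotheses[0][0]  # one hypothesis per prompt (beam size = 1)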