# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#
from dataclasses import dataclass, field
from typing import List, Tuple

import torch
import torch.nn.functional as F
from fairseq2.logging import get_log_writer
from fairseq2.nn.padding import pad_seqs
from torch import Tensor

from lcm.datasets.batch import EmbeddingsBatch, LCMInput, LCMStyle
from lcm.models.two_tower_diffusion_lcm.builder import TwoTowerDiffusionLCModel
from lcm.train.criterion import CriterionsFactory
from lcm.train.lcm.criterion import (
LCMCriterion,
LCMCriterionConfig,
compute_standard_mse,
)
from lcm.train.metrics import LossTerm, format_as_float, register_metric_formatter
from lcm.train.step_sampler import StepsSampler, StepsSamplerConfig

logger = get_log_writer(__name__)


@dataclass
class TowerDiffusionLCMCriterionConfig(LCMCriterionConfig):
cf_guidance_probability: float = 0.0
"""Probability to use classifier-free guidance by dropping conditioning.
Note that this requires the model to be set with
`trained_with_cf_guidance = True`!
"""
step_sampling: StepsSamplerConfig = field(
default_factory=lambda: StepsSamplerConfig()
)
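    """Config of the sampler that draws diffusion timesteps and
    provides the per-timestep loss scales (see `StepsSampler`)."""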
log_losses_per_timestep_bucket: bool = False
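    """If True, additionally log the reconstruction loss aggregated
    per bucket of diffusion timesteps."""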


@CriterionsFactory.register("two_tower_diffusion_next_sent")
class TwoTowerDiffusionCriterion(LCMCriterion):
"""Computes the LCM training objective for next-sentence prediction with diffusion"""
config: TowerDiffusionLCMCriterionConfig
model: TwoTowerDiffusionLCModel

    def __init__(
self,
config: TowerDiffusionLCMCriterionConfig,
model: TwoTowerDiffusionLCModel,
style: LCMStyle = LCMStyle.UNSUPERVISED,
):
super().__init__(config, model, style)
assert hasattr(self.base_model, "noise_scheduler"), (
"Expecting the diffusion model to have a `noise_scheduler`"
)
self.noise_scheduler = self.base_model.noise_scheduler
self.prediction_type = self.noise_scheduler.prediction_type
self.trained_with_cf_guidance = self.base_model.config.trained_with_cf_guidance
self.cf_guidance_probability = config.cf_guidance_probability
        assert (
            bool(self.cf_guidance_probability > 0) == self.trained_with_cf_guidance
        ), (
            "Expecting the config's cf_guidance_probability to align with the model's `trained_with_cf_guidance`. "
            f"Found cf_guidance_probability={config.cf_guidance_probability} and "
            f"trained_with_cf_guidance={self.trained_with_cf_guidance}"
        )
assert self.normalize_in_criterion, (
"We only support `normalize_in_criterion = True` in the diffusion criterions"
)
self.summands.append("unnormalized_reconstruction_loss")
if self.config.log_losses_per_timestep_bucket:
# customize if needed
self.step_bucketing_boundaries = torch.linspace(
0, self.noise_scheduler.num_diffusion_train_steps, 11
)
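            # 11 boundaries -> 10 equal-width timestep buckets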
self.step_bucketing_labels: List[str] = []
for e in range(len(self.step_bucketing_boundaries) - 1):
bucket_left = self.step_bucketing_boundaries[e]
bucket_right = self.step_bucketing_boundaries[e + 1]
self.step_bucketing_labels.append(
f"reconstruction_loss_t{bucket_left:.0f}-{bucket_right:.0f}"
)
self.summands.extend(self.step_bucketing_labels)
for label in self.step_bucketing_labels:
register_metric_formatter(
label, label, 1000, format_as_float, overwrite=True
)
# Step sampler + loss weighter
self.step_sampler = StepsSampler(
config.step_sampling,
noise_scheduler=self.noise_scheduler,
)

    def prepare_input_and_mask(
self,
batch: LCMInput,
) -> Tuple[EmbeddingsBatch, EmbeddingsBatch, torch.Tensor]:
"""
        Prepare the model inputs and the target mask for a batch.
        This method is typically reused by the `__call__`
        implementations of the subclasses.
        Returns:
            - input_batch: context
            - target_batch: denoiser input
            - target_mask: mask of positions to compute the loss over
"""
# Prepare the input as in MSE LCM: each sequence is (src, tgt)
input_embeddings = batch.prepare_input(style=self.style)
# Normalize the embeddings
if self.normalize_in_criterion:
input_embeddings = input_embeddings.normalize_seqs(self.sonar_normalizer)
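        # Start from an all-True mask: by default, every position contributes to the loss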
target_mask = torch.ones(
size=input_embeddings.seqs.shape[:-1],
dtype=torch.bool,
device=input_embeddings.seqs.device,
)
# Factor in padded positions:
if input_embeddings.padding_mask is not None:
target_mask &= input_embeddings.padding_mask.materialize()
return input_embeddings, input_embeddings.clone(), target_mask

    def sample_noisy_input_and_targets(
        self, input_batch: EmbeddingsBatch, target_mask: Tensor
    ) -> Tuple[EmbeddingsBatch, Tensor, Tensor]:
"""
        (1) Prepare the noised inputs (latents) by sampling diffusion timesteps and
        calling the model's noise_scheduler to add noise accordingly.
        (2) Given the scheduler's prediction type, prepare the target that the model
        will be trained to predict.
        :param input_batch: EmbeddingsBatch of the ground truth embeddings with seqs in (B, T, C)
        :param target_mask: Bool tensor in (B, T) where `True` signals that the
            model will be asked to predict the position
"""
input_seqs, padding_mask = input_batch.seqs, input_batch.padding_mask
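        # Sample one diffusion timestep per position: `input_seqs[..., 0]` has shape (B, T)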
timesteps = self.step_sampler.sample(
size=input_seqs[..., 0].size(), device=input_seqs.device
)
# Sample noise
noise_seqs = torch.randn_like(input_seqs)
# Define target in (B*T, C)
sonar_dim = input_seqs.size(-1)
if self.prediction_type == "sample":
"""Predict the clean ground truth embeddings. Default mode"""
target = input_seqs.view(-1, sonar_dim)
elif self.prediction_type == "epsilon":
"""Predict the added noise"""
target = noise_seqs.view(-1, sonar_dim)
elif self.prediction_type == "v_prediction":
"""Predict an interpolation of the ground truth clean
embeddings and the added noise.
As introduced in https://arxiv.org/pdf/2305.08891
"""
target = self.noise_scheduler.get_velocity(
input_seqs.view(-1, sonar_dim),
noise_seqs.view(-1, sonar_dim),
timesteps.view(-1),
).clone()
else:
raise ValueError(
"Prediction type should be either: sample, epsilon, v_prediction"
)
# Add noise
        # Reshape inputs and noise to (B*T, C) -> add noise -> reshape back to (B, T, C)
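        # Assuming a diffusers-style scheduler, `add_noise` computes
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise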
noisy_input_seqs = self.noise_scheduler.add_noise(
input_seqs.view(-1, sonar_dim),
noise_seqs.view(-1, sonar_dim),
timesteps.view(-1),
).view(input_seqs.size())
# Create sequence batch with diffusion timesteps
noisy_input_batch = EmbeddingsBatch(
noisy_input_seqs,
padding_mask,
diffusion_timesteps=timesteps,
)
return noisy_input_batch, target, target_mask

    def compute_loss(
        self, flattened_predictions: Tensor, flattened_target: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor]:
"""
Parameters:
flattened_predictions (Tensor): The predictions in (N, C)
flattened_target (Tensor): The targets in (N, C)
Returns:
            reconstruction_loss (Tensor): the reconstruction loss we want to optimize
                (MSE, SmoothL1, Huber, etc.; the sqrt is taken below when `compute_rmse` is set).
            plain_reconstruction_loss (Tensor): plain MSE loss.
            unnormalized_reconstruction_loss (Tensor): plain MSE loss between unnormalized features.
"""
reconstruction_loss, plain_reconstruction_loss = compute_standard_mse(
flattened_predictions,
flattened_target,
)
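        # With `normalizer` set, `compute_standard_mse` measures the error between
        # denormalized features, i.e. in the raw (unnormalized) SONAR space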
unnormalized_reconstruction_loss, _ = compute_standard_mse(
flattened_predictions,
flattened_target,
normalizer=self.sonar_normalizer,
)
# For backward compatibility with ongoing runs, take the sqrt
if self.config.compute_rmse:
epsilon = 1e-5
reconstruction_loss = torch.sqrt(reconstruction_loss + epsilon)
plain_reconstruction_loss = torch.sqrt(plain_reconstruction_loss + epsilon)
unnormalized_reconstruction_loss = torch.sqrt(
unnormalized_reconstruction_loss + epsilon
)
return (
reconstruction_loss,
plain_reconstruction_loss,
unnormalized_reconstruction_loss,
)

    @torch.no_grad()
def _log_losses_per_step(self, batch_steps, reconstruction_loss):
# Aggregate loss terms based on their bucket of diffusion steps for tracking
summands = {}
if self.config.log_losses_per_timestep_bucket:
            # reconstruction_loss and batch_steps are both flattened to shape (B*T,)
bucket_index = torch.bucketize(
batch_steps, self.step_bucketing_boundaries.to(batch_steps.device)
)
onehot = F.one_hot(
bucket_index,
num_classes=self.step_bucketing_boundaries.numel(),
)
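            # onehot.t() @ loss sums the per-position losses within each bucket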
loss_per_step = torch.matmul(onehot.t().float(), reconstruction_loss)
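            # the small epsilon guards against division by zero for empty buckets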
count_steps = onehot.sum(dim=0) + 1e-6
if self.reduction == "mean":
loss_per_step /= count_steps
for e, label in enumerate(self.step_bucketing_labels):
summands[label] = (
loss_per_step[e].item(),
count_steps[e].long().item(),
)
return summands

    def __call__(self, batch: LCMInput) -> LossTerm:
"""
Input batch is LCMInput with:
source: List[Tensor]
target: Union[None, List[Tensor]]
"""
# Prepare the clean inputs and target mask:
input_batch, target_batch, target_mask = self.prepare_input_and_mask(batch)
noisy_target_batch, target, target_mask = self.sample_noisy_input_and_targets(
target_batch, target_mask
)
# Encode the context and diffuse:
output_batch = self.model(
input_batch,
noisy_target_batch,
cf_guidance_prob=self.cf_guidance_probability,
)
# Shape B, T, C
output_seqs = output_batch.seqs
sonar_dim = output_seqs.size(-1)
# only measure distance over `target_mask = True` positions
target_mask = target_mask.reshape(-1)
        # The target is basically the ground truth sequence before noising
        # (with some modification to adjust for the denoiser's prediction type).
        # Predictions: contextualized latents e_1, e_2, ... flattened to (B*T, C)
        flattened_predictions = output_seqs.view(-1, sonar_dim)[target_mask]
        # Targets x_1, x_2, ..., x_T; `target` is already flattened to (B*T, C)
        flattened_target = target[target_mask]
# Cast features to float32 before computing the loss:
(
reconstruction_loss,
mse_loss,
unnormalized_reconstruction_loss,
) = self.compute_loss(flattened_predictions.float(), flattened_target.float())
num_target_elements = target_mask.sum()
batch_steps = noisy_target_batch.diffusion_timesteps.view(-1)[target_mask]
summands = self._log_losses_per_step(batch_steps, reconstruction_loss)
# Get loss scales per timestep (gamma)
gammas = self.step_sampler.get_loss_scales(batch_steps)
# Weight the loss terms
if gammas is not None:
reconstruction_loss = torch.mul(reconstruction_loss, gammas)
if self.reduction == "sum" or num_target_elements == 0:
reduced_reconstruction_loss = reconstruction_loss.sum()
mse_loss = mse_loss.sum()
unnormalized_reconstruction_loss = unnormalized_reconstruction_loss.sum()
elif self.reduction == "mean":
reduced_reconstruction_loss = reconstruction_loss.mean()
mse_loss = mse_loss.mean()
unnormalized_reconstruction_loss = unnormalized_reconstruction_loss.mean()
final_loss = reduced_reconstruction_loss
# Loss summands for records
summands.update(
{
"mse_loss": (mse_loss.item(), -1),
"reconstruction_loss": (reduced_reconstruction_loss.item(), -1),
"unnormalized_reconstruction_loss": (
unnormalized_reconstruction_loss.item(),
-1,
),
}
)
return LossTerm(
value=final_loss,
batch_size=output_seqs.size(0),
num_target_elements=num_target_elements.item(),
summands=summands,
)


@CriterionsFactory.register("two_tower_diffusion_next_sent_finetuning")
class DiffusionNextSentFinetuningCriterion(TwoTowerDiffusionCriterion):
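    """Supervised variant of `TwoTowerDiffusionCriterion` where the denoiser
    targets are taken from `batch.target` rather than from the input sequence."""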

    def __init__(
self,
config: TowerDiffusionLCMCriterionConfig,
model: TwoTowerDiffusionLCModel,
):
super().__init__(config, model, LCMStyle.SUPERVISED)

    def prepare_input_and_mask(
self,
batch: LCMInput,
) -> Tuple[EmbeddingsBatch, EmbeddingsBatch, torch.Tensor]:
"""
        Prepare the model inputs and the target mask for a batch
        in the supervised (source, target) setting.
        Returns:
            - input_batch: context
            - target_batch: denoiser input
            - target_mask: mask of positions to compute the loss over
"""
# Prepare the input as in MSE LCM
input_embeddings = batch.prepare_input(style=self.style)
assert input_embeddings.source_lengths is not None, (
"Missing source lengths needed for the two-tower supervised fintuning"
)
target_embeddings = EmbeddingsBatch(*pad_seqs(batch.target)) # type: ignore
# Normalize the embeddings
if self.normalize_in_criterion:
input_embeddings = input_embeddings.normalize_seqs(self.sonar_normalizer)
target_embeddings = target_embeddings.normalize_seqs(self.sonar_normalizer)
        target_mask = torch.ones(
            size=target_embeddings.seqs.shape[:-1],
            dtype=torch.bool,
            device=input_embeddings.seqs.device,
        )
# Factor in padded positions:
if target_embeddings.padding_mask is not None:
target_mask &= target_embeddings.padding_mask.materialize()
return input_embeddings, target_embeddings, target_mask