# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
from dataclasses import dataclass
from typing import Literal, Optional
import torch
from fairseq2.config_registry import ConfigRegistry
from fairseq2.typing import DataType, Device
from torch import Tensor
from torch.nn import Module


@dataclass
class SonarNormalizerConfig:
dim: int = 1024
"""The dimension of the features to be normalized"""
    clip_proba: Optional[float] = None
    """
    If `clip_proba` is not None, the normalized features are clipped to the
    pre-computed `clip_min` and `clip_max` bounds, which correspond to the
    `clip_proba` and `1 - clip_proba` quantiles of the normalized training
    features, respectively.
    """
    with_fft: bool = False
    """
    If True, apply an FFT transform to the raw input before all other
    transforms (and invert it after denormalization).
    """
    quantile_min: float = 0.25
    """The lower quantile (as a fraction in [0, 1]) used to measure the IQR
    when estimating the scale with a robust scaler."""
    quantile_max: float = 0.75
    """The upper quantile (as a fraction in [0, 1]) used to measure the IQR
    when estimating the scale with a robust scaler."""
    normalization_method: Literal["standard", "robust", "gaussian_robust"] = (
        "gaussian_robust"
    )
    """
    Dictates how the normalizer's center and scale are estimated when fitting:
    (1) 'standard': center = mean, scale = std.
    (2) 'robust': center = median, scale = IQR = Qmax - Qmin.
    (3) 'gaussian_robust': center = median, scale = IQR / k, where
        k = `stats.norm.ppf(quantile_max) - stats.norm.ppf(quantile_min)`,
        i.e. scale = 0.7413 x IQR if quantile_min=0.25 and quantile_max=0.75.
        This is the robust normalization of https://arxiv.org/abs/2307.05445
    """
sonar_normalizer_archs = ConfigRegistry[SonarNormalizerConfig]()
sonar_normalizer_arch = sonar_normalizer_archs.decorator
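
# Example of registering a named architecture via the decorator (a sketch;
# the arch name "toy" and its config are hypothetical):
#
#   @sonar_normalizer_arch("toy")
#   def _toy() -> SonarNormalizerConfig:
#       return SonarNormalizerConfig(dim=32)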


class FFTInterface:
@staticmethod
    def fft_transform(embeddings: Tensor) -> Tensor:
        """Map a real tensor of even dimension `d` to the packed real and
        imaginary parts of its rFFT, preserving the last dimension size."""
        assert embeddings.shape[-1] % 2 == 0
        dtype = embeddings.dtype
        if dtype in [torch.float16, torch.bfloat16]:
            # torch.fft does not support half-precision inputs on all backends.
            embeddings = embeddings.to(dtype=torch.float32)
        embeddings = torch.fft.rfft(embeddings, norm="backward")
        # For an even input dimension d, rfft returns d // 2 + 1 complex
        # coefficients whose first and last imaginary parts are always zero;
        # dropping those two zeros keeps the output dimension equal to d.
        return torch.concat(
            [torch.real(embeddings), torch.imag(embeddings)[..., 1:-1]], dim=-1
        ).to(dtype)

    @staticmethod
    def fft_inverse_transform(embeddings: Tensor) -> Tensor:
        """Invert :meth:`fft_transform` for tensors of even dimension."""
        assert embeddings.shape[-1] % 2 == 0
        dtype = embeddings.dtype
        if dtype in [torch.float16, torch.bfloat16]:
            embeddings = embeddings.to(dtype=torch.float32)
        # Split the packed representation back into the real parts
        # (d // 2 + 1 values) and the retained imaginary parts (d // 2 - 1).
        rr, im = torch.split(
            embeddings,
            [embeddings.shape[-1] // 2 + 1, embeddings.shape[-1] // 2 - 1],
            dim=-1,
        )
        # Re-insert the zero imaginary parts of the DC and Nyquist bins.
        im = torch.concat(
            [torch.zeros_like(im[..., :1]), im, torch.zeros_like(im[..., :1])], dim=-1
        )
        embeddings = torch.fft.irfft(rr + im * 1j)
        return embeddings.to(dtype)
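

# A minimal round-trip sketch of the FFT packing (illustrative only; the
# names below are local to the example):
#
#   x = torch.randn(2, 1024)
#   y = FFTInterface.fft_transform(x)  # same shape: (2, 1024)
#   x_rec = FFTInterface.fft_inverse_transform(y)
#   assert torch.allclose(x, x_rec, atol=1e-5)

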
class SonarNormalizer(FFTInterface, Module):
    """
    To perform efficient diffusion modeling, SONAR embeddings need to be
    normalized. This SonarNormalizer follows the robust normalization
    introduced in https://arxiv.org/abs/2307.05445

    Quoting from the paper: "Due to the very long-tailed feature distribution,
    typical mean and standard deviation statistics will be heavily biased. We
    thus propose a robust alternative based on the feature distribution
    quantiles. We take the median as the center of the distribution and
    approximate its scale using the Normalized InterQuartile Range (IQR) for a
    normal distribution: 0.7413 x IQR."
    """

def __init__(
self,
config: SonarNormalizerConfig,
device: Optional[Device] = None,
dtype: Optional[DataType] = None,
) -> None:
super().__init__()
self.config = config
self.register_buffer(
"center", torch.zeros(config.dim, dtype=dtype, device=device)
)
self.register_buffer(
"scale", torch.ones(config.dim, dtype=dtype, device=device)
)
if self.config.clip_proba is not None:
self.register_buffer(
"clip_min", torch.ones(config.dim, dtype=dtype, device=device)
)
self.register_buffer(
"clip_max", torch.ones(config.dim, dtype=dtype, device=device)
)

    def normalize(self, embeddings: Tensor) -> Tensor:
        """Apply the fitted transform: optional FFT, centering/scaling, and
        optional clipping."""
if self.config.with_fft:
embeddings = self.fft_transform(embeddings)
embeddings = (embeddings - self.center) / self.scale
if self.config.clip_proba is not None:
embeddings = torch.clamp(embeddings, min=self.clip_min, max=self.clip_max)
return embeddings

    def denormalize(self, embeddings: Tensor) -> Tensor:
        """Invert :meth:`normalize`: optional clipping, de-scaling, and the
        optional inverse FFT."""
if self.config.clip_proba is not None:
embeddings = torch.clamp(embeddings, min=self.clip_min, max=self.clip_max)
embeddings = (embeddings * self.scale) + self.center
if self.config.with_fft:
embeddings = self.fft_inverse_transform(embeddings)
return embeddings
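
    # Note: when `clip_proba` is None, `denormalize` is an exact inverse of
    # `normalize` (up to floating-point error); with clipping enabled, the two
    # are only inverses on embeddings whose normalized values lie within
    # [clip_min, clip_max]. A quick sanity check (a sketch; `normalizer` and
    # `x` are hypothetical names):
    #
    #   z = normalizer.normalize(x)
    #   assert torch.allclose(normalizer.denormalize(z), x, atol=1e-4)
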
    @torch.no_grad()
    def fit(self, embeddings: Tensor) -> None:
        """Estimate `center`, `scale` and the optional clipping bounds from a
        2D batch of embeddings of shape `(n_samples, dim)`."""
if self.config.normalization_method in [
"robust",
"gaussian_robust",
]:
from sklearn.preprocessing import RobustScaler
            # sklearn expects `quantile_range` in percentiles (0-100), while
            # the config stores quantiles as fractions (0-1).
            _scaler = RobustScaler(
                unit_variance=self.config.normalization_method == "gaussian_robust",
                quantile_range=(
                    100 * self.config.quantile_min,
                    100 * self.config.quantile_max,
                ),
            )
elif self.config.normalization_method == "standard":
from sklearn.preprocessing import StandardScaler
_scaler = StandardScaler()
else:
            raise ValueError(
                f"Unrecognized method {self.config.normalization_method} for scaling input features"
            )
assert embeddings.shape[-1] == self.config.dim
assert len(embeddings.shape) == 2
if self.config.with_fft:
embeddings = self.fft_transform(embeddings)
embeddings = _scaler.fit_transform(embeddings.cpu().float().numpy())
if self.config.normalization_method in [
"robust",
"gaussian_robust",
]:
_center = _scaler.center_
_scale = _scaler.scale_
elif self.config.normalization_method == "standard":
_center = _scaler.mean_
_scale = _scaler.scale_
self.center[:] = torch.tensor(
_center, dtype=self.center.dtype, device=self.center.device
)
self.scale[:] = torch.tensor(
_scale, dtype=self.scale.dtype, device=self.scale.device
)
        if self.config.clip_proba is not None:
            # The clipping bounds are quantiles of the *normalized* features:
            # at this point `embeddings` holds the output of `fit_transform`.
            normalized = torch.from_numpy(embeddings)
            self.clip_min[:] = torch.quantile(
                normalized, self.config.clip_proba, dim=0
            ).to(dtype=self.clip_min.dtype, device=self.clip_min.device)
            self.clip_max[:] = torch.quantile(
                normalized, 1 - self.config.clip_proba, dim=0
            ).to(dtype=self.clip_max.dtype, device=self.clip_max.device)


def create_sonar_normalizer(
config: SonarNormalizerConfig,
*,
device: Optional[Device] = None,
dtype: Optional[DataType] = None,
) -> SonarNormalizer:
"""Create an LCM model.
:param config:
The configuration.
:param device:
The device on which to initialize modules.
:param dtype:
The data type of module parameters and buffers.
"""
return SonarNormalizer(config, device=device, dtype=dtype)
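

# A minimal end-to-end smoke test (a sketch: random data stands in for real
# SONAR embeddings, and the settings below are illustrative, not recommended
# defaults).
if __name__ == "__main__":
    config = SonarNormalizerConfig(dim=1024, clip_proba=0.01, with_fft=False)
    normalizer = create_sonar_normalizer(config)

    # Fit on a 2D batch of shape (n_samples, dim).
    embeddings = torch.randn(4096, config.dim) * 5.0 + 2.0
    normalizer.fit(embeddings)

    normalized = normalizer.normalize(embeddings)
    recovered = normalizer.denormalize(normalized)
    # With clipping enabled, only in-range values round-trip exactly.
    print("center[:3] =", normalizer.center[:3])
    print("scale[:3]  =", normalizer.scale[:3])
    print("max |x - denorm(norm(x))| =", (embeddings - recovered).abs().max().item())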