# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from fairseq2.logging import get_log_writer
from fairseq2.models.sequence import SequenceBatch
from fairseq2.nn.padding import PaddingMask, pad_seqs
from fairseq2.typing import Device
from torch import Tensor
from torch.nn import Module
from lcm.utils.common import Batched
logger = get_log_writer(__name__)
DOC_LENGTHS = "__doc_lengths"
class LCMStyle(Enum):
"""Specifies a style for preparing the LCM input."""
SUPERVISED = 1
"""For when the model is fed supervised data with source & target sentences."""
UNSUPERVISED = 2
"""For when the model is fed unsupervised data with source sentences only."""
PACKED_UNSUPERVISED = 3
"""For when the model is fed ``packed`` unsupervised data with source sentences only.
This means that we will look for document_lengths and propagate them to the
packed causal masked attention and the packed position encoders"""
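# Illustrative note (values are hypothetical): with LCMStyle.PACKED_UNSUPERVISED,
# a single packed row of 10 sentences built from documents of 4, 3 and 3
# sentences would carry document_lengths = torch.tensor([4, 3, 3]), so that the
# causal attention masking and the position encoders can be restricted to
# document boundaries.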
@dataclass
class EmbeddingsBatch:
"""Represents a sequence of embeddings batch.
Resembles Fairseq2's SequenceBatch with additional properties"""
seqs: Tensor
"""The sequences. *Shape:* :math:`(B,S,D)`, where :math:`B` is the batch
size, :math:`S` is the sequence length (in sentences per document),
and :math:`D` is the embedding dimension.
"""
padding_mask: Optional[PaddingMask] = None
"""The padding mask of ``seqs``. *Shape:* :math:`(B,S)`, where :math:`B` is
the batch size and :math:`S` is the sequence length."""
diffusion_timesteps: Optional[Tensor] = None
"""Diffusion timesteps of noising process of ``seqs``. *Shape:* :math:`(B,S)`, where :math:`B` is
the batch size and :math:`S` is the sequence length."""
document_lengths: Optional[Tensor] = None
"""Lengths of the documents (in sentences) present in the batch
Shape: (len_doc, )
"""
source_lengths: Optional[Tensor] = None
"""Lengths of source part for each element in batch, so that `seqs[i, :source_lengths[i]]` corresponds to source for each i in [0, batch_size).
Shape: (batch_size, )
"""
def __post_init__(self):
if self.document_lengths is not None:
assert self.document_lengths.sum() == self.seqs.size(
1
) or 2 * self.document_lengths.sum() == self.seqs.size(1), (
"The legnths do no sum up to the sequence length "
"(nor half the length for doubled diffusion sequences). "
f"We have seqs.size={self.seqs.size()} and lengths={self.document_lengths} "
f"summing to {self.document_lengths.sum()}"
)
def __len__(self) -> int:
return self.batch_size
@property
def batch_size(self) -> int:
"""The size of the batch."""
return self.seqs.size(0)
@property
def shape(self) -> torch.Size:
"""The shape of the batch."""
return self.seqs.shape
@property
def device(self) -> Device:
"""The device of the batch."""
return self.seqs.device
def clone(self):
return deepcopy(self)
def __getitem__(self, i: int) -> Any:
raise NotImplementedError(
"Access to each item in EmbeddingsBatch not allowed yet"
)
def unbatch(self) -> List[Tensor]:
if self.padding_mask is None:
return list(self.seqs)
else:
return [
tt[:length] for tt, length in zip(self.seqs, self.padding_mask.seq_lens)
]
def get_last_element(self) -> Tensor:
if self.padding_mask:
return self.seqs[
torch.arange(len(self.padding_mask.seq_lens), device=self.seqs.device),
(self.padding_mask.seq_lens - 1),
]
else:
return self.seqs[:, -1]
def set_last_element(self, element: Tensor) -> None:
element = element.to(self.seqs.device)
if self.padding_mask:
for i, slen in enumerate(self.padding_mask.seq_lens):
self.seqs[i, slen - 1] = element[i]
else:
self.seqs[:, -1] = element
def normalize_seqs(self, normalizer: Optional[Module]) -> "EmbeddingsBatch":
if normalizer is None:
logger.warning(
"The normalizer is None, as such, the features will remain unchanged"
)
return self
return EmbeddingsBatch(
seqs=normalizer.normalize(self.seqs),
padding_mask=self.padding_mask,
diffusion_timesteps=self.diffusion_timesteps,
document_lengths=self.document_lengths,
source_lengths=self.source_lengths,
)
def denormalize_seqs(self, normalizer: Optional[Module]) -> "EmbeddingsBatch":
if normalizer is None:
logger.warning(
"The normalizer is None, as such, the features will remain unchanged"
)
return self
return EmbeddingsBatch(
seqs=normalizer.denormalize(self.seqs),
padding_mask=self.padding_mask,
diffusion_timesteps=self.diffusion_timesteps,
document_lengths=self.document_lengths,
source_lengths=self.source_lengths,
)
def double_seqs(self) -> "EmbeddingsBatch":
"""
Repeats each element along the sequence dimension:
1, 2, 3 -> 1, 1, 2, 2, 3, 3
x, y -> x, x, y, y
"""
if self.padding_mask is not None:
doubled_padding_mask = PaddingMask(
seq_lens=2 * self.padding_mask._seq_lens,
batch_seq_len=2 * self.padding_mask._batch_seq_len,
)
else:
doubled_padding_mask = None
return EmbeddingsBatch(
seqs=torch.repeat_interleave(self.seqs, 2, dim=1),
padding_mask=doubled_padding_mask,
diffusion_timesteps=self.diffusion_timesteps,
document_lengths=self.document_lengths,
source_lengths=(
torch.repeat_interleave(self.source_lengths, 2, dim=0)
if self.source_lengths is not None
else None
),
)
def flatten_to_sentences(self) -> Tensor:
"""Flatten the sequence of embeddings
from B, S, D to B*~S, D after removing the padded positions
"""
embed_dim = self.seqs.size(-1)
if self.padding_mask is not None:
seq_lens = self.padding_mask.seq_lens
embeds_mask = self.padding_mask.materialize().unsqueeze(-1)
# Remove padded positions and reshape as B*~S, D
flat_embeds = torch.masked_select(self.seqs, embeds_mask).reshape(
(-1, embed_dim)
)
# split per document/paragraph
flat_embeds_per_doc = list(torch.split(flat_embeds, seq_lens.tolist()))
# Concatenate back
flat_embeds = torch.concat(flat_embeds_per_doc)
else:
embeds = self.seqs
flat_embeds = embeds.reshape((-1, embed_dim))
return flat_embeds
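# Minimal usage sketch for EmbeddingsBatch (kept as a comment so importing this
# module has no side effects; tensor values and shapes are hypothetical):
#
#     seqs = torch.randn(2, 3, 8)  # (B=2, S=3, D=8)
#     mask = PaddingMask(seq_lens=torch.tensor([3, 2]), batch_seq_len=3)
#     batch = EmbeddingsBatch(seqs, padding_mask=mask)
#     batch.flatten_to_sentences()  # (3 + 2, 8): padded positions removed
#     batch.get_last_element()  # (2, 8): last non-padded embedding per row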
@dataclass
class LCMInput(Batched):
"""Dataclass for a pair of source/target sequences of SONAR embeddings"""
source: List[Tensor]
"""source: SONAR embeddings of the source text
i.e [X^1 in (N_1, D), ... X^M in (N_M, D)]"""
target: Union[None, List[Tensor]]
"""target: If supervised data: SONAR embeddings of the target text"""
tokens: Union[None, SequenceBatch] = None
"""tokens: Tokenized flattened sentences for the SONAR decoder
(see the dataloader `_prepare_subword_tokens`)"""
target_tokens: Union[None, SequenceBatch] = None
"""target_tokens: a sequence of the same shape as target_tokens, but shifted, to serve as the target.
(see the dataloader `_prepare_subword_tokens`)"""
name: Optional[str] = None
"""
Name of the dataset the input comes from
"""
batch: Optional[Dict[str, Any]] = None
"""raw batch of dataloader used for tracking and debugging"""
def __post_init__(self):
assert self.source is not None
length = len(self.source)
assert (self.target is None) or (len(self.target) == length), (
f"all elements in LCMInput should be of the same length, got {len(self.target)} and {length}"
)
def __len__(self) -> int:
return len(self.source)
def __getitem__(self, i: int) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""
Return the content of item ``i`` in the batch
"""
if self.target is None:
return self.source[i]
else:
return self.source[i], self.target[i]
def prepare_input(
self,
style: LCMStyle = LCMStyle.UNSUPERVISED,
) -> EmbeddingsBatch:
"""
Adds special tokens to the source (& target) and prepares
the EmbeddingsBatch (tensor & its padding mask) that will be
forwarded in the LCM model.
`style`: the LCMStyle to prepare for; the supervised style
requires target embeddings.
"""
if style == LCMStyle.UNSUPERVISED:
return get_embeddings_sequence(src_seqs=self.source)
elif style == LCMStyle.PACKED_UNSUPERVISED:
# If using PACKED_UNSUPERVISED, document_lengths will be added to `EmbeddingsBatch`
document_lengths = None
if self.batch is not None and self.batch.get(DOC_LENGTHS, None) is not None:
# document_lengths will only be consumed if the batch_size is 1
assert len(self.batch[DOC_LENGTHS]) == 1, "Expecting batch size of 1"
document_lengths = self.batch[DOC_LENGTHS][0].type(torch.int64)
return get_embeddings_sequence(
src_seqs=self.source,
document_lengths=document_lengths,
)
elif style == LCMStyle.SUPERVISED:
assert self.target is not None, (
"Missing target embeddings for a supervised batch"
)
return get_embeddings_sequence(
src_seqs=self.source,
tgt_seqs=self.target,
)
raise ValueError(f"Unsupported style={style} - could not prepare input")
def prepare_target_mask(
self,
embeddings: EmbeddingsBatch,
style: LCMStyle,
min_context_size: Optional[int] = None,
) -> Tensor:
"""Prepare a target mask signaling what positions
we should predict and optimize the model for
Args:
- min_context_size: the minimum context used to predict the next
concept (only used for unuspervised training)
"""
batch_size, maxlen, _ = embeddings.seqs.size()
device = embeddings.seqs.device
if style == LCMStyle.UNSUPERVISED:
# A target mask for unsupervised next sentence prediction
# All positions are optimized/predicted starting from min_context_size
target_mask = torch.ones(
(batch_size, maxlen),
dtype=torch.bool,
device=device,
)
if min_context_size is not None:
target_mask[:, : min(min_context_size, target_mask.size(1))] = False
elif style == LCMStyle.PACKED_UNSUPERVISED:
# A target mask for unsupervised next sentence prediction when the data is packed
# All positions are optimized starting from min_context_size in each document
document_lengths = embeddings.document_lengths
if document_lengths is not None: # training
def get_document_target_mask(doc_length):
mask = torch.ones(doc_length, dtype=torch.bool, device=device)
if min_context_size is not None:
mask[: min(min_context_size, doc_length)] = False
return mask
target_mask = torch.cat(
[get_document_target_mask(length) for length in document_lengths]
).unsqueeze(0)
else: # validation with unpacked data:
target_mask = torch.ones(
(batch_size, maxlen),
dtype=torch.bool,
device=device,
)
if min_context_size is not None:
target_mask[:, : min(min_context_size, target_mask.size(1))] = False
elif style == LCMStyle.SUPERVISED:
# A target mask for target prediction
indices = torch.arange(maxlen, device=device).expand(batch_size, -1)
source_lengths = torch.tensor(
[seq.size(0) for seq in self.source],
device=device,
)
target_mask = indices >= source_lengths.unsqueeze(1).expand(-1, maxlen)
# Factor in padded positions:
if embeddings.padding_mask is not None:
target_mask = target_mask * embeddings.padding_mask.materialize()
return target_mask.detach()
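# Typical call sequence on the training side (illustrative sketch; variable
# names are hypothetical):
#
#     lcm_input = LCMInput(source=src_embeds, target=tgt_embeds)
#     embeddings = lcm_input.prepare_input(style=LCMStyle.SUPERVISED)
#     target_mask = lcm_input.prepare_target_mask(embeddings, style=LCMStyle.SUPERVISED)
#     # target_mask is True only on target (and non-padded) positions.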
def get_embeddings_sequence(
src_seqs: List[Tensor],
tgt_seqs: Optional[List[Tensor]] = None,
document_lengths: Optional[Tensor] = None,
double_target: bool = False,
) -> EmbeddingsBatch:
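"""Concatenate source (and, if provided, target) embeddings per element and
pad them into a single EmbeddingsBatch. If ``double_target`` is set, each
target embedding is repeated twice along the sequence dimension."""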
seqs_lst: List[Tensor] = []
for src_seq, tgt_seq in zip(src_seqs, tgt_seqs or [None] * len(src_seqs)): # type: ignore
embeds: List[Tensor] = []
device, dtype = src_seq.device, src_seq.dtype
# mandatory src_seq
embeds.append(src_seq)
# supervised tgt_seq
if tgt_seq is not None:
tgt_seq = tgt_seq.to(device).type(dtype)
if double_target:
embeds.append(torch.repeat_interleave(tgt_seq, 2, dim=0))
else:
embeds.append(tgt_seq)
seqs_lst.append(torch.concat(embeds))
seqs, padding_mask = pad_seqs(seqs_lst)
if document_lengths is not None:
document_lengths = document_lengths.to(seqs.device)
if tgt_seqs is not None:
source_lengths = torch.tensor(
[seq.size(0) for seq in src_seqs], device=seqs.device
)
else:
source_lengths = None
output = EmbeddingsBatch(
seqs,
padding_mask=padding_mask,
document_lengths=document_lengths,
source_lengths=source_lengths,
)
return output
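# Example (illustrative; shapes are hypothetical): two source documents of 4
# and 2 sentences with 1024-dim embeddings are padded into a single batch with
# a matching padding mask:
#
#     src = [torch.randn(4, 1024), torch.randn(2, 1024)]
#     batch = get_embeddings_sequence(src_seqs=src)
#     assert batch.shape == (2, 4, 1024)
#     assert batch.padding_mask is not None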