import itertools
import logging
from collections import defaultdict
from collections.abc import Sequence
from typing import Iterator, TypedDict
import torch
from tqdm import tqdm
from .FastPLMs.e1.modeling_e1 import (
    E1ForMaskedLM,
    E1MaskedLMOutputWithPast,
    E1BatchPreparer,
    get_context,
    DataPrepConfig,
    KVCache,
)

IndexedSequence = tuple[int, str]
logger = logging.getLogger(__name__)


class E1Prediction(TypedDict, total=False):
    id: str | int
    context_id: str | int | None
    logits: torch.Tensor
    token_embeddings: torch.Tensor
    mean_token_embeddings: torch.Tensor


class E1Predictor:
    def __init__(
        self,
        model: E1ForMaskedLM,
        data_prep_config: DataPrepConfig | None = None,
        max_batch_tokens: int = 65536,
        use_cache: bool = True,
        cache_size: int = 4,
        save_masked_positions_only: bool = False,
        fields_to_save: list[str] = ["logits", "token_embeddings", "mean_token_embeddings"],
        keep_predictions_in_gpu: bool = False,
    ):
        self.model = model
        self.max_batch_tokens = max_batch_tokens
        self.batch_preparer = E1BatchPreparer(data_prep_config=data_prep_config)
        self.model.eval()
        self.kv_cache = KVCache(cache_size=cache_size) if use_cache else None
        self.fields_to_save = fields_to_save
        self.save_masked_positions_only = save_masked_positions_only
        self.keep_predictions_in_gpu = keep_predictions_in_gpu

    def group_by_length(self, indexed_sequences: list[IndexedSequence]) -> list[list[IndexedSequence]]:
        """Greedily packs length-sorted sequences into batches of at most max_batch_tokens (padded) tokens."""
        batches: list[list[IndexedSequence]] = [[]]
        for idx, seq in sorted(indexed_sequences, key=lambda idx_seq: (len(idx_seq[1]), idx_seq[0])):
            # Sequences in a batch are padded to the longest one, and because of the length sort the
            # current sequence is the longest so far, so the padded size after adding it is roughly
            # len(seq) * (batch_size + 1); start a new batch when that would exceed the budget.
            if len(batches[-1]) > 0 and len(seq) * (len(batches[-1]) + 1) > self.max_batch_tokens:
                batches.append([])
            batches[-1].append((idx, seq))
        return batches

    def group_by_context(self, indexed_sequences: list[IndexedSequence]) -> list[list[IndexedSequence]]:
        batches: dict[str | None, list[IndexedSequence]] = defaultdict(list)
        for idx, seq in indexed_sequences:
            batches[get_context(seq)].append((idx, seq))
        return list(batches.values())

    def batch_sequences(self, sequences: list[str]) -> list[tuple[list[int], bool]]:  # type: ignore[override]
        """
        Batches the sequences and returns, for each batch, the sequence indices for the current rank
        together with a validity flag.
        Sequences of similar length are kept together, and no batch exceeds max_batch_tokens.
        For E1, sequences that share a context are also kept together to preserve context locality.
        """
        indexed_sequences: list[IndexedSequence] = list(enumerate(sequences))
        indexed_batches = self.group_by_context(indexed_sequences)
        # Preserve context ordering while splitting each context group by length.
        indexed_batches = list(
            itertools.chain.from_iterable([self.group_by_length(batch) for batch in indexed_batches])
        )
        batches = [[item[0] for item in batch] for batch in indexed_batches]  # type: ignore[no-redef,misc]
        assert sorted(sum(batches, [])) == list(range(len(sequences))), (
            "Batches must contain all indices with no repetition"
        )
        batches_with_validity = [(b, True) for b in batches]
        return batches_with_validity
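    # Illustrative example (hypothetical inputs and budget; the exact grouping depends on
    # get_context() and on the padded-token heuristic in group_by_length):
    #   predictor = E1Predictor(model, max_batch_tokens=8)
    #   predictor.batch_sequences(["MKVL", "MKTA", "AAAAAAA"])
    #   -> [([0, 1], True), ([2], True)]   # the two short sequences are packed together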

    @torch.no_grad()
    def predict_batch(self, sequences: list[str], sequence_metadata: list[dict[str, str | int]]) -> list[E1Prediction]:
        """
        For multi-sequence inputs, returns the logits/embeddings of the last sequence only.
        """
        outputs = self.predict_batch_padded(sequences)
        outputs["logits"] = outputs["logits"].float()
        outputs["embeddings"] = outputs["embeddings"].float()
        token_mask = outputs["non_boundary_token_mask"] & outputs["last_sequence_mask"]
        if self.save_masked_positions_only:
            token_mask = token_mask & outputs["mask_positions_mask"]
        predictions = []
        for i in range(len(sequences)):
            pred: E1Prediction = {
                "id": sequence_metadata[i]["id"],
                "context_id": sequence_metadata[i].get("context_id", None),
            }
            if "logits" in self.fields_to_save:
                pred["logits"] = outputs["logits"][i, token_mask[i]]
                if not self.keep_predictions_in_gpu:
                    pred["logits"] = pred["logits"].to("cpu")  # type: ignore[union-attr]
            if "token_embeddings" in self.fields_to_save:
                pred["token_embeddings"] = outputs["embeddings"][i, token_mask[i]]
                if not self.keep_predictions_in_gpu:
                    pred["token_embeddings"] = pred["token_embeddings"].to("cpu")  # type: ignore[union-attr]
            if "mean_token_embeddings" in self.fields_to_save:
                pred["mean_token_embeddings"] = outputs["embeddings"][i, token_mask[i]].mean(dim=0)
                if not self.keep_predictions_in_gpu:
                    pred["mean_token_embeddings"] = pred["mean_token_embeddings"].to("cpu")  # type: ignore[union-attr]
            predictions.append(pred)
        return predictions
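    # Sketch of one returned E1Prediction (illustrative shapes: L = number of tokens kept by
    # token_mask, V = vocabulary size, D = hidden size):
    #   {"id": "seq-0", "context_id": None,
    #    "logits": FloatTensor of shape (L, V),
    #    "token_embeddings": FloatTensor of shape (L, D),
    #    "mean_token_embeddings": FloatTensor of shape (D,)}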

    @torch.no_grad()
    def predict_batch_padded(self, sequences: list[str]) -> dict[str, torch.Tensor]:
        """
        If use_cache is True, returns the logits/embeddings for the last sequence of multi-sequence inputs.
        If use_cache is False, returns the logits/embeddings for every sequence of multi-sequence inputs.
        Returns four additional masks:
        - non_boundary_token_mask: True for tokens that are part of the input sequence, i.e. not boundary tokens such as 1, 2, <bos>, <eos>, <pad>, etc.
        - last_sequence_mask: True for tokens that are part of the last sequence (including boundary tokens) in case of multi-sequence input.
        - mask_positions_mask: True for masked positions.
        - valid_token_mask: True for valid (non-padding) tokens.
        """
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        with torch.autocast(device_type, torch.bfloat16):
            batch = self.batch_preparer.get_batch_kwargs(sequences, device=torch.device(device_type))
            if self.kv_cache is not None:
                self.kv_cache.before_forward(batch)
            output: E1MaskedLMOutputWithPast = self.model(
                input_ids=batch["input_ids"],
                within_seq_position_ids=batch["within_seq_position_ids"],
                global_position_ids=batch["global_position_ids"],
                sequence_ids=batch["sequence_ids"],
                past_key_values=batch.get("past_key_values", None),
                use_cache=batch.get("use_cache", False),
                output_attentions=False,
                output_hidden_states=False,
            )
            if self.kv_cache is not None:
                self.kv_cache.after_forward(batch, output)
            logits = output.logits
            embeddings = output.last_hidden_state
            padding_mask = batch["input_ids"] == self.batch_preparer.pad_token_id
            last_sequence_mask = batch["sequence_ids"] == batch["sequence_ids"].max(dim=1)[0][:, None]  # type: ignore[union-attr]
            boundary_token_mask = self.batch_preparer.get_boundary_token_mask(batch["input_ids"])
            mask_positions_mask = self.batch_preparer.get_mask_positions_mask(batch["input_ids"])
            return {
                "logits": logits,
                "embeddings": embeddings,
                "last_sequence_mask": last_sequence_mask,
                "non_boundary_token_mask": ~boundary_token_mask,
                "mask_positions_mask": mask_positions_mask,
                "valid_token_mask": ~padding_mask,
            }
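    # Example of combining the returned masks (illustrative; this mirrors what predict_batch does):
    #   out = predictor.predict_batch_padded(["MKTAYIAKQR"])
    #   keep = out["non_boundary_token_mask"] & out["last_sequence_mask"] & out["valid_token_mask"]
    #   per_token_logits = out["logits"][0, keep[0]]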

    @torch.no_grad()
    def predict(
        self,
        sequences: Sequence[str],
        sequence_ids: Sequence[int | str] | None = None,
        context_seqs: dict[str, str] | None = None,
    ) -> Iterator[E1Prediction]:
        if sequence_ids is None:
            sequence_ids = list(range(len(sequences)))
        if context_seqs:
            # Pair every sequence with every context; the context is prepended with a "," separator.
            sequences_with_context = [
                (ctx + "," + seq, {"context_id": ctx_id, "id": sequence_id})
                for ctx_id, ctx in context_seqs.items()
                for seq, sequence_id in zip(sequences, sequence_ids)
            ]
        else:
            sequences_with_context = [(seq, {"id": sequence_id}) for seq, sequence_id in zip(sequences, sequence_ids)]
        sequences, sequence_metadata = tuple(zip(*sequences_with_context))  # type: ignore[assignment]
        sequence_batch_indices: list[tuple[list[int], bool]] = self.batch_sequences(sequences)  # type: ignore[arg-type]
        logger.info(f"Predicting for {len(sequence_batch_indices)} batches")
        for indices, is_valid_batch in tqdm(sequence_batch_indices, desc="Predicting batches"):
            sequence_batch = [sequences[i] for i in indices]
            sequence_batch_metadata = [sequence_metadata[i] for i in indices]
            # The forward pass runs even when the batch is marked invalid; only its results are skipped.
            batch_predictions = self.predict_batch(sequence_batch, sequence_batch_metadata)
            if not is_valid_batch:
                continue
            for prediction in batch_predictions:
                yield prediction
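

# Minimal usage sketch (assumptions: this module does not show how E1ForMaskedLM is loaded;
# the checkpoint path and example sequences below are placeholders).
if __name__ == "__main__":
    model = E1ForMaskedLM.from_pretrained("path/to/e1-checkpoint")  # assumed HF-style loader
    if torch.cuda.is_available():
        model = model.cuda()  # batches are placed on CUDA when available, so the model must be too
    predictor = E1Predictor(model, fields_to_save=["mean_token_embeddings"])
    for prediction in predictor.predict(["MKTAYIAKQR", "MKVLAAGGLL"], sequence_ids=["a", "b"]):
        print(prediction["id"], prediction["mean_token_embeddings"].shape)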