File size: 6,029 Bytes
473c3a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
from __future__ import annotations
import inspect
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Literal, Protocol, Union
import numpy as np
import torch
from sklearn.decomposition import PCA
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
if TYPE_CHECKING:
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
logger = logging.getLogger(__name__)
PathLike = Union[Path, str]
PCADimType = Union[int, None, float, Literal["auto"]]
_DEFAULT_BATCH_SIZE = 256
class ModulewithWeights(Protocol):
    """Structural type for any module exposing a ``weight`` parameter (e.g. ``nn.Embedding``, ``nn.Linear``)."""

    weight: torch.nn.Parameter
def create_embeddings(
    model: PreTrainedModel,
    tokenized: list[list[int]],
    device: str,
    pad_token_id: int,
) -> np.ndarray:
    """
    Create output embeddings for a bunch of tokens using a pretrained model.

    It does a forward pass for all tokens passed in `tokenized`.

    :param model: The model to use.
        This should be a transformers model.
    :param tokenized: All tokenized tokens.
    :param device: The torch device to use.
    :param pad_token_id: The pad token id. Used to pad sequences.
    :return: The output embeddings, one row per input sequence, in the input order.
    """
    model = model.to(device)

    # Add token_type_ids only if the model supports it.
    add_token_type_ids = "token_type_ids" in inspect.getfullargspec(model.forward).args

    # Sort sequences by length so each batch pads to a similar length,
    # which minimizes wasted compute on padding tokens.
    lengths = np.asarray([len(sequence) for sequence in tokenized])
    sort_order = np.argsort(lengths)
    sorted_tokenized = [tokenized[i] for i in sort_order]

    intermediate_weights: list[np.ndarray] = []
    # Context manager ensures the progress bar is closed even if the forward pass raises.
    with tqdm(total=len(sorted_tokenized), desc="Encoding tokens", unit=" tokens") as pbar:
        for batch_idx in range(0, len(sorted_tokenized), _DEFAULT_BATCH_SIZE):
            batch = [
                torch.tensor(x, dtype=torch.long)
                for x in sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]
            ]
            encoded: dict[str, torch.Tensor] = {}
            encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
            encoded["attention_mask"] = encoded["input_ids"] != pad_token_id
            if add_token_type_ids:
                encoded["token_type_ids"] = torch.zeros_like(encoded["input_ids"])
            out = _encode_mean_using_model(model, encoded)
            intermediate_weights.extend(out.numpy())
            pbar.update(len(batch))

    # Invert the length-sort so the output rows line up with the input order.
    intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
    # Sequences with an all-padding mask mean-pool to NaN; coerce those to zero vectors.
    out_weights = np.nan_to_num(np.stack(intermediate_weights))
    return out_weights
@torch.no_grad()
def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Encode a batch of tokens using a model.
Note that if a token in the input batch does not have any embeddings, it will be output as a vector of zeros.
So detection of these is necessary.
:param model: The model to use.
:param encodings: The encoded tokens to turn into features.
:return: The mean of the output for each token.
"""
encodings = {k: v.to(model.device) for k, v in encodings.items()}
encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings)
out: torch.Tensor = encoded.last_hidden_state.cpu()
# NOTE: If the dtype is bfloat 16, we convert to float32,
# because numpy does not suport bfloat16
# See here: https://github.com/numpy/numpy/issues/19808
if out.dtype == torch.bfloat16:
out = out.float()
# Take the mean by averaging over the attention mask.
mask = encodings["attention_mask"].cpu().float()
mask /= mask.sum(1)[:, None]
return torch.bmm(mask[:, None, :].float(), out).squeeze(1)
def post_process_embeddings(
    embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4
) -> np.ndarray:
    """
    Post process embeddings by applying PCA and SIF weighting.

    Word frequencies for SIF are estimated through Zipf's law, which assumes the
    vocabulary rows are ordered from most to least frequent token.

    :param pca_dims: Target PCA dimensionality. `None` disables PCA, "auto" keeps the
        original dimensionality, an int sets n_components, and a float selects by
        explained variance (sklearn semantics).
    :param sif_coefficient: The SIF down-weighting coefficient. `None` disables SIF.
    :return: The post-processed embeddings.
    """
    if pca_dims is not None:
        if pca_dims == "auto":
            pca_dims = embeddings.shape[1]
        if pca_dims > embeddings.shape[1]:
            # Clamp: sklearn would raise for n_components > n_features.
            logger.warning(
                f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]}). "
                "Applying PCA, but not reducing dimensionality. If this is not desired, please set `pca_dims` to None. "
                "Applying PCA will probably improve performance, so consider just leaving it."
            )
            pca_dims = embeddings.shape[1]
        if pca_dims >= embeddings.shape[0]:
            logger.warning(
                f"PCA dimension ({pca_dims}) is not smaller than the number of tokens in the vocabulary ({embeddings.shape[0]}). Not applying PCA."
            )
        elif pca_dims <= embeddings.shape[1]:
            if isinstance(pca_dims, float):
                logger.info(f"Applying PCA with {pca_dims} explained variance.")
            else:
                logger.info(f"Applying PCA with n_components {pca_dims}")
            orig_dims = embeddings.shape[1]
            p = PCA(n_components=pca_dims, svd_solver="full")
            embeddings = p.fit_transform(embeddings)
            if embeddings.shape[1] < orig_dims:
                explained_variance_ratio = np.sum(p.explained_variance_ratio_)
                explained_variance = np.sum(p.explained_variance_)
                logger.info(f"Reduced dimensionality from {orig_dims} to {embeddings.shape[1]}.")
                logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
                logger.info(f"Explained variance: {explained_variance:.3f}.")
    if sif_coefficient is not None:
        logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
        # Zipf: frequency of the rank-r token is proportional to 1/r (ranks start at 2
        # to avoid over-weighting the very first row).
        inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
        proba = inv_rank / np.sum(inv_rank)
        # SIF: down-weight rows with high estimated frequency.
        embeddings *= (sif_coefficient / (sif_coefficient + proba))[:, None]
    return embeddings
|