File size: 6,725 Bytes
473c3a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Self, TypeVar
import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from distiller.model2vec import StaticModel
if TYPE_CHECKING:
from tokenizers import Encoding, Tokenizer
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class FinetunableStaticModel(nn.Module):
    """
    A trainable torch module built from static token embeddings.

    The module holds a trainable embedding table, one trainable scalar weight per
    token (passed through a sigmoid before use) and a head created by
    `construct_head`. `forward` returns both the head output and the pooled,
    L2-normalized sentence embedding.
    """

    def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0) -> None:
        """
        Initialize a trainable StaticModel from a StaticModel.

        :param vectors: The embeddings of the staticmodel, one row per token.
        :param tokenizer: The tokenizer.
        :param out_dim: The output dimension of the head.
        :param pad_id: The padding id. This is set to 0 in almost all model2vec models
        """
        super().__init__()
        self.pad_id = pad_id
        self.out_dim = out_dim
        self.embed_dim = vectors.shape[1]

        self.vectors = vectors
        if self.vectors.dtype != torch.float32:
            dtype = str(self.vectors.dtype)
            logger.warning(
                "Your vectors are %s precision, converting to torch.float32 to avoid compatibility issues.",
                dtype,
            )
            self.vectors = vectors.float()

        # Clone so that training the embedding table never writes into the tensor the caller passed in.
        self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=pad_id)
        self.head = self.construct_head()
        self.w = self.construct_weights()
        self.tokenizer = tokenizer

    def construct_weights(self) -> nn.Parameter:
        """
        Construct the per-token weights for the model.

        :return: A 1D parameter with one entry per row of the vectors.
        """
        weights = torch.zeros(len(self.vectors))
        # The weights go through a sigmoid in `_encode`, so a very negative value
        # gives the padding token an effective weight of zero.
        weights[self.pad_id] = -10_000
        return nn.Parameter(weights)

    def construct_head(self) -> nn.Sequential:
        """
        Construct the head: a single linear layer from embed_dim to out_dim.

        Method should be overridden for various other classes.
        """
        return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim))

    @classmethod
    def from_pretrained(
        cls, *, out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any
    ) -> Self:
        """
        Load the model from a pretrained model2vec model.

        :param out_dim: The output dimension of the head.
        :param model_name: The name passed to `StaticModel.from_pretrained`.
        :param **kwargs: Extra keyword arguments forwarded to `from_static_model`.
        :return: A new instance of this class.
        """
        model = StaticModel.from_pretrained(model_name)
        return cls.from_static_model(model=model, out_dim=out_dim, **kwargs)

    @classmethod
    def from_static_model(cls, *, model: StaticModel, out_dim: int = 2, **kwargs: Any) -> Self:
        """
        Load the model from a static model.

        Note that the NaN-cleaned embedding is assigned back onto `model`, so the
        passed-in static model is modified by this call.

        :param model: The static model to take the embeddings and tokenizer from.
        :param out_dim: The output dimension of the head.
        :param **kwargs: Extra keyword arguments forwarded to the constructor.
        :return: A new instance of this class.
        """
        model.embedding = np.nan_to_num(model.embedding)
        embeddings_converted = torch.from_numpy(model.embedding)
        # NOTE(review): if the tokenizer has no "[PAD]" token this lookup presumably
        # yields None, which the rest of the class does not handle — confirm.
        return cls(
            vectors=embeddings_converted,
            pad_id=model.tokenizer.token_to_id("[PAD]"),
            out_dim=out_dim,
            tokenizer=model.tokenizer,
            **kwargs,
        )

    def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        A forward pass and mean pooling.

        This function is analogous to `StaticModel.encode`, but reimplemented to allow gradients
        to pass through.

        :param input_ids: A 2D tensor of input ids. All input ids have to be within bounds.
        :return: The mean over the input ids, weighted by token weights, L2-normalized.
        """
        w = self.w[input_ids]
        w = torch.sigmoid(w)
        # 1.0 for real tokens, 0.0 for padding positions.
        zeros = (input_ids != self.pad_id).float()
        w = w * zeros
        # Add a small epsilon to avoid division by zero
        length = zeros.sum(1) + 1e-16
        embedded = self.embeddings(input_ids)
        # Weigh each token
        embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
        # Mean pooling by dividing by the length
        embedded = embedded / length[:, None]

        return nn.functional.normalize(embedded)

    def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the mean, and a classifier layer after.

        :param input_ids: A 2D tensor of input ids.
        :return: A tuple of (head output, pooled embedding).
        """
        encoded = self._encode(input_ids)
        return self.head(encoded), encoded

    def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tensor:
        """
        Tokenize a bunch of strings into a single padded 2D tensor.

        Note that this is not used during training.

        :param texts: The texts to tokenize.
        :param max_length: The maximum number of token ids kept per text. If this is None,
            the sequences are not truncated.
        :return: A 2D padded tensor of dtype long.
        """
        encoded: list[Encoding] = self.tokenizer.encode_batch_fast(texts, add_special_tokens=False)
        # Build integer tensors directly; going through a float tensor first can
        # corrupt large ids.
        encoded_ids: list[torch.Tensor] = [
            torch.tensor(encoding.ids[:max_length], dtype=torch.long) for encoding in encoded
        ]
        return pad_sequence(encoded_ids, batch_first=True, padding_value=self.pad_id)

    @property
    def device(self) -> torch.device:
        """Get the device of the model, taken from the embedding weights."""
        return self.embeddings.weight.device

    def to_static_model(self) -> StaticModel:
        """
        Convert the model to a static model.

        The sigmoid of the per-token weights is folded into the embedding rows, so the
        head is not part of the resulting model.

        :return: A StaticModel built from the weighted embeddings and the tokenizer.
        """
        emb = self.embeddings.weight.detach().cpu().numpy()
        w = torch.sigmoid(self.w).detach().cpu().numpy()

        return StaticModel(emb * w[:, None], self.tokenizer, normalize=True)
class TextDataset(Dataset):
def __init__(self, tokenized_texts: list[list[int]], targets: torch.Tensor) -> None:
"""
A dataset of texts.
:param tokenized_texts: The tokenized texts. Each text is a list of token ids.
:param targets: The targets.
:raises ValueError: If the number of labels does not match the number of texts.
"""
if len(targets) != len(tokenized_texts):
msg = "Number of labels does not match number of texts."
raise ValueError(msg)
self.tokenized_texts = tokenized_texts
self.targets = targets
def __len__(self) -> int:
"""Return the length of the dataset."""
return len(self.tokenized_texts)
def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]:
"""Gets an item."""
return self.tokenized_texts[index], self.targets[index]
@staticmethod
def collate_fn(batch: list[tuple[list[list[int]], int]]) -> tuple[torch.Tensor, torch.Tensor]:
"""Collate function."""
texts, targets = zip(*batch, strict=False)
tensors = [torch.LongTensor(x) for x in texts]
padded = pad_sequence(tensors, batch_first=True, padding_value=0)
return padded, torch.stack(targets)
def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader:
"""Convert the dataset to a DataLoader."""
return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size)
# Type variable bound to FinetunableStaticModel, for annotations that accept the class or a subclass.
ModelType = TypeVar("ModelType", bound=FinetunableStaticModel)
|