from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Self, TypeVar

import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

from distiller.model2vec import StaticModel

if TYPE_CHECKING:
    from tokenizers import Encoding, Tokenizer

logger = logging.getLogger(__name__)


class FinetunableStaticModel(nn.Module):
    def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0) -> None:
        """
        Initialize a trainable StaticModel from a StaticModel.

        :param vectors: The embeddings of the staticmodel.
        :param tokenizer: The tokenizer.
        :param out_dim: The output dimension of the head.
        :param pad_id: The padding id. This is set to 0 in almost all model2vec models
        """
        super().__init__()
        self.pad_id = pad_id
        self.out_dim = out_dim
        self.embed_dim = vectors.shape[1]

        self.vectors = vectors
        if self.vectors.dtype != torch.float32:
            dtype = str(self.vectors.dtype)
            logger.warning(
                f"Your vectors are {dtype} precision, converting to torch.float32 to avoid compatibility issues."
            )
            self.vectors = vectors.float()

        self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=pad_id)
        self.head = self.construct_head()
        self.w = self.construct_weights()
        self.tokenizer = tokenizer

    def construct_weights(self) -> nn.Parameter:
        """Construct the weights for the model."""
        weights = torch.zeros(len(self.vectors))
        weights[self.pad_id] = -10_000
        return nn.Parameter(weights)

    def construct_head(self) -> nn.Sequential:
        """Method should be overridden for various other classes."""
        return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim))

    @classmethod
    def from_pretrained(
        cls, *, out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any
    ) -> Self:
        """Load the model from a pretrained model2vec model."""
        model = StaticModel.from_pretrained(model_name)
        return cls.from_static_model(model=model, out_dim=out_dim, **kwargs)

    @classmethod
    def from_static_model(cls, *, model: StaticModel, out_dim: int = 2, **kwargs: Any) -> Self:
        """Load the model from a static model."""
        model.embedding = np.nan_to_num(model.embedding)
        embeddings_converted = torch.from_numpy(model.embedding)
        return cls(
            vectors=embeddings_converted,
            pad_id=model.tokenizer.token_to_id("[PAD]"),
            out_dim=out_dim,
            tokenizer=model.tokenizer,
            **kwargs,
        )

    def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        A forward pass and mean pooling.

        This function is analogous to `StaticModel.encode`, but reimplemented to allow gradients
        to pass through.

        :param input_ids: A 2D tensor of input ids. All input ids have to be within the vocabulary bounds.
        :return: The mean over the input ids, weighted by token weights.
        """
        # Look up the learned per-token weight logits and squash them to (0, 1)
        w = self.w[input_ids]
        w = torch.sigmoid(w)
        # Mask out padding positions so they contribute nothing
        mask = (input_ids != self.pad_id).float()
        w = w * mask
        # Add a small epsilon to avoid division by zero
        length = mask.sum(1) + 1e-16
        embedded = self.embeddings(input_ids)
        # Weigh each token
        embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
        # Mean pooling by dividing by the length
        embedded = embedded / length[:, None]

        return nn.functional.normalize(embedded)
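
    # A small worked sketch of the pooling above, with made-up numbers: for a
    # row [id_a, id_b, pad] whose masked sigmoid weights are [0.8, 0.6, 0.0],
    # the bmm computes 0.8 * E[id_a] + 0.6 * E[id_b]; dividing by the non-pad
    # length 2 gives the weighted mean, which is then L2-normalized.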

    def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the mean, and a classifier layer after."""
        encoded = self._encode(input_ids)
        return self.head(encoded), encoded

    def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tensor:
        """
        Tokenize a bunch of strings into a single padded 2D tensor.

        Note that this is not used during training.

        :param texts: The texts to tokenize.
        :param max_length: If this is None, the sequence lengths are truncated to 512.
        :return: A 2D padded tensor
        """
        encoded: list[Encoding] = self.tokenizer.encode_batch_fast(texts, add_special_tokens=False)
        encoded_ids: list[torch.Tensor] = [torch.Tensor(encoding.ids[:max_length]).long() for encoding in encoded]
        return pad_sequence(encoded_ids, batch_first=True, padding_value=self.pad_id)
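
    # For example (hypothetical inputs): model.tokenize(["hello world", "hi"])
    # returns a (2, L) long tensor in which the shorter row is right-padded
    # with `pad_id`.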

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return self.embeddings.weight.device

    def to_static_model(self) -> StaticModel:
        """Convert the model to a static model."""
        emb = self.embeddings.weight.detach().cpu().numpy()
        w = torch.sigmoid(self.w).detach().cpu().numpy()

        # Fold the learned token weights into the embeddings so that StaticModel's
        # plain mean pooling approximates the weighted pooling learned here.
        return StaticModel(emb * w[:, None], self.tokenizer, normalize=True)
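

# A hypothetical subclass, not part of the original module, sketching how
# `construct_head` can be overridden: swap the single linear layer for a small
# MLP. The hidden size of 256 is an illustrative assumption.
class MLPHeadModel(FinetunableStaticModel):
    def construct_head(self) -> nn.Sequential:
        """Use a two-layer MLP head instead of a single linear layer."""
        return nn.Sequential(
            nn.Linear(self.embed_dim, 256),
            nn.ReLU(),
            nn.Linear(256, self.out_dim),
        )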


class TextDataset(Dataset):
    def __init__(self, tokenized_texts: list[list[int]], targets: torch.Tensor) -> None:
        """
        A dataset of texts.

        :param tokenized_texts: The tokenized texts. Each text is a list of token ids.
        :param targets: The targets.
        :raises ValueError: If the number of labels does not match the number of texts.
        """
        if len(targets) != len(tokenized_texts):
            msg = "Number of labels does not match number of texts."
            raise ValueError(msg)
        self.tokenized_texts = tokenized_texts
        self.targets = targets

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.tokenized_texts)

    def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]:
        """Gets an item."""
        return self.tokenized_texts[index], self.targets[index]

    @staticmethod
    def collate_fn(batch: list[tuple[list[int], torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
        """Collate tokenized texts and targets into padded batch tensors."""
        texts, targets = zip(*batch, strict=False)

        tensors = [torch.LongTensor(x) for x in texts]
        # Pad with 0, the pad id used by almost all model2vec models.
        padded = pad_sequence(tensors, batch_first=True, padding_value=0)

        return padded, torch.stack(targets)

    def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader:
        """Convert the dataset to a DataLoader."""
        return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size)


ModelType = TypeVar("ModelType", bound=FinetunableStaticModel)
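

# A minimal usage sketch, not part of the original module: fine-tune the head on
# a toy two-class task. The texts and labels below are made up for illustration;
# the checkpoint name is the default used by `from_pretrained`.
if __name__ == "__main__":
    model = FinetunableStaticModel.from_pretrained(out_dim=2)

    texts = ["a great movie", "a terrible movie"]
    tokenized = [ids.tolist() for ids in model.tokenize(texts)]
    dataset = TextDataset(tokenized, torch.tensor([1, 0]))
    loader = dataset.to_dataloader(shuffle=True, batch_size=2)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    for input_ids, targets in loader:
        logits, _embeddings = model(input_ids)
        loss = criterion(logits, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Export back to a plain StaticModel for fast, gradient-free inference.
    static_model = model.to_static_model()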