Sarthak commited on
Commit
473c3a0
·
1 Parent(s): 72121b3

chore: moved model2vec as an internal package

Browse files
src/distiller/model2vec/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .model import StaticModel
2
+ from .version import __version__
3
+
4
+ __all__ = ["StaticModel", "__version__"]
src/distiller/model2vec/distill/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
from distiller.model2vec.utils import get_package_extras, importable

# Optional-dependency group ("extra") that must be installed for distillation.
_REQUIRED_EXTRA = "distill"

# Check every dependency declared under the "distill" extra before importing
# the distillation module that needs them (importable presumably raises a
# helpful error when one is missing — confirm in distiller.model2vec.utils).
for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
    importable(extra_dependency, _REQUIRED_EXTRA)

from distiller.model2vec.distill.distillation import distill, distill_from_model

__all__ = ["distill", "distill_from_model"]
src/distiller/model2vec/distill/distillation.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import re
6
+ from typing import cast
7
+
8
+ import numpy as np
9
+ from huggingface_hub import model_info
10
+ from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast
11
+
12
+ from distiller.model2vec.distill.inference import PCADimType, create_embeddings, post_process_embeddings
13
+ from distiller.model2vec.distill.utils import select_optimal_device
14
+ from distiller.model2vec.model import StaticModel
15
+ from distiller.model2vec.quantization import DType, quantize_embeddings
16
+ from distiller.model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
def distill_from_model(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerFast,
    vocabulary: list[str] | None = None,
    device: str | None = None,
    pca_dims: PCADimType = 256,
    apply_zipf: bool | None = None,
    sif_coefficient: float | None = 1e-4,
    token_remove_pattern: str | None = r"\[unused\d+\]",
    quantize_to: DType | str = DType.Float16,
    use_subword: bool | None = None,
) -> StaticModel:
    """
    Distill a staticmodel from a sentence transformer.

    This function creates a set of embeddings from a sentence transformer. It does this by doing either
    a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed vocabulary.

    If you pass through a vocabulary, we create a custom word tokenizer for that vocabulary.
    If you don't pass a vocabulary, we use the model's tokenizer directly.

    :param model: The model to use.
    :param tokenizer: The tokenizer to use.
    :param vocabulary: The vocabulary to use. If this is None, we use the model's vocabulary.
    :param device: The device to use.
    :param pca_dims: The number of components to use for PCA.
        If this is None, we don't apply PCA.
        If this is 'auto', we don't reduce dimensionality, but still apply PCA.
    :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
        Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
    :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
        Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
    :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
        If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
    :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
    :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
    :return: A StaticModel
    :raises: ValueError if the vocabulary is empty after preprocessing.

    """
    if use_subword is not None:
        logger.warning(
            "The `use_subword` parameter is deprecated and will be removed in the next release. It doesn't do anything."
        )
    quantize_to = DType(quantize_to)
    backend_tokenizer = tokenizer.backend_tokenizer
    # Resolve the deprecated apply_zipf flag and compile the removal regex.
    sif_coefficient, token_remove_regex = _validate_parameters(apply_zipf, sif_coefficient, token_remove_pattern)

    if vocabulary is None:
        vocabulary = []

    device = select_optimal_device(device)

    n_tokens_before = len(vocabulary)
    # Clean the vocabulary by removing duplicate tokens and tokens that are in the internal vocabulary.
    all_tokens, backend_tokenizer = clean_and_create_vocabulary(
        tokenizer, vocabulary, token_remove_regex=token_remove_regex
    )
    n_tokens_after = len([token for token in all_tokens if not token.is_internal])
    # Only report the delta when the caller actually passed extra tokens.
    if n_tokens_before:
        logger.info(
            f"Adding {n_tokens_after} tokens to the vocabulary. Removed {n_tokens_before - n_tokens_after} tokens during preprocessing."
        )

    if not all_tokens:
        msg = "The vocabulary is empty after preprocessing. Please check your token_remove_pattern."
        raise ValueError(msg)

    unk_token = cast("str | None", tokenizer.special_tokens_map.get("unk_token"))
    pad_token = cast("str | None", tokenizer.special_tokens_map.get("pad_token"))

    # Weird if to satisfy mypy
    if pad_token is None:
        if unk_token is not None:
            # Fall back to the unk token as pad token.
            pad_token = unk_token
            logger.warning(
                "The pad token is not set. Setting it to the unk token. This is a workaround for models that don't have a pad token."
            )
        else:
            # unk_token is None on this branch, so this resolves to the first token's form.
            pad_token = unk_token or all_tokens[0].form
            logger.warning(
                "The pad token is not set. Setting it to the first token in the vocabulary. This is a workaround for models that don't have a pad token."
            )

    # Replace the vocabulary in the tokenizer with the new vocabulary.
    backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)

    logger.info(f"Creating embeddings for {len(all_tokens)} tokens")
    # Convert tokens to IDs
    token_ids = turn_tokens_into_ids(all_tokens, tokenizer, unk_token)

    # Create the embeddings by running forward passes over all token sequences.
    embeddings = create_embeddings(
        tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
    )

    # Post process the embeddings by applying PCA and Zipf weighting.
    embeddings = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
    # Quantize the embeddings.
    embeddings = quantize_embeddings(embeddings, quantize_to)

    model_name = getattr(model, "name_or_path", "")

    config = {
        "model_type": "model2vec",
        "architectures": ["StaticModel"],
        "tokenizer_name": model_name,
        "apply_pca": pca_dims,
        "apply_zipf": apply_zipf,
        "sif_coefficient": sif_coefficient,
        "hidden_dim": embeddings.shape[1],
        "seq_length": 1000000,  # Set this to a high value since we don't have a sequence length limit.
        "normalize": True,
    }

    if os.path.exists(model_name):
        # Using a local model. Get the model name from the path.
        model_name = os.path.basename(model_name)
        language = None
    else:
        # Get the language from the model card.
        try:
            info = model_info(model_name)
            language = info.cardData.get("language", None) if info.cardData is not None else None
        except Exception as e:
            # NOTE: broad `except Exception` because there are many reasons this can fail.
            logger.warning(f"Couldn't get the model info from the Hugging Face Hub: {e}. Setting language to None.")
            language = None

    return StaticModel(
        vectors=embeddings,
        tokenizer=backend_tokenizer,
        config=config,
        base_model_name=model_name,
        language=language,
        normalize=True,
    )
158
+
159
+
160
+ def _validate_parameters(
161
+ apply_zipf: bool | None,
162
+ sif_coefficient: float | None,
163
+ token_remove_pattern: str | None,
164
+ ) -> tuple[float | None, re.Pattern | None]:
165
+ """
166
+ Validate the parameters passed to the distillation function.
167
+
168
+ :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
169
+ Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
170
+ :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
171
+ Should be a value >= 0 and < 1.0. A value of 1e-4 is a good default.
172
+ :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
173
+ :return: The SIF coefficient to use.
174
+ :raises: ValueError if the regex can't be compiled.
175
+
176
+ """
177
+ if apply_zipf is not None:
178
+ logger.warning(
179
+ "The `apply_zipf` parameter is deprecated and will be removed in the next release. "
180
+ "Zipf weighting is applied based on the sif_coefficient parameter. If this is set to None, "
181
+ "no weighting is applied."
182
+ )
183
+ if apply_zipf and sif_coefficient is None:
184
+ logger.warning("You set apply_zipf to True, but sif_coefficient is None. Setting sif_coefficient to 1e-4.")
185
+ sif_coefficient = 1e-4
186
+ elif not apply_zipf:
187
+ logger.warning("Because you set apply_zipf to False, we ignore the sif_coefficient parameter.")
188
+ sif_coefficient = None
189
+
190
+ if sif_coefficient is not None and not 0 < sif_coefficient < 1.0:
191
+ msg = "SIF coefficient must be a value > 0 and < 1.0."
192
+ raise ValueError(msg)
193
+
194
+ token_remove_regex: re.Pattern | None = None
195
+ if token_remove_pattern is not None:
196
+ try:
197
+ token_remove_regex = re.compile(token_remove_pattern)
198
+ except re.error as e:
199
+ msg = f"Couldn't compile the regex pattern: {e}"
200
+ raise ValueError(msg)
201
+
202
+ return sif_coefficient, token_remove_regex
203
+
204
+
205
def distill(
    model_name: str,
    vocabulary: list[str] | None = None,
    device: str | None = None,
    pca_dims: PCADimType = 256,
    apply_zipf: bool | None = None,
    sif_coefficient: float | None = 1e-4,
    token_remove_pattern: str | None = r"\[unused\d+\]",
    trust_remote_code: bool = False,
    quantize_to: DType | str = DType.Float16,
    use_subword: bool | None = None,
) -> StaticModel:
    """
    Distill a staticmodel from a sentence transformer.

    This function creates a set of embeddings from a sentence transformer. It does this by doing either
    a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed vocabulary.

    If you pass through a vocabulary, we create a custom word tokenizer for that vocabulary.
    If you don't pass a vocabulary, we use the model's tokenizer directly.

    :param model_name: The model name to use. Any sentence-transformer compatible model works.
    :param vocabulary: The vocabulary to use. If this is None, we use the model's vocabulary.
    :param device: The device to use.
    :param pca_dims: The number of components to use for PCA.
        If this is None, we don't apply PCA.
        If this is 'auto', we don't reduce dimensionality, but still apply PCA.
    :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
        Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
    :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
        Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
    :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
    :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
    :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
    :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
    :return: A StaticModel

    """
    # Load the transformer model and its matching fast tokenizer; remote code
    # is only executed when explicitly requested by the caller.
    model: PreTrainedModel = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    tokenizer = cast(
        "PreTrainedTokenizerFast",
        AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code, use_fast=True),
    )

    # Delegate the actual distillation work to distill_from_model.
    return distill_from_model(
        model=model,
        tokenizer=tokenizer,
        vocabulary=vocabulary,
        device=device,
        pca_dims=pca_dims,
        apply_zipf=apply_zipf,
        token_remove_pattern=token_remove_pattern,
        sif_coefficient=sif_coefficient,
        quantize_to=quantize_to,
        use_subword=use_subword,
    )
src/distiller/model2vec/distill/inference.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Literal, Protocol, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from sklearn.decomposition import PCA
11
+ from torch.nn.utils.rnn import pad_sequence
12
+ from tqdm import tqdm
13
+
14
+ if TYPE_CHECKING:
15
+ from transformers import PreTrainedModel
16
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ PathLike = Union[Path, str]
22
+ PCADimType = Union[int, None, float, Literal["auto"]]
23
+
24
+
25
+ _DEFAULT_BATCH_SIZE = 256
26
+
27
+
28
class ModulewithWeights(Protocol):
    """Structural type for modules that expose a `weight` parameter."""

    weight: torch.nn.Parameter
30
+
31
+
32
def create_embeddings(
    model: PreTrainedModel,
    tokenized: list[list[int]],
    device: str,
    pad_token_id: int,
) -> np.ndarray:
    """
    Create output embeddings for a bunch of tokens using a pretrained model.

    It does a forward pass for all tokens passed in `tokens`.

    :param model: The model to use.
        This should be a transformers model.
    :param tokenized: All tokenized tokens.
    :param device: The torch device to use.
    :param pad_token_id: The pad token id. Used to pad sequences.
    :return: The output embeddings, one row per input sequence.
    """
    model = model.to(device)

    intermediate_weights: list[np.ndarray] = []

    # Add token_type_ids only if the model's forward signature supports it.
    add_token_type_ids = "token_type_ids" in inspect.getfullargspec(model.forward).args

    # Sort sequences by length so each batch contains similarly-sized
    # sequences, which minimizes the padding needed per batch.
    lengths = np.asarray([len(sequence) for sequence in tokenized])
    sort_order = np.argsort(lengths)

    sorted_tokenized = [tokenized[i] for i in sort_order]

    # Use the progress bar as a context manager so it is always closed,
    # even when a forward pass raises (the original never called close()).
    with tqdm(total=len(sorted_tokenized), desc="Encoding tokens", unit=" tokens") as pbar:
        for batch_idx in range(0, len(sorted_tokenized), _DEFAULT_BATCH_SIZE):
            batch = [torch.Tensor(x).long() for x in sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]]

            encoded = {}
            encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
            # NOTE(review): assumes pad_token_id never occurs inside a real
            # sequence, since every position equal to it is masked out.
            encoded["attention_mask"] = encoded["input_ids"] != pad_token_id

            if add_token_type_ids:
                encoded["token_type_ids"] = torch.zeros_like(encoded["input_ids"])

            out = _encode_mean_using_model(model, encoded)
            intermediate_weights.extend(out.numpy())
            pbar.update(len(batch))

    # Undo the length sort so outputs line up with the original input order
    # (argsort of the sort order is the inverse permutation).
    intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
    out_weights = np.stack(intermediate_weights)

    # Replace NaNs (presumably from degenerate sequences — TODO confirm) with zeros.
    out_weights = np.nan_to_num(out_weights)

    return out_weights
86
+
87
+
88
+ @torch.no_grad()
89
+ def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
90
+ """
91
+ Encode a batch of tokens using a model.
92
+
93
+ Note that if a token in the input batch does not have any embeddings, it will be output as a vector of zeros.
94
+ So detection of these is necessary.
95
+
96
+ :param model: The model to use.
97
+ :param encodings: The encoded tokens to turn into features.
98
+ :return: The mean of the output for each token.
99
+ """
100
+ encodings = {k: v.to(model.device) for k, v in encodings.items()}
101
+ encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings)
102
+ out: torch.Tensor = encoded.last_hidden_state.cpu()
103
+ # NOTE: If the dtype is bfloat 16, we convert to float32,
104
+ # because numpy does not suport bfloat16
105
+ # See here: https://github.com/numpy/numpy/issues/19808
106
+ if out.dtype == torch.bfloat16:
107
+ out = out.float()
108
+
109
+ # Take the mean by averaging over the attention mask.
110
+ mask = encodings["attention_mask"].cpu().float()
111
+ mask /= mask.sum(1)[:, None]
112
+
113
+ return torch.bmm(mask[:, None, :].float(), out).squeeze(1)
114
+
115
+
116
+
117
def post_process_embeddings(
    embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4
) -> np.ndarray:
    """
    Post process embeddings by applying PCA and SIF weighting by estimating the frequencies through Zipf's law.

    :param embeddings: The embeddings to post process, shape (n_tokens, dim).
    :param pca_dims: Target PCA dimensionality. None disables PCA; "auto" keeps the
        dimensionality but still applies PCA; a float is passed to sklearn's PCA
        as an explained-variance target.
    :param sif_coefficient: The SIF coefficient. If None, no weighting is applied.
    :return: The post-processed embeddings.
    """
    if pca_dims is not None:
        if pca_dims == "auto":
            pca_dims = embeddings.shape[1]
        if pca_dims > embeddings.shape[1]:
            # FIX: corrected "Is this is not desired" typo in the warning message.
            logger.warning(
                f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]}). "
                "Applying PCA, but not reducing dimensionality. If this is not desired, please set `pca_dims` to None. "
                "Applying PCA will probably improve performance, so consider just leaving it."
            )
            pca_dims = embeddings.shape[1]
        if pca_dims >= embeddings.shape[0]:
            logger.warning(
                f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({embeddings.shape[0]}). Not applying PCA."
            )
        elif pca_dims <= embeddings.shape[1]:
            if isinstance(pca_dims, float):
                logger.info(f"Applying PCA with {pca_dims} explained variance.")
            else:
                logger.info(f"Applying PCA with n_components {pca_dims}")

            orig_dims = embeddings.shape[1]
            p = PCA(n_components=pca_dims, svd_solver="full")
            embeddings = p.fit_transform(embeddings)

            if embeddings.shape[1] < orig_dims:
                explained_variance_ratio = np.sum(p.explained_variance_ratio_)
                explained_variance = np.sum(p.explained_variance_)
                logger.info(f"Reduced dimensionality from {orig_dims} to {embeddings.shape[1]}.")
                logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
                logger.info(f"Explained variance: {explained_variance:.3f}.")

    if sif_coefficient is not None:
        logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
        # Zipf: probability mass of the token at rank r is proportional to 1/r.
        # Ranks start at 2 (mirrors the original implementation; presumably to
        # soften the weight of the very first tokens — TODO confirm).
        inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
        proba = inv_rank / np.sum(inv_rank)
        embeddings *= (sif_coefficient / (sif_coefficient + proba))[:, None]

    return embeddings
src/distiller/model2vec/distill/utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from logging import getLogger
4
+
5
+ import torch
6
+
7
+ logger = getLogger(__name__)
8
+
9
+
10
def select_optimal_device(device: str | None) -> str:
    """
    Guess what your optimal device should be based on backend availability.

    If you pass a device, we just pass it through.

    :param device: The device to use. If this is not None you get back what you passed.
    :return: The selected device.
    """
    # An explicit choice always wins; nothing is logged in that case.
    if device is not None:
        return device

    # Auto-detect: prefer CUDA, then Apple MPS, then plain CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    logger.info(f"Automatically selected device: {device}")
    return device
src/distiller/model2vec/hf_utils.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Any, cast
7
+
8
+ import huggingface_hub
9
+ import safetensors
10
+ from huggingface_hub import ModelCard, ModelCardData
11
+ from safetensors.numpy import save_file
12
+ from tokenizers import Tokenizer
13
+
14
+ if TYPE_CHECKING:
15
+ import numpy as np
16
+
17
+ from distiller.model2vec.utils import SafeOpenProtocol
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def save_pretrained(
    folder_path: Path,
    embeddings: np.ndarray,
    tokenizer: Tokenizer,
    config: dict[str, Any],
    create_model_card: bool = True,
    subfolder: str | None = None,
    **kwargs: Any,
) -> None:
    """
    Save a model to a folder.

    :param folder_path: The path to the folder.
    :param embeddings: The embeddings.
    :param tokenizer: The tokenizer.
    :param config: A metadata config.
    :param create_model_card: Whether to create a model card.
    :param subfolder: The subfolder to save the model in.
    :param **kwargs: Any additional arguments (forwarded to model card creation).
    """
    folder_path = folder_path / subfolder if subfolder else folder_path
    folder_path.mkdir(exist_ok=True, parents=True)
    save_file({"embeddings": embeddings}, folder_path / "model.safetensors")
    tokenizer.save(str(folder_path / "tokenizer.json"), pretty=False)
    # FIX: use context managers for the json dumps — the original
    # `json.dump(config, open(...))` left the file handles open until GC.
    with (folder_path / "config.json").open("w") as config_file:
        json.dump(config, config_file, indent=4)

    # Create modules.json describing the sentence-transformers module stack.
    modules = [{"idx": 0, "name": "0", "path": ".", "type": "sentence_transformers.models.StaticEmbedding"}]
    if config.get("normalize"):
        # If normalize=True, add sentence_transformers.models.Normalize
        modules.append({"idx": 1, "name": "1", "path": "1_Normalize", "type": "sentence_transformers.models.Normalize"})
    with (folder_path / "modules.json").open("w") as modules_file:
        json.dump(modules, modules_file, indent=4)

    logger.info(f"Saved model to {folder_path}")

    # Optionally create the model card
    if create_model_card:
        _create_model_card(folder_path, **kwargs)
60
+
61
+
62
def _create_model_card(
    folder_path: Path,
    base_model_name: str = "unknown",
    license: str = "mit",
    language: list[str] | None = None,
    model_name: str | None = None,
    template_path: str = "modelcards/model_card_template.md",
    **kwargs: Any,
) -> None:
    """
    Create a model card and store it in the specified path.

    :param folder_path: The path where the model card will be stored.
    :param base_model_name: The name of the base model.
    :param license: The license to use.
    :param language: The language of the model.
    :param model_name: The name of the model to use in the Model Card.
    :param template_path: The path to the template.
    :param **kwargs: Additional metadata for the model card (e.g., model_name, base_model, etc.).
    """
    target_folder = Path(folder_path)
    # Fall back to the folder name when no (truthy) model name was given.
    card_name = model_name or target_folder.name
    # The template ships with this package, so resolve it next to this module.
    template_file = Path(__file__).parent / template_path

    card_metadata = ModelCardData(
        model_name=card_name,
        base_model=base_model_name,
        license=license,
        language=language,
        tags=["embeddings", "static-embeddings", "sentence-transformers"],
        library_name="model2vec",
        **kwargs,
    )
    rendered_card = ModelCard.from_template(card_metadata, template_path=str(template_file))
    rendered_card.save(target_folder / "README.md")
97
+
98
+
99
def load_pretrained(
    folder_or_repo_path: str | Path,
    subfolder: str | None = None,
    token: str | None = None,
    from_sentence_transformers: bool = False,
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
    """
    Loads a pretrained model from a folder.

    :param folder_or_repo_path: The folder or repo path to load from.
        - If this is a local path, we will load from the local path.
        - If the local path is not found, we will attempt to load from the huggingface hub.
    :param subfolder: The subfolder to load from.
    :param token: The huggingface token to use.
    :param from_sentence_transformers: Whether to load the model from a sentence transformers model.
    :raises: FileNotFoundError if the folder exists, but the file does not exist locally.
    :return: The embeddings, tokenizer, config, and metadata.

    """
    # Sentence-transformers checkpoints keep the static embedding files in a
    # module subdirectory and use a different config filename.
    if from_sentence_transformers:
        model_file = "0_StaticEmbedding/model.safetensors"
        tokenizer_file = "0_StaticEmbedding/tokenizer.json"
        config_name = "config_sentence_transformers.json"
    else:
        model_file = "model.safetensors"
        tokenizer_file = "tokenizer.json"
        config_name = "config.json"

    folder_or_repo_path = Path(folder_or_repo_path)

    local_folder = folder_or_repo_path / subfolder if subfolder else folder_or_repo_path

    if local_folder.exists():
        embeddings_path = local_folder / model_file
        if not embeddings_path.exists():
            msg = f"Embeddings file does not exist in {local_folder}"
            raise FileNotFoundError(msg)

        config_path = local_folder / config_name
        if not config_path.exists():
            msg = f"Config file does not exist in {local_folder}"
            raise FileNotFoundError(msg)

        tokenizer_path = local_folder / tokenizer_file
        if not tokenizer_path.exists():
            msg = f"Tokenizer file does not exist in {local_folder}"
            raise FileNotFoundError(msg)

        # README is optional, so this is a bit finicky.
        readme_path = local_folder / "README.md"
        metadata = _get_metadata_from_readme(readme_path)

    else:
        logger.info("Folder does not exist locally, attempting to use huggingface hub.")
        embeddings_path = Path(
            huggingface_hub.hf_hub_download(
                folder_or_repo_path.as_posix(), model_file, token=token, subfolder=subfolder
            )
        )

        try:
            readme_path = Path(
                huggingface_hub.hf_hub_download(
                    folder_or_repo_path.as_posix(), "README.md", token=token, subfolder=subfolder
                )
            )
            metadata = _get_metadata_from_readme(Path(readme_path))
        except Exception as e:
            # NOTE: we don't want to raise an error here, since the README is optional.
            logger.info(f"No README found in the model folder: {e} No model card loaded.")
            metadata = {}

        config_path = Path(
            huggingface_hub.hf_hub_download(
                folder_or_repo_path.as_posix(), config_name, token=token, subfolder=subfolder
            )
        )
        tokenizer_path = Path(
            huggingface_hub.hf_hub_download(
                folder_or_repo_path.as_posix(), tokenizer_file, token=token, subfolder=subfolder
            )
        )

    opened_tensor_file = cast("SafeOpenProtocol", safetensors.safe_open(embeddings_path, framework="numpy"))
    # Sentence-transformers stores the matrix under a different tensor key.
    if from_sentence_transformers:
        embeddings = opened_tensor_file.get_tensor("embedding.weight")
    else:
        embeddings = opened_tensor_file.get_tensor("embeddings")

    tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
    # FIX: use a context manager so the config file handle is closed promptly
    # (the original `json.load(open(...))` leaked the handle until GC).
    with open(config_path) as config_file:
        config = json.load(config_file)

    if len(tokenizer.get_vocab()) != len(embeddings):
        logger.warning(
            f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
        )

    return embeddings, tokenizer, config, metadata
197
+
198
+
199
def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]:
    """Get metadata from a README file."""
    # A missing README simply means there is no metadata to load.
    if not readme_path.exists():
        logger.info(f"README file not found in {readme_path}. No model card loaded.")
        return {}
    card_data: dict[str, Any] = ModelCard.load(readme_path).data.to_dict()
    if not card_data:
        logger.info("File README.md exists, but was empty. No model card loaded.")
    return card_data
209
+
210
+
211
def push_folder_to_hub(
    folder_path: Path, subfolder: str | None, repo_id: str, private: bool, token: str | None
) -> None:
    """
    Push a model folder to the huggingface hub, including model card.

    :param folder_path: The path to the folder.
    :param subfolder: The subfolder to push to.
        If None, the folder will be pushed to the root of the repo.
    :param repo_id: The repo name.
    :param private: Whether the repo is private.
    :param token: The huggingface token.
    """
    # Create the repo first if it doesn't exist yet; `private` only takes
    # effect at creation time — an existing repo's visibility is not changed.
    if not huggingface_hub.repo_exists(repo_id=repo_id, token=token):
        huggingface_hub.create_repo(repo_id, token=token, private=private)

    # Push model card and all model files to the Hugging Face hub
    huggingface_hub.upload_folder(repo_id=repo_id, folder_path=folder_path, token=token, path_in_repo=subfolder)

    logger.info(f"Pushed model to {repo_id}")
src/distiller/model2vec/inference/README.md ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference
2
+
3
+ This subpackage mainly contains helper functions for inference with trained models that have been exported to `scikit-learn` compatible pipelines.
4
+
5
+ If you're looking for information on how to train a model, see [here](../train/README.md).
6
+
7
+ # Usage
8
+
9
+ Let's assume you're using our [potion-edu classifier](https://huggingface.co/minishlab/potion-8m-edu-classifier).
10
+
11
+ ```python
12
+ from distiller.model2vec.inference import StaticModelPipeline
13
+
14
+ classifier = StaticModelPipeline.from_pretrained("minishlab/potion-8m-edu-classifier")
15
+ label = classifier.predict("Attitudes towards cattle in the Alps: a study in letting go.")
16
+ ```
17
+
18
+ This should just work.
src/distiller/model2vec/inference/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
from distiller.model2vec.utils import get_package_extras, importable

# Optional-dependency group ("extra") that must be installed for inference.
_REQUIRED_EXTRA = "inference"

# Check every dependency declared under the "inference" extra before importing
# the model module that needs them (importable presumably raises a helpful
# error when one is missing — confirm in distiller.model2vec.utils).
for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
    importable(extra_dependency, _REQUIRED_EXTRA)

from distiller.model2vec.inference.model import StaticModelPipeline, evaluate_single_or_multi_label

__all__ = ["StaticModelPipeline", "evaluate_single_or_multi_label"]
src/distiller/model2vec/inference/model.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from pathlib import Path
5
+ from tempfile import TemporaryDirectory
6
+ from typing import TYPE_CHECKING, TypeVar
7
+
8
+ import huggingface_hub
9
+ import numpy as np
10
+ import skops.io
11
+ from sklearn.metrics import classification_report
12
+ from sklearn.neural_network import MLPClassifier
13
+ from sklearn.preprocessing import MultiLabelBinarizer
14
+
15
+ from distiller.model2vec.hf_utils import _create_model_card
16
+ from distiller.model2vec.model import PathLike, StaticModel
17
+
18
+ if TYPE_CHECKING:
19
+ from collections.abc import Sequence
20
+
21
+ from sklearn.pipeline import Pipeline
22
+
23
+ _DEFAULT_TRUST_PATTERN = re.compile(r"sklearn\..+")
24
+ _DEFAULT_MODEL_FILENAME = "pipeline.skops"
25
+
26
+ LabelType = TypeVar("LabelType", list[str], list[list[str]])
27
+
28
+
29
+ class StaticModelPipeline:
30
+ def __init__(self, model: StaticModel, head: Pipeline) -> None:
31
+ """Create a pipeline with a StaticModel encoder."""
32
+ self.model = model
33
+ self.head = head
34
+ classifier = self.head[-1]
35
+ # Check if the classifier is a multilabel classifier.
36
+ # NOTE: this doesn't look robust, but it is.
37
+ # Different classifiers, such as OVR wrappers, support multilabel output natively, so we
38
+ # can just use predict.
39
+ self.multilabel = False
40
+ if isinstance(classifier, MLPClassifier) and classifier.out_activation_ == "logistic":
41
+ self.multilabel = True
42
+
43
+ @property
44
+ def classes_(self) -> np.ndarray:
45
+ """The classes of the classifier."""
46
+ return self.head.classes_
47
+
48
+ @classmethod
49
+ def from_pretrained(
50
+ cls: type[StaticModelPipeline], path: PathLike, token: str | None = None, trust_remote_code: bool = False
51
+ ) -> StaticModelPipeline:
52
+ """
53
+ Load a StaticModel from a local path or huggingface hub path.
54
+
55
+ NOTE: if you load a private model from the huggingface hub, you need to pass a token.
56
+
57
+ :param path: The path to the folder containing the pipeline, or a repository on the Hugging Face Hub
58
+ :param token: The token to use to download the pipeline from the hub.
59
+ :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `sklearn`.
60
+ :return: The loaded pipeline.
61
+ """
62
+ model, head = _load_pipeline(path, token, trust_remote_code)
63
+ model.embedding = np.nan_to_num(model.embedding)
64
+
65
+ return cls(model, head)
66
+
67
+ def save_pretrained(self, path: str) -> None:
68
+ """Save the model to a folder."""
69
+ save_pipeline(self, path)
70
+
71
+ def push_to_hub(
72
+ self, repo_id: str, subfolder: str | None = None, token: str | None = None, private: bool = False
73
+ ) -> None:
74
+ """
75
+ Save a model to a folder, and then push that folder to the hf hub.
76
+
77
+ :param repo_id: The id of the repository to push to.
78
+ :param subfolder: The subfolder to push to.
79
+ :param token: The token to use to push to the hub.
80
+ :param private: Whether the repository should be private.
81
+ """
82
+ from distiller.model2vec.hf_utils import push_folder_to_hub
83
+
84
+ with TemporaryDirectory() as temp_dir:
85
+ save_pipeline(self, temp_dir)
86
+ self.model.save_pretrained(temp_dir)
87
+ push_folder_to_hub(Path(temp_dir), subfolder, repo_id, private, token)
88
+
89
+ def _encode_and_coerce_to_2d(
90
+ self,
91
+ X: Sequence[str],
92
+ show_progress_bar: bool,
93
+ max_length: int | None,
94
+ batch_size: int,
95
+ use_multiprocessing: bool,
96
+ multiprocessing_threshold: int,
97
+ ) -> np.ndarray:
98
+ """Encode the instances and coerce the output to a matrix."""
99
+ encoded = self.model.encode(
100
+ X,
101
+ show_progress_bar=show_progress_bar,
102
+ max_length=max_length,
103
+ batch_size=batch_size,
104
+ use_multiprocessing=use_multiprocessing,
105
+ multiprocessing_threshold=multiprocessing_threshold,
106
+ )
107
+ if np.ndim(encoded) == 1:
108
+ encoded = encoded[None, :]
109
+
110
+ return encoded
111
+
112
+ def predict(
113
+ self,
114
+ X: Sequence[str],
115
+ show_progress_bar: bool = False,
116
+ max_length: int | None = 512,
117
+ batch_size: int = 1024,
118
+ use_multiprocessing: bool = True,
119
+ multiprocessing_threshold: int = 10_000,
120
+ threshold: float = 0.5,
121
+ ) -> np.ndarray:
122
+ """
123
+ Predict the labels of the input.
124
+
125
+ :param X: The input data to predict. Can be a list of strings or a single string.
126
+ :param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False.
127
+ :param max_length: The maximum length of the input sequences. Defaults to 512.
128
+ :param batch_size: The batch size for prediction. Defaults to 1024.
129
+ :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
130
+ :param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000.
131
+ :param threshold: The threshold for multilabel classification. Defaults to 0.5. Ignored if not multilabel.
132
+ :return: The predicted labels or probabilities.
133
+ """
134
+ encoded = self._encode_and_coerce_to_2d(
135
+ X,
136
+ show_progress_bar=show_progress_bar,
137
+ max_length=max_length,
138
+ batch_size=batch_size,
139
+ use_multiprocessing=use_multiprocessing,
140
+ multiprocessing_threshold=multiprocessing_threshold,
141
+ )
142
+
143
+ if self.multilabel:
144
+ out_labels = []
145
+ proba = self.head.predict_proba(encoded)
146
+ for vector in proba:
147
+ out_labels.append(self.classes_[vector > threshold])
148
+ return np.asarray(out_labels, dtype=object)
149
+
150
+ return self.head.predict(encoded)
151
+
152
+ def predict_proba(
153
+ self,
154
+ X: Sequence[str],
155
+ show_progress_bar: bool = False,
156
+ max_length: int | None = 512,
157
+ batch_size: int = 1024,
158
+ use_multiprocessing: bool = True,
159
+ multiprocessing_threshold: int = 10_000,
160
+ ) -> np.ndarray:
161
+ """
162
+ Predict the labels of the input.
163
+
164
+ :param X: The input data to predict. Can be a list of strings or a single string.
165
+ :param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False.
166
+ :param max_length: The maximum length of the input sequences. Defaults to 512.
167
+ :param batch_size: The batch size for prediction. Defaults to 1024.
168
+ :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
169
+ :param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000.
170
+ :return: The predicted labels or probabilities.
171
+ """
172
+ encoded = self._encode_and_coerce_to_2d(
173
+ X,
174
+ show_progress_bar=show_progress_bar,
175
+ max_length=max_length,
176
+ batch_size=batch_size,
177
+ use_multiprocessing=use_multiprocessing,
178
+ multiprocessing_threshold=multiprocessing_threshold,
179
+ )
180
+
181
+ return self.head.predict_proba(encoded)
182
+
183
+ def evaluate(
184
+ self, X: Sequence[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False
185
+ ) -> str | dict[str, dict[str, float]]:
186
+ """
187
+ Evaluate the classifier on a given dataset using scikit-learn's classification report.
188
+
189
+ :param X: The texts to predict on.
190
+ :param y: The ground truth labels.
191
+ :param batch_size: The batch size.
192
+ :param threshold: The threshold for multilabel classification.
193
+ :param output_dict: Whether to output the classification report as a dictionary.
194
+ :return: A classification report.
195
+ """
196
+ predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold)
197
+ return evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict)
198
+
199
+
200
+
201
def _load_pipeline(
    folder_or_repo_path: PathLike, token: str | None = None, trust_remote_code: bool = False
) -> tuple[StaticModel, Pipeline]:
    """
    Load a model and an sklearn pipeline.

    This assumes the following files are present in the repo:
    - `pipeline.skops`: The head of the pipeline.
    - `config.json`: The configuration of the model.
    - `model.safetensors`: The weights of the model.
    - `tokenizer.json`: The tokenizer of the model.

    :param folder_or_repo_path: The path to the folder containing the pipeline.
    :param token: The token to use to download the pipeline from the hub. If this is None, you will only
        be able to load the pipeline from a local folder, public repository, or a repository that you have access to
        because you are logged in.
    :param trust_remote_code: Whether to trust the remote code. If this is False,
        we will only load components coming from `sklearn`. If this is True, we will load all components.
        If you set this to True, you are responsible for whatever happens.
    :return: The encoder model and the loaded head
    :raises FileNotFoundError: If the pipeline file does not exist in the folder.
    :raises ValueError: If an untrusted type is found in the pipeline, and `trust_remote_code` is False.
    """
    folder_or_repo_path = Path(folder_or_repo_path)
    model_filename = _DEFAULT_MODEL_FILENAME
    head_pipeline_path: str | Path
    if folder_or_repo_path.exists():
        head_pipeline_path = folder_or_repo_path / model_filename
        if not head_pipeline_path.exists():
            msg = f"Pipeline file does not exist in {folder_or_repo_path}"
            raise FileNotFoundError(msg)
    else:
        head_pipeline_path = huggingface_hub.hf_hub_download(
            folder_or_repo_path.as_posix(), model_filename, token=token
        )

    # Fix: forward the token so the encoder weights of a private repository can be
    # downloaded too — previously only the head download above received the token.
    model = StaticModel.from_pretrained(folder_or_repo_path, token=token)

    unknown_types = skops.io.get_untrusted_types(file=head_pipeline_path)
    # If the user does not trust remote code, we should check that the unknown types are trusted.
    # By default, we trust everything coming from scikit-learn.
    if not trust_remote_code:
        for t in unknown_types:
            if not _DEFAULT_TRUST_PATTERN.match(t):
                msg = f"Untrusted type {t}."
                raise ValueError(msg)
    head = skops.io.load(head_pipeline_path, trusted=unknown_types)

    return model, head
250
+
251
+
252
def save_pipeline(pipeline: StaticModelPipeline, folder_path: str | Path) -> None:
    """
    Save a pipeline to a folder.

    :param pipeline: The pipeline to save.
    :param folder_path: The path to the folder to save the pipeline to.
    """
    out_dir = Path(folder_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Persist the sklearn head with skops and the encoder alongside it.
    skops.io.dump(pipeline.head, out_dir / _DEFAULT_MODEL_FILENAME)
    pipeline.model.save_pretrained(out_dir)

    # The base model name may be stored as a string or a (possibly empty) list.
    base_model_name = pipeline.model.base_model_name
    if isinstance(base_model_name, list) and base_model_name:
        display_name = base_model_name[0]
    elif isinstance(base_model_name, str):
        display_name = base_model_name
    else:
        display_name = "unknown"

    _create_model_card(
        out_dir,
        base_model_name=display_name,
        language=pipeline.model.language,
        template_path="modelcards/classifier_template.md",
    )
278
+
279
+
280
def _is_multi_label_shaped(y: LabelType) -> bool:
    """Return True when *y* looks like multi-label targets: a list/tuple whose first element is a collection of labels."""
    if not isinstance(y, (list, tuple)):
        return False
    # Non-empty, and the first entry is itself a collection of labels.
    return bool(y) and isinstance(y[0], (list, tuple, set))
283
+
284
+
285
def evaluate_single_or_multi_label(
    predictions: np.ndarray,
    y: LabelType,
    output_dict: bool = False,
) -> str | dict[str, dict[str, float]]:
    """
    Evaluate the classifier on a given dataset using scikit-learn's classification report.

    Multi-label targets (a sequence of label collections) are binarized over the union
    of the gold labels before scoring; single-label targets are scored directly.

    :param predictions: The predictions.
    :param y: The ground truth labels.
    :param output_dict: Whether to output the classification report as a dictionary.
    :return: A classification report.
    """
    if _is_multi_label_shaped(y):
        # Binarize both ground truth and predictions with the same label ordering.
        classes = sorted({label for labels in y for label in labels})
        mlb = MultiLabelBinarizer(classes=classes)
        y = mlb.fit_transform(y)
        predictions = mlb.transform(predictions)
    # NOTE: the original also computed `sorted(set(y))` for single-label input and
    # discarded the result; that dead code has been removed.

    return classification_report(
        y,
        predictions,
        output_dict=output_dict,
        zero_division=0,
    )
312
+
src/distiller/model2vec/model.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ import os
5
+ from logging import getLogger
6
+ from pathlib import Path
7
+ from tempfile import TemporaryDirectory
8
+ from typing import TYPE_CHECKING, Any, Union, overload
9
+
10
+ import numpy as np
11
+ from joblib import delayed
12
+ from tqdm import tqdm
13
+
14
+ from .quantization import DType, quantize_and_reduce_dim
15
+ from .utils import ProgressParallel, load_local_model
16
+
17
+ if TYPE_CHECKING:
18
+ from collections.abc import Iterator, Sequence
19
+
20
+ from tokenizers import Encoding, Tokenizer
21
+
22
+ PathLike = Union[Path, str]
23
+
24
+ logger = getLogger(__name__)
25
+
26
+
27
class StaticModel:
    """
    A static (non-contextual) embedding model: a token-embedding matrix paired with a tokenizer.

    Sentences are embedded by tokenizing them and mean-pooling the per-token vectors
    (`encode`), or by returning the raw per-token vectors (`encode_as_sequence`).
    """

    def __init__(
        self,
        vectors: np.ndarray,
        tokenizer: Tokenizer,
        config: dict[str, Any] | None = None,
        normalize: bool | None = None,
        base_model_name: str | None = None,
        language: list[str] | None = None,
    ) -> None:
        """
        Initialize the StaticModel.

        :param vectors: The vectors to use.
        :param tokenizer: The Transformers tokenizer to use.
        :param config: Any metadata config.
        :param normalize: Whether to normalize the embeddings.
        :param base_model_name: The used base model name. Used for creating a model card.
        :param language: The language of the model. Used for creating a model card.
        :raises: ValueError if the number of tokens does not match the number of vectors.
        """
        super().__init__()
        # Vocabulary tokens sorted by token id, so self.tokens[i] corresponds to vectors[i].
        tokens, _ = zip(*sorted(tokenizer.get_vocab().items(), key=lambda x: x[1]), strict=False)
        self.tokens = tokens

        self.embedding = vectors

        if len(tokens) != vectors.shape[0]:
            msg = f"Number of tokens ({len(tokens)}) does not match number of vectors ({vectors.shape[0]})"
            raise ValueError(msg)

        self.tokenizer = tokenizer
        self.unk_token_id: int | None
        if hasattr(self.tokenizer.model, "unk_token") and self.tokenizer.model.unk_token is not None:
            self.unk_token_id = tokenizer.get_vocab()[self.tokenizer.model.unk_token]
        else:
            self.unk_token_id = None  # pragma: no cover # Doesn't actually happen, but can happen.

        # Used by `tokenize` to cheaply pre-truncate long strings at the character level.
        self.median_token_length = int(np.median([len(token) for token in self.tokens]))
        self.config = config or {}
        self.base_model_name = base_model_name
        self.language = language
        # `encode_batch_fast` only exists on newer `tokenizers` versions.
        if hasattr(self.tokenizer, "encode_batch_fast"):
            self._can_encode_fast = True
        else:
            self._can_encode_fast = False

        # An explicit `normalize` argument wins over the value stored in the config.
        if normalize is not None:
            self.normalize = normalize
        else:
            self.normalize = self.config.get("normalize", False)

    @property
    def dim(self) -> int:
        """Get the dimension of the model."""
        return self.embedding.shape[1]

    @property
    def normalize(self) -> bool:
        """
        Get the normalize value.

        :return: The normalize value.
        """
        return self._normalize

    @normalize.setter
    def normalize(self, value: bool) -> None:
        """Update the config if the value of normalize changes."""
        config_normalize = self.config.get("normalize")
        self._normalize = value
        # Warn when the caller overrides a value that was persisted in the config.
        if config_normalize is not None and value != config_normalize:
            logger.warning(
                f"Set normalization to `{value}`, which does not match config value `{config_normalize}`. Updating config."
            )
        self.config["normalize"] = value

    def save_pretrained(self, path: PathLike, model_name: str | None = None, subfolder: str | None = None) -> None:
        """
        Save the pretrained model.

        :param path: The path to save to.
        :param model_name: The model name to use in the Model Card.
        :param subfolder: The subfolder to save to.
        """
        # Imported lazily to keep module import light.
        from .hf_utils import save_pretrained

        save_pretrained(
            folder_path=Path(path),
            embeddings=self.embedding,
            tokenizer=self.tokenizer,
            config=self.config,
            base_model_name=self.base_model_name,
            language=self.language,
            model_name=model_name,
            subfolder=subfolder,
        )

    def tokenize(self, sentences: Sequence[str], max_length: int | None = None) -> list[list[int]]:
        """
        Tokenize a list of sentences.

        :param sentences: The sentences to tokenize.
        :param max_length: The maximum length of the sentences in tokens. If this is None, sequences
            are not truncated.
        :return: A list of list of tokens.
        """
        if max_length is not None:
            # Cheap character-level pre-truncation: a string longer than
            # max_length * median_token_length chars cannot fit in max_length tokens.
            m = max_length * self.median_token_length
            sentences = [sentence[:m] for sentence in sentences]

        if self._can_encode_fast:
            encodings: list[Encoding] = self.tokenizer.encode_batch_fast(sentences, add_special_tokens=False)
        else:
            encodings = self.tokenizer.encode_batch(sentences, add_special_tokens=False)

        encodings_ids = [encoding.ids for encoding in encodings]

        if self.unk_token_id is not None:
            # NOTE: Remove the unknown token: necessary for word-level models.
            encodings_ids = [
                [token_id for token_id in token_ids if token_id != self.unk_token_id] for token_ids in encodings_ids
            ]
        if max_length is not None:
            # Exact token-level truncation after the character-level heuristic above.
            encodings_ids = [token_ids[:max_length] for token_ids in encodings_ids]

        return encodings_ids

    @classmethod
    def from_pretrained(
        cls: type[StaticModel],
        path: PathLike,
        token: str | None = None,
        normalize: bool | None = None,
        subfolder: str | None = None,
        quantize_to: str | DType | None = None,
        dimensionality: int | None = None,
    ) -> StaticModel:
        """
        Load a StaticModel from a local path or huggingface hub path.

        NOTE: if you load a private model from the huggingface hub, you need to pass a token.

        :param path: The path to load your static model from.
        :param token: The huggingface token to use.
        :param normalize: Whether to normalize the embeddings.
        :param subfolder: The subfolder to load from.
        :param quantize_to: The dtype to quantize the model to. If None, no quantization is done.
            If a string is passed, it is converted to a DType.
        :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
            This is useful if you want to load a model with a lower dimensionality.
            Note that this only applies if you have trained your model using mrl or PCA.
        :return: A StaticModel.
        """
        from .hf_utils import load_pretrained

        embeddings, tokenizer, config, metadata = load_pretrained(
            folder_or_repo_path=path,
            token=token,
            from_sentence_transformers=False,
            subfolder=subfolder,
        )

        # Optional post-load quantization / dimensionality reduction.
        embeddings = quantize_and_reduce_dim(
            embeddings=embeddings,
            quantize_to=quantize_to,
            dimensionality=dimensionality,
        )

        return cls(
            embeddings,
            tokenizer,
            config,
            normalize=normalize,
            base_model_name=metadata.get("base_model"),
            language=metadata.get("language"),
        )

    @classmethod
    def from_sentence_transformers(
        cls: type[StaticModel],
        path: PathLike,
        token: str | None = None,
        normalize: bool | None = None,
        quantize_to: str | DType | None = None,
        dimensionality: int | None = None,
    ) -> StaticModel:
        """
        Load a StaticModel trained with sentence transformers from a local path or huggingface hub path.

        NOTE: if you load a private model from the huggingface hub, you need to pass a token.

        :param path: The path to load your static model from.
        :param token: The huggingface token to use.
        :param normalize: Whether to normalize the embeddings.
        :param quantize_to: The dtype to quantize the model to. If None, no quantization is done.
            If a string is passed, it is converted to a DType.
        :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
            This is useful if you want to load a model with a lower dimensionality.
            Note that this only applies if you have trained your model using mrl or PCA.
        :return: A StaticModel.
        """
        from .hf_utils import load_pretrained

        embeddings, tokenizer, config, metadata = load_pretrained(
            folder_or_repo_path=path,
            token=token,
            from_sentence_transformers=True,
            subfolder=None,
        )

        embeddings = quantize_and_reduce_dim(
            embeddings=embeddings,
            quantize_to=quantize_to,
            dimensionality=dimensionality,
        )

        return cls(
            embeddings,
            tokenizer,
            config,
            normalize=normalize,
            base_model_name=metadata.get("base_model"),
            language=metadata.get("language"),
        )

    @overload
    def encode_as_sequence(
        self,
        sentences: str,
        max_length: int | None = None,
        batch_size: int = 1024,
        show_progress_bar: bool = False,
        use_multiprocessing: bool = True,
        multiprocessing_threshold: int = 10_000,
    ) -> np.ndarray: ...

    @overload
    def encode_as_sequence(
        self,
        sentences: list[str],
        max_length: int | None = None,
        batch_size: int = 1024,
        show_progress_bar: bool = False,
        use_multiprocessing: bool = True,
        multiprocessing_threshold: int = 10_000,
    ) -> list[np.ndarray]: ...

    def encode_as_sequence(
        self,
        sentences: str | list[str],
        max_length: int | None = None,
        batch_size: int = 1024,
        show_progress_bar: bool = False,
        use_multiprocessing: bool = True,
        multiprocessing_threshold: int = 10_000,
    ) -> list[np.ndarray] | np.ndarray:
        """
        Encode a list of sentences as a list of numpy arrays of tokens.

        This is useful if you want to use the tokens for further processing, or if you want to do sequence
        modeling.
        Note that if you just want the mean, you should use the `encode` method.
        This is about twice as slow.
        Sentences that do not contain any tokens will be turned into an empty array.

        NOTE: the input type is currently underspecified. The actual input type is `Sequence[str] | str`, but this
        is not possible to implement in python typing currently.

        :param sentences: The list of sentences to encode.
        :param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
            If this is None, no truncation is done.
        :param batch_size: The batch size to use.
        :param show_progress_bar: Whether to show the progress bar.
        :param use_multiprocessing: Whether to use multiprocessing.
            By default, this is enabled for inputs > multiprocessing_threshold sentences and disabled otherwise.
        :param multiprocessing_threshold: The threshold in number of sentences for using multiprocessing.
        :return: The encoded sentences with an embedding per token.
        """
        was_single = False
        if isinstance(sentences, str):
            sentences = [sentences]
            was_single = True

        # Prepare all batches
        sentence_batches = list(self._batch(sentences, batch_size))
        total_batches = math.ceil(len(sentences) / batch_size)

        # Use joblib for multiprocessing if requested, and if we have enough sentences
        if use_multiprocessing and len(sentences) > multiprocessing_threshold:
            # Disable parallelism for tokenizers
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

            results = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=total_batches)(
                delayed(self._encode_batch_as_sequence)(batch, max_length) for batch in sentence_batches
            )
            out_array: list[np.ndarray] = []
            for r in results:
                out_array.extend(r)
        else:
            out_array = []
            for batch in tqdm(
                sentence_batches,
                total=total_batches,
                disable=not show_progress_bar,
            ):
                out_array.extend(self._encode_batch_as_sequence(batch, max_length))

        if was_single:
            # Mirror the input shape: a single sentence in, a single array out.
            return out_array[0]
        return out_array

    def _encode_batch_as_sequence(self, sentences: Sequence[str], max_length: int | None) -> list[np.ndarray]:
        """Encode a batch of sentences as a sequence."""
        ids = self.tokenize(sentences=sentences, max_length=max_length)
        out: list[np.ndarray] = []
        for id_list in ids:
            if id_list:
                out.append(self.embedding[id_list])
            else:
                # No tokens survived tokenization: return an empty (0, dim) array.
                out.append(np.zeros((0, self.dim)))

        return out

    def encode(
        self,
        sentences: Sequence[str],
        show_progress_bar: bool = False,
        max_length: int | None = 512,
        batch_size: int = 1024,
        use_multiprocessing: bool = True,
        multiprocessing_threshold: int = 10_000,
        **kwargs: Any,
    ) -> np.ndarray:
        """
        Encode a list of sentences.

        This function encodes a list of sentences by averaging the word embeddings of the tokens in the sentence.
        For ease of use, we don't batch sentences together.

        NOTE: the return type is currently underspecified. In the case of a single string, this returns a 1D array,
        but in the case of a list of strings, this returns a 2D array. Not possible to implement in numpy currently.

        :param sentences: The list of sentences to encode. You can also pass a single sentence.
        :param show_progress_bar: Whether to show the progress bar.
        :param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
            If this is None, no truncation is done.
        :param batch_size: The batch size to use.
        :param use_multiprocessing: Whether to use multiprocessing.
            By default, this is enabled for inputs > multiprocessing_threshold sentences and disabled otherwise.
        :param multiprocessing_threshold: The threshold in number of sentences for using multiprocessing.
        :param **kwargs: Any additional arguments. These are ignored.
        :return: The encoded sentences. If a single sentence was passed, a vector is returned.
        """
        was_single = False
        if isinstance(sentences, str):
            sentences = [sentences]
            was_single = True

        # Prepare all batches
        sentence_batches = list(self._batch(sentences, batch_size))
        total_batches = math.ceil(len(sentences) / batch_size)

        # Use joblib for multiprocessing if requested, and if we have enough sentences
        if use_multiprocessing and len(sentences) > multiprocessing_threshold:
            # Disable parallelism for tokenizers
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

            results = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=total_batches)(
                delayed(self._encode_batch)(batch, max_length) for batch in sentence_batches
            )
            out_array = np.concatenate(results, axis=0)
        else:
            # Don't use multiprocessing
            out_arrays: list[np.ndarray] = []
            for batch in tqdm(
                sentence_batches,
                total=total_batches,
                disable=not show_progress_bar,
            ):
                out_arrays.append(self._encode_batch(batch, max_length))
            out_array = np.concatenate(out_arrays, axis=0)

        if was_single:
            return out_array[0]
        return out_array

    def _encode_batch(self, sentences: Sequence[str], max_length: int | None) -> np.ndarray:
        """Encode a batch of sentences."""
        ids = self.tokenize(sentences=sentences, max_length=max_length)
        out: list[np.ndarray] = []
        for id_list in ids:
            if id_list:
                # Mean-pool the token embeddings of the sentence.
                out.append(self.embedding[id_list].mean(0))
            else:
                # No tokens: represent the sentence as the zero vector.
                out.append(np.zeros(self.dim))

        out_array = np.stack(out)
        if self.normalize:
            # Small epsilon avoids division by zero for all-zero vectors.
            norm = np.linalg.norm(out_array, axis=1, keepdims=True) + 1e-32
            out_array = out_array / norm

        return out_array

    @staticmethod
    def _batch(sentences: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
        """Batch the sentences into equal-sized."""
        return (sentences[i : i + batch_size] for i in range(0, len(sentences), batch_size))

    def push_to_hub(
        self, repo_id: str, private: bool = False, token: str | None = None, subfolder: str | None = None
    ) -> None:
        """
        Push the model to the huggingface hub.

        NOTE: you need to pass a token if you are pushing a private model.

        :param repo_id: The repo id to push to.
        :param private: Whether the repo, if created is set to private.
            If the repo already exists, this doesn't change the visibility.
        :param token: The huggingface token to use.
        :param subfolder: The subfolder to push to.
        """
        from .hf_utils import push_folder_to_hub

        with TemporaryDirectory() as temp_dir:
            self.save_pretrained(temp_dir, model_name=repo_id)
            push_folder_to_hub(Path(temp_dir), subfolder=subfolder, repo_id=repo_id, private=private, token=token)

    @classmethod
    def load_local(cls: type[StaticModel], path: PathLike) -> StaticModel:
        """
        Loads a model from a local path.

        You should only use this code path if you are concerned with start-up time.
        Loading via the `from_pretrained` method is safer, and auto-downloads, but
        also means we import a whole bunch of huggingface code that we don't need.

        Additionally, huggingface will check the most recent version of the model,
        which can be slow.

        :param path: The path to load the model from. The path is a directory saved by the
            `save_pretrained` method.
        :return: A StaticModel
        :raises: ValueError if the path is not a directory.
        """
        path = Path(path)
        if not path.is_dir():
            msg = f"Path {path} is not a directory."
            raise ValueError(msg)

        embeddings, tokenizer, config = load_local_model(path)

        return StaticModel(embeddings, tokenizer, config)
src/distiller/model2vec/modelcards/classifier_template.md ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ {{ card_data }}
3
+ ---
4
+
5
+ # {{ model_name }} Model Card
6
+
7
+ This [Model2Vec](https://github.com/MinishLab/model2vec) model is a fine-tuned version of {% if base_model %}the [{{ base_model }}](https://huggingface.co/{{ base_model }}){% else %}a{% endif %} Model2Vec model. It also includes a classifier head on top.
8
+
9
+ ## Installation
10
+
11
+ Install model2vec using pip:
12
+ ```
13
+ pip install model2vec[inference]
14
+ ```
15
+
16
+ ## Usage
17
+ Load this model using the `from_pretrained` method:
18
+ ```python
19
+ from model2vec.inference import StaticModelPipeline
20
+
21
+ # Load a pretrained Model2Vec model
22
+ model = StaticModelPipeline.from_pretrained("{{ model_name }}")
23
+
24
+ # Predict labels
25
+ predicted = model.predict(["Example sentence"])
26
+ ```
27
+
28
+ ## Additional Resources
29
+
30
+ - [Model2Vec Repo](https://github.com/MinishLab/model2vec)
31
+ - [Model2Vec Base Models](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e)
32
+ - [Model2Vec Results](https://github.com/MinishLab/model2vec/tree/main/results)
33
+ - [Model2Vec Tutorials](https://github.com/MinishLab/model2vec/tree/main/tutorials)
34
+ - [Website](https://minishlab.github.io/)
35
+
36
+ ## Library Authors
37
+
38
+ Model2Vec was developed by the [Minish Lab](https://github.com/MinishLab) team consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled).
39
+
40
+ ## Citation
41
+
42
+ Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) if you use this model in your work.
43
+ ```
44
+ @article{minishlab2024model2vec,
45
+ author = {Tulkens, Stephan and {van Dongen}, Thomas},
46
+ title = {Model2Vec: Fast State-of-the-Art Static Embeddings},
47
+ year = {2024},
48
+ url = {https://github.com/MinishLab/model2vec}
49
+ }
50
+ ```
src/distiller/model2vec/modelcards/model_card_template.md ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ {{ card_data }}
3
+ ---
4
+
5
+ # {{ model_name }} Model Card
6
+
7
+ This [Model2Vec](https://github.com/MinishLab/model2vec) model is a distilled version of {% if base_model %}the [{{ base_model }}](https://huggingface.co/{{ base_model }}){% else %}a{% endif %} Sentence Transformer. It uses static embeddings, allowing text embeddings to be computed orders of magnitude faster on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is critical. Model2Vec models are the smallest, fastest, and most performant static embedders available. The distilled models are up to 50 times smaller and 500 times faster than traditional Sentence Transformers.
8
+
9
+
10
+ ## Installation
11
+
12
+ Install model2vec using pip:
13
+ ```
14
+ pip install model2vec
15
+ ```
16
+
17
+ ## Usage
18
+
19
+ ### Using Model2Vec
20
+
21
+ The [Model2Vec library](https://github.com/MinishLab/model2vec) is the fastest and most lightweight way to run Model2Vec models.
22
+
23
+ Load this model using the `from_pretrained` method:
24
+ ```python
25
+ from model2vec import StaticModel
26
+
27
+ # Load a pretrained Model2Vec model
28
+ model = StaticModel.from_pretrained("{{ model_name }}")
29
+
30
+ # Compute text embeddings
31
+ embeddings = model.encode(["Example sentence"])
32
+ ```
33
+
34
+ ### Using Sentence Transformers
35
+
36
+ You can also use the [Sentence Transformers library](https://github.com/UKPLab/sentence-transformers) to load and use the model:
37
+
38
+ ```python
39
+ from sentence_transformers import SentenceTransformer
40
+
41
+ # Load a pretrained Sentence Transformer model
42
+ model = SentenceTransformer("{{ model_name }}")
43
+
44
+ # Compute text embeddings
45
+ embeddings = model.encode(["Example sentence"])
46
+ ```
47
+
48
+ ### Distilling a Model2Vec model
49
+
50
+ You can distill a Model2Vec model from a Sentence Transformer model using the `distill` method. First, install the `distill` extra with `pip install model2vec[distill]`. Then, run the following code:
51
+
52
+ ```python
53
+ from model2vec.distill import distill
54
+
55
+ # Distill a Sentence Transformer model, in this case the BAAI/bge-base-en-v1.5 model
56
+ m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256)
57
+
58
+ # Save the model
59
+ m2v_model.save_pretrained("m2v_model")
60
+ ```
61
+
62
+ ## How it works
63
+
64
+ Model2vec creates a small, fast, and powerful model that outperforms other static embedding models by a large margin on all tasks we could find, while being much faster to create than traditional static embedding models such as GloVe. Best of all, you don't need any data to distill a model using Model2Vec.
65
+
66
+ It works by passing a vocabulary through a sentence transformer model, then reducing the dimensionality of the resulting embeddings using PCA, and finally weighting the embeddings using [SIF weighting](https://openreview.net/pdf?id=SyK00v5xx). During inference, we simply take the mean of all token embeddings occurring in a sentence.
67
+
68
+ ## Additional Resources
69
+
70
+ - [Model2Vec Repo](https://github.com/MinishLab/model2vec)
71
+ - [Model2Vec Base Models](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e)
72
+ - [Model2Vec Results](https://github.com/MinishLab/model2vec/tree/main/results)
73
+ - [Model2Vec Tutorials](https://github.com/MinishLab/model2vec/tree/main/tutorials)
74
+ - [Website](https://minishlab.github.io/)
75
+
76
+
77
+ ## Library Authors
78
+
79
+ Model2Vec was developed by the [Minish Lab](https://github.com/MinishLab) team consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled).
80
+
81
+ ## Citation
82
+
83
+ Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) if you use this model in your work.
84
+ ```
85
+ @article{minishlab2024model2vec,
86
+ author = {Tulkens, Stephan and {van Dongen}, Thomas},
87
+ title = {Model2Vec: Fast State-of-the-Art Static Embeddings},
88
+ year = {2024},
89
+ url = {https://github.com/MinishLab/model2vec}
90
+ }
91
+ ```
src/distiller/model2vec/py.typed ADDED
File without changes
src/distiller/model2vec/quantization.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from enum import Enum
4
+
5
+ import numpy as np
6
+
7
+
8
class DType(str, Enum):
    """Supported storage datatypes for embedding matrices."""

    Float16 = "float16"
    Float32 = "float32"
    Float64 = "float64"
    Int8 = "int8"


def quantize_embeddings(embeddings: np.ndarray, quantize_to: DType) -> np.ndarray:
    """
    Quantize embeddings to a specified data type to reduce memory usage.

    :param embeddings: The embeddings to quantize, as a numpy array.
    :param quantize_to: The data type to quantize to.
    :return: The quantized embeddings.
    :raises ValueError: If the quantization type is not valid.
    """
    if quantize_to == DType.Float16:
        return embeddings.astype(np.float16)
    if quantize_to == DType.Float32:
        return embeddings.astype(np.float32)
    if quantize_to == DType.Float64:
        return embeddings.astype(np.float64)
    if quantize_to == DType.Int8:
        # Symmetric linear quantization: normalize to [-127, 127] to keep symmetry.
        scale = np.max(np.abs(embeddings)) / 127.0
        if scale == 0.0:
            # All-zero embeddings would otherwise divide by zero; they map to zeros.
            return np.zeros_like(embeddings, dtype=np.int8)
        return np.round(embeddings / scale).astype(np.int8)
    msg = "Not a valid enum member of DType."
    raise ValueError(msg)
+ raise ValueError(msg)
37
+
38
+
39
def quantize_and_reduce_dim(
    embeddings: np.ndarray, quantize_to: str | DType | None, dimensionality: int | None
) -> np.ndarray:
    """
    Quantize embeddings to a datatype and reduce dimensionality.

    :param embeddings: The embeddings to quantize and reduce, as a numpy array.
    :param quantize_to: The data type to quantize to. If None, no quantization is performed.
    :param dimensionality: The number of dimensions to keep. If None, no dimensionality reduction is performed.
    :return: The quantized and reduced embeddings.
    :raises ValueError: If the passed dimensionality is not None and greater than the model dimensionality.
    """
    # Optional quantization step; strings are coerced into DType members.
    if quantize_to is not None:
        embeddings = quantize_embeddings(embeddings, DType(quantize_to))

    # Optional truncation step (keeps the leading dimensions).
    if dimensionality is None:
        return embeddings

    model_dim = embeddings.shape[1]
    if dimensionality > model_dim:
        msg = f"Dimensionality {dimensionality} is greater than the model dimensionality {model_dim}"
        raise ValueError(msg)
    return embeddings[:, :dimensionality]
+ return embeddings
src/distiller/model2vec/tokenizer/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from distiller.model2vec.utils import importable
2
+
3
+ importable("transformers", "tokenizer")
4
+
5
+ from distiller.model2vec.tokenizer.tokenizer import (
6
+ clean_and_create_vocabulary,
7
+ create_tokenizer,
8
+ replace_vocabulary,
9
+ turn_tokens_into_ids,
10
+ )
11
+
12
+ __all__ = ["clean_and_create_vocabulary", "create_tokenizer", "replace_vocabulary", "turn_tokens_into_ids"]
src/distiller/model2vec/tokenizer/datamodels.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+
4
@dataclass
class Token:
    """A single vocabulary entry, tracking its surface and normalized forms."""

    # The raw surface form of the token.
    form: str
    # The normalized and pretokenized form of the token.
    normalized_form: str
    # Whether the token is a continuing subword.
    is_subword: bool
    # Whether the token came from the source tokenizer's own vocabulary.
    is_internal: bool
src/distiller/model2vec/tokenizer/model.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+
7
+
8
+ def process_tokenizer(
9
+ tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
10
+ ) -> dict[str, Any]:
11
+ """Process the WordPiece tokenizer JSON."""
12
+ if tokenizer_json["model"]["type"] == "Unigram":
13
+ return _process_unigram(tokenizer_json, pre_tokenized_tokens, unk_token)
14
+ tokenizer_json["model"]["type"] = "Unigram"
15
+ tokenizer_json["model"]["unk_id"] = pre_tokenized_tokens.index(unk_token) if unk_token else None
16
+
17
+ token_weights = np.asarray([_calculate_token_weight_for_unigram(token) for token in pre_tokenized_tokens])
18
+ proba = (token_weights / np.sum(token_weights)).tolist()
19
+ tokenizer_json["model"]["vocab"] = [(token, np.log(p)) for token, p in zip(pre_tokenized_tokens, proba, strict=False)]
20
+
21
+ return tokenizer_json
22
+
23
+
24
+ def _process_unigram(
25
+ tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
26
+ ) -> dict[str, Any]:
27
+ """Process the Unigram tokenizer JSON."""
28
+ current_probas = dict(tokenizer_json["model"]["vocab"])
29
+ avg_proba = sum(current_probas.values()) / len(current_probas)
30
+ new_probas = [[word, current_probas.get(word, avg_proba)] for word in pre_tokenized_tokens]
31
+ tokenizer_json["model"]["vocab"] = new_probas
32
+
33
+ tokens, _ = zip(*tokenizer_json["model"]["vocab"], strict=False)
34
+ if unk_token is not None:
35
+ tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token)
36
+
37
+ return tokenizer_json
38
+
39
+
40
+ def _calculate_token_weight_for_unigram(token: str) -> float:
41
+ """Calculate the token weight for Unigram."""
42
+ # Always prefer longer tokens.
43
+ return len(token) + token.count("▁") + token.count("Ġ")
src/distiller/model2vec/tokenizer/normalizer.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from string import punctuation
2
+
3
+ from tokenizers import Regex, Tokenizer
4
+ from tokenizers.normalizers import Replace, Sequence, Strip
5
+
6
+
7
def replace_normalizer(
    tokenizer: Tokenizer,
) -> Tokenizer:
    """
    Replace the normalizer for the tokenizer.

    The new normalizer pads every punctuation character with a space on both
    sides, collapses runs of whitespace into a single space, and strips the
    right side of the string. If the tokenizer already has a normalizer, it
    keeps running first, followed by these extra steps.

    :param tokenizer: The tokenizer to change.
    :return: The tokenizer with a replaced normalizer.
    """
    # One Replace step per punctuation character: "a,b" -> "a , b".
    extra_steps = [Replace(char, f" {char} ") for char in punctuation]
    # Collapse whitespace runs, then strip trailing whitespace.
    extra_steps.append(Replace(Regex(r"\s+"), " "))
    extra_steps.append(Strip(right=True))

    existing = tokenizer.normalizer
    if existing is None:
        tokenizer.normalizer = Sequence(extra_steps)  # type: ignore
    else:
        tokenizer.normalizer = Sequence([existing, *extra_steps])  # type: ignore

    return tokenizer
src/distiller/model2vec/tokenizer/pretokenizer.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ if TYPE_CHECKING:
7
+ from tokenizers import Tokenizer
8
+
9
_FORBIDDEN_PRETOKENIZERS = (
    "WhiteSpace",
    "WhitespaceSplit",
    "BertPreTokenizer",
    "CharDelimiterSplit",
    "Punctuation",
    "Split",
    "UnicodeScripts",
)
# NOTE(review): the tokenizers library spells its whitespace pretokenizer type
# "Whitespace" — confirm that the "WhiteSpace" entry above is intentional.
_BASIC_METASPACE = {"type": "Metaspace", "replacement": "▁", "prepend_scheme": "always", "split": False}


def _fix_single_pretokenizer(pre_tokenizer: dict[str, Any]) -> dict[str, Any] | None:
    """Fix a single pretokenizer JSON dict so multiword units survive, or drop it."""
    kind = pre_tokenizer["type"]
    # Pretokenizers that split on whitespace/punctuation would break multiword
    # tokens, so they are removed entirely.
    if kind in _FORBIDDEN_PRETOKENIZERS:
        return None
    if kind == "ByteLevel":
        pre_tokenizer["add_prefix_space"] = True
        pre_tokenizer["use_regex"] = False
    if kind == "Metaspace":
        pre_tokenizer["split"] = False
        pre_tokenizer["prepend_scheme"] = "always"

    return pre_tokenizer


def replace_pretokenizer(tokenizer: Tokenizer) -> Tokenizer:
    """Rewrite the tokenizer's pretokenizer so multiword units are preserved."""
    tokenizer_json = json.loads(tokenizer.to_str())
    pre_tok = tokenizer_json.get("pre_tokenizer", None)

    if pre_tok is None:
        # No pretokenizer at all: install a plain non-splitting Metaspace.
        pre_tok = _BASIC_METASPACE
    elif pre_tok["type"] == "Sequence":
        # Fix each member of the sequence, dropping forbidden ones.
        kept = [
            fixed
            for fixed in (_fix_single_pretokenizer(p) for p in pre_tok["pretokenizers"])
            if fixed is not None
        ]
        if kept:
            pre_tok["pretokenizers"] = kept
        else:
            pre_tok = _BASIC_METASPACE

    # Top-level fix; falls back to the basic Metaspace if the whole thing is dropped.
    pre_tok = _fix_single_pretokenizer(pre_tok) or _BASIC_METASPACE
    tokenizer_json["pre_tokenizer"] = pre_tok

    return tokenizer.from_str(json.dumps(tokenizer_json))
src/distiller/model2vec/tokenizer/tokenizer.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ from typing import TYPE_CHECKING, Any, cast
6
+
7
+ from tokenizers import Tokenizer
8
+ from transformers import PreTrainedTokenizerFast
9
+
10
+ from distiller.model2vec.tokenizer.datamodels import Token
11
+ from distiller.model2vec.tokenizer.model import process_tokenizer
12
+ from distiller.model2vec.tokenizer.normalizer import replace_normalizer
13
+ from distiller.model2vec.tokenizer.pretokenizer import replace_pretokenizer
14
+
15
+ if TYPE_CHECKING:
16
+ import re
17
+
18
+ from tokenizers.normalizers import Normalizer
19
+ from tokenizers.pre_tokenizers import (
20
+ PreTokenizer,
21
+ )
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ _DEFAULT_POST_PROCESSOR_TEMPLATE = {
27
+ "type": "TemplateProcessing",
28
+ "single": [{"Sequence": {"id": "A", "type_id": 0}}],
29
+ "pair": [{"Sequence": {"id": "A", "type_id": 0}}, {"Sequence": {"id": "B", "type_id": 0}}],
30
+ "special_tokens": {},
31
+ }
32
+
33
+
34
def _remap_added_tokens(
    special_tokens: list[dict[str, Any]],
    vocabulary: list[str],
) -> list[dict[str, Any]]:
    """
    Remap special tokens in the tokenizer.

    Each special-token entry gets its "id" field pointed at that token's index
    in the given vocabulary. The input list and its dicts are left untouched;
    shallow copies are returned.

    :param special_tokens: The special tokens to remap.
    :param vocabulary: The vocabulary as a list of tokens.
    :return: The updated special tokens.
    :raises ValueError: If a special token's content is not in the vocabulary.
    """
    # Shallow-copy every entry so the caller's dicts are not mutated.
    remapped = [dict(entry) for entry in special_tokens]
    for entry in remapped:
        entry["id"] = vocabulary.index(entry["content"])

    return remapped
54
+
55
+
56
def replace_vocabulary(
    tokenizer: Tokenizer, new_vocabulary: list[Token], unk_token: str | None, pad_token: str | None
) -> Tokenizer:
    """
    Replace the vocabulary of a tokenizer with a new one.

    The tokenizer is round-tripped through its JSON representation: the unk/pad
    added tokens are renamed to "[UNK]"/"[PAD]", all other added tokens are
    dropped, the model is rebuilt over the new vocabulary's normalized forms
    (see `process_tokenizer`), and the post-processor is reset to a template
    with no special tokens.

    :param tokenizer: The backend tokenizer whose vocabulary is replaced.
    :param new_vocabulary: The new vocabulary; the normalized forms become the
        tokenizer's vocabulary entries.
    :param unk_token: The original unk token string (renamed to "[UNK]"), or None.
    :param pad_token: The original pad token string (renamed to "[PAD]"), or None.
    :return: A new Tokenizer built from the modified JSON.
    """
    tokenizer_json: dict[str, Any] = json.loads(tokenizer.to_str())
    added_tokens: list[dict[str, Any]] = tokenizer_json["added_tokens"]

    pre_tokenized_tokens = [x.normalized_form for x in new_vocabulary]

    # We need to remove the added tokens but keep [UNK] and [PAD] tokens.
    # Renaming also rewrites the matching slot in pre_tokenized_tokens.
    added_tokens = _rename_added_token(unk_token, "[UNK]", added_tokens, pre_tokenized_tokens)
    added_tokens = _rename_added_token(pad_token, "[PAD]", added_tokens, pre_tokenized_tokens)

    # Remove old added tokens from added tokens
    tokenizer_json["added_tokens"] = [x for x in added_tokens if x["content"] in {"[UNK]", "[PAD]"}]
    tokenizer_json = process_tokenizer(
        tokenizer_json, pre_tokenized_tokens, "[UNK]" if "[UNK]" in pre_tokenized_tokens else None
    )

    # Remap special tokens
    tokenizer_json["added_tokens"] = _remap_added_tokens(
        special_tokens=tokenizer_json["added_tokens"],
        vocabulary=pre_tokenized_tokens,
    )
    # Reset the post-processor: no bos/eos wrapping, plain sequence passthrough.
    tokenizer_json["post_processor"] = _DEFAULT_POST_PROCESSOR_TEMPLATE

    return Tokenizer.from_str(json.dumps(tokenizer_json))
83
+
84
+
85
def _rename_added_token(
    form: str | None, new_form: str, added_tokens: list[dict[str, Any]], vocabulary: list[str]
) -> list[dict[str, Any]]:
    """
    Rename one added token in place.

    If ``form`` is present among the added tokens, its entry's content and id
    are rewritten, and the matching vocabulary slot is renamed as well.

    :param form: The current token string, or None to do nothing.
    :param new_form: The replacement token string.
    :param added_tokens: The added-token dicts (mutated in place).
    :param vocabulary: The vocabulary list (mutated in place).
    :return: The (same) added_tokens list.
    :raises ValueError: If ``form`` is not None and not in the vocabulary.
    """
    if form is None:
        return added_tokens

    position = vocabulary.index(form)
    matches = [entry for entry in added_tokens if entry["content"] == form]
    if matches:
        target = matches[0]
        target["id"] = position
        target["content"] = new_form
        vocabulary[position] = new_form

    return added_tokens
100
+
101
+
102
def clean_and_create_vocabulary(
    tokenizer: PreTrainedTokenizerFast,
    vocabulary: list[str],
    token_remove_regex: re.Pattern | None,
) -> tuple[list[Token], Tokenizer]:
    """
    Cleans a vocabulary by removing duplicates and tokens that were already in the vocabulary.

    First collects and cleans the tokenizer's own (internal) tokens, then walks
    the user-supplied vocabulary: each entry is normalized and pretokenized,
    dropped if empty or already seen, and otherwise appended as an external Token.

    :param tokenizer: The source transformers tokenizer.
    :param vocabulary: Extra token strings to add.
    :param token_remove_regex: Internal tokens matching this regex are dropped.
    :return: A tuple of (cleaned Token list, a copied backend tokenizer with a
        replaced normalizer and pretokenizer).
    """
    seen_tokens = set()
    post_normalize_seen_tokens = set()
    n_empty = 0
    n_duplicates = 0

    backend_tokenizer = tokenizer.backend_tokenizer

    # Make a base list of tokens, ordered by their ids in the internal vocab.
    internal_vocab: dict[str, int] = tokenizer.get_vocab()
    internal_tokens: list[str] = [k for k, _ in sorted(internal_vocab.items(), key=lambda x: x[1])]

    cleaned_vocabulary = _process_internal_tokens(tokenizer, backend_tokenizer, internal_tokens, token_remove_regex)
    # Copy the backend tokenizer to avoid modifying the original.
    backend_tokenizer = backend_tokenizer.from_str(backend_tokenizer.to_str())
    backend_tokenizer = replace_normalizer(backend_tokenizer)

    internal_tokens_set = {token.form for token in cleaned_vocabulary}

    normalizer: Normalizer | None = backend_tokenizer.normalizer
    for token in vocabulary:
        # Run the (replaced) normalizer first, e.g. punctuation spacing.
        if normalizer is not None:
            token = cast("str", normalizer.normalize_str(token))

        if not token:
            n_empty += 1
            continue

        pre_tokenizer: PreTokenizer | None = backend_tokenizer.pre_tokenizer
        normalized_token = token
        if pre_tokenizer is not None:
            normalized_token = _normalize_vocabulary_token(
                token=token,
                pre_tokenizer=pre_tokenizer,
            )

        # We need to check whether the pretokenized token is in the vocabulary.
        # But we need to return the original token, because that will be tokenized
        # again by the tokenizer during featurization.
        if normalized_token in seen_tokens or normalized_token in internal_tokens_set:
            n_duplicates += 1
            continue

        # Add the possibly pretokenized token to seen
        seen_tokens.add(normalized_token)

        # After checking the token exists, we need to normalize it into the token
        # it will become. For byte tokens, this means we don't do anything. For
        # other types of tokens, we will insert a metaspace.
        # In the case of multiword tokens, we replace any spaces with the metaspace
        # or byte prefix token.
        if not normalized_token.startswith(("▁", "Ġ")):
            normalized_token = normalized_token.replace(" ", "▁")
            normalized_token = f"▁{normalized_token}"
        else:
            normalized_token = normalized_token.replace(" ", normalized_token[0])

        # Normalization can collapse distinct inputs; dedupe a second time.
        if normalized_token in post_normalize_seen_tokens:
            n_duplicates += 1
            continue

        post_normalize_seen_tokens.add(normalized_token)
        # Add the original string to the vocabulary.
        cleaned_vocabulary.append(
            Token(form=token, normalized_form=normalized_token, is_subword=False, is_internal=False)
        )

    if n_duplicates:
        logger.warning(f"Removed {n_duplicates} duplicate tokens.")
    if n_empty:
        logger.warning(f"Removed {n_empty} empty tokens.")

    return cleaned_vocabulary, replace_pretokenizer(backend_tokenizer)
180
+
181
+
182
def _process_internal_tokens(
    tokenizer: PreTrainedTokenizerFast,
    backend_tokenizer: Tokenizer,
    internal_tokens: list[str],
    token_remove_regex: re.Pattern | None,
) -> list[Token]:
    """
    Clean internal tokens.

    Probes the backend tokenizer to discover its word/subword prefixes, then
    converts each internal token string into a Token object, dropping added
    tokens (except unk/pad), regex-matched tokens, and unreachable tokens.

    :param tokenizer: The source transformers tokenizer.
    :param backend_tokenizer: The backend (tokenizers-library) tokenizer.
    :param internal_tokens: The internal vocabulary, in id order.
    :param token_remove_regex: Tokens matching this regex are dropped.
    :return: The cleaned internal tokens as Token objects.
    """
    # Get the pad and unk token from the tokenizer.
    pad_token: str | None = tokenizer.special_tokens_map.get("pad_token")  # type: ignore[assignment]
    unk_token: str | None = tokenizer.special_tokens_map.get("unk_token")  # type: ignore[assignment]
    # Empty set if no pad or unk token is set.
    added_tokens_to_keep: set[str] = {x for x in (pad_token, unk_token) if x is not None}
    added_tokens_to_remove = set(tokenizer.added_tokens_encoder) - added_tokens_to_keep
    cleaned_internal_tokens: list[Token] = []

    # Figure out whether token is a subword or not.
    # NOTE(review): assumes " " + 25*"a" encodes to at least two tokens — confirm
    # this holds for tokenizers with very long vocabulary entries.
    encoded = backend_tokenizer.encode(f" {'a' * 25}", add_special_tokens=False)
    first_token, second_token, *_ = encoded.tokens
    # Isolate the prefix. We can't do first_token[0] because we don't know
    # how long the prefix is.
    # e.g., "Ġaaaa" -> "Ġ"
    a_index = None if "a" not in first_token else first_token.index("a")
    word_prefix = first_token[:a_index]
    is_byte_prefix = word_prefix == "Ġ"
    second_token = encoded.tokens[1]
    # The second token is the first subword token.
    # If a tokenizer uses subwords, this token will have been prefixed.
    # We don't know how long the prefix is.
    a_index = None if "a" not in second_token else second_token.index("a")
    subword_prefix = second_token[:a_index]

    pre_tokenizer: PreTokenizer | None = backend_tokenizer.pre_tokenizer

    for token in internal_tokens:
        # Create the token objects. If this returns None, it was unsuccessful for some reason.
        if token_object := _create_single_internal_token(
            token=token,
            subword_prefix=subword_prefix,
            word_prefix=word_prefix,
            pre_tokenizer=pre_tokenizer,
            is_byte_prefix=is_byte_prefix,
            token_remove_regex=token_remove_regex,
            added_tokens_to_keep=added_tokens_to_keep,
            added_tokens_to_remove=added_tokens_to_remove,
        ):
            cleaned_internal_tokens.append(token_object)

    if len(cleaned_internal_tokens) != len(internal_tokens):
        logger.info(
            f"Removed {len(internal_tokens) - len(cleaned_internal_tokens)} internal tokens from the vocabulary."
        )

    return cleaned_internal_tokens
235
+
236
+
237
def _create_single_internal_token(
    token: str,
    subword_prefix: str,
    word_prefix: str,
    pre_tokenizer: PreTokenizer | None,
    is_byte_prefix: bool,
    token_remove_regex: re.Pattern | None,
    added_tokens_to_keep: set[str],
    added_tokens_to_remove: set[str],
) -> Token | None:
    """
    Create a token object from a string.

    :param token: The internal token string.
    :param subword_prefix: The prefix marking continuing subwords (may be empty).
    :param word_prefix: The prefix marking word-initial tokens (may be empty).
    :param pre_tokenizer: The backend pretokenizer, used for reachability checks.
    :param is_byte_prefix: Whether the tokenizer uses a byte-level ("Ġ") prefix.
    :param token_remove_regex: Tokens matching this regex are dropped.
    :param added_tokens_to_keep: Added tokens (unk/pad) passed through as-is.
    :param added_tokens_to_remove: Added tokens to drop entirely.
    :return: A Token, or None if the token should be removed.
    """
    if token in added_tokens_to_remove:
        # We remove any tokens that are added tokens that aren't [UNK] or [PAD].
        return None
    if token in added_tokens_to_keep:
        # Don't put added tokens through the regular motions.
        return Token(form=token, normalized_form=token, is_subword=False, is_internal=True)
    if token_remove_regex and token_remove_regex.match(token):
        # If the regex matches, remove the token.
        return None

    # A token is a subword if there is a subword prefix and the word
    # starts with a subword prefix, or if there is a WORD prefix, and the word
    # does not start with this prefix. For metaspace tokenizers, for example:
    # "doghouse" -> ["_dog", "house"]
    # So we can only tell that "house" is a subword by knowing that it is not prefixed
    # and word-initial tokens are.
    is_subword = False
    if subword_prefix:
        is_subword = bool(token.startswith(subword_prefix))
    if word_prefix:
        is_subword = not bool(token.startswith(word_prefix))

    # Byte prefixed tokenizers don't need to be checked.
    if pre_tokenizer is not None and not is_byte_prefix:
        # We need to check the thing without prefixes. If we have a word prefix,
        # we need to check tokens that are subwords. Other way around for subword
        # prefixes.
        if (subword_prefix and not is_subword) or (word_prefix and is_subword):
            # If this is True, the token is unreachable, even though it is a subword token.
            if len(pre_tokenizer.pre_tokenize_str(token)) > 1:
                return None

    # Turn a token into a normalized form for later processing.
    normalized_form = _create_normalized_form(token, subword_prefix, word_prefix, is_byte_prefix, is_subword)

    return Token(form=token, normalized_form=normalized_form, is_subword=is_subword, is_internal=True)
284
+
285
+
286
def _create_normalized_form(
    token: str, subword_prefix: str, word_prefix: str, is_byte_prefix: bool, is_subword: bool
) -> str:
    """
    Turn an internal token string into its normalized form.

    Byte-prefixed tokens pass through unchanged; subwords lose their subword
    prefix; word-initial tokens lose their word prefix and gain a metaspace.
    """
    if is_byte_prefix:
        # Byte-level tokens are already in canonical form.
        return token
    if is_subword:
        # Strip the continuation marker.
        return token.removeprefix(subword_prefix)
    # Word-initial token: drop the word prefix and prepend a metaspace.
    stripped = token.removeprefix(word_prefix)
    return f"▁{stripped}"
298
+
299
+
300
def turn_tokens_into_ids(
    tokens: list[Token], tokenizer: PreTrainedTokenizerFast, unk_token: str | None
) -> list[list[int]]:
    """
    Convert a list of Token objects to their corresponding token ID sequences.

    Internal tokens are looked up directly and wrapped in the tokenizer's
    bos/eos-style special-token ids; external tokens go through a full encode.

    :param tokens: List of Token objects to convert
    :param tokenizer: The tokenizer to use for converting tokens to IDs
    :param unk_token: The string form of the unk token.
    :return: List of token IDs corresponding to the input tokens
    """
    unk_id = None if unk_token is None else tokenizer.convert_tokens_to_ids(unk_token)
    # Special-token id sequences that surround a single regular token.
    prefix, suffix = find_eos_bos(tokenizer)

    token_ids: list[list[int]] = []
    for token in tokens:
        if token.is_internal:
            # Careful. Any incorrect tokens will just get `[UNK]``, so this could go horribly wrong
            # Cast because return type is wrong.
            # `or 0` maps a None lookup result to id 0.
            token_id: int = cast("int", tokenizer.convert_tokens_to_ids(token.form)) or 0
            # Explicitly check and warn if `unk_id` appears, but don't crash.
            if unk_id is not None and token_id == unk_id and token.form != unk_token:
                logger.warning(f"Token {token.form} was set to unk. This is wrong.")
            token_ids.append([*prefix, token_id, *suffix])
        else:
            # External tokens: full encode path (includes special tokens).
            token_ids.append(tokenizer.encode(token.form))

    return token_ids
328
+
329
+
330
def find_eos_bos(tokenizer: PreTrainedTokenizerFast) -> tuple[list[int], list[int]]:
    """
    Finds the eos and bos tokens for a tokenizer.

    Encodes the single character "a" with special tokens and treats whatever
    surrounds it as the prefix (bos-like) and suffix (eos-like) id sequences.

    :param tokenizer: The tokenizer to probe.
    :return: A (prefix_ids, suffix_ids) tuple of special-token id lists.
    :raises ValueError: If 'a' does not encode to a single id, so the
        surrounding special tokens cannot be located.
    """
    # Little bit complicated, because not all tokenizers have eos and bos tokens.
    encoding = tokenizer.encode("a", add_special_tokens=True)
    if len(encoding) != 3:
        # Not the common [bos, a, eos] shape: locate 'a' explicitly.
        a_encoded = tokenizer.encode("a", add_special_tokens=False)
        if len(a_encoded) != 1:
            msg = f"Error while encoding, couldn't determine eos and bos tokens. The model tokenizes 'a' to '{a_encoded}'"
            raise ValueError(
                msg
            )
        a_idx = encoding.index(a_encoded[0])
        prefix, suffix = encoding[:a_idx], encoding[a_idx + 1 :]
    else:
        # Common case: exactly one id before and one after the probe token.
        prefix, suffix = encoding[:1], encoding[2:]
    return prefix, suffix
346
+
347
+
348
def _normalize_vocabulary_token(token: str, pre_tokenizer: PreTokenizer) -> str:
    """
    Normalize a token that is not in the initial token vocabulary.

    Runs the token through the pretokenizer and rejoins the pieces, keeping a
    space between pieces that were space-separated in the original string.

    :param token: The vocabulary token to normalize.
    :param pre_tokenizer: The backend pretokenizer to apply.
    :return: The rejoined, pretokenized form of the token.
    """
    # Add prefix space for byte tokenizers.
    prefixed_token = f" {token}"
    pretokenized_tokens: tuple[str, ...]
    pretokenized_tokens, offsets = zip(*pre_tokenizer.pre_tokenize_str(prefixed_token), strict=False)
    # The first item is always the start of the token.
    new_token = [pretokenized_tokens[0]]
    # Loop over the subtokens and offsets.
    for t, (s, _) in zip(pretokenized_tokens[1:], offsets[1:], strict=False):
        # Do not prefix the token with a space if it starts with a metaspace.
        if t.startswith("▁"):
            new_token.append(t)
        # If the character before the subtoken is a space, we have a
        # multiword token. e.g., "room for the moon", which is split into
        # ["room", "for", "the", "moon"].
        # If it doesn't have a space, it is part of a complex multiword token,
        # e.g., "chat-gpt", which is split into ["chat", "-", "gpt"].
        elif prefixed_token[s - 1] == " ":
            new_token.append(f" {t}")
        else:
            new_token.append(t)
    return "".join(new_token)
371
+
372
+
373
+
374
def create_tokenizer(
    tokenizer: PreTrainedTokenizerFast,
    vocabulary: list[str],
    token_remove_regex: re.Pattern | None = None,
) -> PreTrainedTokenizerFast:
    """
    Create a tokenizer by adding tokens to the vocabulary.

    This function turns any tokenizer into a supertoken tokenizer. It does the following:
    1. Turns the tokenizer model into a unigram model.
    2. Adds a new pretokenizer, splitting on punctuation.
    3. Adds all tokens in vocabulary to the model.
    4. Removes any internal tokens that conform to the regex.

    :param tokenizer: The tokenizer to use.
    :param vocabulary: The vocabulary to use.
    :param token_remove_regex: The regex to use to remove tokens from the vocabulary.
    :return: The created tokenizer.
    """
    # Preserve the original unk/pad forms so they can be renamed to [UNK]/[PAD].
    unk_token = cast("str | None", tokenizer.special_tokens_map.get("unk_token"))
    pad_token = cast("str | None", tokenizer.special_tokens_map.get("pad_token"))
    # Deduplicate/normalize the vocabulary and get a cleaned backend tokenizer copy.
    cleaned_vocabulary, backend_tokenizer = clean_and_create_vocabulary(tokenizer, vocabulary, token_remove_regex)
    # Swap the vocabulary into the backend tokenizer, then wrap it for transformers.
    new_tokenizer = replace_vocabulary(backend_tokenizer, cleaned_vocabulary, unk_token, pad_token)

    return PreTrainedTokenizerFast(tokenizer_object=new_tokenizer)
src/distiller/model2vec/train/README.md ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Training
2
+
3
+ Aside from [distillation](../../README.md#distillation), `model2vec` also supports training simple classifiers on top of static models, using [pytorch](https://pytorch.org/), [lightning](https://lightning.ai/) and [scikit-learn](https://scikit-learn.org/stable/index.html).
4
+
5
+ We support both single and multi-label classification, which work seamlessly based on the labels you provide.
6
+
7
+ # Installation
8
+
9
+ To train, make sure you install the training extra:
10
+
11
+ ```
12
+ pip install model2vec[training]
13
+ ```
14
+
15
+ # Quickstart
16
+
17
+ To train a model, simply initialize it using a `StaticModel`, or from a pre-trained model, as follows:
18
+
19
+ ```python
20
+ from model2vec.distill import distill
21
+ from model2vec.train import StaticModelForClassification
22
+
23
+ # From a distilled model
24
+ distilled_model = distill("baai/bge-base-en-v1.5")
25
+ classifier = StaticModelForClassification.from_static_model(model=distilled_model)
26
+
27
+ # From a pre-trained model: potion is the default
28
+ classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32m")
29
+ ```
30
+
31
+ This creates a very simple classifier: a StaticModel with a single 512-unit hidden layer on top. You can adjust the number of hidden layers and the number of units through some parameters on both functions. Note that the default for `from_pretrained` is [potion-base-32m](https://huggingface.co/minishlab/potion-base-32M), our best model to date. This is our recommended path if you're working with general English data.
32
+
33
+ Now that you have created the classifier, let's just train a model. The example below assumes you have the [`datasets`](https://github.com/huggingface/datasets) library installed.
34
+
35
+ ```python
36
+ import numpy as np
37
+ from datasets import load_dataset
38
+
39
+ # Load the subj dataset
40
+ ds = load_dataset("setfit/subj")
41
+ train = ds["train"]
42
+ test = ds["test"]
43
+
44
+ from time import perf_counter
+
+ s = perf_counter()
45
+ classifier = classifier.fit(train["text"], train["label"])
46
+
47
+ print(f"Training took {int(perf_counter() - s)} seconds.")
48
+ # Training took 81 seconds
49
+ classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["label"])
50
+ print(classification_report)
51
+ # Achieved 91.0 test accuracy
52
+ ```
53
+
54
+ As you can see, we got a pretty nice 91% accuracy, with only 81 seconds of training.
55
+
56
+ The training loop is handled by [`lightning`](https://pypi.org/project/lightning/). By default the training loop splits the data into a train and validation split, with 90% of the data being used for training and 10% for validation. By default, it runs with early stopping on the validation set accuracy, with a patience of 5.
57
+
58
+ Note that this model is as fast as you're used to from us:
59
+
60
+ ```python
61
+ from time import perf_counter
62
+
63
+ s = perf_counter()
64
+ classifier.predict(test["text"])
65
+ print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} instances on CPU.")
66
+ # Took 67 milliseconds for 2000 instances on CPU.
67
+ ```
68
+
69
+ ## Multi-label classification
70
+
71
+ Multi-label classification is supported out of the box. Just pass a list of lists to the `fit` function (e.g. `[[label1, label2], [label1, label3]]`), and a multi-label classifier will be trained. For example, the following code trains a multi-label classifier on the [go_emotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) dataset:
72
+
73
+ ```python
74
+ from datasets import load_dataset
75
+ from model2vec.train import StaticModelForClassification
76
+
77
+ # Initialize a classifier from a pre-trained model
78
+ classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
79
+
80
+ # Load a multi-label dataset
81
+ ds = load_dataset("google-research-datasets/go_emotions")
82
+
83
+ # Inspect some of the labels
84
+ print(ds["train"]["labels"][40:50])
85
+ # [[0, 15], [15, 18], [16, 27], [27], [7, 13], [10], [20], [27], [27], [27]]
86
+
87
+ # Train the classifier on text (X) and labels (y)
88
+ classifier.fit(ds["train"]["text"], ds["train"]["labels"])
89
+ ```
90
+
91
+ Then, we can evaluate the classifier:
92
+
93
+ ```python
94
+ from sklearn import metrics
95
+ from sklearn.preprocessing import MultiLabelBinarizer
96
+
97
+ classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["labels"], threshold=0.3)
98
+ print(classification_report)
99
+ # Accuracy: 0.410
100
+ # Precision: 0.527
101
+ # Recall: 0.410
102
+ # F1: 0.439
103
+ ```
104
+
105
+ The scores are competitive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster.
106
+
107
+ # Persistence
108
+
109
+ You can turn a classifier into a scikit-learn compatible pipeline, as follows:
110
+
111
+ ```python
112
+ pipeline = classifier.to_pipeline()
113
+ ```
114
+
115
+ This pipeline object can be persisted using standard pickle-based methods, such as [joblib](https://joblib.readthedocs.io/en/stable/). This makes it easy to use your model in inference pipelines (no installing torch!), although `joblib` and `pickle` should not be used to share models outside of your organization.
116
+
117
+ If you want to persist your pipeline to the Hugging Face hub, you can use our built-in functions:
118
+
119
+ ```python
120
+ pipeline.save_pretrained(path)
121
+ pipeline.push_to_hub("my_cool/project")
122
+ ```
123
+
124
+ Later, you can load these as follows:
125
+
126
+ ```python
127
+ from model2vec.inference import StaticModelPipeline
128
+
129
+ pipeline = StaticModelPipeline.from_pretrained("my_cool/project")
130
+ ```
131
+
132
+ Loading pipelines in this way is _extremely_ fast. It takes only 30ms to load a pipeline from disk.
133
+
134
+
135
+ # Bring your own architecture
136
+
137
+ Our training architecture is set up to be extensible, with each task having a specific class. Right now, we only offer `StaticModelForClassification`, but in the future we'll also offer regression, etc.
138
+
139
+ The core functionality of the `StaticModelForClassification` is contained in a couple of functions:
140
+
141
+ * `construct_head`: This function constructs the classifier on top of the staticmodel. For example, if you want to create a model that has LayerNorm, just subclass, and replace this function. This should be the main function to update if you want to change model behavior.
142
+ * `train_test_split`: governs the train test split before classification.
143
+ * `prepare_dataset`: Selects the `torch.Dataset` that will be used in the `Dataloader` during training.
144
+ * `_encode`: The encoding function used in the model.
145
+ * `fit`: contains all the lightning-related fitting logic.
146
+
147
+ The training of the model is done in a `lighting.LightningModule`, which can be modified but is very basic.
148
+
149
+ # Results
150
+
151
+ We ran extensive benchmarks where we compared our model to several well known architectures. The results can be found in the [training results](https://github.com/MinishLab/model2vec/tree/main/results#training-results) documentation.
src/distiller/model2vec/train/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
from distiller.model2vec.utils import get_package_extras, importable

_REQUIRED_EXTRA = "train"

# Fail fast with an informative error if any optional "train" dependency
# (declared as a package extra) is missing, before importing modules that need it.
for _dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
    importable(_dependency, _REQUIRED_EXTRA)

from distiller.model2vec.train.classifier import StaticModelForClassification

__all__ = ["StaticModelForClassification"]
src/distiller/model2vec/train/base.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import TYPE_CHECKING, Any, Self, TypeVar
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn.utils.rnn import pad_sequence
10
+ from torch.utils.data import DataLoader, Dataset
11
+
12
+ from distiller.model2vec import StaticModel
13
+
14
+ if TYPE_CHECKING:
15
+ from tokenizers import Encoding, Tokenizer
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class FinetunableStaticModel(nn.Module):
21
+ def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0) -> None:
22
+ """
23
+ Initialize a trainable StaticModel from a StaticModel.
24
+
25
+ :param vectors: The embeddings of the staticmodel.
26
+ :param tokenizer: The tokenizer.
27
+ :param out_dim: The output dimension of the head.
28
+ :param pad_id: The padding id. This is set to 0 in almost all model2vec models
29
+ """
30
+ super().__init__()
31
+ self.pad_id = pad_id
32
+ self.out_dim = out_dim
33
+ self.embed_dim = vectors.shape[1]
34
+
35
+ self.vectors = vectors
36
+ if self.vectors.dtype != torch.float32:
37
+ dtype = str(self.vectors.dtype)
38
+ logger.warning(
39
+ f"Your vectors are {dtype} precision, converting to to torch.float32 to avoid compatibility issues."
40
+ )
41
+ self.vectors = vectors.float()
42
+
43
+ self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=pad_id)
44
+ self.head = self.construct_head()
45
+ self.w = self.construct_weights()
46
+ self.tokenizer = tokenizer
47
+
48
+ def construct_weights(self) -> nn.Parameter:
49
+ """Construct the weights for the model."""
50
+ weights = torch.zeros(len(self.vectors))
51
+ weights[self.pad_id] = -10_000
52
+ return nn.Parameter(weights)
53
+
54
+ def construct_head(self) -> nn.Sequential:
55
+ """Method should be overridden for various other classes."""
56
+ return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim))
57
+
58
+ @classmethod
59
+ def from_pretrained(
60
+ cls, *, out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any
61
+ ) -> Self:
62
+ """Load the model from a pretrained model2vec model."""
63
+ model = StaticModel.from_pretrained(model_name)
64
+ return cls.from_static_model(model=model, out_dim=out_dim, **kwargs)
65
+
66
+ @classmethod
67
+ def from_static_model(cls, *, model: StaticModel, out_dim: int = 2, **kwargs: Any) -> Self:
68
+ """Load the model from a static model."""
69
+ model.embedding = np.nan_to_num(model.embedding)
70
+ embeddings_converted = torch.from_numpy(model.embedding)
71
+ return cls(
72
+ vectors=embeddings_converted,
73
+ pad_id=model.tokenizer.token_to_id("[PAD]"),
74
+ out_dim=out_dim,
75
+ tokenizer=model.tokenizer,
76
+ **kwargs,
77
+ )
78
+
79
+ def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
80
+ """
81
+ A forward pass and mean pooling.
82
+
83
+ This function is analogous to `StaticModel.encode`, but reimplemented to allow gradients
84
+ to pass through.
85
+
86
+ :param input_ids: A 2D tensor of input ids. All input ids are have to be within bounds.
87
+ :return: The mean over the input ids, weighted by token weights.
88
+ """
89
+ w = self.w[input_ids]
90
+ w = torch.sigmoid(w)
91
+ zeros = (input_ids != self.pad_id).float()
92
+ w = w * zeros
93
+ # Add a small epsilon to avoid division by zero
94
+ length = zeros.sum(1) + 1e-16
95
+ embedded = self.embeddings(input_ids)
96
+ # Weigh each token
97
+ embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
98
+ # Mean pooling by dividing by the length
99
+ embedded = embedded / length[:, None]
100
+
101
+ return nn.functional.normalize(embedded)
102
+
103
+ def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
104
+ """Forward pass through the mean, and a classifier layer after."""
105
+ encoded = self._encode(input_ids)
106
+ return self.head(encoded), encoded
107
+
108
+ def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tensor:
109
+ """
110
+ Tokenize a bunch of strings into a single padded 2D tensor.
111
+
112
+ Note that this is not used during training.
113
+
114
+ :param texts: The texts to tokenize.
115
+ :param max_length: If this is None, the sequence lengths are truncated to 512.
116
+ :return: A 2D padded tensor
117
+ """
118
+ encoded: list[Encoding] = self.tokenizer.encode_batch_fast(texts, add_special_tokens=False)
119
+ encoded_ids: list[torch.Tensor] = [torch.Tensor(encoding.ids[:max_length]).long() for encoding in encoded]
120
+ return pad_sequence(encoded_ids, batch_first=True, padding_value=self.pad_id)
121
+
122
+ @property
123
+ def device(self) -> str:
124
+ """Get the device of the model."""
125
+ return self.embeddings.weight.device
126
+
127
+ def to_static_model(self) -> StaticModel:
128
+ """Convert the model to a static model."""
129
+ emb = self.embeddings.weight.detach().cpu().numpy()
130
+ w = torch.sigmoid(self.w).detach().cpu().numpy()
131
+
132
+ return StaticModel(emb * w[:, None], self.tokenizer, normalize=True)
133
+
134
+
135
+ class TextDataset(Dataset):
136
+ def __init__(self, tokenized_texts: list[list[int]], targets: torch.Tensor) -> None:
137
+ """
138
+ A dataset of texts.
139
+
140
+ :param tokenized_texts: The tokenized texts. Each text is a list of token ids.
141
+ :param targets: The targets.
142
+ :raises ValueError: If the number of labels does not match the number of texts.
143
+ """
144
+ if len(targets) != len(tokenized_texts):
145
+ msg = "Number of labels does not match number of texts."
146
+ raise ValueError(msg)
147
+ self.tokenized_texts = tokenized_texts
148
+ self.targets = targets
149
+
150
+ def __len__(self) -> int:
151
+ """Return the length of the dataset."""
152
+ return len(self.tokenized_texts)
153
+
154
+ def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]:
155
+ """Gets an item."""
156
+ return self.tokenized_texts[index], self.targets[index]
157
+
158
+ @staticmethod
159
+ def collate_fn(batch: list[tuple[list[list[int]], int]]) -> tuple[torch.Tensor, torch.Tensor]:
160
+ """Collate function."""
161
+ texts, targets = zip(*batch, strict=False)
162
+
163
+ tensors = [torch.LongTensor(x) for x in texts]
164
+ padded = pad_sequence(tensors, batch_first=True, padding_value=0)
165
+
166
+ return padded, torch.stack(targets)
167
+
168
+ def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader:
169
+ """Convert the dataset to a DataLoader."""
170
+ return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size)
171
+
172
+
173
# Generic type variable for helpers that accept/return any FinetunableStaticModel subclass.
ModelType = TypeVar("ModelType", bound=FinetunableStaticModel)
src/distiller/model2vec/train/classifier.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from collections import Counter
5
+ from itertools import chain
6
+ from tempfile import TemporaryDirectory
7
+ from typing import TYPE_CHECKING, TypeVar, cast
8
+
9
+ import lightning as pl
10
+ import numpy as np
11
+ import torch
12
+ from lightning.pytorch.callbacks import Callback, EarlyStopping
13
+ from sklearn.metrics import jaccard_score
14
+ from sklearn.model_selection import train_test_split
15
+ from sklearn.neural_network import MLPClassifier
16
+ from sklearn.pipeline import make_pipeline
17
+ from torch import nn
18
+ from tqdm import trange
19
+
20
+ from distiller.model2vec.inference import StaticModelPipeline, evaluate_single_or_multi_label
21
+ from distiller.model2vec.train.base import FinetunableStaticModel, TextDataset
22
+
23
+ if TYPE_CHECKING:
24
+ from lightning.pytorch.utilities.types import OptimizerLRScheduler
25
+ from tokenizers import Tokenizer
26
+
27
+ logger = logging.getLogger(__name__)
28
+ _RANDOM_SEED = 42
29
+
30
+ LabelType = TypeVar("LabelType", list[str], list[list[str]])
31
+
32
+
33
class StaticModelForClassification(FinetunableStaticModel):
    """A FinetunableStaticModel with a feed-forward classification head.

    Supports both single-label and multi-label classification; the mode is
    inferred from the shape of the labels passed to :meth:`fit`.
    """

    def __init__(
        self,
        *,
        vectors: torch.Tensor,
        tokenizer: Tokenizer,
        n_layers: int = 1,
        hidden_dim: int = 512,
        out_dim: int = 2,
        pad_id: int = 0,
    ) -> None:
        """
        Initialize a standard classifier model.

        :param vectors: The embeddings of the static model.
        :param tokenizer: The tokenizer.
        :param n_layers: The number of hidden layers in the head (0 means a single linear layer).
        :param hidden_dim: The number of units per hidden layer.
        :param out_dim: The output dimension of the head; replaced by the number of classes in fit().
        :param pad_id: The padding token id.
        """
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        # Alias: Follows scikit-learn. Set to dummy classes until fit() is called.
        self.classes_: list[str] = [str(x) for x in range(out_dim)]
        # multilabel flag will be set based on the type of `y` passed to fit.
        self.multilabel: bool = False
        super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer)

    @property
    def classes(self) -> np.ndarray:
        """Return all classes in the correct order."""
        return np.array(self.classes_)

    def construct_head(self) -> nn.Sequential:
        """Constructs a simple MLP classifier head with Kaiming-initialized linear layers."""
        if self.n_layers == 0:
            # Degenerate case: a single linear projection, no hidden layers.
            return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim))
        modules = [
            nn.Linear(self.embed_dim, self.hidden_dim),
            nn.ReLU(),
        ]
        for _ in range(self.n_layers - 1):
            modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
        modules.extend([nn.Linear(self.hidden_dim, self.out_dim)])

        for module in modules:
            if isinstance(module, nn.Linear):
                nn.init.kaiming_uniform_(module.weight)
                nn.init.zeros_(module.bias)

        return nn.Sequential(*modules)

    def predict(
        self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024, threshold: float = 0.5
    ) -> np.ndarray:
        """
        Predict labels for a set of texts.

        In single-label mode, each prediction is a single class.
        In multilabel mode, each prediction is a list of classes.

        :param X: The texts to predict on.
        :param show_progress_bar: Whether to show a progress bar.
        :param batch_size: The batch size.
        :param threshold: The threshold for multilabel classification.
        :return: The predictions.
        """
        pred = []
        for batch in trange(0, len(X), batch_size, disable=not show_progress_bar):
            logits = self._predict_single_batch(X[batch : batch + batch_size])
            if self.multilabel:
                # Every class whose sigmoid probability exceeds the threshold is predicted.
                probs = torch.sigmoid(logits)
                mask = (probs > threshold).cpu().numpy()
                pred.extend([self.classes[np.flatnonzero(row)] for row in mask])
            else:
                pred.extend([self.classes[idx] for idx in logits.argmax(dim=1).tolist()])
        if self.multilabel:
            # Return as object array to allow for lists of varying lengths.
            return np.array(pred, dtype=object)
        return np.array(pred)

    @torch.no_grad()
    def _predict_single_batch(self, X: list[str]) -> torch.Tensor:
        """Tokenize one batch and return the head logits, without tracking gradients."""
        input_ids = self.tokenize(X)
        vectors, _ = self.forward(input_ids)
        return vectors

    def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024) -> np.ndarray:
        """
        Predict probabilities for each class.

        In single-label mode, returns softmax probabilities.
        In multilabel mode, returns sigmoid probabilities.

        :param X: The texts to predict on.
        :param show_progress_bar: Whether to show a progress bar.
        :param batch_size: The batch size.
        :return: A (len(X), n_classes) array of probabilities.
        """
        pred = []
        for batch in trange(0, len(X), batch_size, disable=not show_progress_bar):
            logits = self._predict_single_batch(X[batch : batch + batch_size])
            if self.multilabel:
                pred.append(torch.sigmoid(logits).cpu().numpy())
            else:
                pred.append(torch.softmax(logits, dim=1).cpu().numpy())
        return np.concatenate(pred, axis=0)

    def fit(
        self,
        X: list[str],
        y: LabelType,
        learning_rate: float = 1e-3,
        batch_size: int | None = None,
        min_epochs: int | None = None,
        max_epochs: int | None = -1,
        early_stopping_patience: int | None = 5,
        test_size: float = 0.1,
        device: str = "auto",
        X_val: list[str] | None = None,
        y_val: LabelType | None = None,
    ) -> StaticModelForClassification:
        """
        Fit a model.

        This function creates a Lightning Trainer object and fits the model to the data.
        It supports both single-label and multi-label classification.
        We use early stopping. After training, the weights of the best model are loaded back into the model.

        This function seeds everything with a seed of 42, so the results are reproducible.
        It also splits the data into a train and validation set, again with a random seed.

        If `X_val` and `y_val` are not provided, the function will automatically
        split the training data into a train and validation set using `test_size`.

        :param X: The texts to train on.
        :param y: The labels to train on. If the first element is a list, multi-label classification is assumed.
        :param learning_rate: The learning rate.
        :param batch_size: The batch size. If None, a good batch size is chosen automatically.
        :param min_epochs: The minimum number of epochs to train for.
        :param max_epochs: The maximum number of epochs to train for.
            If this is -1, the model trains until early stopping is triggered.
        :param early_stopping_patience: The patience for early stopping.
            If this is None, early stopping is disabled.
        :param test_size: The test size for the train-test split.
        :param device: The device to train on. If this is "auto", the device is chosen automatically.
        :param X_val: The texts to be used for validation.
        :param y_val: The labels to be used for validation.
        :return: The fitted model.
        :raises ValueError: If either X_val or y_val are provided, but not both.
        """
        pl.seed_everything(_RANDOM_SEED)
        logger.info("Re-initializing model.")

        # Determine whether the task is multilabel based on the type of y,
        # and rebuild the head/embeddings for the discovered classes.
        self._initialize(y)

        if (X_val is not None) != (y_val is not None):
            msg = "Both X_val and y_val must be provided together, or neither."
            raise ValueError(msg)

        if X_val is not None and y_val is not None:
            # Additional check to ensure y_val is of the same type as y
            if type(y_val[0]) != type(y[0]):
                msg = "X_val and y_val must be of the same type as X and y."
                raise ValueError(msg)

            train_texts = X
            train_labels = y
            validation_texts = X_val
            validation_labels = y_val
        else:
            train_texts, validation_texts, train_labels, validation_labels = self._train_test_split(
                X,
                y,
                test_size=test_size,
            )

        if batch_size is None:
            # Set to a multiple of 32, capped at 512 (16 * 32).
            base_number = int(min(max(1, (len(train_texts) / 30) // 32), 16))
            batch_size = int(base_number * 32)
            logger.info("Batch size automatically set to %d.", batch_size)

        logger.info("Preparing train dataset.")
        train_dataset = self._prepare_dataset(train_texts, train_labels)
        logger.info("Preparing validation dataset.")
        val_dataset = self._prepare_dataset(validation_texts, validation_labels)

        c = _ClassifierLightningModule(self, learning_rate=learning_rate)

        n_train_batches = len(train_dataset) // batch_size
        callbacks: list[Callback] = []
        if early_stopping_patience is not None:
            callback = EarlyStopping(monitor="val_accuracy", mode="max", patience=early_stopping_patience)
            callbacks.append(callback)

        # If the dataset is small, we check the validation set every epoch.
        # If the dataset is large, we check the validation set every 250 batches.
        if n_train_batches < 250:
            val_check_interval = None
            check_val_every_epoch = 1
        else:
            val_check_interval = max(250, 2 * len(val_dataset) // batch_size)
            check_val_every_epoch = None

        with TemporaryDirectory() as tempdir:
            trainer = pl.Trainer(
                min_epochs=min_epochs,
                max_epochs=max_epochs,
                callbacks=callbacks,
                val_check_interval=val_check_interval,
                check_val_every_n_epoch=check_val_every_epoch,
                accelerator=device,
                default_root_dir=tempdir,
            )

            trainer.fit(
                c,
                train_dataloaders=train_dataset.to_dataloader(shuffle=True, batch_size=batch_size),
                val_dataloaders=val_dataset.to_dataloader(shuffle=False, batch_size=batch_size),
            )
            # Load the best checkpoint before the temporary directory is removed.
            best_model_path = trainer.checkpoint_callback.best_model_path  # type: ignore
            best_model_weights = torch.load(best_model_path, weights_only=True)

        # Checkpoint keys are prefixed with "model." by the LightningModule wrapper; strip it.
        state_dict = {}
        for weight_name, weight in best_model_weights["state_dict"].items():
            state_dict[weight_name.removeprefix("model.")] = weight

        self.load_state_dict(state_dict)
        self.eval()
        return self

    def evaluate(
        self, X: list[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False
    ) -> str | dict[str, dict[str, float]]:
        """
        Evaluate the classifier on a given dataset using scikit-learn's classification report.

        :param X: The texts to predict on.
        :param y: The ground truth labels.
        :param batch_size: The batch size.
        :param threshold: The threshold for multilabel classification.
        :param output_dict: Whether to output the classification report as a dictionary.
        :return: A classification report.
        """
        self.eval()
        predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold)
        return evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict)

    def _initialize(self, y: LabelType) -> None:
        """
        Sets the output dimensionality, the classes, and initializes the head.

        :param y: The labels.
        :raises ValueError: If the labels are inconsistent.
        """
        if isinstance(y[0], (str, int)):
            # Check if all labels are strings or integers.
            if not all(isinstance(label, (str, int)) for label in y):
                msg = "Inconsistent label types in y. All labels must be strings or integers."
                raise ValueError(msg)
            self.multilabel = False
            classes = sorted(set(y))
        else:
            # Check if all labels are lists or tuples.
            if not all(isinstance(label, (list, tuple)) for label in y):
                msg = "Inconsistent label types in y. All labels must be lists or tuples."
                raise ValueError(msg)
            self.multilabel = True
            classes = sorted(set(chain.from_iterable(y)))

        self.classes_ = classes
        self.out_dim = len(self.classes_)  # Update output dimension
        # Rebuild the head and re-clone the embeddings so training starts fresh.
        self.head = self.construct_head()
        self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id)
        self.w = self.construct_weights()
        self.train()

    def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> TextDataset:
        """
        Prepare a dataset. For multilabel classification, each target is converted into a multi-hot vector.

        :param X: The texts.
        :param y: The labels.
        :param max_length: The maximum length of the input.
        :return: A TextDataset.
        """
        # This is a speed optimization.
        # assumes a mean token length of 10, which is really high, so safe.
        truncate_length = max_length * 10
        X = [x[:truncate_length] for x in X]
        tokenized: list[list[int]] = [
            encoding.ids[:max_length] for encoding in self.tokenizer.encode_batch_fast(X, add_special_tokens=False)
        ]
        if self.multilabel:
            # Convert labels to multi-hot vectors
            num_classes = len(self.classes_)
            labels_tensor = torch.zeros(len(y), num_classes, dtype=torch.float)
            mapping = {label: idx for idx, label in enumerate(self.classes_)}
            for i, sample_labels in enumerate(y):
                indices = [mapping[label] for label in sample_labels]
                labels_tensor[i, indices] = 1.0
        else:
            labels_tensor = torch.tensor([self.classes_.index(label) for label in cast("list[str]", y)], dtype=torch.long)
        return TextDataset(tokenized, labels_tensor)

    def _train_test_split(
        self,
        X: list[str],
        y: list[str] | list[list[str]],
        test_size: float,
    ) -> tuple[list[str], list[str], LabelType, LabelType]:
        """
        Split the data.

        For single-label classification, stratification is attempted (if possible).
        For multilabel classification, a random split is performed.
        """
        if not self.multilabel:
            label_counts = Counter(y)
            if min(label_counts.values()) < 2:
                # Stratified splitting requires at least 2 samples per class.
                logger.info("Some classes have less than 2 samples. Stratification is disabled.")
                return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True)
            return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True, stratify=y)
        # Multilabel classification does not support stratification.
        return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True)

    def to_pipeline(self) -> StaticModelPipeline:
        """Convert the model to an sklearn pipeline."""
        static_model = self.to_static_model()

        # Fit an MLPClassifier on random data just to build its structure,
        # then overwrite its weights with the trained torch head below.
        random_state = np.random.RandomState(_RANDOM_SEED)
        n_items = len(self.classes)
        X = random_state.randn(n_items, static_model.dim)
        y = self.classes

        converted = make_pipeline(MLPClassifier(hidden_layer_sizes=(self.hidden_dim,) * self.n_layers))
        converted.fit(X, y)
        mlp_head: MLPClassifier = converted[-1]

        # Copy the torch Linear weights into the sklearn MLP (transposed layout).
        for index, layer in enumerate([module for module in self.head if isinstance(module, nn.Linear)]):
            mlp_head.coefs_[index] = layer.weight.detach().cpu().numpy().T
            mlp_head.intercepts_[index] = layer.bias.detach().cpu().numpy()
        # Below is necessary to ensure that the converted model works correctly.
        # In scikit-learn, a binary classifier only has a single vector of output coefficients
        # and a single intercept. We use two output vectors.
        # To convert correctly, we need to set the outputs correctly, and fix the activation function.
        # Make sure n_outputs is set to > 1.
        mlp_head.n_outputs_ = self.out_dim
        # Set to softmax or sigmoid
        mlp_head.out_activation_ = "logistic" if self.multilabel else "softmax"

        return StaticModelPipeline(static_model, converted)
376
+
377
+
378
class _ClassifierLightningModule(pl.LightningModule):
    """Thin Lightning wrapper around a StaticModelForClassification.

    Selects BCE-with-logits loss for multilabel models and cross-entropy
    otherwise, and logs training/validation loss plus validation accuracy.
    """

    def __init__(self, model: StaticModelForClassification, learning_rate: float) -> None:
        """Initialize the LightningModule."""
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        if model.multilabel:
            self.loss_function = nn.BCEWithLogitsLoss()
        else:
            self.loss_function = nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate the forward pass to the wrapped classifier."""
        return self.model(x)

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """One training step: compute the loss on a batch and log it."""
        inputs, targets = batch
        logits, _ = self.model(inputs)
        loss = self.loss_function(logits, targets)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Validation step: compute and log loss and accuracy."""
        inputs, targets = batch
        logits, _ = self.model(inputs)
        loss = self.loss_function(logits, targets)
        if self.model.multilabel:
            # Multilabel accuracy is defined as the Jaccard score averaged over samples.
            hard_preds = (torch.sigmoid(logits) > 0.5).float()
            accuracy = jaccard_score(targets.cpu(), hard_preds.cpu(), average="samples")
        else:
            accuracy = (logits.argmax(dim=1) == targets).float().mean()
        self.log("val_loss", loss)
        self.log("val_accuracy", accuracy, prog_bar=True)

        return loss

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Adam optimizer plus a plateau scheduler that halves the LR when val_loss stalls."""
        opt = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt,
            mode="min",
            factor=0.5,
            patience=3,
            min_lr=1e-6,
            threshold=0.03,
            threshold_mode="rel",
        )
        return {"optimizer": opt, "lr_scheduler": {"scheduler": plateau, "monitor": "val_loss"}}
src/distiller/model2vec/utils.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ import re
6
+ from importlib import import_module
7
+ from importlib.metadata import metadata
8
+ from typing import TYPE_CHECKING, Any, Protocol, cast
9
+
10
+ import safetensors
11
+ from joblib import Parallel
12
+ from tokenizers import Tokenizer
13
+ from tqdm import tqdm
14
+
15
+ if TYPE_CHECKING:
16
+ from collections.abc import Iterator
17
+ from pathlib import Path
18
+
19
+ import numpy as np
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class ProgressParallel(Parallel):
    """A drop-in replacement for joblib.Parallel that shows a tqdm progress bar."""

    def __init__(self, use_tqdm: bool = True, total: int | None = None, *args: Any, **kwargs: Any) -> None:
        """
        Initialize the ProgressParallel object.

        :param use_tqdm: Whether to show the progress bar.
        :param total: Total number of tasks (batches) you expect to process. If None,
            it updates the total dynamically to the number of dispatched tasks.
        :param *args: Additional arguments to pass to `Parallel.__init__`.
        :param **kwargs: Additional keyword arguments to pass to `Parallel.__init__`.
        """
        self._use_tqdm = use_tqdm
        self._total = total
        super().__init__(*args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the parallel job inside a tqdm context so progress can be reported."""
        # The bar is bound to `self._pbar` so that print_progress (invoked by
        # joblib as tasks complete) can update it. The original body also had
        # a redundant `self._pbar = self._pbar` self-assignment, removed here.
        with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
            return super().__call__(*args, **kwargs)

    def print_progress(self) -> None:
        """Hook called by joblib as tasks complete. We update the tqdm bar here."""
        if self._total is None:
            # If no fixed total was given, we dynamically set the total.
            self._pbar.total = self.n_dispatched_tasks
        # Move the bar to the number of completed tasks.
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()
55
+
56
+
57
class SafeOpenProtocol(Protocol):
    """Protocol to fix safetensors safe open.

    `safetensors.safe_open` has no usable return type annotation; this
    structural type declares the one method this module relies on so the
    `cast` at the call site type-checks.
    """

    def get_tensor(self, key: str) -> np.ndarray:
        """Get a tensor."""
        ...  # pragma: no cover
63
+
64
+
65
# Maps PyPI distribution names to their importable module names
# (used by `importable` below).
_MODULE_MAP = (("scikit-learn", "sklearn"),)
# Splits a requirement string at its version specifier, e.g. "foo>=1.0" -> "foo".
_DIVIDERS = re.compile(r"[=<>!]+")
67
+
68
+
69
def get_package_extras(package: str, extra: str) -> Iterator[str]:
    """Yield the bare dependency names behind *extra* of *package*.

    :param package: The distribution name to inspect (e.g. "model2vec").
    :param extra: The extra whose dependencies should be listed.
    :return: An iterator over dependency names with version specifiers stripped.
    """
    try:
        message = metadata(package)
    except Exception as e:
        # For local packages without metadata, yield nothing.
        # This allows the package to work without installed metadata.
        logger.debug(f"Could not retrieve metadata for package '{package}': {e}")
        # Bare return: this is a generator, so any return value (the original
        # used `return iter([])`) would be silently discarded anyway.
        return

    all_packages = message.get_all("Requires-Dist") or []
    # NOTE: the loop variable is named `requirement` to avoid shadowing the
    # `package` parameter (which the original code did).
    for requirement in all_packages:
        name, *marker = requirement.split(";", maxsplit=1)
        if marker:
            # Markers look like: extra == "distill" — extract the extra name.
            found_extra = marker[0].split("==")[-1].strip(" \"'")
            if found_extra == extra:
                # Strip the version specifier (e.g. ">=1.0") from the name.
                prefix, *_ = _DIVIDERS.split(name)
                yield prefix.strip()
88
+
89
+
90
def importable(module: str, extra: str) -> None:
    """Check that *module* can be imported, raising a helpful error otherwise.

    :param module: The distribution name to check. Names that differ from
        their import name (e.g. scikit-learn) are remapped via `_MODULE_MAP`.
    :param extra: The model2vec extra that provides the module; mentioned in
        the error message so the user knows what to reinstall.
    :raises ImportError: If the module cannot be imported.
    """
    # Translate distribution names to importable module names.
    module = dict(_MODULE_MAP).get(module, module)
    try:
        import_module(module)
    except ImportError as exc:
        msg = f"`{module}`, is required. Please reinstall model2vec with the `{extra}` extra. `pip install model2vec[{extra}]`"
        # Chain explicitly so the underlying import failure stays visible
        # (the original relied on implicit chaining).
        raise ImportError(msg) from exc
98
+
99
+
100
def setup_logging() -> None:
    """Configure root logging at INFO level with a Rich console handler."""
    # Imported lazily: rich is only needed when logging is actually set up.
    from rich.logging import RichHandler

    rich_handler = RichHandler(rich_tracebacks=True)
    logging.basicConfig(
        handlers=[rich_handler],
        level="INFO",
        format="%(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
110
+
111
+
112
def load_local_model(folder: Path) -> tuple[np.ndarray, Tokenizer, dict[str, str]]:
    """Load a local model from *folder*.

    Expects the folder to contain `model.safetensors` (with an "embeddings"
    tensor), `tokenizer.json`, and optionally `config.json`.

    :param folder: Directory containing the model files.
    :return: A tuple of (embeddings array, tokenizer, config dict).
    """
    embeddings_path = folder / "model.safetensors"
    tokenizer_path = folder / "tokenizer.json"
    config_path = folder / "config.json"

    opened_tensor_file = cast("SafeOpenProtocol", safetensors.safe_open(embeddings_path, framework="numpy"))
    embeddings = opened_tensor_file.get_tensor("embeddings")

    # Use a context manager with an explicit encoding: the original
    # `json.load(open(config_path))` leaked the file handle.
    if config_path.exists():
        with config_path.open(encoding="utf-8") as config_file:
            config = json.load(config_file)
    else:
        config = {}

    tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))

    if len(tokenizer.get_vocab()) != len(embeddings):
        # Lazy %-args so the message is only formatted when actually emitted.
        logger.warning(
            "Number of tokens does not match number of embeddings: `%d` vs `%d`",
            len(tokenizer.get_vocab()),
            len(embeddings),
        )

    return embeddings, tokenizer, config
src/distiller/model2vec/version.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __version_triple__ = (0, 5, 0)
2
+ __version__ = ".".join(map(str, __version_triple__))