Sarthak committed on
Commit ·
473c3a0
1
Parent(s): 72121b3
chore: moved model2vec in as an internal package
Browse files- src/distiller/model2vec/__init__.py +4 -0
- src/distiller/model2vec/distill/__init__.py +10 -0
- src/distiller/model2vec/distill/distillation.py +260 -0
- src/distiller/model2vec/distill/inference.py +158 -0
- src/distiller/model2vec/distill/utils.py +28 -0
- src/distiller/model2vec/hf_utils.py +230 -0
- src/distiller/model2vec/inference/README.md +18 -0
- src/distiller/model2vec/inference/__init__.py +10 -0
- src/distiller/model2vec/inference/model.py +312 -0
- src/distiller/model2vec/model.py +480 -0
- src/distiller/model2vec/modelcards/classifier_template.md +50 -0
- src/distiller/model2vec/modelcards/model_card_template.md +91 -0
- src/distiller/model2vec/py.typed +0 -0
- src/distiller/model2vec/quantization.py +63 -0
- src/distiller/model2vec/tokenizer/__init__.py +12 -0
- src/distiller/model2vec/tokenizer/datamodels.py +14 -0
- src/distiller/model2vec/tokenizer/model.py +43 -0
- src/distiller/model2vec/tokenizer/normalizer.py +34 -0
- src/distiller/model2vec/tokenizer/pretokenizer.py +58 -0
- src/distiller/model2vec/tokenizer/tokenizer.py +398 -0
- src/distiller/model2vec/train/README.md +151 -0
- src/distiller/model2vec/train/__init__.py +10 -0
- src/distiller/model2vec/train/base.py +173 -0
- src/distiller/model2vec/train/classifier.py +426 -0
- src/distiller/model2vec/utils.py +130 -0
- src/distiller/model2vec/version.py +2 -0
src/distiller/model2vec/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .model import StaticModel
from .version import __version__

# Public API of the model2vec package: the static embedding model plus the
# package version string.
__all__ = ["StaticModel", "__version__"]
|
src/distiller/model2vec/distill/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from distiller.model2vec.utils import get_package_extras, importable

# The extras group whose dependencies must be installed for distillation.
_REQUIRED_EXTRA = "distill"

# Fail fast with a clear error message if any optional dependency of the
# "distill" extra is missing, instead of an opaque ImportError later on.
for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
    importable(extra_dependency, _REQUIRED_EXTRA)

# NOTE: deliberately imported *after* the dependency check above, so users
# see the helpful message from `importable` rather than a deep import failure.
from distiller.model2vec.distill.distillation import distill, distill_from_model

__all__ = ["distill", "distill_from_model"]
|
src/distiller/model2vec/distill/distillation.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
from typing import cast
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from huggingface_hub import model_info
|
| 10 |
+
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast
|
| 11 |
+
|
| 12 |
+
from distiller.model2vec.distill.inference import PCADimType, create_embeddings, post_process_embeddings
|
| 13 |
+
from distiller.model2vec.distill.utils import select_optimal_device
|
| 14 |
+
from distiller.model2vec.model import StaticModel
|
| 15 |
+
from distiller.model2vec.quantization import DType, quantize_embeddings
|
| 16 |
+
from distiller.model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def distill_from_model(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerFast,
    vocabulary: list[str] | None = None,
    device: str | None = None,
    pca_dims: PCADimType = 256,
    apply_zipf: bool | None = None,
    sif_coefficient: float | None = 1e-4,
    token_remove_pattern: str | None = r"\[unused\d+\]",
    quantize_to: DType | str = DType.Float16,
    use_subword: bool | None = None,
) -> StaticModel:
    """
    Distill a StaticModel from an already-loaded sentence transformer.

    This function creates a set of embeddings from a sentence transformer. It does this by doing either
    a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed vocabulary.

    If you pass through a vocabulary, we create a custom word tokenizer for that vocabulary.
    If you don't pass a vocabulary, we use the model's tokenizer directly.

    :param model: The model to use.
    :param tokenizer: The (fast) tokenizer to use.
    :param vocabulary: The vocabulary to use. If this is None, we use the model's vocabulary.
    :param device: The device to use. If None, the best available backend is auto-selected.
    :param pca_dims: The number of components to use for PCA.
        If this is None, we don't apply PCA.
        If this is 'auto', we don't reduce dimensionality, but still apply PCA.
    :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
        Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
    :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
        Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
    :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
        If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
    :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
    :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
    :return: A StaticModel
    :raises: ValueError if the vocabulary is empty after preprocessing.

    """
    if use_subword is not None:
        logger.warning(
            "The `use_subword` parameter is deprecated and will be removed in the next release. It doesn't do anything."
        )
    # Normalize string inputs ("float16", ...) to the DType enum.
    quantize_to = DType(quantize_to)
    backend_tokenizer = tokenizer.backend_tokenizer
    # Resolves the deprecated apply_zipf flag and compiles token_remove_pattern.
    sif_coefficient, token_remove_regex = _validate_parameters(apply_zipf, sif_coefficient, token_remove_pattern)

    if vocabulary is None:
        vocabulary = []

    device = select_optimal_device(device)

    n_tokens_before = len(vocabulary)
    # Clean the vocabulary by removing duplicate tokens and tokens that are in the internal vocabulary.
    all_tokens, backend_tokenizer = clean_and_create_vocabulary(
        tokenizer, vocabulary, token_remove_regex=token_remove_regex
    )
    # Tokens flagged is_internal come from the model's own vocabulary; the rest
    # were added by the caller.
    n_tokens_after = len([token for token in all_tokens if not token.is_internal])
    if n_tokens_before:
        logger.info(
            f"Adding {n_tokens_after} tokens to the vocabulary. Removed {n_tokens_before - n_tokens_after} tokens during preprocessing."
        )

    if not all_tokens:
        msg = "The vocabulary is empty after preprocessing. Please check your token_remove_pattern."
        raise ValueError(msg)

    unk_token = cast("str | None", tokenizer.special_tokens_map.get("unk_token"))
    pad_token = cast("str | None", tokenizer.special_tokens_map.get("pad_token"))

    # Nested `if` (rather than one combined condition) to satisfy mypy's
    # narrowing of Optional types.
    if pad_token is None:
        if unk_token is not None:
            pad_token = unk_token
            logger.warning(
                "The pad token is not set. Setting it to the unk token. This is a workaround for models that don't have a pad token."
            )
        else:
            # Neither pad nor unk is defined: fall back to the first token.
            pad_token = unk_token or all_tokens[0].form
            logger.warning(
                "The pad token is not set. Setting it to the first token in the vocabulary. This is a workaround for models that don't have a pad token."
            )

    # Replace the vocabulary in the tokenizer with the new vocabulary.
    backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)

    logger.info(f"Creating embeddings for {len(all_tokens)} tokens")
    # Convert tokens to IDs
    token_ids = turn_tokens_into_ids(all_tokens, tokenizer, unk_token)

    # Create the embeddings via forward passes through the transformer.
    embeddings = create_embeddings(
        tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
    )

    # Post process the embeddings by applying PCA and Zipf weighting.
    embeddings = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
    # Quantize the embeddings.
    embeddings = quantize_embeddings(embeddings, quantize_to)

    # Either a hub id or a local path; empty string if the model has no name.
    model_name = getattr(model, "name_or_path", "")

    config = {
        "model_type": "model2vec",
        "architectures": ["StaticModel"],
        "tokenizer_name": model_name,
        "apply_pca": pca_dims,
        "apply_zipf": apply_zipf,
        "sif_coefficient": sif_coefficient,
        "hidden_dim": embeddings.shape[1],
        "seq_length": 1000000,  # Set this to a high value since we don't have a sequence length limit.
        "normalize": True,
    }

    if os.path.exists(model_name):
        # Using a local model. Get the model name from the path.
        model_name = os.path.basename(model_name)
        language = None
    else:
        # Get the language from the model card.
        try:
            info = model_info(model_name)
            language = info.cardData.get("language", None) if info.cardData is not None else None
        except Exception as e:
            # NOTE: broad except on purpose — hub lookups can fail for many
            # reasons (offline, private repo, missing card) and none are fatal.
            logger.warning(f"Couldn't get the model info from the Hugging Face Hub: {e}. Setting language to None.")
            language = None

    return StaticModel(
        vectors=embeddings,
        tokenizer=backend_tokenizer,
        config=config,
        base_model_name=model_name,
        language=language,
        normalize=True,
    )
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _validate_parameters(
|
| 161 |
+
apply_zipf: bool | None,
|
| 162 |
+
sif_coefficient: float | None,
|
| 163 |
+
token_remove_pattern: str | None,
|
| 164 |
+
) -> tuple[float | None, re.Pattern | None]:
|
| 165 |
+
"""
|
| 166 |
+
Validate the parameters passed to the distillation function.
|
| 167 |
+
|
| 168 |
+
:param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
|
| 169 |
+
Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
|
| 170 |
+
:param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
|
| 171 |
+
Should be a value >= 0 and < 1.0. A value of 1e-4 is a good default.
|
| 172 |
+
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
|
| 173 |
+
:return: The SIF coefficient to use.
|
| 174 |
+
:raises: ValueError if the regex can't be compiled.
|
| 175 |
+
|
| 176 |
+
"""
|
| 177 |
+
if apply_zipf is not None:
|
| 178 |
+
logger.warning(
|
| 179 |
+
"The `apply_zipf` parameter is deprecated and will be removed in the next release. "
|
| 180 |
+
"Zipf weighting is applied based on the sif_coefficient parameter. If this is set to None, "
|
| 181 |
+
"no weighting is applied."
|
| 182 |
+
)
|
| 183 |
+
if apply_zipf and sif_coefficient is None:
|
| 184 |
+
logger.warning("You set apply_zipf to True, but sif_coefficient is None. Setting sif_coefficient to 1e-4.")
|
| 185 |
+
sif_coefficient = 1e-4
|
| 186 |
+
elif not apply_zipf:
|
| 187 |
+
logger.warning("Because you set apply_zipf to False, we ignore the sif_coefficient parameter.")
|
| 188 |
+
sif_coefficient = None
|
| 189 |
+
|
| 190 |
+
if sif_coefficient is not None and not 0 < sif_coefficient < 1.0:
|
| 191 |
+
msg = "SIF coefficient must be a value > 0 and < 1.0."
|
| 192 |
+
raise ValueError(msg)
|
| 193 |
+
|
| 194 |
+
token_remove_regex: re.Pattern | None = None
|
| 195 |
+
if token_remove_pattern is not None:
|
| 196 |
+
try:
|
| 197 |
+
token_remove_regex = re.compile(token_remove_pattern)
|
| 198 |
+
except re.error as e:
|
| 199 |
+
msg = f"Couldn't compile the regex pattern: {e}"
|
| 200 |
+
raise ValueError(msg)
|
| 201 |
+
|
| 202 |
+
return sif_coefficient, token_remove_regex
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def distill(
    model_name: str,
    vocabulary: list[str] | None = None,
    device: str | None = None,
    pca_dims: PCADimType = 256,
    apply_zipf: bool | None = None,
    sif_coefficient: float | None = 1e-4,
    token_remove_pattern: str | None = r"\[unused\d+\]",
    trust_remote_code: bool = False,
    quantize_to: DType | str = DType.Float16,
    use_subword: bool | None = None,
) -> StaticModel:
    """
    Distill a StaticModel from a sentence transformer identified by name.

    Convenience wrapper around :func:`distill_from_model`: it loads the
    transformer and its fast tokenizer from the Hugging Face Hub (or a local
    path) and then runs the distillation. A forward pass is performed either
    for every subword token in the tokenizer or, if a vocabulary is supplied,
    for every token in that vocabulary (for which a custom word tokenizer is
    created).

    :param model_name: The model name to use. Any sentencetransformer compatible model works.
    :param vocabulary: The vocabulary to use. If this is None, we use the model's vocabulary.
    :param device: The device to use.
    :param pca_dims: The number of components to use for PCA.
        If this is None, we don't apply PCA.
        If this is 'auto', we don't reduce dimensionality, but still apply PCA.
    :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
        Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
    :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
        Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
    :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
    :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
    :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
    :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
    :return: A StaticModel

    """
    loaded_model: PreTrainedModel = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code, use_fast=True)
    # use_fast=True guarantees a fast tokenizer; the cast informs the type checker.
    fast_tokenizer = cast("PreTrainedTokenizerFast", loaded_tokenizer)

    return distill_from_model(
        model=loaded_model,
        tokenizer=fast_tokenizer,
        vocabulary=vocabulary,
        device=device,
        pca_dims=pca_dims,
        apply_zipf=apply_zipf,
        token_remove_pattern=token_remove_pattern,
        sif_coefficient=sif_coefficient,
        quantize_to=quantize_to,
        use_subword=use_subword,
    )
|
src/distiller/model2vec/distill/inference.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import inspect
|
| 4 |
+
import logging
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import TYPE_CHECKING, Literal, Protocol, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from sklearn.decomposition import PCA
|
| 11 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from transformers import PreTrainedModel
|
| 16 |
+
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# A filesystem path given either as a pathlib.Path or a plain string.
PathLike = Union[Path, str]
# Accepted values for PCA dimensionality: an explicit component count, a float
# (explained-variance target), "auto", or None to skip PCA entirely.
PCADimType = Union[int, None, float, Literal["auto"]]


# Number of token sequences encoded per forward pass in create_embeddings.
_DEFAULT_BATCH_SIZE = 256


class ModulewithWeights(Protocol):
    # Structural type for any torch module exposing a `weight` parameter
    # (e.g. an embedding layer).
    weight: torch.nn.Parameter
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def create_embeddings(
    model: PreTrainedModel,
    tokenized: list[list[int]],
    device: str,
    pad_token_id: int,
) -> np.ndarray:
    """
    Create output embeddings for a bunch of tokens using a pretrained model.

    It does a forward pass for all tokens passed in `tokenized`.

    :param model: The model to use.
        This should be a transformers model.
    :param tokenized: All tokenized tokens (one list of token ids per token string).
    :param device: The torch device to use.
    :param pad_token_id: The pad token id. Used to pad sequences.
    :return: The output embeddings, one row per input sequence, in input order.
    """
    model = model.to(device)

    intermediate_weights: list[np.ndarray] = []

    # Add token_type_ids only if the model's forward signature accepts it.
    add_token_type_ids = "token_type_ids" in inspect.getfullargspec(model.forward).args

    # Sort sequences by length so each batch holds similarly sized sequences,
    # which minimizes padding and wasted compute.
    lengths = np.asarray([len(sequence) for sequence in tokenized])
    sort_order = np.argsort(lengths)
    sorted_tokenized = [tokenized[i] for i in sort_order]

    # FIX: run the progress bar as a context manager so it is always closed.
    with tqdm(total=len(sorted_tokenized), desc="Encoding tokens", unit=" tokens") as pbar:
        for batch_idx in range(0, len(sorted_tokenized), _DEFAULT_BATCH_SIZE):
            # FIX: build tensors with torch.tensor(..., dtype=torch.long) instead
            # of torch.Tensor(x).long(), which round-trips ids through float32
            # and can corrupt token ids larger than 2**24.
            batch = [
                torch.tensor(x, dtype=torch.long)
                for x in sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]
            ]

            encoded = {}
            encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
            encoded["attention_mask"] = encoded["input_ids"] != pad_token_id

            if add_token_type_ids:
                encoded["token_type_ids"] = torch.zeros_like(encoded["input_ids"])

            out = _encode_mean_using_model(model, encoded)
            intermediate_weights.extend(out.numpy())
            pbar.update(len(batch))

    # Restore the original input order: argsort of the sort order is its
    # inverse permutation.
    intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
    out_weights = np.stack(intermediate_weights)

    # Replace any NaN/inf values from the model output with finite numbers.
    out_weights = np.nan_to_num(out_weights)

    return out_weights
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@torch.no_grad()
|
| 89 |
+
def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
|
| 90 |
+
"""
|
| 91 |
+
Encode a batch of tokens using a model.
|
| 92 |
+
|
| 93 |
+
Note that if a token in the input batch does not have any embeddings, it will be output as a vector of zeros.
|
| 94 |
+
So detection of these is necessary.
|
| 95 |
+
|
| 96 |
+
:param model: The model to use.
|
| 97 |
+
:param encodings: The encoded tokens to turn into features.
|
| 98 |
+
:return: The mean of the output for each token.
|
| 99 |
+
"""
|
| 100 |
+
encodings = {k: v.to(model.device) for k, v in encodings.items()}
|
| 101 |
+
encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings)
|
| 102 |
+
out: torch.Tensor = encoded.last_hidden_state.cpu()
|
| 103 |
+
# NOTE: If the dtype is bfloat 16, we convert to float32,
|
| 104 |
+
# because numpy does not suport bfloat16
|
| 105 |
+
# See here: https://github.com/numpy/numpy/issues/19808
|
| 106 |
+
if out.dtype == torch.bfloat16:
|
| 107 |
+
out = out.float()
|
| 108 |
+
|
| 109 |
+
# Take the mean by averaging over the attention mask.
|
| 110 |
+
mask = encodings["attention_mask"].cpu().float()
|
| 111 |
+
mask /= mask.sum(1)[:, None]
|
| 112 |
+
|
| 113 |
+
return torch.bmm(mask[:, None, :].float(), out).squeeze(1)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def post_process_embeddings(
    embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4
) -> np.ndarray:
    """
    Post process embeddings by applying PCA and SIF weighting by estimating the frequencies through Zipf's law.

    :param embeddings: The embeddings, one row per vocabulary token, assumed
        ordered from most to least frequent token (required for the Zipf prior).
    :param pca_dims: Target PCA dimensionality: an int (component count), a
        float in (0, 1) (explained-variance target), "auto" (apply PCA without
        reducing), or None (skip PCA).
    :param sif_coefficient: SIF weighting coefficient; None disables weighting.
    :return: The post-processed embeddings.
    """
    if pca_dims is not None:
        if pca_dims == "auto":
            # "auto": keep the current dimensionality but still run PCA.
            pca_dims = embeddings.shape[1]
        if pca_dims > embeddings.shape[1]:
            # Requested more components than exist: clamp instead of failing.
            logger.warning(
                f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]}). "
                "Applying PCA, but not reducing dimensionality. Is this is not desired, please set `pca_dims` to None. "
                "Applying PCA will probably improve performance, so consider just leaving it."
            )
            pca_dims = embeddings.shape[1]
        if pca_dims >= embeddings.shape[0]:
            # PCA needs more samples (tokens) than components; skip otherwise.
            logger.warning(
                f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({embeddings.shape[0]}). Not applying PCA."
            )
        elif pca_dims <= embeddings.shape[1]:
            # NOTE: a float pca_dims (< 1.0) always takes this branch and is
            # interpreted by sklearn as an explained-variance target.
            if isinstance(pca_dims, float):
                logger.info(f"Applying PCA with {pca_dims} explained variance.")
            else:
                logger.info(f"Applying PCA with n_components {pca_dims}")

            orig_dims = embeddings.shape[1]
            p = PCA(n_components=pca_dims, svd_solver="full")
            embeddings = p.fit_transform(embeddings)

            if embeddings.shape[1] < orig_dims:
                explained_variance_ratio = np.sum(p.explained_variance_ratio_)
                explained_variance = np.sum(p.explained_variance_)
                logger.info(f"Reduced dimensionality from {orig_dims} to {embeddings.shape[1]}.")
                logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
                logger.info(f"Explained variance: {explained_variance:.3f}.")

    if sif_coefficient is not None:
        # Zipf prior: the probability of the token at rank r is proportional
        # to 1/(r+1); SIF then downweights frequent (low-rank) tokens.
        logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
        inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
        proba = inv_rank / np.sum(inv_rank)
        embeddings *= (sif_coefficient / (sif_coefficient + proba))[:, None]

    return embeddings
|
src/distiller/model2vec/distill/utils.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from logging import getLogger
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
logger = getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def select_optimal_device(device: str | None) -> str:
    """
    Guess what your optimal device should be based on backend availability.

    If you pass a device, we just pass it through.

    :param device: The device to use. If this is not None you get back what you passed.
    :return: The selected device.
    """
    if device is not None:
        return device

    # No preference given: probe backends from fastest to most portable.
    if torch.cuda.is_available():
        chosen = "cuda"
    elif torch.backends.mps.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    logger.info(f"Automatically selected device: {chosen}")
    return chosen
|
src/distiller/model2vec/hf_utils.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import TYPE_CHECKING, Any, cast
|
| 7 |
+
|
| 8 |
+
import huggingface_hub
|
| 9 |
+
import safetensors
|
| 10 |
+
from huggingface_hub import ModelCard, ModelCardData
|
| 11 |
+
from safetensors.numpy import save_file
|
| 12 |
+
from tokenizers import Tokenizer
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
from distiller.model2vec.utils import SafeOpenProtocol
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def save_pretrained(
    folder_path: Path,
    embeddings: np.ndarray,
    tokenizer: Tokenizer,
    config: dict[str, Any],
    create_model_card: bool = True,
    subfolder: str | None = None,
    **kwargs: Any,
) -> None:
    """
    Save a model to a folder.

    :param folder_path: The path to the folder.
    :param embeddings: The embeddings.
    :param tokenizer: The tokenizer.
    :param config: A metadata config.
    :param create_model_card: Whether to create a model card.
    :param subfolder: The subfolder to save the model in.
    :param **kwargs: Any additional arguments, forwarded to model card creation.
    """
    folder_path = folder_path / subfolder if subfolder else folder_path
    folder_path.mkdir(exist_ok=True, parents=True)
    save_file({"embeddings": embeddings}, folder_path / "model.safetensors")
    tokenizer.save(str(folder_path / "tokenizer.json"), pretty=False)
    # FIX: write JSON via a context manager (the previous
    # `json.dump(..., open(...))` leaked the file handle) and pin the encoding
    # so output does not depend on the platform default.
    with open(folder_path / "config.json", "w", encoding="utf-8") as config_file:
        json.dump(config, config_file, indent=4)

    # Create modules.json describing the sentence-transformers module stack.
    modules = [{"idx": 0, "name": "0", "path": ".", "type": "sentence_transformers.models.StaticEmbedding"}]
    if config.get("normalize"):
        # If normalize=True, add sentence_transformers.models.Normalize
        modules.append({"idx": 1, "name": "1", "path": "1_Normalize", "type": "sentence_transformers.models.Normalize"})
    with open(folder_path / "modules.json", "w", encoding="utf-8") as modules_file:
        json.dump(modules, modules_file, indent=4)

    logger.info(f"Saved model to {folder_path}")

    # Optionally create the model card
    if create_model_card:
        _create_model_card(folder_path, **kwargs)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _create_model_card(
    folder_path: Path,
    base_model_name: str = "unknown",
    license: str = "mit",
    language: list[str] | None = None,
    model_name: str | None = None,
    template_path: str = "modelcards/model_card_template.md",
    **kwargs: Any,
) -> None:
    """
    Render a model card from the bundled template and store it in the given path.

    :param folder_path: The path where the model card will be stored.
    :param base_model_name: The name of the base model.
    :param license: The license to use.
    :param language: The language of the model.
    :param model_name: The name of the model to use in the Model Card.
    :param template_path: The path to the template, relative to this module.
    :param **kwargs: Additional metadata for the model card (e.g., model_name, base_model, etc.).
    """
    folder_path = Path(folder_path)
    # Fall back to the folder name when no (truthy) model name was supplied.
    card_title = model_name or folder_path.name
    template_file = Path(__file__).parent / template_path

    card_data = ModelCardData(
        model_name=card_title,
        base_model=base_model_name,
        license=license,
        language=language,
        tags=["embeddings", "static-embeddings", "sentence-transformers"],
        library_name="model2vec",
        **kwargs,
    )
    rendered_card = ModelCard.from_template(card_data, template_path=str(template_file))
    rendered_card.save(folder_path / "README.md")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def load_pretrained(
    folder_or_repo_path: str | Path,
    subfolder: str | None = None,
    token: str | None = None,
    from_sentence_transformers: bool = False,
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
    """
    Load a pretrained model from a folder or the Hugging Face hub.

    :param folder_or_repo_path: The folder or repo path to load from.
        - If this is a local path, we will load from the local path.
        - If the local path is not found, we will attempt to load from the huggingface hub.
    :param subfolder: The subfolder to load from.
    :param token: The huggingface token to use.
    :param from_sentence_transformers: Whether to load the model from a sentence transformers model.
    :raises: FileNotFoundError if the folder exists, but the file does not exist locally.
    :return: The embeddings, tokenizer, config, and metadata.
    """
    # Sentence-transformers repos keep the static embedding weights in a
    # submodule folder and use a differently-named config file.
    if from_sentence_transformers:
        model_file = "0_StaticEmbedding/model.safetensors"
        tokenizer_file = "0_StaticEmbedding/tokenizer.json"
        config_name = "config_sentence_transformers.json"
    else:
        model_file = "model.safetensors"
        tokenizer_file = "tokenizer.json"
        config_name = "config.json"

    folder_or_repo_path = Path(folder_or_repo_path)

    local_folder = folder_or_repo_path / subfolder if subfolder else folder_or_repo_path

    if local_folder.exists():
        embeddings_path = local_folder / model_file
        if not embeddings_path.exists():
            msg = f"Embeddings file does not exist in {local_folder}"
            raise FileNotFoundError(msg)

        config_path = local_folder / config_name
        if not config_path.exists():
            msg = f"Config file does not exist in {local_folder}"
            raise FileNotFoundError(msg)

        tokenizer_path = local_folder / tokenizer_file
        if not tokenizer_path.exists():
            msg = f"Tokenizer file does not exist in {local_folder}"
            raise FileNotFoundError(msg)

        # README is optional, so this is a bit finicky.
        readme_path = local_folder / "README.md"
        metadata = _get_metadata_from_readme(readme_path)

    else:
        logger.info("Folder does not exist locally, attempting to use huggingface hub.")
        embeddings_path = Path(
            huggingface_hub.hf_hub_download(
                folder_or_repo_path.as_posix(), model_file, token=token, subfolder=subfolder
            )
        )

        try:
            readme_path = Path(
                huggingface_hub.hf_hub_download(
                    folder_or_repo_path.as_posix(), "README.md", token=token, subfolder=subfolder
                )
            )
            metadata = _get_metadata_from_readme(Path(readme_path))
        except Exception as e:
            # NOTE: we don't want to raise an error here, since the README is optional.
            logger.info(f"No README found in the model folder: {e} No model card loaded.")
            metadata = {}

        config_path = Path(
            huggingface_hub.hf_hub_download(
                folder_or_repo_path.as_posix(), config_name, token=token, subfolder=subfolder
            )
        )
        tokenizer_path = Path(
            huggingface_hub.hf_hub_download(
                folder_or_repo_path.as_posix(), tokenizer_file, token=token, subfolder=subfolder
            )
        )

    opened_tensor_file = cast("SafeOpenProtocol", safetensors.safe_open(embeddings_path, framework="numpy"))
    # Sentence-transformers stores the weights under a different tensor key.
    if from_sentence_transformers:
        embeddings = opened_tensor_file.get_tensor("embedding.weight")
    else:
        embeddings = opened_tensor_file.get_tensor("embeddings")

    tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
    # FIX: use a context manager so the config file handle is closed
    # deterministically (the original `json.load(open(...))` leaked it until GC).
    with open(config_path) as config_file:
        config = json.load(config_file)

    if len(tokenizer.get_vocab()) != len(embeddings):
        logger.warning(
            f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
        )

    return embeddings, tokenizer, config, metadata
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]:
    """Extract model-card metadata from a README file, returning an empty dict when missing or empty."""
    if not readme_path.exists():
        logger.info(f"README file not found in {readme_path}. No model card loaded.")
        return {}
    card_data: dict[str, Any] = ModelCard.load(readme_path).data.to_dict()
    if not card_data:
        logger.info("File README.md exists, but was empty. No model card loaded.")
    return card_data
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def push_folder_to_hub(
    folder_path: Path, subfolder: str | None, repo_id: str, private: bool, token: str | None
) -> None:
    """
    Push a model folder to the huggingface hub, including model card.

    :param folder_path: The path to the folder.
    :param subfolder: The subfolder to push to.
        If None, the folder will be pushed to the root of the repo.
    :param repo_id: The repo name.
    :param private: Whether the repo is private.
    :param token: The huggingface token.
    """
    # Create the repository first if it is not already there.
    repo_missing = not huggingface_hub.repo_exists(repo_id=repo_id, token=token)
    if repo_missing:
        huggingface_hub.create_repo(repo_id, token=token, private=private)

    # Upload the model card together with every model file in one call.
    huggingface_hub.upload_folder(repo_id=repo_id, folder_path=folder_path, token=token, path_in_repo=subfolder)

    logger.info(f"Pushed model to {repo_id}")
|
src/distiller/model2vec/inference/README.md
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inference
|
| 2 |
+
|
| 3 |
+
This subpackage mainly contains helper functions for inference with trained models that have been exported to `scikit-learn` compatible pipelines.
|
| 4 |
+
|
| 5 |
+
If you're looking for information on how to train a model, see [here](../train/README.md).
|
| 6 |
+
|
| 7 |
+
# Usage
|
| 8 |
+
|
| 9 |
+
Let's assume you're using our [potion-edu classifier](https://huggingface.co/minishlab/potion-8m-edu-classifier).
|
| 10 |
+
|
| 11 |
+
```python
|
| 12 |
+
from model2vec.inference import StaticModelPipeline
|
| 13 |
+
|
| 14 |
+
classifier = StaticModelPipeline.from_pretrained("minishlab/potion-8m-edu-classifier")
|
| 15 |
+
label = classifier.predict("Attitudes towards cattle in the Alps: a study in letting go.")
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
This should just work.
|
src/distiller/model2vec/inference/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from distiller.model2vec.utils import get_package_extras, importable
|
| 2 |
+
|
| 3 |
+
_REQUIRED_EXTRA = "inference"
|
| 4 |
+
|
| 5 |
+
for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
|
| 6 |
+
importable(extra_dependency, _REQUIRED_EXTRA)
|
| 7 |
+
|
| 8 |
+
from distiller.model2vec.inference.model import StaticModelPipeline, evaluate_single_or_multi_label
|
| 9 |
+
|
| 10 |
+
__all__ = ["StaticModelPipeline", "evaluate_single_or_multi_label"]
|
src/distiller/model2vec/inference/model.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from tempfile import TemporaryDirectory
|
| 6 |
+
from typing import TYPE_CHECKING, TypeVar
|
| 7 |
+
|
| 8 |
+
import huggingface_hub
|
| 9 |
+
import numpy as np
|
| 10 |
+
import skops.io
|
| 11 |
+
from sklearn.metrics import classification_report
|
| 12 |
+
from sklearn.neural_network import MLPClassifier
|
| 13 |
+
from sklearn.preprocessing import MultiLabelBinarizer
|
| 14 |
+
|
| 15 |
+
from distiller.model2vec.hf_utils import _create_model_card
|
| 16 |
+
from distiller.model2vec.model import PathLike, StaticModel
|
| 17 |
+
|
| 18 |
+
if TYPE_CHECKING:
|
| 19 |
+
from collections.abc import Sequence
|
| 20 |
+
|
| 21 |
+
from sklearn.pipeline import Pipeline
|
| 22 |
+
|
| 23 |
+
_DEFAULT_TRUST_PATTERN = re.compile(r"sklearn\..+")
|
| 24 |
+
_DEFAULT_MODEL_FILENAME = "pipeline.skops"
|
| 25 |
+
|
| 26 |
+
LabelType = TypeVar("LabelType", list[str], list[list[str]])
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class StaticModelPipeline:
    """A StaticModel encoder combined with a trained scikit-learn classification head."""

    def __init__(self, model: StaticModel, head: Pipeline) -> None:
        """Create a pipeline with a StaticModel encoder."""
        self.model = model
        self.head = head
        final_estimator = self.head[-1]
        # Check if the classifier is a multilabel classifier.
        # NOTE: this doesn't look robust, but it is.
        # Different classifiers, such as OVR wrappers, support multilabel output
        # natively, so we can just use predict; only a logistic-output MLP needs
        # manual thresholding over predict_proba.
        self.multilabel = (
            isinstance(final_estimator, MLPClassifier) and final_estimator.out_activation_ == "logistic"
        )

    @property
    def classes_(self) -> np.ndarray:
        """The classes of the classifier."""
        return self.head.classes_

    @classmethod
    def from_pretrained(
        cls: type[StaticModelPipeline], path: PathLike, token: str | None = None, trust_remote_code: bool = False
    ) -> StaticModelPipeline:
        """
        Load a StaticModel from a local path or huggingface hub path.

        NOTE: if you load a private model from the huggingface hub, you need to pass a token.

        :param path: The path to the folder containing the pipeline, or a repository on the Hugging Face Hub
        :param token: The token to use to download the pipeline from the hub.
        :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `sklearn`.
        :return: The loaded pipeline.
        """
        model, head = _load_pipeline(path, token, trust_remote_code)
        # Sanitize the embedding matrix so NaNs cannot poison downstream predictions.
        model.embedding = np.nan_to_num(model.embedding)
        return cls(model, head)

    def save_pretrained(self, path: str) -> None:
        """Save the model to a folder."""
        save_pipeline(self, path)

    def push_to_hub(
        self, repo_id: str, subfolder: str | None = None, token: str | None = None, private: bool = False
    ) -> None:
        """
        Save a model to a folder, and then push that folder to the hf hub.

        :param repo_id: The id of the repository to push to.
        :param subfolder: The subfolder to push to.
        :param token: The token to use to push to the hub.
        :param private: Whether the repository should be private.
        """
        from distiller.model2vec.hf_utils import push_folder_to_hub

        with TemporaryDirectory() as temp_dir:
            save_pipeline(self, temp_dir)
            self.model.save_pretrained(temp_dir)
            push_folder_to_hub(Path(temp_dir), subfolder, repo_id, private, token)

    def _encode_and_coerce_to_2d(
        self,
        X: Sequence[str],
        show_progress_bar: bool,
        max_length: int | None,
        batch_size: int,
        use_multiprocessing: bool,
        multiprocessing_threshold: int,
    ) -> np.ndarray:
        """Encode the instances and coerce the output to a matrix."""
        vectors = self.model.encode(
            X,
            show_progress_bar=show_progress_bar,
            max_length=max_length,
            batch_size=batch_size,
            use_multiprocessing=use_multiprocessing,
            multiprocessing_threshold=multiprocessing_threshold,
        )
        # A single input produces a 1-D vector; promote it to a single-row matrix.
        return vectors[None, :] if np.ndim(vectors) == 1 else vectors

    def predict(
        self,
        X: Sequence[str],
        show_progress_bar: bool = False,
        max_length: int | None = 512,
        batch_size: int = 1024,
        use_multiprocessing: bool = True,
        multiprocessing_threshold: int = 10_000,
        threshold: float = 0.5,
    ) -> np.ndarray:
        """
        Predict the labels of the input.

        :param X: The input data to predict. Can be a list of strings or a single string.
        :param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False.
        :param max_length: The maximum length of the input sequences. Defaults to 512.
        :param batch_size: The batch size for prediction. Defaults to 1024.
        :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
        :param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000.
        :param threshold: The threshold for multilabel classification. Defaults to 0.5. Ignored if not multilabel.
        :return: The predicted labels or probabilities.
        """
        encoded = self._encode_and_coerce_to_2d(
            X,
            show_progress_bar=show_progress_bar,
            max_length=max_length,
            batch_size=batch_size,
            use_multiprocessing=use_multiprocessing,
            multiprocessing_threshold=multiprocessing_threshold,
        )

        if self.multilabel:
            # Threshold each probability row into its set of active classes.
            rows = [self.classes_[probs > threshold] for probs in self.head.predict_proba(encoded)]
            return np.asarray(rows, dtype=object)

        return self.head.predict(encoded)

    def predict_proba(
        self,
        X: Sequence[str],
        show_progress_bar: bool = False,
        max_length: int | None = 512,
        batch_size: int = 1024,
        use_multiprocessing: bool = True,
        multiprocessing_threshold: int = 10_000,
    ) -> np.ndarray:
        """
        Predict the class probabilities of the input.

        :param X: The input data to predict. Can be a list of strings or a single string.
        :param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False.
        :param max_length: The maximum length of the input sequences. Defaults to 512.
        :param batch_size: The batch size for prediction. Defaults to 1024.
        :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
        :param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000.
        :return: The predicted probabilities.
        """
        encoded = self._encode_and_coerce_to_2d(
            X,
            show_progress_bar=show_progress_bar,
            max_length=max_length,
            batch_size=batch_size,
            use_multiprocessing=use_multiprocessing,
            multiprocessing_threshold=multiprocessing_threshold,
        )

        return self.head.predict_proba(encoded)

    def evaluate(
        self, X: Sequence[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False
    ) -> str | dict[str, dict[str, float]]:
        """
        Evaluate the classifier on a given dataset using scikit-learn's classification report.

        :param X: The texts to predict on.
        :param y: The ground truth labels.
        :param batch_size: The batch size.
        :param threshold: The threshold for multilabel classification.
        :param output_dict: Whether to output the classification report as a dictionary.
        :return: A classification report.
        """
        predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold)
        return evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _load_pipeline(
    folder_or_repo_path: PathLike, token: str | None = None, trust_remote_code: bool = False
) -> tuple[StaticModel, Pipeline]:
    """
    Load a model and an sklearn pipeline.

    This assumes the following files are present in the repo:
    - `pipeline.skops`: The head of the pipeline.
    - `config.json`: The configuration of the model.
    - `model.safetensors`: The weights of the model.
    - `tokenizer.json`: The tokenizer of the model.

    :param folder_or_repo_path: The path to the folder containing the pipeline.
    :param token: The token to use to download the pipeline from the hub. If this is None, you will only
        be able to load the pipeline from a local folder, public repository, or a repository that you have access to
        because you are logged in.
    :param trust_remote_code: Whether to trust the remote code. If this is False,
        we will only load components coming from `sklearn`. If this is True, we will load all components.
        If you set this to True, you are responsible for whatever happens.
    :return: The encoder model and the loaded head
    :raises FileNotFoundError: If the pipeline file does not exist in the folder.
    :raises ValueError: If an untrusted type is found in the pipeline, and `trust_remote_code` is False.
    """
    folder_or_repo_path = Path(folder_or_repo_path)
    head_pipeline_path: str | Path
    if folder_or_repo_path.exists():
        # Local folder: the serialized head must already be present on disk.
        head_pipeline_path = folder_or_repo_path / _DEFAULT_MODEL_FILENAME
        if not head_pipeline_path.exists():
            msg = f"Pipeline file does not exist in {folder_or_repo_path}"
            raise FileNotFoundError(msg)
    else:
        # Otherwise treat the path as a hub repository id and download the head.
        head_pipeline_path = huggingface_hub.hf_hub_download(
            folder_or_repo_path.as_posix(), _DEFAULT_MODEL_FILENAME, token=token
        )

    model = StaticModel.from_pretrained(folder_or_repo_path)

    unknown_types = skops.io.get_untrusted_types(file=head_pipeline_path)
    # If the user does not trust remote code, we should check that the unknown types are trusted.
    # By default, we trust everything coming from scikit-learn.
    if not trust_remote_code:
        untrusted = [t for t in unknown_types if not _DEFAULT_TRUST_PATTERN.match(t)]
        if untrusted:
            msg = f"Untrusted type {untrusted[0]}."
            raise ValueError(msg)
    head = skops.io.load(head_pipeline_path, trusted=unknown_types)

    return model, head
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def save_pipeline(pipeline: StaticModelPipeline, folder_path: str | Path) -> None:
    """
    Save a pipeline to a folder.

    :param pipeline: The pipeline to save.
    :param folder_path: The path to the folder to save the pipeline to.
    """
    target = Path(folder_path)
    target.mkdir(parents=True, exist_ok=True)
    # Serialize the sklearn head alongside the encoder weights.
    skops.io.dump(pipeline.head, target / _DEFAULT_MODEL_FILENAME)
    pipeline.model.save_pretrained(target)

    # The base model name may be a list, a string, or missing entirely.
    base = pipeline.model.base_model_name
    if isinstance(base, str):
        name = base
    elif isinstance(base, list) and base:
        name = base[0]
    else:
        name = "unknown"
    _create_model_card(
        target,
        base_model_name=name,
        language=pipeline.model.language,
        template_path="modelcards/classifier_template.md",
    )
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def _is_multi_label_shaped(y: LabelType) -> bool:
|
| 281 |
+
"""Check if the labels are in a multi-label shape."""
|
| 282 |
+
return isinstance(y, (list, tuple)) and len(y) > 0 and isinstance(y[0], (list, tuple, set))
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def evaluate_single_or_multi_label(
    predictions: np.ndarray,
    y: LabelType,
    output_dict: bool = False,
) -> str | dict[str, dict[str, float]]:
    """
    Evaluate predictions against ground truth using scikit-learn's classification report.

    Multi-label targets (sequences of label collections) are binarized with a
    MultiLabelBinarizer fitted on the union of labels in ``y``; single-label
    targets are passed through to the report unchanged.

    :param predictions: The predictions.
    :param y: The ground truth labels.
    :param output_dict: Whether to output the classification report as a dictionary.
    :return: A classification report.
    """
    if _is_multi_label_shaped(y):
        # Fit the binarizer on the ground-truth label universe so predictions
        # are binarized against the same class ordering.
        classes = sorted({label for labels in y for label in labels})
        mlb = MultiLabelBinarizer(classes=classes)
        y = mlb.fit_transform(y)
        predictions = mlb.transform(predictions)
    # FIX: the original also computed `sorted(set(y))` for single-label input,
    # but never used the result (and it raised IndexError on empty `y` via the
    # `y[0]` check); that dead branch has been removed.

    return classification_report(
        y,
        predictions,
        output_dict=output_dict,
        zero_division=0,
    )
|
| 312 |
+
|
src/distiller/model2vec/model.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
from logging import getLogger
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from tempfile import TemporaryDirectory
|
| 8 |
+
from typing import TYPE_CHECKING, Any, Union, overload
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
from joblib import delayed
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from .quantization import DType, quantize_and_reduce_dim
|
| 15 |
+
from .utils import ProgressParallel, load_local_model
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from collections.abc import Iterator, Sequence
|
| 19 |
+
|
| 20 |
+
from tokenizers import Encoding, Tokenizer
|
| 21 |
+
|
| 22 |
+
PathLike = Union[Path, str]
|
| 23 |
+
|
| 24 |
+
logger = getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class StaticModel:
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
vectors: np.ndarray,
|
| 31 |
+
tokenizer: Tokenizer,
|
| 32 |
+
config: dict[str, Any] | None = None,
|
| 33 |
+
normalize: bool | None = None,
|
| 34 |
+
base_model_name: str | None = None,
|
| 35 |
+
language: list[str] | None = None,
|
| 36 |
+
) -> None:
|
| 37 |
+
"""
|
| 38 |
+
Initialize the StaticModel.
|
| 39 |
+
|
| 40 |
+
:param vectors: The vectors to use.
|
| 41 |
+
:param tokenizer: The Transformers tokenizer to use.
|
| 42 |
+
:param config: Any metadata config.
|
| 43 |
+
:param normalize: Whether to normalize the embeddings.
|
| 44 |
+
:param base_model_name: The used base model name. Used for creating a model card.
|
| 45 |
+
:param language: The language of the model. Used for creating a model card.
|
| 46 |
+
:raises: ValueError if the number of tokens does not match the number of vectors.
|
| 47 |
+
"""
|
| 48 |
+
super().__init__()
|
| 49 |
+
tokens, _ = zip(*sorted(tokenizer.get_vocab().items(), key=lambda x: x[1]), strict=False)
|
| 50 |
+
self.tokens = tokens
|
| 51 |
+
|
| 52 |
+
self.embedding = vectors
|
| 53 |
+
|
| 54 |
+
if len(tokens) != vectors.shape[0]:
|
| 55 |
+
msg = f"Number of tokens ({len(tokens)}) does not match number of vectors ({vectors.shape[0]})"
|
| 56 |
+
raise ValueError(msg)
|
| 57 |
+
|
| 58 |
+
self.tokenizer = tokenizer
|
| 59 |
+
self.unk_token_id: int | None
|
| 60 |
+
if hasattr(self.tokenizer.model, "unk_token") and self.tokenizer.model.unk_token is not None:
|
| 61 |
+
self.unk_token_id = tokenizer.get_vocab()[self.tokenizer.model.unk_token]
|
| 62 |
+
else:
|
| 63 |
+
self.unk_token_id = None # pragma: no cover # Doesn't actually happen, but can happen.
|
| 64 |
+
|
| 65 |
+
self.median_token_length = int(np.median([len(token) for token in self.tokens]))
|
| 66 |
+
self.config = config or {}
|
| 67 |
+
self.base_model_name = base_model_name
|
| 68 |
+
self.language = language
|
| 69 |
+
if hasattr(self.tokenizer, "encode_batch_fast"):
|
| 70 |
+
self._can_encode_fast = True
|
| 71 |
+
else:
|
| 72 |
+
self._can_encode_fast = False
|
| 73 |
+
|
| 74 |
+
if normalize is not None:
|
| 75 |
+
self.normalize = normalize
|
| 76 |
+
else:
|
| 77 |
+
self.normalize = self.config.get("normalize", False)
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def dim(self) -> int:
|
| 81 |
+
"""Get the dimension of the model."""
|
| 82 |
+
return self.embedding.shape[1]
|
| 83 |
+
|
| 84 |
+
@property
|
| 85 |
+
def normalize(self) -> bool:
|
| 86 |
+
"""
|
| 87 |
+
Get the normalize value.
|
| 88 |
+
|
| 89 |
+
:return: The normalize value.
|
| 90 |
+
"""
|
| 91 |
+
return self._normalize
|
| 92 |
+
|
| 93 |
+
@normalize.setter
|
| 94 |
+
def normalize(self, value: bool) -> None:
|
| 95 |
+
"""Update the config if the value of normalize changes."""
|
| 96 |
+
config_normalize = self.config.get("normalize")
|
| 97 |
+
self._normalize = value
|
| 98 |
+
if config_normalize is not None and value != config_normalize:
|
| 99 |
+
logger.warning(
|
| 100 |
+
f"Set normalization to `{value}`, which does not match config value `{config_normalize}`. Updating config."
|
| 101 |
+
)
|
| 102 |
+
self.config["normalize"] = value
|
| 103 |
+
|
| 104 |
+
def save_pretrained(self, path: PathLike, model_name: str | None = None, subfolder: str | None = None) -> None:
    """
    Save the pretrained model to disk.

    :param path: The path to save to.
    :param model_name: The model name to use in the Model Card.
    :param subfolder: The subfolder to save to.
    """
    # Imported lazily to keep module import time low.
    from .hf_utils import save_pretrained

    save_kwargs = {
        "folder_path": Path(path),
        "embeddings": self.embedding,
        "tokenizer": self.tokenizer,
        "config": self.config,
        "base_model_name": self.base_model_name,
        "language": self.language,
        "model_name": model_name,
        "subfolder": subfolder,
    }
    save_pretrained(**save_kwargs)
|
| 124 |
+
|
| 125 |
+
def tokenize(self, sentences: Sequence[str], max_length: int | None = None) -> list[list[int]]:
    """
    Tokenize a list of sentences into lists of token ids.

    :param sentences: The sentences to tokenize.
    :param max_length: The maximum length of the sentences in tokens. If this is None, sequences
        are not truncated.
    :return: A list of list of tokens.
    """
    if max_length is not None:
        # Pre-truncate on characters so the tokenizer never sees absurdly long inputs;
        # the median token length gives a conservative character budget.
        char_budget = max_length * self.median_token_length
        sentences = [text[:char_budget] for text in sentences]

    batch_encoder = self.tokenizer.encode_batch_fast if self._can_encode_fast else self.tokenizer.encode_batch
    encodings: list[Encoding] = batch_encoder(sentences, add_special_tokens=False)

    token_id_lists = [enc.ids for enc in encodings]

    unk_id = self.unk_token_id
    if unk_id is not None:
        # NOTE: drop the unknown token: necessary for word-level models.
        token_id_lists = [[tid for tid in ids if tid != unk_id] for ids in token_id_lists]
    if max_length is not None:
        token_id_lists = [ids[:max_length] for ids in token_id_lists]

    return token_id_lists
|
| 154 |
+
|
| 155 |
+
@classmethod
def from_pretrained(
    cls: type[StaticModel],
    path: PathLike,
    token: str | None = None,
    normalize: bool | None = None,
    subfolder: str | None = None,
    quantize_to: str | DType | None = None,
    dimensionality: int | None = None,
) -> StaticModel:
    """
    Load a StaticModel from a local path or huggingface hub path.

    NOTE: if you load a private model from the huggingface hub, you need to pass a token.

    :param path: The path to load your static model from.
    :param token: The huggingface token to use.
    :param normalize: Whether to normalize the embeddings.
    :param subfolder: The subfolder to load from.
    :param quantize_to: The dtype to quantize the model to. If None, no quantization is done.
        If a string is passed, it is converted to a DType.
    :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
        This is useful if you want to load a model with a lower dimensionality.
        Note that this only applies if you have trained your model using mrl or PCA.
    :return: A StaticModel.
    """
    from .hf_utils import load_pretrained

    loaded_embeddings, loaded_tokenizer, loaded_config, metadata = load_pretrained(
        folder_or_repo_path=path,
        token=token,
        from_sentence_transformers=False,
        subfolder=subfolder,
    )

    # Optionally quantize and/or truncate the embedding matrix before building the model.
    processed_embeddings = quantize_and_reduce_dim(
        embeddings=loaded_embeddings,
        quantize_to=quantize_to,
        dimensionality=dimensionality,
    )

    return cls(
        processed_embeddings,
        loaded_tokenizer,
        loaded_config,
        normalize=normalize,
        base_model_name=metadata.get("base_model"),
        language=metadata.get("language"),
    )
|
| 204 |
+
|
| 205 |
+
@classmethod
def from_sentence_transformers(
    cls: type[StaticModel],
    path: PathLike,
    token: str | None = None,
    normalize: bool | None = None,
    quantize_to: str | DType | None = None,
    dimensionality: int | None = None,
) -> StaticModel:
    """
    Load a StaticModel trained with sentence transformers from a local path or huggingface hub path.

    NOTE: if you load a private model from the huggingface hub, you need to pass a token.

    :param path: The path to load your static model from.
    :param token: The huggingface token to use.
    :param normalize: Whether to normalize the embeddings.
    :param quantize_to: The dtype to quantize the model to. If None, no quantization is done.
        If a string is passed, it is converted to a DType.
    :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
        This is useful if you want to load a model with a lower dimensionality.
        Note that this only applies if you have trained your model using mrl or PCA.
    :return: A StaticModel.
    """
    from .hf_utils import load_pretrained

    loaded_embeddings, loaded_tokenizer, loaded_config, metadata = load_pretrained(
        folder_or_repo_path=path,
        token=token,
        from_sentence_transformers=True,
        subfolder=None,
    )

    processed_embeddings = quantize_and_reduce_dim(
        embeddings=loaded_embeddings,
        quantize_to=quantize_to,
        dimensionality=dimensionality,
    )

    return cls(
        processed_embeddings,
        loaded_tokenizer,
        loaded_config,
        normalize=normalize,
        base_model_name=metadata.get("base_model"),
        language=metadata.get("language"),
    )
|
| 252 |
+
|
| 253 |
+
@overload
def encode_as_sequence(
    self,
    sentences: str,
    max_length: int | None = None,
    batch_size: int = 1024,
    show_progress_bar: bool = False,
    use_multiprocessing: bool = True,
    multiprocessing_threshold: int = 10_000,
) -> np.ndarray: ...

@overload
def encode_as_sequence(
    self,
    sentences: list[str],
    max_length: int | None = None,
    batch_size: int = 1024,
    show_progress_bar: bool = False,
    use_multiprocessing: bool = True,
    multiprocessing_threshold: int = 10_000,
) -> list[np.ndarray]: ...

def encode_as_sequence(
    self,
    sentences: str | list[str],
    max_length: int | None = None,
    batch_size: int = 1024,
    show_progress_bar: bool = False,
    use_multiprocessing: bool = True,
    multiprocessing_threshold: int = 10_000,
) -> list[np.ndarray] | np.ndarray:
    """
    Encode a list of sentences as a list of numpy arrays of tokens.

    This is useful if you want to use the tokens for further processing, or if you want to do sequence
    modeling. Note that if you just want the mean, you should use the `encode` method: this is
    about twice as slow. Sentences that do not contain any tokens will be turned into an empty array.

    NOTE: the input type is currently underspecified. The actual input type is `Sequence[str] | str`, but this
    is not possible to implement in python typing currently.

    :param sentences: The list of sentences to encode.
    :param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
        If this is None, no truncation is done.
    :param batch_size: The batch size to use.
    :param show_progress_bar: Whether to show the progress bar.
    :param use_multiprocessing: Whether to use multiprocessing.
        By default, this is enabled for inputs > multiprocessing_threshold sentences and disabled otherwise.
    :param multiprocessing_threshold: The threshold in number of sentences for using multiprocessing.
    :return: The encoded sentences with an embedding per token.
    """
    single_input = isinstance(sentences, str)
    if single_input:
        sentences = [sentences]

    batches = list(self._batch(sentences, batch_size))
    n_batches = math.ceil(len(sentences) / batch_size)

    sequences: list[np.ndarray] = []
    if use_multiprocessing and len(sentences) > multiprocessing_threshold:
        # Tokenizer-internal threading clashes with joblib workers; turn it off.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        parallel = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=n_batches)
        batch_results = parallel(
            delayed(self._encode_batch_as_sequence)(batch, max_length) for batch in batches
        )
        for batch_result in batch_results:
            sequences.extend(batch_result)
    else:
        for batch in tqdm(batches, total=n_batches, disable=not show_progress_bar):
            sequences.extend(self._encode_batch_as_sequence(batch, max_length))

    return sequences[0] if single_input else sequences
|
| 338 |
+
|
| 339 |
+
def _encode_batch_as_sequence(self, sentences: Sequence[str], max_length: int | None) -> list[np.ndarray]:
|
| 340 |
+
"""Encode a batch of sentences as a sequence."""
|
| 341 |
+
ids = self.tokenize(sentences=sentences, max_length=max_length)
|
| 342 |
+
out: list[np.ndarray] = []
|
| 343 |
+
for id_list in ids:
|
| 344 |
+
if id_list:
|
| 345 |
+
out.append(self.embedding[id_list])
|
| 346 |
+
else:
|
| 347 |
+
out.append(np.zeros((0, self.dim)))
|
| 348 |
+
|
| 349 |
+
return out
|
| 350 |
+
|
| 351 |
+
def encode(
    self,
    sentences: Sequence[str],
    show_progress_bar: bool = False,
    max_length: int | None = 512,
    batch_size: int = 1024,
    use_multiprocessing: bool = True,
    multiprocessing_threshold: int = 10_000,
    **kwargs: Any,
) -> np.ndarray:
    """
    Encode a list of sentences.

    Each sentence is encoded by averaging the word embeddings of its tokens.
    For ease of use, we don't batch sentences together.

    NOTE: the return type is currently underspecified. In the case of a single string, this returns a 1D array,
    but in the case of a list of strings, this returns a 2D array. Not possible to implement in numpy currently.

    :param sentences: The list of sentences to encode. You can also pass a single sentence.
    :param show_progress_bar: Whether to show the progress bar.
    :param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
        If this is None, no truncation is done.
    :param batch_size: The batch size to use.
    :param use_multiprocessing: Whether to use multiprocessing.
        By default, this is enabled for inputs > multiprocessing_threshold sentences and disabled otherwise.
    :param multiprocessing_threshold: The threshold in number of sentences for using multiprocessing.
    :param **kwargs: Any additional arguments. These are ignored.
    :return: The encoded sentences. If a single sentence was passed, a vector is returned.
    """
    single_input = isinstance(sentences, str)
    if single_input:
        sentences = [sentences]

    batches = list(self._batch(sentences, batch_size))
    n_batches = math.ceil(len(sentences) / batch_size)

    if use_multiprocessing and len(sentences) > multiprocessing_threshold:
        # Tokenizer-internal threading clashes with joblib workers; turn it off.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        parallel = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=n_batches)
        partial_results = parallel(delayed(self._encode_batch)(batch, max_length) for batch in batches)
    else:
        # Single-process path with an optional progress bar.
        partial_results = [
            self._encode_batch(batch, max_length)
            for batch in tqdm(batches, total=n_batches, disable=not show_progress_bar)
        ]

    stacked = np.concatenate(partial_results, axis=0)
    return stacked[0] if single_input else stacked
|
| 413 |
+
|
| 414 |
+
def _encode_batch(self, sentences: Sequence[str], max_length: int | None) -> np.ndarray:
|
| 415 |
+
"""Encode a batch of sentences."""
|
| 416 |
+
ids = self.tokenize(sentences=sentences, max_length=max_length)
|
| 417 |
+
out: list[np.ndarray] = []
|
| 418 |
+
for id_list in ids:
|
| 419 |
+
if id_list:
|
| 420 |
+
out.append(self.embedding[id_list].mean(0))
|
| 421 |
+
else:
|
| 422 |
+
out.append(np.zeros(self.dim))
|
| 423 |
+
|
| 424 |
+
out_array = np.stack(out)
|
| 425 |
+
if self.normalize:
|
| 426 |
+
norm = np.linalg.norm(out_array, axis=1, keepdims=True) + 1e-32
|
| 427 |
+
out_array = out_array / norm
|
| 428 |
+
|
| 429 |
+
return out_array
|
| 430 |
+
|
| 431 |
+
@staticmethod
|
| 432 |
+
def _batch(sentences: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
|
| 433 |
+
"""Batch the sentences into equal-sized."""
|
| 434 |
+
return (sentences[i : i + batch_size] for i in range(0, len(sentences), batch_size))
|
| 435 |
+
|
| 436 |
+
def push_to_hub(
    self, repo_id: str, private: bool = False, token: str | None = None, subfolder: str | None = None
) -> None:
    """
    Push the model to the huggingface hub.

    NOTE: you need to pass a token if you are pushing a private model.

    :param repo_id: The repo id to push to.
    :param private: Whether the repo, if created is set to private.
        If the repo already exists, this doesn't change the visibility.
    :param token: The huggingface token to use.
    :param subfolder: The subfolder to push to.
    """
    from .hf_utils import push_folder_to_hub

    # Stage a full save in a temp dir, then upload the whole folder in one go.
    with TemporaryDirectory() as staging_dir:
        self.save_pretrained(staging_dir, model_name=repo_id)
        push_folder_to_hub(
            Path(staging_dir),
            subfolder=subfolder,
            repo_id=repo_id,
            private=private,
            token=token,
        )
|
| 455 |
+
|
| 456 |
+
@classmethod
def load_local(cls: type[StaticModel], path: PathLike) -> StaticModel:
    """
    Loads a model from a local path.

    You should only use this code path if you are concerned with start-up time:
    loading via the `from_pretrained` method is safer and auto-downloads, but
    also imports a whole bunch of huggingface code that we don't need here.
    Additionally, huggingface will check the most recent version of the model,
    which can be slow.

    :param path: The path to load the model from. The path is a directory saved by the
        `save_pretrained` method.
    :return: A StaticModel
    :raises: ValueError if the path is not a directory.
    """
    model_dir = Path(path)
    if not model_dir.is_dir():
        msg = f"Path {model_dir} is not a directory."
        raise ValueError(msg)

    embeddings, tokenizer, config = load_local_model(model_dir)
    return StaticModel(embeddings, tokenizer, config)
|
src/distiller/model2vec/modelcards/classifier_template.md
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
{{ card_data }}
|
| 3 |
+
---
|
| 4 |
+
|
| 5 |
+
# {{ model_name }} Model Card
|
| 6 |
+
|
| 7 |
+
This [Model2Vec](https://github.com/MinishLab/model2vec) model is a fine-tuned version of {% if base_model %}the [{{ base_model }}](https://huggingface.co/{{ base_model }}){% else %}a{% endif %} Model2Vec model. It also includes a classifier head on top.
|
| 8 |
+
|
| 9 |
+
## Installation
|
| 10 |
+
|
| 11 |
+
Install model2vec using pip:
|
| 12 |
+
```
|
| 13 |
+
pip install model2vec[inference]
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
## Usage
|
| 17 |
+
Load this model using the `from_pretrained` method:
|
| 18 |
+
```python
|
| 19 |
+
from model2vec.inference import StaticModelPipeline
|
| 20 |
+
|
| 21 |
+
# Load a pretrained Model2Vec model
|
| 22 |
+
model = StaticModelPipeline.from_pretrained("{{ model_name }}")
|
| 23 |
+
|
| 24 |
+
# Predict labels
|
| 25 |
+
predicted = model.predict(["Example sentence"])
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## Additional Resources
|
| 29 |
+
|
| 30 |
+
- [Model2Vec Repo](https://github.com/MinishLab/model2vec)
|
| 31 |
+
- [Model2Vec Base Models](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e)
|
| 32 |
+
- [Model2Vec Results](https://github.com/MinishLab/model2vec/tree/main/results)
|
| 33 |
+
- [Model2Vec Tutorials](https://github.com/MinishLab/model2vec/tree/main/tutorials)
|
| 34 |
+
- [Website](https://minishlab.github.io/)
|
| 35 |
+
|
| 36 |
+
## Library Authors
|
| 37 |
+
|
| 38 |
+
Model2Vec was developed by the [Minish Lab](https://github.com/MinishLab) team consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled).
|
| 39 |
+
|
| 40 |
+
## Citation
|
| 41 |
+
|
| 42 |
+
Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) if you use this model in your work.
|
| 43 |
+
```
|
| 44 |
+
@article{minishlab2024model2vec,
|
| 45 |
+
author = {Tulkens, Stephan and {van Dongen}, Thomas},
|
| 46 |
+
title = {Model2Vec: Fast State-of-the-Art Static Embeddings},
|
| 47 |
+
year = {2024},
|
| 48 |
+
url = {https://github.com/MinishLab/model2vec}
|
| 49 |
+
}
|
| 50 |
+
```
|
src/distiller/model2vec/modelcards/model_card_template.md
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
{{ card_data }}
|
| 3 |
+
---
|
| 4 |
+
|
| 5 |
+
# {{ model_name }} Model Card
|
| 6 |
+
|
| 7 |
+
This [Model2Vec](https://github.com/MinishLab/model2vec) model is a distilled version of {% if base_model %}the [{{ base_model }}](https://huggingface.co/{{ base_model }}){% else %}a{% endif %} Sentence Transformer. It uses static embeddings, allowing text embeddings to be computed orders of magnitude faster on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is critical. Model2Vec models are the smallest, fastest, and most performant static embedders available. The distilled models are up to 50 times smaller and 500 times faster than traditional Sentence Transformers.
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
## Installation
|
| 11 |
+
|
| 12 |
+
Install model2vec using pip:
|
| 13 |
+
```
|
| 14 |
+
pip install model2vec
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
## Usage
|
| 18 |
+
|
| 19 |
+
### Using Model2Vec
|
| 20 |
+
|
| 21 |
+
The [Model2Vec library](https://github.com/MinishLab/model2vec) is the fastest and most lightweight way to run Model2Vec models.
|
| 22 |
+
|
| 23 |
+
Load this model using the `from_pretrained` method:
|
| 24 |
+
```python
|
| 25 |
+
from model2vec import StaticModel
|
| 26 |
+
|
| 27 |
+
# Load a pretrained Model2Vec model
|
| 28 |
+
model = StaticModel.from_pretrained("{{ model_name }}")
|
| 29 |
+
|
| 30 |
+
# Compute text embeddings
|
| 31 |
+
embeddings = model.encode(["Example sentence"])
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
### Using Sentence Transformers
|
| 35 |
+
|
| 36 |
+
You can also use the [Sentence Transformers library](https://github.com/UKPLab/sentence-transformers) to load and use the model:
|
| 37 |
+
|
| 38 |
+
```python
|
| 39 |
+
from sentence_transformers import SentenceTransformer
|
| 40 |
+
|
| 41 |
+
# Load a pretrained Sentence Transformer model
|
| 42 |
+
model = SentenceTransformer("{{ model_name }}")
|
| 43 |
+
|
| 44 |
+
# Compute text embeddings
|
| 45 |
+
embeddings = model.encode(["Example sentence"])
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
### Distilling a Model2Vec model
|
| 49 |
+
|
| 50 |
+
You can distill a Model2Vec model from a Sentence Transformer model using the `distill` method. First, install the `distill` extra with `pip install model2vec[distill]`. Then, run the following code:
|
| 51 |
+
|
| 52 |
+
```python
|
| 53 |
+
from model2vec.distill import distill
|
| 54 |
+
|
| 55 |
+
# Distill a Sentence Transformer model, in this case the BAAI/bge-base-en-v1.5 model
|
| 56 |
+
m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256)
|
| 57 |
+
|
| 58 |
+
# Save the model
|
| 59 |
+
m2v_model.save_pretrained("m2v_model")
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
## How it works
|
| 63 |
+
|
| 64 |
+
Model2vec creates a small, fast, and powerful model that outperforms other static embedding models by a large margin on all tasks we could find, while being much faster to create than traditional static embedding models such as GloVe. Best of all, you don't need any data to distill a model using Model2Vec.
|
| 65 |
+
|
| 66 |
+
It works by passing a vocabulary through a sentence transformer model, then reducing the dimensionality of the resulting embeddings using PCA, and finally weighting the embeddings using [SIF weighting](https://openreview.net/pdf?id=SyK00v5xx). During inference, we simply take the mean of all token embeddings occurring in a sentence.
|
| 67 |
+
|
| 68 |
+
## Additional Resources
|
| 69 |
+
|
| 70 |
+
- [Model2Vec Repo](https://github.com/MinishLab/model2vec)
|
| 71 |
+
- [Model2Vec Base Models](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e)
|
| 72 |
+
- [Model2Vec Results](https://github.com/MinishLab/model2vec/tree/main/results)
|
| 73 |
+
- [Model2Vec Tutorials](https://github.com/MinishLab/model2vec/tree/main/tutorials)
|
| 74 |
+
- [Website](https://minishlab.github.io/)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
## Library Authors
|
| 78 |
+
|
| 79 |
+
Model2Vec was developed by the [Minish Lab](https://github.com/MinishLab) team consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled).
|
| 80 |
+
|
| 81 |
+
## Citation
|
| 82 |
+
|
| 83 |
+
Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) if you use this model in your work.
|
| 84 |
+
```
|
| 85 |
+
@article{minishlab2024model2vec,
|
| 86 |
+
author = {Tulkens, Stephan and {van Dongen}, Thomas},
|
| 87 |
+
title = {Model2Vec: Fast State-of-the-Art Static Embeddings},
|
| 88 |
+
year = {2024},
|
| 89 |
+
url = {https://github.com/MinishLab/model2vec}
|
| 90 |
+
}
|
| 91 |
+
```
|
src/distiller/model2vec/py.typed
ADDED
|
File without changes
|
src/distiller/model2vec/quantization.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from enum import Enum
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DType(str, Enum):
    """Supported target dtypes for quantizing embedding matrices."""

    Float16 = "float16"
    Float32 = "float32"
    Float64 = "float64"
    Int8 = "int8"


def quantize_embeddings(embeddings: np.ndarray, quantize_to: DType) -> np.ndarray:
    """
    Quantize embeddings to a specified data type to reduce memory usage.

    :param embeddings: The embeddings to quantize, as a numpy array.
    :param quantize_to: The data type to quantize to.
    :return: The quantized embeddings.
    :raises ValueError: If the quantization type is not valid.
    """
    if quantize_to == DType.Float16:
        return embeddings.astype(np.float16)
    if quantize_to == DType.Float32:
        return embeddings.astype(np.float32)
    if quantize_to == DType.Float64:
        return embeddings.astype(np.float64)
    if quantize_to == DType.Int8:
        # Symmetric linear quantization into [-127, 127] (we use -127 rather
        # than -128 to keep the scale symmetric around zero).
        scale = np.max(np.abs(embeddings)) / 127.0
        if scale == 0.0:
            # All-zero embeddings: avoid division by zero; the quantized
            # result is exactly zero everywhere.
            return np.zeros_like(embeddings, dtype=np.int8)
        return np.round(embeddings / scale).astype(np.int8)
    msg = "Not a valid enum member of DType."
    raise ValueError(msg)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def quantize_and_reduce_dim(
    embeddings: np.ndarray, quantize_to: str | DType | None, dimensionality: int | None
) -> np.ndarray:
    """
    Optionally quantize embeddings to a datatype and truncate them to fewer dimensions.

    :param embeddings: The embeddings to quantize and reduce, as a numpy array.
    :param quantize_to: The data type to quantize to. If None, no quantization is performed.
    :param dimensionality: The number of dimensions to keep. If None, no dimensionality reduction is performed.
    :return: The quantized and reduced embeddings.
    :raises ValueError: If the passed dimensionality is not None and greater than the model dimensionality.
    """
    result = embeddings
    if quantize_to is not None:
        # Accept plain strings as well as DType members.
        result = quantize_embeddings(result, DType(quantize_to))

    if dimensionality is not None:
        model_dim = result.shape[1]
        if dimensionality > model_dim:
            msg = f"Dimensionality {dimensionality} is greater than the model dimensionality {model_dim}"
            raise ValueError(msg)
        # Truncation is only meaningful for models trained with MRL or PCA,
        # where leading dimensions carry the most information.
        result = result[:, :dimensionality]

    return result
|
src/distiller/model2vec/tokenizer/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from distiller.model2vec.utils import importable
|
| 2 |
+
|
| 3 |
+
importable("transformers", "tokenizer")
|
| 4 |
+
|
| 5 |
+
from distiller.model2vec.tokenizer.tokenizer import (
|
| 6 |
+
clean_and_create_vocabulary,
|
| 7 |
+
create_tokenizer,
|
| 8 |
+
replace_vocabulary,
|
| 9 |
+
turn_tokens_into_ids,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
__all__ = ["clean_and_create_vocabulary", "create_tokenizer", "replace_vocabulary", "turn_tokens_into_ids"]
|
src/distiller/model2vec/tokenizer/datamodels.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@dataclass
class Token:
    """A class to represent a token."""

    # The surface form of the token as stored in the vocabulary.
    form: str
    # The normalized and pretokenized form of the token
    normalized_form: str
    # Whether the word is a continuing subword.
    is_subword: bool
    # Whether the token is internal to the model.
    is_internal: bool
|
src/distiller/model2vec/tokenizer/model.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def process_tokenizer(
|
| 9 |
+
tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
|
| 10 |
+
) -> dict[str, Any]:
|
| 11 |
+
"""Process the WordPiece tokenizer JSON."""
|
| 12 |
+
if tokenizer_json["model"]["type"] == "Unigram":
|
| 13 |
+
return _process_unigram(tokenizer_json, pre_tokenized_tokens, unk_token)
|
| 14 |
+
tokenizer_json["model"]["type"] = "Unigram"
|
| 15 |
+
tokenizer_json["model"]["unk_id"] = pre_tokenized_tokens.index(unk_token) if unk_token else None
|
| 16 |
+
|
| 17 |
+
token_weights = np.asarray([_calculate_token_weight_for_unigram(token) for token in pre_tokenized_tokens])
|
| 18 |
+
proba = (token_weights / np.sum(token_weights)).tolist()
|
| 19 |
+
tokenizer_json["model"]["vocab"] = [(token, np.log(p)) for token, p in zip(pre_tokenized_tokens, proba, strict=False)]
|
| 20 |
+
|
| 21 |
+
return tokenizer_json
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _process_unigram(
|
| 25 |
+
tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
|
| 26 |
+
) -> dict[str, Any]:
|
| 27 |
+
"""Process the Unigram tokenizer JSON."""
|
| 28 |
+
current_probas = dict(tokenizer_json["model"]["vocab"])
|
| 29 |
+
avg_proba = sum(current_probas.values()) / len(current_probas)
|
| 30 |
+
new_probas = [[word, current_probas.get(word, avg_proba)] for word in pre_tokenized_tokens]
|
| 31 |
+
tokenizer_json["model"]["vocab"] = new_probas
|
| 32 |
+
|
| 33 |
+
tokens, _ = zip(*tokenizer_json["model"]["vocab"], strict=False)
|
| 34 |
+
if unk_token is not None:
|
| 35 |
+
tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token)
|
| 36 |
+
|
| 37 |
+
return tokenizer_json
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _calculate_token_weight_for_unigram(token: str) -> float:
|
| 41 |
+
"""Calculate the token weight for Unigram."""
|
| 42 |
+
# Always prefer longer tokens.
|
| 43 |
+
return len(token) + token.count("▁") + token.count("Ġ")
|
src/distiller/model2vec/tokenizer/normalizer.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from string import punctuation
|
| 2 |
+
|
| 3 |
+
from tokenizers import Regex, Tokenizer
|
| 4 |
+
from tokenizers.normalizers import Replace, Sequence, Strip
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def replace_normalizer(
    tokenizer: Tokenizer,
) -> Tokenizer:
    """
    Replace the normalizer for the tokenizer.

    The new normalizer will replace punctuation with a space before and after the punctuation.
    It will also replace multiple spaces with a single space and strip the right side of the string.
    If the tokenizer already has a normalizer, it will be added to the new normalizer.
    If the tokenizer does not have a normalizer, a new normalizer will be created.

    :param tokenizer: The tokenizer to change.
    :return: The tokenizer with a replaced normalizer.
    """
    normalizer = tokenizer.normalizer
    # Pad every punctuation character with spaces so punctuation always forms
    # its own unit during pretokenization.
    new_normalizers = [Replace(char, f" {char} ") for char in punctuation]

    # Collapse runs of whitespace (including those just introduced) and strip
    # trailing whitespace.
    new_normalizers.append(Replace(Regex(r"\s+"), " "))
    new_normalizers.append(Strip(right=True))
    if normalizer is None:
        normalizer = Sequence(new_normalizers)  # type: ignore
    else:
        # Keep the original normalizer and run ours on top of its output.
        normalizer = Sequence([normalizer, *new_normalizers])  # type: ignore
    tokenizer.normalizer = normalizer  # type: ignore

    return tokenizer
|
src/distiller/model2vec/tokenizer/pretokenizer.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from typing import TYPE_CHECKING, Any
|
| 5 |
+
|
| 6 |
+
if TYPE_CHECKING:
|
| 7 |
+
from tokenizers import Tokenizer
|
| 8 |
+
|
| 9 |
+
_FORBIDDEN_PRETOKENIZERS = (
|
| 10 |
+
"WhiteSpace",
|
| 11 |
+
"WhitespaceSplit",
|
| 12 |
+
"BertPreTokenizer",
|
| 13 |
+
"CharDelimiterSplit",
|
| 14 |
+
"Punctuation",
|
| 15 |
+
"Split",
|
| 16 |
+
"UnicodeScripts",
|
| 17 |
+
)
|
| 18 |
+
_BASIC_METASPACE = {"type": "Metaspace", "replacement": "▁", "prepend_scheme": "always", "split": False}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _fix_single_pretokenizer(pre_tokenizer: dict[str, Any]) -> dict[str, Any] | None:
|
| 22 |
+
"""Fixes a single pretokenizer to allow multiword units."""
|
| 23 |
+
if pre_tokenizer["type"] in _FORBIDDEN_PRETOKENIZERS:
|
| 24 |
+
return None
|
| 25 |
+
if pre_tokenizer["type"] == "ByteLevel":
|
| 26 |
+
pre_tokenizer["add_prefix_space"] = True
|
| 27 |
+
pre_tokenizer["use_regex"] = False
|
| 28 |
+
if pre_tokenizer["type"] == "Metaspace":
|
| 29 |
+
pre_tokenizer["split"] = False
|
| 30 |
+
pre_tokenizer["prepend_scheme"] = "always"
|
| 31 |
+
|
| 32 |
+
return pre_tokenizer
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def replace_pretokenizer(tokenizer: Tokenizer) -> Tokenizer:
    """
    Replace the tokenizer's pretokenizer so multiword units survive pretokenization.

    Works on the serialized JSON form of the tokenizer: each pretokenizer config
    is passed through `_fix_single_pretokenizer`; anything removed is replaced by
    the non-splitting `_BASIC_METASPACE` default.

    :param tokenizer: The tokenizer to change.
    :return: A new Tokenizer rebuilt from the modified JSON.
    """
    tokenizer_json = json.loads(tokenizer.to_str())
    pre_tokenizer_json = tokenizer_json.get("pre_tokenizer", None)

    if pre_tokenizer_json is None:
        # No pretokenizer at all: install the metaspace default.
        pre_tokenizer_json = _BASIC_METASPACE

    elif pre_tokenizer_json["type"] == "Sequence":
        # Fix each member of the sequence, dropping the ones that must go.
        new_pretokenizers = []
        for single_pretokenizer in pre_tokenizer_json["pretokenizers"]:
            new_pretokenizer = _fix_single_pretokenizer(single_pretokenizer)
            if new_pretokenizer is not None:
                new_pretokenizers.append(new_pretokenizer)

        if new_pretokenizers:
            pre_tokenizer_json["pretokenizers"] = new_pretokenizers
        else:
            # Every member was forbidden: fall back to the default.
            pre_tokenizer_json = _BASIC_METASPACE

    # Applied unconditionally: also covers the single-pretokenizer case, and is
    # a no-op for "Sequence" / the metaspace default chosen above.
    pre_tokenizer_json = _fix_single_pretokenizer(pre_tokenizer_json) or _BASIC_METASPACE
    tokenizer_json["pre_tokenizer"] = pre_tokenizer_json

    return tokenizer.from_str(json.dumps(tokenizer_json))
|
src/distiller/model2vec/tokenizer/tokenizer.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
from typing import TYPE_CHECKING, Any, cast
|
| 6 |
+
|
| 7 |
+
from tokenizers import Tokenizer
|
| 8 |
+
from transformers import PreTrainedTokenizerFast
|
| 9 |
+
|
| 10 |
+
from distiller.model2vec.tokenizer.datamodels import Token
|
| 11 |
+
from distiller.model2vec.tokenizer.model import process_tokenizer
|
| 12 |
+
from distiller.model2vec.tokenizer.normalizer import replace_normalizer
|
| 13 |
+
from distiller.model2vec.tokenizer.pretokenizer import replace_pretokenizer
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
import re
|
| 17 |
+
|
| 18 |
+
from tokenizers.normalizers import Normalizer
|
| 19 |
+
from tokenizers.pre_tokenizers import (
|
| 20 |
+
PreTokenizer,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)


# Template post-processor that passes single sequences and pairs through
# unchanged, with no special tokens added; used to replace any model-specific
# post-processing after the vocabulary has been rewritten.
_DEFAULT_POST_PROCESSOR_TEMPLATE = {
    "type": "TemplateProcessing",
    "single": [{"Sequence": {"id": "A", "type_id": 0}}],
    "pair": [{"Sequence": {"id": "A", "type_id": 0}}, {"Sequence": {"id": "B", "type_id": 0}}],
    "special_tokens": {},
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _remap_added_tokens(
|
| 35 |
+
special_tokens: list[dict[str, Any]],
|
| 36 |
+
vocabulary: list[str],
|
| 37 |
+
) -> list[dict[str, Any]]:
|
| 38 |
+
"""
|
| 39 |
+
Remap special tokens in the tokenizer.
|
| 40 |
+
|
| 41 |
+
This function updates the special tokens in the tokenizer based on a mapping provided.
|
| 42 |
+
It also ensures that the special tokens are present in the vocabulary.
|
| 43 |
+
|
| 44 |
+
:param special_tokens: The special tokens to remap.
|
| 45 |
+
:param vocabulary: The vocabulary as a list of tokens.
|
| 46 |
+
:return: The updated special tokens.
|
| 47 |
+
"""
|
| 48 |
+
# Deepcopy
|
| 49 |
+
special_tokens = [{**x} for x in special_tokens]
|
| 50 |
+
for token in special_tokens:
|
| 51 |
+
token["id"] = vocabulary.index(token["content"])
|
| 52 |
+
|
| 53 |
+
return special_tokens
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def replace_vocabulary(
    tokenizer: Tokenizer, new_vocabulary: list[Token], unk_token: str | None, pad_token: str | None
) -> Tokenizer:
    """
    Replace the vocabulary of a tokenizer with a new one.

    Works on the serialized JSON form: the unk/pad tokens are renamed to the
    canonical "[UNK]"/"[PAD]" forms, all other added tokens are dropped, the
    model is rewritten to a Unigram model over the new vocabulary, and the
    post-processor is replaced with a pass-through template.

    :param tokenizer: The backend tokenizer to rewrite.
    :param new_vocabulary: The new vocabulary as Token objects.
    :param unk_token: The original unk token string, or None.
    :param pad_token: The original pad token string, or None.
    :return: A new Tokenizer rebuilt from the modified JSON.
    """
    tokenizer_json: dict[str, Any] = json.loads(tokenizer.to_str())
    added_tokens: list[dict[str, Any]] = tokenizer_json["added_tokens"]

    pre_tokenized_tokens = [x.normalized_form for x in new_vocabulary]

    # We need to remove the added tokens but keep [UNK] and [PAD] tokens.
    # Renaming also rewrites the corresponding slots in pre_tokenized_tokens.
    added_tokens = _rename_added_token(unk_token, "[UNK]", added_tokens, pre_tokenized_tokens)
    added_tokens = _rename_added_token(pad_token, "[PAD]", added_tokens, pre_tokenized_tokens)

    # Remove old added tokens from added tokens
    tokenizer_json["added_tokens"] = [x for x in added_tokens if x["content"] in {"[UNK]", "[PAD]"}]
    tokenizer_json = process_tokenizer(
        tokenizer_json, pre_tokenized_tokens, "[UNK]" if "[UNK]" in pre_tokenized_tokens else None
    )

    # Remap special tokens
    tokenizer_json["added_tokens"] = _remap_added_tokens(
        special_tokens=tokenizer_json["added_tokens"],
        vocabulary=pre_tokenized_tokens,
    )
    tokenizer_json["post_processor"] = _DEFAULT_POST_PROCESSOR_TEMPLATE

    return Tokenizer.from_str(json.dumps(tokenizer_json))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _rename_added_token(
|
| 86 |
+
form: str | None, new_form: str, added_tokens: list[dict[str, Any]], vocabulary: list[str]
|
| 87 |
+
) -> list[dict[str, Any]]:
|
| 88 |
+
"""Rename added tokens in the tokenizer."""
|
| 89 |
+
if form is None:
|
| 90 |
+
return added_tokens
|
| 91 |
+
|
| 92 |
+
idx = vocabulary.index(form)
|
| 93 |
+
added_token = [x for x in added_tokens if x["content"] == form]
|
| 94 |
+
if added_token:
|
| 95 |
+
added_token[0]["id"] = idx
|
| 96 |
+
added_token[0]["content"] = new_form
|
| 97 |
+
vocabulary[idx] = new_form
|
| 98 |
+
|
| 99 |
+
return added_tokens
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def clean_and_create_vocabulary(
    tokenizer: PreTrainedTokenizerFast,
    vocabulary: list[str],
    token_remove_regex: re.Pattern | None,
) -> tuple[list[Token], Tokenizer]:
    """
    Clean a vocabulary by removing duplicates and tokens already in the tokenizer.

    Internal tokens are cleaned via `_process_internal_tokens`; the external
    `vocabulary` strings are then normalized, deduplicated (both pre- and
    post-normalization), and appended as non-internal Tokens.

    :param tokenizer: The tokenizer whose internal vocabulary is reused.
    :param vocabulary: Candidate token strings to add.
    :param token_remove_regex: Internal tokens matching this regex are dropped.
    :return: The cleaned Token list and a copied backend tokenizer with the
        replaced normalizer and pretokenizer.
    """
    seen_tokens = set()
    post_normalize_seen_tokens = set()
    n_empty = 0
    n_duplicates = 0

    backend_tokenizer = tokenizer.backend_tokenizer

    # Make a base list of tokens, ordered by their original ids.
    internal_vocab: dict[str, int] = tokenizer.get_vocab()
    internal_tokens: list[str] = [k for k, _ in sorted(internal_vocab.items(), key=lambda x: x[1])]

    # NOTE: internal tokens are processed with the ORIGINAL normalizer/pretokenizer;
    # only afterwards is the backend tokenizer copied and its normalizer replaced.
    cleaned_vocabulary = _process_internal_tokens(tokenizer, backend_tokenizer, internal_tokens, token_remove_regex)
    # Copy the backend tokenizer to avoid modifying the original.
    backend_tokenizer = backend_tokenizer.from_str(backend_tokenizer.to_str())
    backend_tokenizer = replace_normalizer(backend_tokenizer)

    internal_tokens_set = {token.form for token in cleaned_vocabulary}

    normalizer: Normalizer | None = backend_tokenizer.normalizer
    for token in vocabulary:
        # Rebinds the loop variable: from here on `token` is the normalized string.
        if normalizer is not None:
            token = cast("str", normalizer.normalize_str(token))

        if not token:
            n_empty += 1
            continue

        pre_tokenizer: PreTokenizer | None = backend_tokenizer.pre_tokenizer
        normalized_token = token
        if pre_tokenizer is not None:
            normalized_token = _normalize_vocabulary_token(
                token=token,
                pre_tokenizer=pre_tokenizer,
            )

        # We need to check whether the pretokenized token is in the vocabulary.
        # But we need to return the original token, because that will be tokenized
        # again by the tokenizer during featurization.
        if normalized_token in seen_tokens or normalized_token in internal_tokens_set:
            n_duplicates += 1
            continue

        # Add the possibly pretokenized token to seen
        seen_tokens.add(normalized_token)

        # After checking the token exists, we need to normalize it into the token
        # it will become. For byte tokens, this means we don't do anything. For
        # other types of tokens, we will insert a metaspace.
        # In the case of multiword tokens, we replace any spaces with the metaspace
        # or byte prefix token.
        if not normalized_token.startswith(("▁", "Ġ")):
            normalized_token = normalized_token.replace(" ", "▁")
            normalized_token = f"▁{normalized_token}"
        else:
            # Reuse the token's own boundary marker (its first character).
            normalized_token = normalized_token.replace(" ", normalized_token[0])

        # Second dedup pass: distinct inputs can collapse to the same normalized form.
        if normalized_token in post_normalize_seen_tokens:
            n_duplicates += 1
            continue

        post_normalize_seen_tokens.add(normalized_token)
        # Add the original string to the vocabulary.
        cleaned_vocabulary.append(
            Token(form=token, normalized_form=normalized_token, is_subword=False, is_internal=False)
        )

    if n_duplicates:
        logger.warning(f"Removed {n_duplicates} duplicate tokens.")
    if n_empty:
        logger.warning(f"Removed {n_empty} empty tokens.")

    return cleaned_vocabulary, replace_pretokenizer(backend_tokenizer)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _process_internal_tokens(
    tokenizer: PreTrainedTokenizerFast,
    backend_tokenizer: Tokenizer,
    internal_tokens: list[str],
    token_remove_regex: re.Pattern | None,
) -> list[Token]:
    """
    Clean internal tokens.

    Probes the tokenizer with a synthetic string to discover its word/subword
    prefixes, then converts each internal token to a Token object, dropping
    added tokens (except unk/pad), regex matches, and unreachable tokens.

    :param tokenizer: The high-level tokenizer (for special/added token info).
    :param backend_tokenizer: The backend tokenizer used for probing.
    :param internal_tokens: The internal vocabulary, ordered by id.
    :param token_remove_regex: Tokens matching this regex are dropped.
    :return: The cleaned list of internal Token objects.
    """
    # Get the pad and unk token from the tokenizer.
    pad_token: str | None = tokenizer.special_tokens_map.get("pad_token")  # type: ignore[assignment]
    unk_token: str | None = tokenizer.special_tokens_map.get("unk_token")  # type: ignore[assignment]
    # Empty set if no pad or unk token is set.
    added_tokens_to_keep: set[str] = {x for x in (pad_token, unk_token) if x is not None}
    added_tokens_to_remove = set(tokenizer.added_tokens_encoder) - added_tokens_to_keep
    cleaned_internal_tokens: list[Token] = []

    # Figure out whether token is a subword or not: encode a long run of 'a's,
    # which is guaranteed to be split into a word-initial piece plus subwords.
    encoded = backend_tokenizer.encode(f" {'a' * 25}", add_special_tokens=False)
    first_token, second_token, *_ = encoded.tokens
    # Isolate the prefix. We can't do first_token[0] because we don't know
    # how long the prefix is.
    # e.g., "Ġaaaa" -> "Ġ"
    a_index = None if "a" not in first_token else first_token.index("a")
    word_prefix = first_token[:a_index]
    is_byte_prefix = word_prefix == "Ġ"
    # (Rebinds second_token to the same value unpacked above.)
    second_token = encoded.tokens[1]
    # The second token is the first subword token.
    # If a tokenizer uses subwords, this token will have been prefixed.
    # We don't know how long the prefix is.
    a_index = None if "a" not in second_token else second_token.index("a")
    subword_prefix = second_token[:a_index]

    pre_tokenizer: PreTokenizer | None = backend_tokenizer.pre_tokenizer

    for token in internal_tokens:
        # Create the token objects. If this returns None, it was unsuccessful for some reason.
        if token_object := _create_single_internal_token(
            token=token,
            subword_prefix=subword_prefix,
            word_prefix=word_prefix,
            pre_tokenizer=pre_tokenizer,
            is_byte_prefix=is_byte_prefix,
            token_remove_regex=token_remove_regex,
            added_tokens_to_keep=added_tokens_to_keep,
            added_tokens_to_remove=added_tokens_to_remove,
        ):
            cleaned_internal_tokens.append(token_object)

    if len(cleaned_internal_tokens) != len(internal_tokens):
        logger.info(
            f"Removed {len(internal_tokens) - len(cleaned_internal_tokens)} internal tokens from the vocabulary."
        )

    return cleaned_internal_tokens
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def _create_single_internal_token(
    token: str,
    subword_prefix: str,
    word_prefix: str,
    pre_tokenizer: PreTokenizer | None,
    is_byte_prefix: bool,
    token_remove_regex: re.Pattern | None,
    added_tokens_to_keep: set[str],
    added_tokens_to_remove: set[str],
) -> Token | None:
    """
    Create a Token object from an internal token string.

    :param token: The raw internal token string.
    :param subword_prefix: Prefix marking subword tokens (may be empty).
    :param word_prefix: Prefix marking word-initial tokens (may be empty).
    :param pre_tokenizer: The backend pretokenizer, used for reachability checks.
    :param is_byte_prefix: Whether the tokenizer uses the "Ġ" byte-level prefix.
    :param token_remove_regex: Tokens matching this regex are dropped.
    :param added_tokens_to_keep: Added tokens to keep verbatim (unk/pad).
    :param added_tokens_to_remove: Added tokens to discard.
    :return: The Token, or None if the token should be removed.
    """
    if token in added_tokens_to_remove:
        # We remove any tokens that are added tokens that aren't [UNK] or [PAD].
        return None
    if token in added_tokens_to_keep:
        # Don't put added tokens through the regular motions.
        return Token(form=token, normalized_form=token, is_subword=False, is_internal=True)
    if token_remove_regex and token_remove_regex.match(token):
        # If the regex matches, remove the token.
        return None

    # A token is a subword if there is a subword prefix and the word
    # starts with a subword prefix, or if there is a WORD prefix, and the word
    # does not start with this prefix. For metaspace tokenizers, for example:
    # "doghouse" -> ["_dog", "house"]
    # So we can only tell that "house" is a subword by knowing that it is not prefixed
    # and word-initial tokens are.
    # NOTE: when both prefixes are non-empty, the word-prefix check takes precedence.
    is_subword = False
    if subword_prefix:
        is_subword = bool(token.startswith(subword_prefix))
    if word_prefix:
        is_subword = not bool(token.startswith(word_prefix))

    # Byte prefixed tokenizers don't need to be checked.
    if pre_tokenizer is not None and not is_byte_prefix:
        # We need to check the thing without prefixes. If we have a word prefix,
        # we need to check tokens that are subwords. Other way around for subword
        # prefixes.
        if (subword_prefix and not is_subword) or (word_prefix and is_subword):
            # If this is True, the token is unreachable, even though it is a subword token:
            # the pretokenizer would split it before the model ever saw it whole.
            if len(pre_tokenizer.pre_tokenize_str(token)) > 1:
                return None

    # Turn a token into a normalized form for later processing.
    normalized_form = _create_normalized_form(token, subword_prefix, word_prefix, is_byte_prefix, is_subword)

    return Token(form=token, normalized_form=normalized_form, is_subword=is_subword, is_internal=True)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def _create_normalized_form(
|
| 287 |
+
token: str, subword_prefix: str, word_prefix: str, is_byte_prefix: bool, is_subword: bool
|
| 288 |
+
) -> str:
|
| 289 |
+
"""Turn an internal token string into a normalized form."""
|
| 290 |
+
# We don't need to check byte prefixed strings.
|
| 291 |
+
if is_byte_prefix:
|
| 292 |
+
return token
|
| 293 |
+
# We need to check if the token is a subword or not and remove the prefix.
|
| 294 |
+
if is_subword:
|
| 295 |
+
return token.removeprefix(subword_prefix)
|
| 296 |
+
# If the token is not a subword, we need to remove the word prefix, and add metaspace.
|
| 297 |
+
return f"▁{token.removeprefix(word_prefix)}"
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def turn_tokens_into_ids(
    tokens: list[Token], tokenizer: PreTrainedTokenizerFast, unk_token: str | None
) -> list[list[int]]:
    """
    Convert a list of Token objects to their corresponding token ID sequences.

    Internal tokens are looked up directly and wrapped in the tokenizer's
    bos/eos prefix and suffix; external tokens are fully re-encoded.

    :param tokens: List of Token objects to convert
    :param tokenizer: The tokenizer to use for converting tokens to IDs
    :param unk_token: The string form of the unk token.
    :return: List of token IDs corresponding to the input tokens
    """
    unk_id = None if unk_token is None else tokenizer.convert_tokens_to_ids(unk_token)
    prefix, suffix = find_eos_bos(tokenizer)

    token_ids: list[list[int]] = []
    for token in tokens:
        if token.is_internal:
            # Careful. Any incorrect tokens will just get `[UNK]``, so this could go horribly wrong
            # Cast because return type is wrong.
            # NOTE(review): `or 0` maps a None lookup result to id 0, but it would
            # also coerce a legitimate id of 0 to itself — confirm the intent is
            # only the None fallback.
            token_id: int = cast("int", tokenizer.convert_tokens_to_ids(token.form)) or 0
            # Explicitly check and warn if `unk_id` appears, but don't crash.
            if unk_id is not None and token_id == unk_id and token.form != unk_token:
                logger.warning(f"Token {token.form} was set to unk. This is wrong.")
            token_ids.append([*prefix, token_id, *suffix])
        else:
            # External tokens go through the full encode path (with special tokens).
            token_ids.append(tokenizer.encode(token.form))

    return token_ids
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def find_eos_bos(tokenizer: PreTrainedTokenizerFast) -> tuple[list[int], list[int]]:
    """
    Find the special-token id sequences a tokenizer adds around a sequence.

    Encodes a single character with special tokens and splits the result
    around it: everything before is the prefix (bos-like), everything after
    the suffix (eos-like).

    :param tokenizer: The tokenizer to probe.
    :return: The (prefix, suffix) id lists.
    :raises ValueError: if 'a' does not encode to exactly one id, in which
        case the boundaries cannot be determined.
    """
    encoding = tokenizer.encode("a", add_special_tokens=True)
    # Common case: exactly [bos, a, eos].
    if len(encoding) == 3:
        return encoding[:1], encoding[2:]

    # Not all tokenizers add exactly one token on each side; locate 'a' instead.
    a_encoded = tokenizer.encode("a", add_special_tokens=False)
    if len(a_encoded) != 1:
        msg = f"Error while encoding, couldn't determine eos and bos tokens. The model tokenizes 'a' to '{a_encoded}'"
        raise ValueError(msg)
    a_idx = encoding.index(a_encoded[0])
    return encoding[:a_idx], encoding[a_idx + 1 :]
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def _normalize_vocabulary_token(token: str, pre_tokenizer: PreTokenizer) -> str:
    """
    Normalize a token that is not in the initial token vocabulary.

    Runs the token through the pretokenizer and re-joins the pieces, keeping a
    space only where the original string had one, so multiword tokens keep
    their word boundaries.

    :param token: The (already normalizer-processed) token string.
    :param pre_tokenizer: The backend pretokenizer to run the token through.
    :return: The re-joined, pretokenizer-normalized token string.
    """
    # Add prefix space for byte tokenizers.
    prefixed_token = f" {token}"
    pretokenized_tokens: tuple[str, ...]
    # pre_tokenize_str yields (piece, (start, end)) pairs over prefixed_token.
    pretokenized_tokens, offsets = zip(*pre_tokenizer.pre_tokenize_str(prefixed_token), strict=False)
    # The first item is always the start of the token.
    new_token = [pretokenized_tokens[0]]
    # Loop over the subtokens and offsets.
    for t, (s, _) in zip(pretokenized_tokens[1:], offsets[1:], strict=False):
        # Do not prefix the token with a space if it starts with a metaspace.
        if t.startswith("▁"):
            new_token.append(t)
        # If the character before the subtoken is a space, we have a
        # multiword token. e.g., "room for the moon", which is split into
        # ["room", "for", "the", "moon"].
        # If it doesn't have a space, it is part of a complex multiword token,
        # e.g., "chat-gpt", which is split into ["chat", "-", "gpt"].
        elif prefixed_token[s - 1] == " ":
            new_token.append(f" {t}")
        else:
            new_token.append(t)
    return "".join(new_token)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def create_tokenizer(
    tokenizer: PreTrainedTokenizerFast,
    vocabulary: list[str],
    token_remove_regex: re.Pattern | None = None,
) -> PreTrainedTokenizerFast:
    """
    Create a tokenizer by adding tokens to the vocabulary.

    This function turns any tokenizer into a supertoken tokenizer. It does the following:
    1. Turns the tokenizer model into a unigram model.
    2. Adds a new pretokenizer, splitting on punctuation.
    3. Adds all tokens in vocabulary to the model.
    4. Removes any internal tokens that conform to the regex.

    :param tokenizer: The tokenizer to use.
    :param vocabulary: The vocabulary to use.
    :param token_remove_regex: The regex to use to remove tokens from the vocabulary.
    :return: The created tokenizer.
    """
    # Grab the unk/pad tokens before any vocabulary surgery happens.
    unk_token = cast("str | None", tokenizer.special_tokens_map.get("unk_token"))
    pad_token = cast("str | None", tokenizer.special_tokens_map.get("pad_token"))

    # Clean both the internal and the external vocabulary, then splice the
    # result into a fresh backend tokenizer.
    cleaned_vocabulary, backend_tokenizer = clean_and_create_vocabulary(tokenizer, vocabulary, token_remove_regex)
    rebuilt_backend = replace_vocabulary(backend_tokenizer, cleaned_vocabulary, unk_token, pad_token)

    return PreTrainedTokenizerFast(tokenizer_object=rebuilt_backend)
|
src/distiller/model2vec/train/README.md
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training
|
| 2 |
+
|
| 3 |
+
Aside from [distillation](../../README.md#distillation), `model2vec` also supports training simple classifiers on top of static models, using [pytorch](https://pytorch.org/), [lightning](https://lightning.ai/) and [scikit-learn](https://scikit-learn.org/stable/index.html).
|
| 4 |
+
|
| 5 |
+
We support both single and multi-label classification, which work seamlessly based on the labels you provide.
|
| 6 |
+
|
| 7 |
+
# Installation
|
| 8 |
+
|
| 9 |
+
To train, make sure you install the training extra:
|
| 10 |
+
|
| 11 |
+
```
|
| 12 |
+
pip install model2vec[training]
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
# Quickstart
|
| 16 |
+
|
| 17 |
+
To train a model, simply initialize it using a `StaticModel`, or from a pre-trained model, as follows:
|
| 18 |
+
|
| 19 |
+
```python
|
| 20 |
+
from model2vec.distill import distill
|
| 21 |
+
from model2vec.train import StaticModelForClassification
|
| 22 |
+
|
| 23 |
+
# From a distilled model
|
| 24 |
+
distilled_model = distill("baai/bge-base-en-v1.5")
|
| 25 |
+
classifier = StaticModelForClassification.from_static_model(model=distilled_model)
|
| 26 |
+
|
| 27 |
+
# From a pre-trained model: potion is the default
|
| 28 |
+
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32m")
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
This creates a very simple classifier: a StaticModel with a single 512-unit hidden layer on top. You can adjust the number of hidden layers and the number units through some parameters on both functions. Note that the default for `from_pretrained` is [potion-base-32m](https://huggingface.co/minishlab/potion-base-32M), our best model to date. This is our recommended path if you're working with general English data.
|
| 32 |
+
|
| 33 |
+
Now that you have created the classifier, let's just train a model. The example below assumes you have the [`datasets`](https://github.com/huggingface/datasets) library installed.
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
import numpy as np
|
| 37 |
+
from datasets import load_dataset
|
| 38 |
+
|
| 39 |
+
# Load the subj dataset
|
| 40 |
+
ds = load_dataset("setfit/subj")
|
| 41 |
+
train = ds["train"]
|
| 42 |
+
test = ds["test"]
|
| 43 |
+
|
| 44 |
+
from time import perf_counter

s = perf_counter()
|
| 45 |
+
classifier = classifier.fit(train["text"], train["label"])
|
| 46 |
+
|
| 47 |
+
print(f"Training took {int(perf_counter() - s)} seconds.")
|
| 48 |
+
# Training took 81 seconds
|
| 49 |
+
classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["label"])
|
| 50 |
+
print(classification_report)
|
| 51 |
+
# Achieved 91.0 test accuracy
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
As you can see, we got a pretty nice 91% accuracy, with only 81 seconds of training.
|
| 55 |
+
|
| 56 |
+
The training loop is handled by [`lightning`](https://pypi.org/project/lightning/). By default the training loop splits the data into a train and validation split, with 90% of the data being used for training and 10% for validation. By default, it runs with early stopping on the validation set accuracy, with a patience of 5.
|
| 57 |
+
|
| 58 |
+
Note that this model is as fast as you're used to from us:
|
| 59 |
+
|
| 60 |
+
```python
|
| 61 |
+
from time import perf_counter
|
| 62 |
+
|
| 63 |
+
s = perf_counter()
|
| 64 |
+
classifier.predict(test["text"])
|
| 65 |
+
print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} instances on CPU.")
|
| 66 |
+
# Took 67 milliseconds for 2000 instances on CPU.
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
## Multi-label classification
|
| 70 |
+
|
| 71 |
+
Multi-label classification is supported out of the box. Just pass a list of lists to the `fit` function (e.g. `[[label1, label2], [label1, label3]]`), and a multi-label classifier will be trained. For example, the following code trains a multi-label classifier on the [go_emotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) dataset:
|
| 72 |
+
|
| 73 |
+
```python
|
| 74 |
+
from datasets import load_dataset
|
| 75 |
+
from model2vec.train import StaticModelForClassification
|
| 76 |
+
|
| 77 |
+
# Initialize a classifier from a pre-trained model
|
| 78 |
+
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
|
| 79 |
+
|
| 80 |
+
# Load a multi-label dataset
|
| 81 |
+
ds = load_dataset("google-research-datasets/go_emotions")
|
| 82 |
+
|
| 83 |
+
# Inspect some of the labels
|
| 84 |
+
print(ds["train"]["labels"][40:50])
|
| 85 |
+
# [[0, 15], [15, 18], [16, 27], [27], [7, 13], [10], [20], [27], [27], [27]]
|
| 86 |
+
|
| 87 |
+
# Train the classifier on text (X) and labels (y)
|
| 88 |
+
classifier.fit(ds["train"]["text"], ds["train"]["labels"])
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
Then, we can evaluate the classifier:
|
| 92 |
+
|
| 93 |
+
```python
|
| 94 |
+
from sklearn import metrics
|
| 95 |
+
from sklearn.preprocessing import MultiLabelBinarizer
|
| 96 |
+
|
| 97 |
+
classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["labels"], threshold=0.3)
|
| 98 |
+
print(classification_report)
|
| 99 |
+
# Accuracy: 0.410
|
| 100 |
+
# Precision: 0.527
|
| 101 |
+
# Recall: 0.410
|
| 102 |
+
# F1: 0.439
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
The scores are competitive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster.
|
| 106 |
+
|
| 107 |
+
# Persistence
|
| 108 |
+
|
| 109 |
+
You can turn a classifier into a scikit-learn compatible pipeline, as follows:
|
| 110 |
+
|
| 111 |
+
```python
|
| 112 |
+
pipeline = classifier.to_pipeline()
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
This pipeline object can be persisted using standard pickle-based methods, such as [joblib](https://joblib.readthedocs.io/en/stable/). This makes it easy to use your model in inference pipelines (no installing torch!), although `joblib` and `pickle` should not be used to share models outside of your organization.
|
| 116 |
+
|
| 117 |
+
If you want to persist your pipeline to the Hugging Face hub, you can use our built-in functions:
|
| 118 |
+
|
| 119 |
+
```python
|
| 120 |
+
pipeline.save_pretrained(path)
|
| 121 |
+
pipeline.push_to_hub("my_cool/project")
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
Later, you can load these as follows:
|
| 125 |
+
|
| 126 |
+
```python
|
| 127 |
+
from model2vec.inference import StaticModelPipeline
|
| 128 |
+
|
| 129 |
+
pipeline = StaticModelPipeline.from_pretrained("my_cool/project")
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
Loading pipelines in this way is _extremely_ fast. It takes only 30ms to load a pipeline from disk.
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# Bring your own architecture
|
| 136 |
+
|
| 137 |
+
Our training architecture is set up to be extensible, with each task having a specific class. Right now, we only offer `StaticModelForClassification`, but in the future we'll also offer regression, etc.
|
| 138 |
+
|
| 139 |
+
The core functionality of the `StaticModelForClassification` is contained in a couple of functions:
|
| 140 |
+
|
| 141 |
+
* `construct_head`: This function constructs the classifier on top of the staticmodel. For example, if you want to create a model that has LayerNorm, just subclass, and replace this function. This should be the main function to update if you want to change model behavior.
|
| 142 |
+
* `train_test_split`: governs the train test split before classification.
|
| 143 |
+
* `prepare_dataset`: Selects the `torch.Dataset` that will be used in the `Dataloader` during training.
|
| 144 |
+
* `_encode`: The encoding function used in the model.
|
| 145 |
+
* `fit`: contains all the lightning-related fitting logic.
|
| 146 |
+
|
| 147 |
+
The training of the model is done in a `lightning.LightningModule`, which can be modified but is very basic.
|
| 148 |
+
|
| 149 |
+
# Results
|
| 150 |
+
|
| 151 |
+
We ran extensive benchmarks where we compared our model to several well known architectures. The results can be found in the [training results](https://github.com/MinishLab/model2vec/tree/main/results#training-results) documentation.
|
src/distiller/model2vec/train/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from distiller.model2vec.utils import get_package_extras, importable
|
| 2 |
+
|
| 3 |
+
_REQUIRED_EXTRA = "train"
|
| 4 |
+
|
| 5 |
+
for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
|
| 6 |
+
importable(extra_dependency, _REQUIRED_EXTRA)
|
| 7 |
+
|
| 8 |
+
from distiller.model2vec.train.classifier import StaticModelForClassification
|
| 9 |
+
|
| 10 |
+
__all__ = ["StaticModelForClassification"]
|
src/distiller/model2vec/train/base.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import TYPE_CHECKING, Any, Self, TypeVar
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 10 |
+
from torch.utils.data import DataLoader, Dataset
|
| 11 |
+
|
| 12 |
+
from distiller.model2vec import StaticModel
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from tokenizers import Encoding, Tokenizer
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class FinetunableStaticModel(nn.Module):
    """A StaticModel wrapped as a torch module so its embeddings and token weights can be finetuned."""

    def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0) -> None:
        """
        Initialize a trainable StaticModel from a StaticModel.

        :param vectors: The embeddings of the staticmodel.
        :param tokenizer: The tokenizer.
        :param out_dim: The output dimension of the head.
        :param pad_id: The padding id. This is set to 0 in almost all model2vec models.
        """
        super().__init__()
        self.pad_id = pad_id
        self.out_dim = out_dim
        self.embed_dim = vectors.shape[1]

        self.vectors = vectors
        if self.vectors.dtype != torch.float32:
            dtype = str(self.vectors.dtype)
            logger.warning(
                f"Your vectors are {dtype} precision, converting to torch.float32 to avoid compatibility issues."
            )
            self.vectors = vectors.float()

        self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=pad_id)
        self.head = self.construct_head()
        self.w = self.construct_weights()
        self.tokenizer = tokenizer

    def construct_weights(self) -> nn.Parameter:
        """Construct the per-token weights; the pad token gets a large negative logit so sigmoid(w) ~ 0."""
        weights = torch.zeros(len(self.vectors))
        weights[self.pad_id] = -10_000
        return nn.Parameter(weights)

    def construct_head(self) -> nn.Sequential:
        """Method should be overridden for various other classes."""
        return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim))

    @classmethod
    def from_pretrained(
        cls, *, out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any
    ) -> Self:
        """Load the model from a pretrained model2vec model."""
        model = StaticModel.from_pretrained(model_name)
        return cls.from_static_model(model=model, out_dim=out_dim, **kwargs)

    @classmethod
    def from_static_model(cls, *, model: StaticModel, out_dim: int = 2, **kwargs: Any) -> Self:
        """Load the model from a static model."""
        # NaNs in the embedding matrix would poison training; zero them out first.
        model.embedding = np.nan_to_num(model.embedding)
        embeddings_converted = torch.from_numpy(model.embedding)
        return cls(
            vectors=embeddings_converted,
            pad_id=model.tokenizer.token_to_id("[PAD]"),
            out_dim=out_dim,
            tokenizer=model.tokenizer,
            **kwargs,
        )

    def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        A forward pass and mean pooling.

        This function is analogous to `StaticModel.encode`, but reimplemented to allow gradients
        to pass through.

        :param input_ids: A 2D tensor of input ids. All input ids have to be within bounds.
        :return: The mean over the input ids, weighted by token weights.
        """
        w = self.w[input_ids]
        w = torch.sigmoid(w)
        # Mask out padding positions so they contribute nothing to the pooled vector.
        zeros = (input_ids != self.pad_id).float()
        w = w * zeros
        # Add a small epsilon to avoid division by zero
        length = zeros.sum(1) + 1e-16
        embedded = self.embeddings(input_ids)
        # Weigh each token
        embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
        # Mean pooling by dividing by the length
        embedded = embedded / length[:, None]

        return nn.functional.normalize(embedded)

    def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the mean, and a classifier layer after."""
        encoded = self._encode(input_ids)
        return self.head(encoded), encoded

    def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tensor:
        """
        Tokenize a bunch of strings into a single padded 2D tensor.

        Note that this is not used during training.

        :param texts: The texts to tokenize.
        :param max_length: The maximum sequence length; longer sequences are truncated.
            If this is None, sequences are not truncated.
        :return: A 2D padded tensor
        """
        encoded: list[Encoding] = self.tokenizer.encode_batch_fast(texts, add_special_tokens=False)
        encoded_ids: list[torch.Tensor] = [torch.Tensor(encoding.ids[:max_length]).long() for encoding in encoded]
        return pad_sequence(encoded_ids, batch_first=True, padding_value=self.pad_id)

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return self.embeddings.weight.device

    def to_static_model(self) -> StaticModel:
        """Convert the model to a static model."""
        emb = self.embeddings.weight.detach().cpu().numpy()
        # Bake the learned token weights into the embedding matrix.
        w = torch.sigmoid(self.w).detach().cpu().numpy()

        return StaticModel(emb * w[:, None], self.tokenizer, normalize=True)
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class TextDataset(Dataset):
    """A torch ``Dataset`` pairing pre-tokenized texts with their targets."""

    def __init__(self, tokenized_texts: list[list[int]], targets: torch.Tensor) -> None:
        """
        A dataset of texts.

        :param tokenized_texts: The tokenized texts. Each text is a list of token ids.
        :param targets: The targets.
        :raises ValueError: If the number of labels does not match the number of texts.
        """
        if len(tokenized_texts) != len(targets):
            msg = "Number of labels does not match number of texts."
            raise ValueError(msg)
        self.tokenized_texts = tokenized_texts
        self.targets = targets

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.tokenized_texts)

    def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]:
        """Gets an item."""
        return self.tokenized_texts[index], self.targets[index]

    @staticmethod
    def collate_fn(batch: list[tuple[list[list[int]], int]]) -> tuple[torch.Tensor, torch.Tensor]:
        """Collate function: pad the token id lists and stack the targets."""
        id_tensors = [torch.LongTensor(token_ids) for token_ids, _ in batch]
        target_list = [target for _, target in batch]

        padded_ids = pad_sequence(id_tensors, batch_first=True, padding_value=0)
        stacked_targets = torch.stack(target_list)
        return padded_ids, stacked_targets

    def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader:
        """Convert the dataset to a DataLoader."""
        return DataLoader(
            self,
            collate_fn=self.collate_fn,
            shuffle=shuffle,
            batch_size=batch_size,
        )
| 171 |
+
|
| 172 |
+
|
| 173 |
+
ModelType = TypeVar("ModelType", bound=FinetunableStaticModel)
|
src/distiller/model2vec/train/classifier.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from collections import Counter
|
| 5 |
+
from itertools import chain
|
| 6 |
+
from tempfile import TemporaryDirectory
|
| 7 |
+
from typing import TYPE_CHECKING, TypeVar, cast
|
| 8 |
+
|
| 9 |
+
import lightning as pl
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from lightning.pytorch.callbacks import Callback, EarlyStopping
|
| 13 |
+
from sklearn.metrics import jaccard_score
|
| 14 |
+
from sklearn.model_selection import train_test_split
|
| 15 |
+
from sklearn.neural_network import MLPClassifier
|
| 16 |
+
from sklearn.pipeline import make_pipeline
|
| 17 |
+
from torch import nn
|
| 18 |
+
from tqdm import trange
|
| 19 |
+
|
| 20 |
+
from distiller.model2vec.inference import StaticModelPipeline, evaluate_single_or_multi_label
|
| 21 |
+
from distiller.model2vec.train.base import FinetunableStaticModel, TextDataset
|
| 22 |
+
|
| 23 |
+
if TYPE_CHECKING:
|
| 24 |
+
from lightning.pytorch.utilities.types import OptimizerLRScheduler
|
| 25 |
+
from tokenizers import Tokenizer
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
_RANDOM_SEED = 42
|
| 29 |
+
|
| 30 |
+
LabelType = TypeVar("LabelType", list[str], list[list[str]])
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class StaticModelForClassification(FinetunableStaticModel):
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
*,
|
| 37 |
+
vectors: torch.Tensor,
|
| 38 |
+
tokenizer: Tokenizer,
|
| 39 |
+
n_layers: int = 1,
|
| 40 |
+
hidden_dim: int = 512,
|
| 41 |
+
out_dim: int = 2,
|
| 42 |
+
pad_id: int = 0,
|
| 43 |
+
) -> None:
|
| 44 |
+
"""Initialize a standard classifier model."""
|
| 45 |
+
self.n_layers = n_layers
|
| 46 |
+
self.hidden_dim = hidden_dim
|
| 47 |
+
# Alias: Follows scikit-learn. Set to dummy classes
|
| 48 |
+
self.classes_: list[str] = [str(x) for x in range(out_dim)]
|
| 49 |
+
# multilabel flag will be set based on the type of `y` passed to fit.
|
| 50 |
+
self.multilabel: bool = False
|
| 51 |
+
super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer)
|
| 52 |
+
|
| 53 |
+
@property
|
| 54 |
+
def classes(self) -> np.ndarray:
|
| 55 |
+
"""Return all clasess in the correct order."""
|
| 56 |
+
return np.array(self.classes_)
|
| 57 |
+
|
| 58 |
+
def construct_head(self) -> nn.Sequential:
|
| 59 |
+
"""Constructs a simple classifier head."""
|
| 60 |
+
if self.n_layers == 0:
|
| 61 |
+
return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim))
|
| 62 |
+
modules = [
|
| 63 |
+
nn.Linear(self.embed_dim, self.hidden_dim),
|
| 64 |
+
nn.ReLU(),
|
| 65 |
+
]
|
| 66 |
+
for _ in range(self.n_layers - 1):
|
| 67 |
+
modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
|
| 68 |
+
modules.extend([nn.Linear(self.hidden_dim, self.out_dim)])
|
| 69 |
+
|
| 70 |
+
for module in modules:
|
| 71 |
+
if isinstance(module, nn.Linear):
|
| 72 |
+
nn.init.kaiming_uniform_(module.weight)
|
| 73 |
+
nn.init.zeros_(module.bias)
|
| 74 |
+
|
| 75 |
+
return nn.Sequential(*modules)
|
| 76 |
+
|
| 77 |
+
def predict(
|
| 78 |
+
self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024, threshold: float = 0.5
|
| 79 |
+
) -> np.ndarray:
|
| 80 |
+
"""
|
| 81 |
+
Predict labels for a set of texts.
|
| 82 |
+
|
| 83 |
+
In single-label mode, each prediction is a single class.
|
| 84 |
+
In multilabel mode, each prediction is a list of classes.
|
| 85 |
+
|
| 86 |
+
:param X: The texts to predict on.
|
| 87 |
+
:param show_progress_bar: Whether to show a progress bar.
|
| 88 |
+
:param batch_size: The batch size.
|
| 89 |
+
:param threshold: The threshold for multilabel classification.
|
| 90 |
+
:return: The predictions.
|
| 91 |
+
"""
|
| 92 |
+
pred = []
|
| 93 |
+
for batch in trange(0, len(X), batch_size, disable=not show_progress_bar):
|
| 94 |
+
logits = self._predict_single_batch(X[batch : batch + batch_size])
|
| 95 |
+
if self.multilabel:
|
| 96 |
+
probs = torch.sigmoid(logits)
|
| 97 |
+
mask = (probs > threshold).cpu().numpy()
|
| 98 |
+
pred.extend([self.classes[np.flatnonzero(row)] for row in mask])
|
| 99 |
+
else:
|
| 100 |
+
pred.extend([self.classes[idx] for idx in logits.argmax(dim=1).tolist()])
|
| 101 |
+
if self.multilabel:
|
| 102 |
+
# Return as object array to allow for lists of varying lengths.
|
| 103 |
+
return np.array(pred, dtype=object)
|
| 104 |
+
return np.array(pred)
|
| 105 |
+
|
| 106 |
+
@torch.no_grad()
|
| 107 |
+
def _predict_single_batch(self, X: list[str]) -> torch.Tensor:
|
| 108 |
+
input_ids = self.tokenize(X)
|
| 109 |
+
vectors, _ = self.forward(input_ids)
|
| 110 |
+
return vectors
|
| 111 |
+
|
| 112 |
+
def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024) -> np.ndarray:
|
| 113 |
+
"""
|
| 114 |
+
Predict probabilities for each class.
|
| 115 |
+
|
| 116 |
+
In single-label mode, returns softmax probabilities.
|
| 117 |
+
In multilabel mode, returns sigmoid probabilities.
|
| 118 |
+
"""
|
| 119 |
+
pred = []
|
| 120 |
+
for batch in trange(0, len(X), batch_size, disable=not show_progress_bar):
|
| 121 |
+
logits = self._predict_single_batch(X[batch : batch + batch_size])
|
| 122 |
+
if self.multilabel:
|
| 123 |
+
pred.append(torch.sigmoid(logits).cpu().numpy())
|
| 124 |
+
else:
|
| 125 |
+
pred.append(torch.softmax(logits, dim=1).cpu().numpy())
|
| 126 |
+
return np.concatenate(pred, axis=0)
|
| 127 |
+
|
| 128 |
+
def fit(
|
| 129 |
+
self,
|
| 130 |
+
X: list[str],
|
| 131 |
+
y: LabelType,
|
| 132 |
+
learning_rate: float = 1e-3,
|
| 133 |
+
batch_size: int | None = None,
|
| 134 |
+
min_epochs: int | None = None,
|
| 135 |
+
max_epochs: int | None = -1,
|
| 136 |
+
early_stopping_patience: int | None = 5,
|
| 137 |
+
test_size: float = 0.1,
|
| 138 |
+
device: str = "auto",
|
| 139 |
+
X_val: list[str] | None = None,
|
| 140 |
+
y_val: LabelType | None = None,
|
| 141 |
+
) -> StaticModelForClassification:
|
| 142 |
+
"""
|
| 143 |
+
Fit a model.
|
| 144 |
+
|
| 145 |
+
This function creates a Lightning Trainer object and fits the model to the data.
|
| 146 |
+
It supports both single-label and multi-label classification.
|
| 147 |
+
We use early stopping. After training, the weights of the best model are loaded back into the model.
|
| 148 |
+
|
| 149 |
+
This function seeds everything with a seed of 42, so the results are reproducible.
|
| 150 |
+
It also splits the data into a train and validation set, again with a random seed.
|
| 151 |
+
|
| 152 |
+
If `X_val` and `y_val` are not provided, the function will automatically
|
| 153 |
+
split the training data into a train and validation set using `test_size`.
|
| 154 |
+
|
| 155 |
+
:param X: The texts to train on.
|
| 156 |
+
:param y: The labels to train on. If the first element is a list, multi-label classification is assumed.
|
| 157 |
+
:param learning_rate: The learning rate.
|
| 158 |
+
:param batch_size: The batch size. If None, a good batch size is chosen automatically.
|
| 159 |
+
:param min_epochs: The minimum number of epochs to train for.
|
| 160 |
+
:param max_epochs: The maximum number of epochs to train for.
|
| 161 |
+
If this is -1, the model trains until early stopping is triggered.
|
| 162 |
+
:param early_stopping_patience: The patience for early stopping.
|
| 163 |
+
If this is None, early stopping is disabled.
|
| 164 |
+
:param test_size: The test size for the train-test split.
|
| 165 |
+
:param device: The device to train on. If this is "auto", the device is chosen automatically.
|
| 166 |
+
:param X_val: The texts to be used for validation.
|
| 167 |
+
:param y_val: The labels to be used for validation.
|
| 168 |
+
:return: The fitted model.
|
| 169 |
+
:raises ValueError: If either X_val or y_val are provided, but not both.
|
| 170 |
+
"""
|
| 171 |
+
pl.seed_everything(_RANDOM_SEED)
|
| 172 |
+
logger.info("Re-initializing model.")
|
| 173 |
+
|
| 174 |
+
# Determine whether the task is multilabel based on the type of y.
|
| 175 |
+
|
| 176 |
+
self._initialize(y)
|
| 177 |
+
|
| 178 |
+
if (X_val is not None) != (y_val is not None):
|
| 179 |
+
msg = "Both X_val and y_val must be provided together, or neither."
|
| 180 |
+
raise ValueError(msg)
|
| 181 |
+
|
| 182 |
+
if X_val is not None and y_val is not None:
|
| 183 |
+
# Additional check to ensure y_val is of the same type as y
|
| 184 |
+
if type(y_val[0]) != type(y[0]):
|
| 185 |
+
msg = "X_val and y_val must be of the same type as X and y."
|
| 186 |
+
raise ValueError(msg)
|
| 187 |
+
|
| 188 |
+
train_texts = X
|
| 189 |
+
train_labels = y
|
| 190 |
+
validation_texts = X_val
|
| 191 |
+
validation_labels = y_val
|
| 192 |
+
else:
|
| 193 |
+
train_texts, validation_texts, train_labels, validation_labels = self._train_test_split(
|
| 194 |
+
X,
|
| 195 |
+
y,
|
| 196 |
+
test_size=test_size,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
if batch_size is None:
|
| 200 |
+
# Set to a multiple of 32
|
| 201 |
+
base_number = int(min(max(1, (len(train_texts) / 30) // 32), 16))
|
| 202 |
+
batch_size = int(base_number * 32)
|
| 203 |
+
logger.info("Batch size automatically set to %d.", batch_size)
|
| 204 |
+
|
| 205 |
+
logger.info("Preparing train dataset.")
|
| 206 |
+
train_dataset = self._prepare_dataset(train_texts, train_labels)
|
| 207 |
+
logger.info("Preparing validation dataset.")
|
| 208 |
+
val_dataset = self._prepare_dataset(validation_texts, validation_labels)
|
| 209 |
+
|
| 210 |
+
c = _ClassifierLightningModule(self, learning_rate=learning_rate)
|
| 211 |
+
|
| 212 |
+
n_train_batches = len(train_dataset) // batch_size
|
| 213 |
+
callbacks: list[Callback] = []
|
| 214 |
+
if early_stopping_patience is not None:
|
| 215 |
+
callback = EarlyStopping(monitor="val_accuracy", mode="max", patience=early_stopping_patience)
|
| 216 |
+
callbacks.append(callback)
|
| 217 |
+
|
| 218 |
+
# If the dataset is small, we check the validation set every epoch.
|
| 219 |
+
# If the dataset is large, we check the validation set every 250 batches.
|
| 220 |
+
if n_train_batches < 250:
|
| 221 |
+
val_check_interval = None
|
| 222 |
+
check_val_every_epoch = 1
|
| 223 |
+
else:
|
| 224 |
+
val_check_interval = max(250, 2 * len(val_dataset) // batch_size)
|
| 225 |
+
check_val_every_epoch = None
|
| 226 |
+
|
| 227 |
+
with TemporaryDirectory() as tempdir:
|
| 228 |
+
trainer = pl.Trainer(
|
| 229 |
+
min_epochs=min_epochs,
|
| 230 |
+
max_epochs=max_epochs,
|
| 231 |
+
callbacks=callbacks,
|
| 232 |
+
val_check_interval=val_check_interval,
|
| 233 |
+
check_val_every_n_epoch=check_val_every_epoch,
|
| 234 |
+
accelerator=device,
|
| 235 |
+
default_root_dir=tempdir,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
trainer.fit(
|
| 239 |
+
c,
|
| 240 |
+
train_dataloaders=train_dataset.to_dataloader(shuffle=True, batch_size=batch_size),
|
| 241 |
+
val_dataloaders=val_dataset.to_dataloader(shuffle=False, batch_size=batch_size),
|
| 242 |
+
)
|
| 243 |
+
best_model_path = trainer.checkpoint_callback.best_model_path # type: ignore
|
| 244 |
+
best_model_weights = torch.load(best_model_path, weights_only=True)
|
| 245 |
+
|
| 246 |
+
state_dict = {}
|
| 247 |
+
for weight_name, weight in best_model_weights["state_dict"].items():
|
| 248 |
+
state_dict[weight_name.removeprefix("model.")] = weight
|
| 249 |
+
|
| 250 |
+
self.load_state_dict(state_dict)
|
| 251 |
+
self.eval()
|
| 252 |
+
return self
|
| 253 |
+
|
| 254 |
+
def evaluate(
|
| 255 |
+
self, X: list[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False
|
| 256 |
+
) -> str | dict[str, dict[str, float]]:
|
| 257 |
+
"""
|
| 258 |
+
Evaluate the classifier on a given dataset using scikit-learn's classification report.
|
| 259 |
+
|
| 260 |
+
:param X: The texts to predict on.
|
| 261 |
+
:param y: The ground truth labels.
|
| 262 |
+
:param batch_size: The batch size.
|
| 263 |
+
:param threshold: The threshold for multilabel classification.
|
| 264 |
+
:param output_dict: Whether to output the classification report as a dictionary.
|
| 265 |
+
:return: A classification report.
|
| 266 |
+
"""
|
| 267 |
+
self.eval()
|
| 268 |
+
predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold)
|
| 269 |
+
return evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def _initialize(self, y: LabelType) -> None:
|
| 273 |
+
"""
|
| 274 |
+
Sets the output dimensionality, the classes, and initializes the head.
|
| 275 |
+
|
| 276 |
+
:param y: The labels.
|
| 277 |
+
:raises ValueError: If the labels are inconsistent.
|
| 278 |
+
"""
|
| 279 |
+
if isinstance(y[0], (str, int)):
|
| 280 |
+
# Check if all labels are strings or integers.
|
| 281 |
+
if not all(isinstance(label, (str, int)) for label in y):
|
| 282 |
+
msg = "Inconsistent label types in y. All labels must be strings or integers."
|
| 283 |
+
raise ValueError(msg)
|
| 284 |
+
self.multilabel = False
|
| 285 |
+
classes = sorted(set(y))
|
| 286 |
+
else:
|
| 287 |
+
# Check if all labels are lists or tuples.
|
| 288 |
+
if not all(isinstance(label, (list, tuple)) for label in y):
|
| 289 |
+
msg = "Inconsistent label types in y. All labels must be lists or tuples."
|
| 290 |
+
raise ValueError(msg)
|
| 291 |
+
self.multilabel = True
|
| 292 |
+
classes = sorted(set(chain.from_iterable(y)))
|
| 293 |
+
|
| 294 |
+
self.classes_ = classes
|
| 295 |
+
self.out_dim = len(self.classes_) # Update output dimension
|
| 296 |
+
self.head = self.construct_head()
|
| 297 |
+
self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id)
|
| 298 |
+
self.w = self.construct_weights()
|
| 299 |
+
self.train()
|
| 300 |
+
|
| 301 |
+
def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> TextDataset:
|
| 302 |
+
"""
|
| 303 |
+
Prepare a dataset. For multilabel classification, each target is converted into a multi-hot vector.
|
| 304 |
+
|
| 305 |
+
:param X: The texts.
|
| 306 |
+
:param y: The labels.
|
| 307 |
+
:param max_length: The maximum length of the input.
|
| 308 |
+
:return: A TextDataset.
|
| 309 |
+
"""
|
| 310 |
+
# This is a speed optimization.
|
| 311 |
+
# assumes a mean token length of 10, which is really high, so safe.
|
| 312 |
+
truncate_length = max_length * 10
|
| 313 |
+
X = [x[:truncate_length] for x in X]
|
| 314 |
+
tokenized: list[list[int]] = [
|
| 315 |
+
encoding.ids[:max_length] for encoding in self.tokenizer.encode_batch_fast(X, add_special_tokens=False)
|
| 316 |
+
]
|
| 317 |
+
if self.multilabel:
|
| 318 |
+
# Convert labels to multi-hot vectors
|
| 319 |
+
num_classes = len(self.classes_)
|
| 320 |
+
labels_tensor = torch.zeros(len(y), num_classes, dtype=torch.float)
|
| 321 |
+
mapping = {label: idx for idx, label in enumerate(self.classes_)}
|
| 322 |
+
for i, sample_labels in enumerate(y):
|
| 323 |
+
indices = [mapping[label] for label in sample_labels]
|
| 324 |
+
labels_tensor[i, indices] = 1.0
|
| 325 |
+
else:
|
| 326 |
+
labels_tensor = torch.tensor([self.classes_.index(label) for label in cast("list[str]", y)], dtype=torch.long)
|
| 327 |
+
return TextDataset(tokenized, labels_tensor)
|
| 328 |
+
|
| 329 |
+
def _train_test_split(
|
| 330 |
+
self,
|
| 331 |
+
X: list[str],
|
| 332 |
+
y: list[str] | list[list[str]],
|
| 333 |
+
test_size: float,
|
| 334 |
+
) -> tuple[list[str], list[str], LabelType, LabelType]:
|
| 335 |
+
"""
|
| 336 |
+
Split the data.
|
| 337 |
+
|
| 338 |
+
For single-label classification, stratification is attempted (if possible).
|
| 339 |
+
For multilabel classification, a random split is performed.
|
| 340 |
+
"""
|
| 341 |
+
if not self.multilabel:
|
| 342 |
+
label_counts = Counter(y)
|
| 343 |
+
if min(label_counts.values()) < 2:
|
| 344 |
+
logger.info("Some classes have less than 2 samples. Stratification is disabled.")
|
| 345 |
+
return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True)
|
| 346 |
+
return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True, stratify=y)
|
| 347 |
+
# Multilabel classification does not support stratification.
|
| 348 |
+
return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True)
|
| 349 |
+
|
| 350 |
+
def to_pipeline(self) -> StaticModelPipeline:
    """Convert the model to an sklearn pipeline."""
    static_model = self.to_static_model()

    # Fit a throwaway MLP on random data only to initialize sklearn's internal
    # state (layer shapes, classes_); the real weights are copied in below.
    random_state = np.random.RandomState(_RANDOM_SEED)
    n_items = len(self.classes)
    X = random_state.randn(n_items, static_model.dim)
    y = self.classes

    converted = make_pipeline(MLPClassifier(hidden_layer_sizes=(self.hidden_dim,) * self.n_layers))
    converted.fit(X, y)
    mlp_head: MLPClassifier = converted[-1]

    # Transplant the trained torch head into the sklearn MLP. sklearn stores
    # coefficients as (in_features, out_features), hence the transpose.
    for index, layer in enumerate([module for module in self.head if isinstance(module, nn.Linear)]):
        mlp_head.coefs_[index] = layer.weight.detach().cpu().numpy().T
        mlp_head.intercepts_[index] = layer.bias.detach().cpu().numpy()
    # Below is necessary to ensure that the converted model works correctly.
    # In scikit-learn, a binary classifier only has a single vector of output coefficients
    # and a single intercept. We use two output vectors.
    # To convert correctly, we need to set the outputs correctly, and fix the activation function.
    # Make sure n_outputs is set to > 1.
    mlp_head.n_outputs_ = self.out_dim
    # Set to softmax or sigmoid
    mlp_head.out_activation_ = "logistic" if self.multilabel else "softmax"

    return StaticModelPipeline(static_model, converted)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class _ClassifierLightningModule(pl.LightningModule):
    """Lightning wrapper that trains a StaticModelForClassification."""

    def __init__(self, model: StaticModelForClassification, learning_rate: float) -> None:
        """Initialize the LightningModule."""
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        # BCE-with-logits for multilabel targets, cross-entropy otherwise.
        if model.multilabel:
            self.loss_function = nn.BCEWithLogitsLoss()
        else:
            self.loss_function = nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Simple forward pass."""
        return self.model(x)

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Training step using cross-entropy loss for single-label and binary cross-entropy for multilabel training."""
        inputs, targets = batch
        logits, _ = self.model(inputs)
        loss = self.loss_function(logits, targets)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Validation step computing loss and accuracy."""
        inputs, targets = batch
        logits, _ = self.model(inputs)
        loss = self.loss_function(logits, targets)
        if self.model.multilabel:
            predictions = (torch.sigmoid(logits) > 0.5).float()
            # Multilabel accuracy is defined as the Jaccard score averaged over samples.
            accuracy = jaccard_score(targets.cpu(), predictions.cpu(), average="samples")
        else:
            accuracy = (logits.argmax(dim=1) == targets).float().mean()
        self.log("val_loss", loss)
        self.log("val_accuracy", accuracy, prog_bar=True)
        return loss

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Configure optimizer and learning rate scheduler."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.5,
            patience=3,
            min_lr=1e-6,
            threshold=0.03,
            threshold_mode="rel",
        )
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}
|
src/distiller/model2vec/utils.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import re
|
| 6 |
+
from importlib import import_module
|
| 7 |
+
from importlib.metadata import metadata
|
| 8 |
+
from typing import TYPE_CHECKING, Any, Protocol, cast
|
| 9 |
+
|
| 10 |
+
import safetensors
|
| 11 |
+
from joblib import Parallel
|
| 12 |
+
from tokenizers import Tokenizer
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
from collections.abc import Iterator
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ProgressParallel(Parallel):
    """A drop-in replacement for joblib.Parallel that shows a tqdm progress bar."""

    def __init__(self, use_tqdm: bool = True, total: int | None = None, *args: Any, **kwargs: Any) -> None:
        """
        Initialize the ProgressParallel object.

        :param use_tqdm: Whether to show the progress bar.
        :param total: Total number of tasks (batches) you expect to process. If None,
            it updates the total dynamically to the number of dispatched tasks.
        :param *args: Additional arguments to pass to `Parallel.__init__`.
        :param **kwargs: Additional keyword arguments to pass to `Parallel.__init__`.
        """
        self._use_tqdm = use_tqdm
        self._total = total
        super().__init__(*args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the parallel job inside a tqdm context so print_progress can update the bar."""
        # `with ... as self._pbar` binds the bar on the instance for print_progress.
        # (The previous no-op `self._pbar = self._pbar` re-assignment was removed.)
        with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
            return super().__call__(*args, **kwargs)

    def print_progress(self) -> None:
        """Hook called by joblib as tasks complete. We update the tqdm bar here."""
        if self._total is None:
            # If no fixed total was given, we dynamically set the total
            self._pbar.total = self.n_dispatched_tasks
        # Move the bar to the number of completed tasks
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class SafeOpenProtocol(Protocol):
    """Protocol to fix safetensors safe open."""

    # safetensors' safe_open return value is untyped; this protocol gives the
    # single method we actually use a static signature for type checkers.
    def get_tensor(self, key: str) -> np.ndarray:
        """Get a tensor."""
        ...  # pragma: no cover
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# Maps PyPI distribution names to their importable module names.
_MODULE_MAP = (("scikit-learn", "sklearn"),)
# Splits a requirement string at version-comparison operators (e.g. "foo>=1.0").
_DIVIDERS = re.compile(r"[=<>!]+")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_package_extras(package: str, extra: str) -> Iterator[str]:
    """
    Yield the dependency names pulled in by an extra of a package.

    :param package: The distribution name whose metadata to inspect.
    :param extra: The extra whose dependencies should be yielded.
    :return: An iterator over bare package names (version specifiers stripped).
    """
    try:
        message = metadata(package)
    except Exception as e:
        # For local packages without metadata, return empty iterator
        # This allows the package to work without installed metadata
        logger.debug(f"Could not retrieve metadata for package '{package}': {e}")
        return

    all_packages = message.get_all("Requires-Dist") or []
    # NOTE: the loop variable was renamed from `package` — it previously
    # shadowed the `package` parameter.
    for requirement in all_packages:
        # A conditional requirement looks like: 'name>=1.0; extra == "distill"'.
        name, *rest = requirement.split(";", maxsplit=1)
        if rest:
            # Extract and clean the extra requirement
            found_extra = rest[0].split("==")[-1].strip(" \"'")
            if found_extra == extra:
                # Strip any version specifier from the requirement name.
                prefix, *_ = _DIVIDERS.split(name)
                yield prefix.strip()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def importable(module: str, extra: str) -> None:
    """
    Check that a module is importable, raising a helpful error otherwise.

    :param module: The distribution name to check (mapped to its import name via _MODULE_MAP).
    :param extra: The model2vec extra that provides the module, used in the error message.
    :raises ImportError: If the module cannot be imported.
    """
    # Translate distribution names (e.g. "scikit-learn") to import names (e.g. "sklearn").
    module = dict(_MODULE_MAP).get(module, module)
    try:
        import_module(module)
    except ImportError as e:
        msg = f"`{module}`, is required. Please reinstall model2vec with the `{extra}` extra. `pip install model2vec[{extra}]`"
        # Chain the original error (`from e`) so the real import failure stays
        # visible in the traceback; the original raise discarded it implicitly.
        raise ImportError(msg) from e
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def setup_logging() -> None:
    """Simple logging setup."""
    from rich.logging import RichHandler

    rich_handler = RichHandler(rich_tracebacks=True)
    logging.basicConfig(
        handlers=[rich_handler],
        level="INFO",
        datefmt="%Y-%m-%d %H:%M:%S",
        format="%(name)s - %(message)s",
    )
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def load_local_model(folder: Path) -> tuple[np.ndarray, Tokenizer, dict[str, str]]:
    """
    Load a local model.

    :param folder: Directory containing model.safetensors, tokenizer.json and
        (optionally) config.json.
    :return: A tuple of (embeddings, tokenizer, config); config is an empty
        dict when config.json is absent.
    """
    embeddings_path = folder / "model.safetensors"
    tokenizer_path = folder / "tokenizer.json"
    config_path = folder / "config.json"

    opened_tensor_file = cast("SafeOpenProtocol", safetensors.safe_open(embeddings_path, framework="numpy"))
    embeddings = opened_tensor_file.get_tensor("embeddings")

    # Read via pathlib so the file handle is closed promptly and decoding is
    # explicit; the previous `json.load(open(config_path))` leaked the handle.
    config = json.loads(config_path.read_text(encoding="utf-8")) if config_path.exists() else {}

    tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))

    # Warn (but do not fail) on a vocab/embedding size mismatch so callers can
    # still load partially edited models.
    if len(tokenizer.get_vocab()) != len(embeddings):
        logger.warning(
            f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
        )

    return embeddings, tokenizer, config
|
src/distiller/model2vec/version.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__version_triple__ = (0, 5, 0)
|
| 2 |
+
__version__ = ".".join(map(str, __version_triple__))
|