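"""
Export static token-embedding models to flat .npy / .npy.zst artifacts.

For each configured model, the script extracts the token-embedding table and
tokenizer, truncates the table to each Matryoshka dimension, saves fp32, fp16,
and fp8 variants, and writes a README.md with basic stats and quantization-loss
metrics.
"""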
from dataclasses import dataclass
import shutil
from textwrap import dedent, indent
import numpy as np
from zstandard import ZstdCompressor
from pathlib import Path
import io
from sentence_transformers import SentenceTransformer
from torch.nn import EmbeddingBag
import torch
from model2vec import StaticModel
from tokenizers import Encoding, Tokenizer

models_path = Path("models")


@dataclass
class ModelCard:
    owner: str
    repo: str
    # The dimensions that were applied with Matryoshka loss.
    matroyshka_dims: list[int]
    description: str
    license: str

    def name(self):
        return f"{self.owner}/{self.repo}"

    def path(self):
        return models_path / self.owner / self.repo

    def get_description(self):
        return dedent(self.description).strip()


def zst_compress_file(input_path: Path):
    """Compress the file to a sibling `.zst` file, leaving the original in place."""
    cctx = ZstdCompressor()
    output = input_path.parent / f"{input_path.name}.zst"
    print(f"Compressing {output}")
    with open(input_path, "rb") as fin, open(output, "wb") as fout:
        cctx.copy_stream(fin, fout)


def save_data(path: Path, tensor: torch.Tensor):
    """Writes out the static embeddings to a .npy and .npy.zst file"""
    buffer = io.BytesIO()

    if tensor.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # Store as the raw bytes.
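        # NumPy has no float8 dtype, so the fp8 bits are reinterpreted as uint8;
        # consumers must view the bytes back as the matching fp8 format.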
        np.save(buffer, tensor.detach().view(torch.uint8).numpy())
    else:
        np.save(buffer, tensor.detach().numpy())

    print(f"Saving {path}")
    with open(path, "wb") as outfile:
        outfile.write(buffer.getvalue())

    zst_compress_file(path)


def quantization_loss_mse(tensor: torch.Tensor, dtype: torch.dtype):
    """
    Compute reconstruction loss when converting embeddings to a datatype and back using
    the mean squared error, which punishes big errors more than small ones.
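
    Illustrative usage:

        loss = quantization_loss_mse(torch.randn(1000, 64), torch.float16)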
    """

    # Original → quantize → dequantize
    roundtrip = tensor.detach().to(dtype).to(tensor.dtype)

    # Mean squared error
    return torch.mean((tensor - roundtrip) ** 2).item()


def quantization_loss_mae(tensor: torch.Tensor, dtype: torch.dtype):
    """
    Compute reconstruction loss when converting embeddings to a datatype and back using
    the mean absolute error, which is less sensitive to outliers than MSE.
    """

    # Original → quantize → dequantize
    roundtrip = tensor.detach().to(dtype).to(tensor.dtype)

    # Mean absolute error
    return torch.mean(torch.abs(tensor - roundtrip)).item()


def quantization_loss_cosine(tensor: torch.Tensor, dtype: torch.dtype):
    """
    Compute reconstruction loss when converting embeddings to a datatype and back using
    cosine similarity. This measures whether the embedding directions are preserved
    after quantization, independent of their magnitudes.
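
    Returns the mean cosine similarity in [-1, 1], where 1.0 means the vector
    directions are perfectly preserved.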
    """

    # Original → quantize → dequantize
    roundtrip = tensor.detach().to(dtype).to(tensor.dtype)

    # Flatten both to 2D (num_vectors, dimensions) in case tensor is 1D or higher-D
    if tensor.ndim == 1:
        orig = tensor.unsqueeze(0)
        recon = roundtrip.unsqueeze(0)
    else:
        orig = tensor.view(tensor.shape[0], -1)
        recon = roundtrip.view(roundtrip.shape[0], -1)

    # Cosine similarity per vector, then average
    cos = torch.nn.functional.cosine_similarity(orig, recon, dim=1)
    return cos.mean().item()


def export_embeddings(model_card: ModelCard, embeddings: torch.Tensor) -> None:
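    """Truncate the embedding table to each Matryoshka dimension and save fp32, fp16, and fp8 variants."""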
    vocab_size, dimensions = embeddings.shape

    # This logic can always be adjusted for models with different shapes.
    assert (
        embeddings.dtype == torch.float32
    ), f"Expected float32 embeddings, got {embeddings.dtype}."

    for dim in model_card.matroyshka_dims:
        assert (
            dim <= dimensions
        ), f"The Matryoshka dimension {dim} is larger than the model's dimension of {dimensions}"

        truncated = embeddings[:, :dim]
        assert truncated.shape == torch.Size([vocab_size, dim])

        save_data(model_card.path() / f"fp32.d{dim}.npy", truncated)
        save_data(
            model_card.path() / f"fp16.d{dim}.npy",
            truncated.to(dtype=torch.float16),
        )
        save_data(
            model_card.path() / f"fp8_e5m2.d{dim}.npy",
            truncated.to(dtype=torch.float8_e5m2),
        )
        save_data(
            model_card.path() / f"fp8_e4m3.d{dim}.npy",
            truncated.to(dtype=torch.float8_e4m3fn),
        )


def normalized_mean_pooling(x: torch.Tensor) -> torch.Tensor:
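    """Mean-pool the token vectors along dim 0 and L2-normalize the result."""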
    pooled = x.mean(dim=0)
    normalized = torch.nn.functional.normalize(pooled, dim=0)
    return normalized


def export_readme(
    model_card: ModelCard,
    embeddings: torch.Tensor,
    tokenizer: Tokenizer,
):
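    """Write a README.md with model stats, quantization-loss measurements, and tokenizer examples."""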
    vocab_size, dimensions = embeddings.shape
    norms = torch.norm(embeddings, dim=1)  # shape: [vocab_size]

    phrases = [
        "The committee approved the proposal after hours of heated discussion and several last-minute amendments.",
        "When training large neural networks, careful tuning of hyperparameters can significantly affect performance and stability.",
        "Despite the heavy rain, the concert continued as planned and the crowd stayed enthusiastic until the final encore.",
        "In ancient mythology, heroes often embarked on perilous journeys to discover hidden truths about themselves and their world.",
        "The new smartphone model features an improved camera system, faster processing, and extended battery life compared to its predecessor.",
        "He tried to explain the concept using simple analogies, but the underlying mathematics remained difficult to grasp for most listeners.",
        "After weeks of negotiations, the two countries signed a historic trade agreement aimed at reducing tariffs and boosting cooperation.",
        "She paused for a moment before answering, choosing her words carefully to avoid misunderstanding in such a delicate situation.",
        "The detective pieced together the timeline of events, realizing that the key witness had provided a contradictory statement.",
        "Remote work has changed the way teams collaborate, with online tools replacing traditional office routines and in-person meetings.",
    ]

    cosine_similarity = {
        torch.float16: [],
        torch.float8_e4m3fn: [],
        torch.float8_e5m2: [],
    }

    for phrase in phrases:
        encoding: Encoding = tokenizer.encode(phrase)
        embedded_phrase = embeddings[torch.tensor(encoding.ids, dtype=torch.long)]

        for dtype in cosine_similarity.keys():
            pooling_unquantized = normalized_mean_pooling(embedded_phrase)
            pooling_roundtrip = normalized_mean_pooling(
                embedded_phrase.to(dtype).to(torch.float32)
            )
            cosine = torch.dot(pooling_unquantized, pooling_roundtrip).item()
            cosine_similarity[dtype].append(cosine)

    avg_cosine_similarity = {
        dtype: sum(values) / len(values) for dtype, values in cosine_similarity.items()
    }

    tokenizer_examples = []
    for text in [
        "This is an example of encoding",
        "The quick brown fox jumps over the lazy dog.",
        "Curaçao, naïve fiancé, jalapeño, déjà vu.",
        "Привет, как дела?",
        "Бързата кафява лисица прескача мързеливото куче.",
        "Γρήγορη καφέ αλεπού πηδάει πάνω από τον τεμπέλη σκύλο.",
        "اللغة العربية جميلة وغنية بالتاريخ.",
        "مرحبا بالعالم!",
        "Simplified: 快速的棕色狐狸跳过懒狗。",
        "Traditional: 快速的棕色狐狸跳過懶狗。",
        "素早い茶色の狐が怠け者の犬を飛び越える。",
        "コンピュータープログラミング",
        "빠른 갈색 여우가 게으른 개를 뛰어넘습니다.",
        "तेज़ भूरी लोमड़ी आलसी कुत्ते के ऊपर कूदती है।",
        "দ্রুত বাদামী শিয়াল অলস কুকুরের উপর দিয়ে লাফ দেয়।",
        "வேகமான பழுப்பு நரி சோம்பேறி நாயின் மேல் குதிக்கிறது.",
        "สุนัขจิ้งจอกสีน้ำตาลกระโดดข้ามสุนัขขี้เกียจ.",
        "ብሩክ ቡናማ ቀበሮ ሰነፍ ውሻን ተዘልሏል።",
        "Hello 世界 مرحبا 🌍",
        "123, αβγ, абв, العربية, 中文, हिन्दी.",
    ]:
        encoding = tokenizer.encode(text)
        tokens = [f"`{token}`" for token in encoding.tokens]

        tokenizer_examples.append(f"**Input:** {text}<br/>")
        tokenizer_examples.append(f"**Tokens**: {' '.join(tokens)}")
        tokenizer_examples.append("")

    tokenizer_output = "\n".join(tokenizer_examples)

    with (model_card.path() / "README.md").open("wt") as file:
        prefix = "                "
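        # 16 spaces: matches the indentation of the f-string body below, so text
        # re-indented with `indent()` lines up before `dedent()` strips the margin.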

        file.write(
            dedent(
                f"""
                # [{model_card.name()}](https://huggingface.co/{model_card.name()})
                
                License: [{model_card.license}](https://choosealicense.com/licenses/{model_card.license}/)
                
                {indent(model_card.get_description(), prefix).strip()}
                
                ## Model Stats
                
                Stats that describe the embedding tensor's shape and value distribution.

                | item          | metric                  | value |
                | --------------| ----------------------- | ----- |
                | vocab         | size                    | {vocab_size:,.0f} |
                | embedding     | dimensions              | {dimensions:,.0f} |
                | vector length | mean                    | {norms.mean().item():.2f} |
                | vector length | median                  | {norms.median().item():.2f} |
                | vector length | stddev                  | {norms.std().item():.2f} |
                | values        | mean                    | {embeddings.mean().item():.2f} |
                | values        | median                  | {embeddings.median().item():.2f} |
                | values        | stddev                  | {embeddings.std().item():.2f} |
                
                ## Mean Pooled Quantization Loss
                
                This test roundtrips the vectors through quantization, but performs the
                mean pooling arithmetic in float32 space. The quantized and unquantized
                mean pooled vectors are compared to each other to determine their cosine
                similarity, to show how much the meaning of the vector has changed due
                to quantization.
                
                | Precision     | Cosine Similarity |
                | ------------- | ----------------- |
                | fp16          | {avg_cosine_similarity[torch.float16]:.5f} |
                | fp8 e4m3      | {avg_cosine_similarity[torch.float8_e4m3fn]:.5f} |
                | fp8 e5m2      | {avg_cosine_similarity[torch.float8_e5m2]:.5f} |
                
                ## Quantization Loss Per Vector
                
                While the embedding vectors will ultimately be mean pooled together, it's
                still useful to look at the per-vector loss in the embedding table to see
                which quantization strategies retain the most meaning.
                
                - **Cosine Similarity** — measures how well the *direction* of embedding vectors
                is preserved after quantization, independent of scale. This is especially
                relevant when embeddings are used for similarity search or retrieval.
                - **MSE (Mean Squared Error)** — emphasizes large errors by squaring the
                differences. Useful for detecting whether any values are badly distorted.
                - **MAE (Mean Absolute Error)** — the average absolute difference between
                original and quantized values. Easier to interpret, less sensitive to outliers.

                | Precision     | Metric | Value |
                | ------------- | ------ | ----- |
                | fp16          | cosine similarity | {quantization_loss_cosine(embeddings, torch.float16):.5f} |
                | fp8 e4m3      | cosine similarity | {quantization_loss_cosine(embeddings, torch.float8_e4m3fn):.5f} |
                | fp8 e5m2      | cosine similarity | {quantization_loss_cosine(embeddings, torch.float8_e5m2):.5f} |
                | fp16          | MSE    | {quantization_loss_mse(embeddings, torch.float16):.5f} |
                | fp8 e4m3      | MSE    | {quantization_loss_mse(embeddings, torch.float8_e4m3fn):.5f} |
                | fp8 e5m2      | MSE    | {quantization_loss_mse(embeddings, torch.float8_e5m2):.5f} |
                | fp16          | MAE    | {quantization_loss_mae(embeddings, torch.float16):.5f} |
                | fp8 e4m3      | MAE    | {quantization_loss_mae(embeddings, torch.float8_e4m3fn):.5f} |
                | fp8 e5m2      | MAE    | {quantization_loss_mae(embeddings, torch.float8_e5m2):.5f} |
                
                ## Tokenizer Examples
                
                {indent(tokenizer_output, prefix).strip()}
                """
            ).strip()
        )


def export_tokenizer(model_card: ModelCard, tokenizer: Tokenizer) -> None:
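    """Save tokenizer.json alongside the embeddings and compress it with zstd."""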
    tokenizer_path = model_card.path() / "tokenizer.json"
    print(f"Exporting tokenizer: {tokenizer_path}")
    tokenizer.save(str(tokenizer_path))
    zst_compress_file(tokenizer_path)


def export_sentence_transformers(model_card: ModelCard) -> None:
    """Extract the embeddings and tokenizer from SentenceTransformers"""

    print("Processing", model_card.name())

    model = SentenceTransformer(model_card.name(), device="cpu")
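    # For static models, the first module wraps an EmbeddingBag whose weight is
    # the token-embedding table (assumed layout; see the type annotation below).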
    embedding_bag: EmbeddingBag = model[0].embedding  # type: ignore
    model_card.path().mkdir(exist_ok=True, parents=True)
    embeddings = torch.Tensor(embedding_bag.weight)

    export_embeddings(model_card, embeddings)
    export_tokenizer(model_card, model.tokenizer)
    export_readme(model_card, embeddings, model.tokenizer)


def export_model2vec(model_card: ModelCard) -> None:
    """Extract the embeddings and tokenizer from model2vec"""

    print("Processing", model_card.name())

    model = StaticModel.from_pretrained(model_card.name())
    model_card.path().mkdir(exist_ok=True, parents=True)
    embeddings = torch.from_numpy(model.embedding)
    export_embeddings(model_card, embeddings)
    export_tokenizer(model_card, model.tokenizer)
    export_readme(model_card, embeddings, model.tokenizer)


def main() -> None:
    # Static embedders that use sentence_transformers models.
    sentence_transformers_models = [
        ModelCard(
            owner="sentence-transformers",
            repo="static-similarity-mrl-multilingual-v1",
            description="""
            Multilingual similarity embeddings that were trained with Matryoshka loss,
            which allows the embedding vectors to be truncated more effectively. The
            model was trained on multilingual datasets spanning a variety of domains.

            It's a general-purpose model that can be used for semantic textual similarity,
            paraphrase mining, text classification, clustering, and more.
            """,
            matroyshka_dims=[32, 64, 128, 256, 512, 1024],
            license="apache-2.0",
        ),
        ModelCard(
            owner="sentence-transformers",
            repo="static-retrieval-mrl-en-v1",
            description="""
            English-only uncased embeddings that were trained with Matryoshka loss,
            which allows the embedding vectors to be truncated more effectively. The
            model was trained on monolingual datasets spanning a variety of domains
            and was designed specifically for similarity retrieval.
            """,
            matroyshka_dims=[32, 64, 128, 256, 512, 1024],
            license="apache-2.0",
        ),
    ]
    # Static embedders that use model2vec.
    model2vec_models = [
        ModelCard(
            owner="minishlab",
            repo="potion-multilingual-128M",
            # These are assumed, as there is no Python reference implementation:
            matroyshka_dims=[32, 64, 128, 256],
            description="""
            A multilingual embedder. The training details are scant since no source
            code is available, but the architecture is likely close to the
            potion-retrieval-32M model, trained on Common Crawl data.

            The 128M refers to the number of parameters in the embeddings:

            256 dimensions * 500,353 vocab = 128,090,368 ≈ 128M parameters.
            """,
            license="mit",
        ),
        ModelCard(
            owner="minishlab",
            repo="potion-retrieval-32M",
            matroyshka_dims=[32, 64, 128, 256, 512],
            description="""
            The token embeddings from a monolingual English 32M-parameter model that was
            distilled from embeddings initialized from the multi-domain
            [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5).

            The 32M refers to the number of parameters in the embeddings:

            512 dimensions * 63,091 vocab = 32,302,592 ≈ 32M parameters.
            """,
            license="mit",
        ),
    ]

    if models_path.exists():
        print(f"Removing the old models folder: {models_path}")
        shutil.rmtree(models_path)
    models_path.mkdir(parents=True, exist_ok=True)

    for model_card in sentence_transformers_models:
        export_sentence_transformers(model_card)

    for model_card in model2vec_models:
        export_model2vec(model_card)


if __name__ == "__main__":
    main()