|
|
import numpy as np |
|
|
from zstandard import ZstdCompressor |
|
|
from pathlib import Path |
|
|
import io |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from torch.nn import EmbeddingBag |
|
|
import torch |
|
|
|
|
|
|
|
|
def save_data(path: Path, tensor: torch.Tensor) -> None:
    """Serialize *tensor* to NumPy .npy format and write it zstd-compressed.

    The tensor is detached, converted to a NumPy array, serialized into an
    in-memory buffer, and the whole payload is streamed through a zstd
    compressor into *path*.

    Args:
        path: Destination file; must end with ".npy.zst".
        tensor: The tensor to export (any dtype/shape numpy supports).

    Raises:
        ValueError: If *path* does not have the expected ".npy.zst" suffix.
    """
    # Validate with an explicit raise rather than `assert`, which is
    # silently stripped when Python runs with -O.
    if not str(path).endswith(".npy.zst"):
        raise ValueError(f"expected a .npy.zst path, got: {path}")

    # Build the full .npy payload in memory first, then compress it in one
    # streamed pass to the output file.
    buffer = io.BytesIO()
    np.save(buffer, tensor.detach().numpy())

    with (
        open(path, "wb") as outfile,
        ZstdCompressor().stream_writer(outfile) as writer,
    ):
        writer.write(buffer.getvalue())
|
|
|
|
|
|
|
|
# Output directory for the exported embedding files.
model_path = Path("model")


# Hugging Face model id of the static-embedding model to export.
model_name = "sentence-transformers/static-similarity-mrl-multilingual-v1"


# Expected embedding-table shape for this model; verified after loading.
vocab_size = 105_879


dimensions = 1024
|
|
|
|
|
|
|
|
def load_embeddings():
    """Load the static embedding table and export it at several dimensions.

    Loads the SentenceTransformer model on CPU, extracts the EmbeddingBag
    weight matrix, prints estimated on-disk sizes, and writes one
    ``.npy.zst`` file per (dimension, dtype) combination into ``model_path``.

    Raises:
        AssertionError: If the loaded table does not match the expected
            ``(vocab_size, dimensions)`` shape.
    """
    model = SentenceTransformer(model_name, device="cpu")

    embedding_bag: EmbeddingBag = model[0].embedding
    embeddings = torch.Tensor(embedding_bag.weight)

    print(embeddings.shape)
    assert embeddings.shape == torch.Size([vocab_size, dimensions])

    # Report estimated raw (uncompressed) sizes for a few truncation choices.
    for dtype_name, bytes_per_value in (("float32", 4), ("float16", 2)):
        print(dtype_name)
        for dim in (1024, 512, 256):
            mib = embeddings.shape[0] * dim * bytes_per_value / 1024 / 1024
            print(f" {dim:4} dim - {mib:,.1f} MiB")

    for dim in (1024, 512, 384, 256, 128):
        truncated = embeddings[:, :dim]
        assert truncated.shape == torch.Size([vocab_size, dim])

        # BUG FIX: the original saved the full `embeddings` tensor for every
        # dim, so each file contained the untruncated 1024-dim table despite
        # its filename. Save the truncated view instead.
        save_data(model_path / f"static-embeddings.{dim}.fp32.npy.zst", truncated)
        save_data(
            model_path / f"static-embeddings.{dim}.fp16.npy.zst",
            truncated.to(dtype=torch.float16),
        )
        # NOTE(review): `.to(torch.int8)` truncates float values toward zero;
        # if the weights are mostly in (-1, 1) that collapses them to 0. A
        # real int8 export likely needs scale/zero-point quantization —
        # confirm intent before relying on these files.
        save_data(
            model_path / f"static-embeddings.{dim}.int8.npy.zst",
            truncated.to(dtype=torch.int8),
        )
|
|
|
|
|
|
|
|
def main() -> None:
    """Entry point: export the static embedding table in all formats."""
    load_embeddings()
|
|
|
|
|
|
|
|
# Run the export only when executed as a script, not on import.
if __name__ == "__main__":


    main()
|
|
|