# test_space / src/utils/fastembed_manager.py
# Last commit: Minh, "init" (6912ad8)
from fastembed import TextEmbedding
from fastembed.common.model_description import PoolingType, ModelSource
from huggingface_hub import snapshot_download
import time
# def download_model_from_hf(model_name: str, save_path: str):
# try:
# snapshot_download(
# repo_id=model_name,
# local_dir=save_path,
# allow_patterns=["onnx/*"],
# local_dir_use_symlinks=False,
# )
# except Exception as e:
# print(f"Error downloading model from Hugging Face: {e}")
# raise e
def add_custom_embedding_model(
    model_name: str,
    source_model: str,
    source_file: str,
    dim: int,
    from_hf: bool = True,
    *,
    pooling: PoolingType = PoolingType.MEAN,
    normalization: bool = True,
) -> TextEmbedding:
    """Register a custom embedding model with FastEmbed and return a loaded instance.

    Args:
        model_name: Name under which the model is registered; the same name is
            passed to ``TextEmbedding(model_name=...)`` to load it.
        source_model: Hugging Face repo id when ``from_hf`` is true, otherwise
            a URL passed to ``ModelSource(url=...)`` (e.g. private storage).
        source_file: Model file inside the source, e.g. ``"model.onnx"`` or
            ``"onnx/model_O4.onnx"`` (allows picking another optimization or
            quantization of a model).
        dim: Dimension of the produced embeddings.
        from_hf: ``True`` if ``source_model`` is a Hugging Face repo id,
            ``False`` if it is a URL.
        pooling: Pooling strategy applied to token embeddings
            (keyword-only; defaults to mean pooling).
        normalization: Whether output embeddings are normalized
            (keyword-only; defaults to ``True``).

    Returns:
        A ``TextEmbedding`` instance for the newly registered model.

    Raises:
        Exception: Any error from building the source, registering the model,
            or loading it is printed and then re-raised unchanged.
    """
    # The two source kinds differ only in which ModelSource field is set,
    # so pick it once instead of duplicating the whole registration call.
    origin = "Hugging Face" if from_hf else "URL"
    try:
        if from_hf:
            source = ModelSource(hf=source_model)
        else:
            source = ModelSource(url=source_model)
        TextEmbedding.add_custom_model(
            model=model_name,
            pooling=pooling,
            normalization=normalization,
            sources=source,
            dim=dim,
            model_file=source_file,
        )
        print(f"Successfully added model '{model_name}' from {origin}.")
        return TextEmbedding(model_name=model_name)
    except Exception as e:
        print(f"Error adding model from {origin}: {e}")
        # Bare raise keeps the original traceback intact.
        raise
def main() -> None:
    """Register the quantized Vietnamese embedding model and time one embedding call."""
    # Reuse the helper rather than duplicating the add_custom_model call;
    # its default pooling (MEAN) and normalization (True) match the original script.
    model = add_custom_embedding_model(
        model_name="Mint1456/Vietnamese_Embedding_OnnX_Quantized",
        source_model="Mint1456/Vietnamese_Embedding_OnnX_Quantized",
        source_file="model.onnx",
        dim=1024,
        from_hf=True,
    )
    start = time.perf_counter()
    # embed() returns an iterable; list() materializes it inside the timed region.
    embeddings = list(model.embed("define artificial intelligence"))
    elapsed = time.perf_counter() - start
    print(f"len embedding {len(embeddings[0])}, time taken: {elapsed:.3f} seconds")


if __name__ == "__main__":
    main()