Spaces:
Sleeping
Sleeping
File size: 1,961 Bytes
5038bcb 998ba81 5038bcb 998ba81 5038bcb 998ba81 0cca1ec 998ba81 fd50b36 998ba81 5038bcb fd50b36 998ba81 5038bcb 998ba81 5038bcb fd50b36 998ba81 5038bcb 998ba81 5038bcb 998ba81 5038bcb 998ba81 5038bcb 998ba81 5038bcb 998ba81 5038bcb 998ba81 5038bcb 998ba81 5038bcb 998ba81 d562e10 998ba81 d562e10 998ba81 0cca1ec 998ba81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import os
import numpy as np
from typing import List
import onnxruntime as ort
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository from which both the tokenizer and the ONNX
# export (graph file plus its external-data weights file) are loaded.
MODEL_ID = "onnx-community/embeddinggemma-300m-ONNX"
class EmbeddingModel:
    """Text-embedding wrapper around the ONNX export of ``MODEL_ID``.

    Construction loads the tokenizer, downloads the ONNX graph and its
    external weights from the Hugging Face Hub, and opens a CPU-only
    ONNX Runtime session, so creating an instance performs network/disk I/O.
    """

    def __init__(self):
        print("Loading tokenizer…")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

        print("Downloading ONNX model files…")
        self.model_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="onnx/model.onnx",
        )
        # The weights live in a separate external-data file. Its path is never
        # handed to ONNX Runtime; downloading it places it alongside
        # model.onnx, which is where the runtime resolves external data from.
        self.data_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="onnx/model.onnx_data",
        )

        print("Creating inference session…")
        self.session = ort.InferenceSession(
            self.model_path,
            providers=["CPUExecutionProvider"],
        )
        self.input_names = [i.name for i in self.session.get_inputs()]
        self.output_names = [o.name for o in self.session.get_outputs()]

    def _build_feed(self, input_ids, attention_mask):
        """Map the tokenizer tensors onto the session's declared input names.

        When every graph input uses a standard tokenizer name
        (``input_ids`` / ``attention_mask``), tensors are matched by name, so
        the feed is correct regardless of declaration order. Otherwise fall
        back to the positional convention: first input = ids, second = mask.
        """
        by_name = {"input_ids": input_ids, "attention_mask": attention_mask}
        if all(name in by_name for name in self.input_names):
            return {name: by_name[name] for name in self.input_names}
        return {
            self.input_names[0]: input_ids,
            self.input_names[1]: attention_mask,
        }

    async def embed_text(self, text: str, max_length=512) -> List[float]:
        """Return the L2-normalised, mean-pooled embedding of ``text``.

        Args:
            text: The string to embed.
            max_length: Maximum number of tokens; longer input is truncated.

        Returns:
            The embedding as a list of Python floats. It has unit length
            unless the pooled vector is all zeros, in which case it is
            returned unnormalised (all zeros).

        NOTE(review): although declared ``async``, tokenization and
        ``session.run`` are synchronous and nothing is awaited, so each call
        blocks the event loop for the full duration of inference.
        """
        encoded = self.tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="np",
        )
        # Normalise to int64; presumably the graph declares int64 inputs.
        input_ids = encoded["input_ids"].astype(np.int64)
        attention_mask = encoded["attention_mask"].astype(np.int64)

        outputs = self.session.run(
            self.output_names,
            self._build_feed(input_ids, attention_mask),
        )

        # Assumes outputs[0] is the per-token hidden state shaped
        # (batch, seq, hidden). NOTE(review): the model card for this export
        # shows a second output carrying the model's own sentence embedding
        # (including its projection layers); pooling outputs[0] here likely
        # differs from it. Confirm before switching, since changing the
        # output would change every vector this class produces.
        last_hidden = outputs[0]
        mask = attention_mask[..., None]
        # Mean over unmasked tokens. The token count is clamped to 1 so an
        # all-padding mask yields a zero vector instead of NaNs from 0 / 0.
        token_count = np.maximum(mask.sum(axis=1), 1)
        pooled = (last_hidden * mask).sum(axis=1) / token_count

        # A single string was tokenized, so only the first batch row is used.
        vec = pooled[0]
        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm
        return vec.tolist()
# Shared module-level instance. It is constructed at import time, so importing
# this module loads the tokenizer, downloads the ONNX files and opens the
# inference session before any caller runs.
embedding_model = EmbeddingModel()
|