File size: 1,961 Bytes
5038bcb
998ba81
5038bcb
998ba81
 
 
5038bcb
998ba81
0cca1ec
998ba81
 
fd50b36
998ba81
5038bcb
fd50b36
998ba81
 
 
 
 
 
 
 
 
5038bcb
998ba81
5038bcb
fd50b36
998ba81
 
 
 
5038bcb
998ba81
 
5038bcb
998ba81
5038bcb
 
 
 
998ba81
5038bcb
998ba81
5038bcb
 
 
998ba81
 
 
 
 
 
 
 
5038bcb
998ba81
5038bcb
998ba81
 
5038bcb
998ba81
d562e10
998ba81
 
 
d562e10
998ba81
0cca1ec
 
998ba81
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import numpy as np
from typing import List
import onnxruntime as ort
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download

MODEL_ID = "onnx-community/embeddinggemma-300m-ONNX"

class EmbeddingModel:
    """Sentence-embedding wrapper around the embeddinggemma-300m ONNX model.

    On construction this downloads the tokenizer and the ONNX graph plus its
    external weight file from the Hugging Face Hub (network I/O), then opens
    a CPU onnxruntime session. ``embed_text`` produces a mean-pooled,
    L2-normalized embedding vector for a single string.
    """

    def __init__(self) -> None:
        print("Loading tokenizer…")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

        print("Downloading ONNX model files…")

        # The graph and its external-data file must land in the same cache
        # directory so the session can resolve the weights by relative path;
        # hf_hub_download guarantees that for files of one repo revision.
        self.model_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="onnx/model.onnx",
        )
        self.data_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="onnx/model.onnx_data",
        )

        print("Creating inference session…")
        self.session = ort.InferenceSession(
            self.model_path,
            providers=["CPUExecutionProvider"],
        )

        # Cache graph input/output names once. embed_text assumes input 0
        # is the token ids and input 1 is the attention mask — TODO confirm
        # this ordering holds for the exported graph.
        self.input_names = [i.name for i in self.session.get_inputs()]
        self.output_names = [o.name for o in self.session.get_outputs()]

    async def embed_text(self, text: str, max_length: int = 512) -> List[float]:
        """Return the L2-normalized mean-pooled embedding of *text*.

        Args:
            text: Input string to embed.
            max_length: Maximum token length; longer inputs are truncated.

        Returns:
            The embedding as a plain list of floats (unit norm unless the
            raw vector is all zeros, in which case it is returned as-is).

        NOTE(review): this coroutine never awaits — tokenization and the
        ONNX forward pass run synchronously and will block the event loop.
        Kept ``async`` for caller compatibility; consider offloading to a
        thread (e.g. ``asyncio.to_thread``) if used under real load.
        """
        encoded = self.tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="np",
        )

        input_ids = encoded["input_ids"].astype(np.int64)
        attention_mask = encoded["attention_mask"].astype(np.int64)

        outputs = self.session.run(
            self.output_names,
            {
                self.input_names[0]: input_ids,
                self.input_names[1]: attention_mask,
            },
        )
        # First output is taken as the token-level hidden states,
        # shape (batch, seq, dim) — presumed from the pooling below.
        last_hidden = outputs[0]

        # Masked mean pooling over the sequence axis; clip the denominator
        # so a degenerate all-zero mask cannot divide by zero.
        mask = attention_mask[..., None]
        denom = np.clip(mask.sum(axis=1), 1e-9, None)
        pooled = (last_hidden * mask).sum(axis=1) / denom

        vec = pooled[0]

        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm

        return vec.tolist()


# Module-level singleton. NOTE: constructing this performs network I/O
# (Hub downloads) and loads the ONNX session at import time, so importing
# this module is slow and requires connectivity on a cold cache.
embedding_model = EmbeddingModel()