File size: 1,826 Bytes
5038bcb
998ba81
5038bcb
998ba81
 
 
5038bcb
998ba81
0cca1ec
998ba81
 
 
5038bcb
998ba81
 
 
 
 
 
 
 
5038bcb
998ba81
5038bcb
998ba81
 
 
 
5038bcb
998ba81
 
5038bcb
998ba81
5038bcb
 
 
 
998ba81
5038bcb
998ba81
5038bcb
 
 
998ba81
 
 
 
 
 
 
 
5038bcb
998ba81
5038bcb
998ba81
 
5038bcb
998ba81
d562e10
998ba81
 
 
d562e10
998ba81
0cca1ec
 
998ba81
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import numpy as np
from typing import List
import onnxruntime as ort
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download

# Hugging Face Hub repo id for the community ONNX export of EmbeddingGemma (300M).
MODEL_ID = "onnx-community/embeddinggemma-300m-ONNX"

class EmbeddingModel:
    """Sentence-embedding model backed by an ONNX Runtime CPU session.

    Downloads the EmbeddingGemma ONNX export from the Hugging Face Hub,
    tokenizes input with the matching tokenizer, and produces L2-normalized
    mean-pooled embedding vectors.
    """

    def __init__(self):
        # Tokenizer must come from the same checkpoint as the ONNX graph.
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

        self.model_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="onnx/model.onnx",
        )
        # The external-data file is never referenced directly, but downloading
        # it places it next to model.onnx in the Hub cache so ONNX Runtime can
        # resolve the external weights when it loads the graph.
        self.data_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="onnx/model.onnx_data",
        )

        self.session = ort.InferenceSession(
            self.model_path,
            providers=["CPUExecutionProvider"],
        )

        # Cache input/output names once so per-call run() doesn't re-query.
        self.input_names = [i.name for i in self.session.get_inputs()]
        self.output_names = [o.name for o in self.session.get_outputs()]

    async def embed_text(self, text: str, max_length: int = 512) -> List[float]:
        """Embed *text* into a unit-length vector (mean-pooled, L2-normalized).

        Args:
            text: Input string; truncated to ``max_length`` tokens.
            max_length: Maximum token length passed to the tokenizer.

        Returns:
            The embedding as a plain list of floats with unit L2 norm
            (or the raw pooled vector if its norm is zero).

        NOTE(review): this coroutine contains no ``await`` — tokenization and
        ``session.run`` execute synchronously and block the event loop. If
        called concurrently under asyncio, consider offloading via
        ``loop.run_in_executor``. Kept ``async`` to preserve the caller-facing
        interface.
        """
        encoded = self.tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="np",
        )

        # ONNX graph expects int64 ids/mask.
        input_ids = encoded["input_ids"].astype(np.int64)
        attention_mask = encoded["attention_mask"].astype(np.int64)

        outputs = self.session.run(
            self.output_names,
            {
                self.input_names[0]: input_ids,
                self.input_names[1]: attention_mask,
            },
        )
        # assumes outputs[0] is the last hidden state, shaped
        # (batch, seq, dim) — TODO confirm against the exported graph.
        last_hidden = outputs[0]

        # Mean pooling over non-padding positions. Clip the denominator to at
        # least 1 so an all-zero mask cannot produce 0/0 -> NaN.
        mask = attention_mask[..., None]
        token_counts = np.clip(mask.sum(axis=1), 1, None)
        pooled = (last_hidden * mask).sum(axis=1) / token_counts

        vec = pooled[0]

        # Normalize to unit length; skip when the vector is identically zero.
        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm

        return vec.tolist()


embedding_model = EmbeddingModel()