File size: 2,336 Bytes
8201032
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import onnxruntime as ort
import numpy as np
from tokenizers import Tokenizer
from typing import List


class IndonesianEmbeddingEngine:
    """Compute L2-normalized sentence embeddings with an ONNX encoder.

    Texts are tokenized with a Hugging Face ``tokenizers`` tokenizer, run
    through an ONNX Runtime session on the CPU execution provider, mean-pooled
    over non-padding tokens (when the model emits per-token states), and
    L2-normalized.
    """

    def __init__(
        self,
        model_path: str = "./onnx/indonesian_embedding.onnx",
        tokenizer_path: str = "./onnx/tokenizer.json",
        max_length: int = 384,
    ):
        """Load the tokenizer and the ONNX model.

        Args:
            model_path: Path to the exported ONNX model file.
            tokenizer_path: Path to a ``tokenizer.json`` file.
            max_length: Fixed sequence length every text is truncated or
                zero-padded to before inference.
        """
        self.max_length = max_length

        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        # Let the tokenizer itself truncate so the post-processor's trailing
        # special token survives; slicing ids afterwards would cut it off for
        # long texts. A stricter limit already set in tokenizer.json is kept.
        truncation = self.tokenizer.truncation
        if truncation is None or truncation.get("max_length", max_length) > max_length:
            self.tokenizer.enable_truncation(max_length=max_length)

        self.session = ort.InferenceSession(
            model_path,
            providers=["CPUExecutionProvider"],
        )

        # Graph input names; embed() feeds only tensors the model declares.
        self.input_names = {i.name for i in self.session.get_inputs()}

    def _tokenize(self, texts: List[str]):
        """Tokenize *texts* into fixed-size ``int64`` model inputs.

        Args:
            texts: Texts to encode as one batch.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``token_type_ids``,
            each an ``int64`` array of shape ``(len(texts), max_length)``.
            Positions past a text's tokens are zero-filled (pad id 0, mask 0).
        """
        encodings = self.tokenizer.encode_batch(texts)
        shape = (len(encodings), self.max_length)
        input_ids = np.zeros(shape, dtype=np.int64)
        attention_mask = np.zeros(shape, dtype=np.int64)
        token_type_ids = np.zeros(shape, dtype=np.int64)

        for row, enc in enumerate(encodings):
            ids = enc.ids[: self.max_length]
            length = len(ids)
            input_ids[row, :length] = ids
            # Use the encoding's own mask: if tokenizer.json enables padding,
            # those pad tokens must stay masked out instead of counted as real.
            attention_mask[row, :length] = enc.attention_mask[:length]
            token_type_ids[row, :length] = enc.type_ids[:length]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    def _mean_pooling(self, token_embeddings, attention_mask):
        """Average token embeddings over positions where the mask is non-zero.

        Args:
            token_embeddings: Array indexed ``[batch, seq, hidden]``.
            attention_mask: Array indexed ``[batch, seq]``; 0 marks padding.

        Returns:
            Array of shape ``[batch, hidden]``. Rows with no unmasked tokens
            come out as zeros (the count is clamped to avoid division by 0).
        """
        # Cast to the embedding dtype so float32 states are not upcast to
        # float64 by multiplication with an int64 mask.
        mask = attention_mask[..., None].astype(token_embeddings.dtype)
        summed = np.sum(token_embeddings * mask, axis=1)
        counts = np.maximum(mask.sum(axis=1), 1e-9)
        return summed / counts

    def _normalize(self, vectors, eps: float = 1e-12):
        """L2-normalize each row of *vectors*.

        Zero rows are returned as zeros rather than NaN, because the norm is
        clamped to *eps* before dividing.
        """
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        return vectors / np.maximum(norms, eps)

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* and return one L2-normalized vector per text.

        Args:
            texts: Texts to embed. An empty sequence returns ``[]`` without
                running the model.

        Returns:
            List of embedding vectors as Python float lists, in input order.
        """
        if not texts:
            return []

        inputs = self._tokenize(texts)

        # Feed only the inputs the exported graph declares; some exports
        # (e.g. BERT-style) also require token_type_ids, others do not.
        ort_inputs = {
            name: array for name, array in inputs.items() if name in self.input_names
        }

        outputs = self.session.run(None, ort_inputs)
        hidden = outputs[0]

        if hidden.ndim == 3:
            # Per-token states [batch, seq, hidden]: pool over real tokens.
            pooled = self._mean_pooling(hidden, inputs["attention_mask"])
        else:
            # The graph already emits one vector per text.
            pooled = hidden

        return self._normalize(pooled).tolist()