| import onnxruntime as ort |
| import numpy as np |
| from tokenizers import Tokenizer |
| from typing import List |
|
|
|
|
class IndonesianEmbeddingEngine:
    """Sentence-embedding engine backed by an ONNX model and a HF tokenizer.

    Pipeline: tokenize -> ONNX forward pass -> attention-masked mean pooling
    -> L2 normalization. Returns plain Python lists of floats.
    """

    def __init__(
        self,
        model_path: str = "./onnx/indonesian_embedding.onnx",
        tokenizer_path: str = "./onnx/tokenizer.json",
        max_length: int = 384,
    ):
        """Load the tokenizer and the CPU ONNX inference session.

        Args:
            model_path: Path to the exported ONNX embedding model.
            tokenizer_path: Path to the serialized ``tokenizers`` JSON file.
            max_length: Fixed sequence length; longer inputs are truncated,
                shorter ones right-padded.
        """
        self.max_length = max_length
        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        self.session = ort.InferenceSession(
            model_path,
            providers=["CPUExecutionProvider"]
        )
        # Names the graph declares, so embed() feeds only expected tensors.
        self.input_names = {i.name for i in self.session.get_inputs()}

    def _tokenize(self, texts: List[str]):
        """Encode *texts* into fixed-length arrays of shape (batch, max_length).

        Returns a dict with int64 ``input_ids`` and ``attention_mask``
        (1 = real token, 0 = padding).

        NOTE(review): pad token id is assumed to be 0 — verify against the
        tokenizer config. Truncation at max_length may also drop a trailing
        special token ([SEP]/</s>) for over-length inputs — confirm the
        model tolerates this.
        """
        encodings = self.tokenizer.encode_batch(texts)

        input_ids = []
        attention_mask = []

        for enc in encodings:
            ids = enc.ids[: self.max_length]
            mask = [1] * len(ids)

            pad_len = self.max_length - len(ids)
            if pad_len > 0:
                ids += [0] * pad_len
                mask += [0] * pad_len

            input_ids.append(ids)
            attention_mask.append(mask)

        return {
            "input_ids": np.array(input_ids, dtype=np.int64),
            "attention_mask": np.array(attention_mask, dtype=np.int64),
        }

    def _mean_pooling(self, token_embeddings, attention_mask):
        """Average token embeddings over real (non-padding) positions.

        Args:
            token_embeddings: float array, shape (batch, seq, dim).
            attention_mask: int array, shape (batch, seq).
        """
        mask = attention_mask[..., None]
        summed = np.sum(token_embeddings * mask, axis=1)
        # Clip so an all-padding row divides by ~0 rather than exactly 0.
        counts = np.clip(mask.sum(axis=1), a_min=1e-9, a_max=None)
        return summed / counts

    def _normalize(self, vectors):
        """L2-normalize each row.

        Fix: floor the norm so an all-zero vector is returned as zeros
        instead of NaNs from a 0/0 division.
        """
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        norms = np.maximum(norms, 1e-12)
        return vectors / norms

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts*; returns one L2-normalized vector per input text."""
        # Fix: an empty batch would produce a shape-(0,) array and crash
        # inside session.run; short-circuit instead.
        if not texts:
            return []

        inputs = self._tokenize(texts)

        ort_inputs = {}
        if "input_ids" in self.input_names:
            ort_inputs["input_ids"] = inputs["input_ids"]
        if "attention_mask" in self.input_names:
            ort_inputs["attention_mask"] = inputs["attention_mask"]
        # Fix: BERT-style exports may also declare token_type_ids; feed an
        # all-zero segment tensor when the graph expects it (previously such
        # models failed with a missing-input error).
        if "token_type_ids" in self.input_names:
            ort_inputs["token_type_ids"] = np.zeros_like(inputs["input_ids"])

        outputs = self.session.run(None, ort_inputs)

        # Assumes outputs[0] is the last hidden state (batch, seq, dim) —
        # TODO confirm against the ONNX export's output ordering.
        token_embeddings = outputs[0]
        pooled = self._mean_pooling(token_embeddings, inputs["attention_mask"])
        normalized = self._normalize(pooled)

        return normalized.tolist()
|
|