# indonesian-embedding-small-custom / embedding_engine.py
import onnxruntime as ort
import numpy as np
from tokenizers import Tokenizer
from typing import List

class IndonesianEmbeddingEngine:
    """ONNX-based sentence embedding engine for Indonesian text."""

    def __init__(
        self,
        model_path: str = "./onnx/indonesian_embedding.onnx",
        tokenizer_path: str = "./onnx/tokenizer.json",
        max_length: int = 384,
    ):
        self.max_length = max_length
        # Load tokenizer
        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        # Load ONNX model
        self.session = ort.InferenceSession(
            model_path,
            providers=["CPUExecutionProvider"],
        )
        self.input_names = {i.name for i in self.session.get_inputs()}
    def _tokenize(self, texts: List[str]):
        """Encode a batch of texts into fixed-length id/mask arrays."""
        encodings = self.tokenizer.encode_batch(texts)
        input_ids = []
        attention_mask = []
        for enc in encodings:
            # Truncate to max_length, then right-pad with the pad id
            # (assumed to be 0, as in most BERT-style vocabularies).
            ids = enc.ids[: self.max_length]
            mask = [1] * len(ids)
            pad_len = self.max_length - len(ids)
            if pad_len > 0:
                ids += [0] * pad_len
                mask += [0] * pad_len
            input_ids.append(ids)
            attention_mask.append(mask)
        return {
            "input_ids": np.array(input_ids, dtype=np.int64),
            "attention_mask": np.array(attention_mask, dtype=np.int64),
        }
    def _mean_pooling(self, token_embeddings, attention_mask):
        # Average token embeddings, ignoring padded positions.
        mask = attention_mask[..., None]
        summed = np.sum(token_embeddings * mask, axis=1)
        counts = np.clip(mask.sum(axis=1), a_min=1e-9, a_max=None)
        return summed / counts
    def _normalize(self, vectors):
        # L2-normalize; clamp the norm so an all-zero vector cannot
        # trigger a division by zero.
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        return vectors / np.clip(norms, a_min=1e-12, a_max=None)
    def embed(self, texts: List[str]) -> List[List[float]]:
        """Return one L2-normalized embedding per input text."""
        inputs = self._tokenize(texts)
        # Only feed the inputs this particular ONNX graph declares.
        ort_inputs = {}
        if "input_ids" in self.input_names:
            ort_inputs["input_ids"] = inputs["input_ids"]
        if "attention_mask" in self.input_names:
            ort_inputs["attention_mask"] = inputs["attention_mask"]
        outputs = self.session.run(None, ort_inputs)
        token_embeddings = outputs[0]  # last_hidden_state: [batch, seq, hidden]
        pooled = self._mean_pooling(token_embeddings, inputs["attention_mask"])
        normalized = self._normalize(pooled)
        return normalized.tolist()
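

# Usage sketch: a minimal, hypothetical example assuming the ONNX model and
# tokenizer exist at the default paths above. The sentences are illustrative
# only. Because embed() returns L2-normalized vectors, cosine similarity
# reduces to a plain dot product.
if __name__ == "__main__":
    engine = IndonesianEmbeddingEngine()
    sentences = [
        "Bagaimana cara mendaftar akun baru?",       # "How do I register a new account?"
        "Langkah-langkah membuat akun baru",         # "Steps to create a new account"
    ]
    vectors = np.array(engine.embed(sentences))
    # Unit-length vectors: dot product == cosine similarity.
    similarity = float(vectors[0] @ vectors[1])
    print(f"cosine similarity: {similarity:.4f}")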