|
|
|
|
|
from typing import List |
|
|
import numpy as np |
|
|
import onnxruntime as ort |
|
|
from fastapi import Request |
|
|
|
|
|
def _l2_normalize(vec: np.ndarray) -> List[float]: |
|
|
norm = np.linalg.norm(vec) or 1.0 |
|
|
return (vec / norm).tolist() |
|
|
|
|
|
def _generate_position_ids(input_ids: np.ndarray) -> np.ndarray: |
|
|
""" |
|
|
input_ids: [batch_size, seq_len] |
|
|
return: position_ids of shape [batch_size, seq_len] with int64 dtype |
|
|
""" |
|
|
batch_size, seq_len = input_ids.shape |
|
|
position_ids = np.arange(seq_len)[None, :].astype("int64") |
|
|
return np.tile(position_ids, (batch_size, 1)) |
|
|
|
|
|
def get_embedding(request: Request, text: str) -> List[float]:
    """Embed *text* with the app's ONNX encoder and return an L2-normalized vector.

    Expects the following to be present on application state:
      - request.app.state.embedder_sess: ONNX Runtime InferenceSession
      - request.app.state.embedder_tokenizer: tokenizer producing numpy tensors

    Returns the mean-pooled token embeddings of the first (only) sequence,
    L2-normalized, as a plain list of floats.
    """
    tokenizer = request.app.state.embedder_tokenizer
    sess: ort.InferenceSession = request.app.state.embedder_sess

    # Tokenize to numpy tensors; sequences longer than 256 tokens are truncated.
    inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True, max_length=256)

    # The exported model takes explicit position_ids that the tokenizer does
    # not supply, so generate [0..seq_len) per batch row ourselves.
    inputs["position_ids"] = _generate_position_ids(inputs["input_ids"])

    ort_outs = sess.run(None, dict(inputs))

    # First output is taken to be the token embeddings, presumably shaped
    # [batch, seq_len, hidden] -- TODO confirm against the exported model.
    token_embeddings = ort_outs[0]

    # Mean-pool over the sequence axis and take the single (batch size 1)
    # sequence. NOTE(review): padding tokens would dilute this mean, but a
    # single text produces no padding even with padding=True.
    vec = token_embeddings.mean(axis=1)[0]
    return _l2_normalize(vec)
|
|
|
|
|
|
|
|
|