|
|
|
|
|
from typing import List |
|
|
import numpy as np |
|
|
import onnxruntime as ort |
|
|
from fastapi import Request |
|
|
|
|
|
def _l2_normalize(vec: np.ndarray) -> List[float]: |
|
|
norm = np.linalg.norm(vec) or 1.0 |
|
|
return (vec / norm).tolist() |
|
|
|
|
|
def get_embedding(request: Request, text: str) -> List[float]:
    """Embed *text* with the ONNX model held on the application state.

    Reads two objects from ``request.app.state``:

    * ``embedder_sess``      -- ONNX Runtime ``InferenceSession``
    * ``embedder_tokenizer`` -- tokenizer, called with ``return_tensors="np"``

    The text is tokenized (truncated to at most 256 tokens), run through the
    session, and the first row of the first model output is passed through
    ``_l2_normalize``.

    Args:
        request: Incoming request; only ``request.app.state`` is accessed.
        text: The string to embed.

    Returns:
        The normalized embedding as returned by ``_l2_normalize``.
    """
    tokenizer = request.app.state.embedder_tokenizer
    sess: ort.InferenceSession = request.app.state.embedder_sess

    inputs = tokenizer(
        text,
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=256,
    )

    # Feed only the tensors the graph actually declares. Tokenizers may emit
    # extra keys (e.g. ``token_type_ids``) that a given exported model does
    # not accept, and ONNX Runtime rejects unknown input names.
    expected_names = {node.name for node in sess.get_inputs()}
    ort_inputs = {
        name: value for name, value in inputs.items() if name in expected_names
    }

    # ``None`` asks the session to return every output of the graph.
    ort_outs = sess.run(None, ort_inputs)

    # First output, first row.
    # NOTE(review): presumably output 0 is a pooled (batch, dim) embedding.
    # If the model instead returns per-token states (batch, seq, dim), this
    # row is 2-D and needs pooling before normalization -- confirm against
    # the exported model.
    vec = ort_outs[0][0]
    return _l2_normalize(vec)
|
|
|
|
|
|
|
|
|