File size: 1,503 Bytes
2aa7bf4
 
 
 
 
 
 
 
 
 
5f612cd
 
 
 
 
 
 
 
 
2aa7bf4
 
 
 
 
 
 
 
 
6e89dad
5f612cd
2aa7bf4
 
5ca4913
3c9754e
 
 
 
2aa7bf4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# rag/models/embedder.py
from typing import List
import numpy as np
import onnxruntime as ort
from fastapi import Request

def _l2_normalize(vec: np.ndarray) -> List[float]:
    norm = np.linalg.norm(vec) or 1.0
    return (vec / norm).tolist()

def _generate_position_ids(input_ids: np.ndarray) -> np.ndarray:
    """
    input_ids: [batch_size, seq_len]
    return: position_ids of shape [batch_size, seq_len] with int64 dtype
    """
    batch_size, seq_len = input_ids.shape
    position_ids = np.arange(seq_len)[None, :].astype("int64")
    return np.tile(position_ids, (batch_size, 1))

def get_embedding(request: Request, text: str) -> List[float]:
    """Embed *text* into an L2-normalized sentence vector.

    Reads two objects stored on the FastAPI application state:
      - request.app.state.embedder_sess: ONNX Runtime InferenceSession
      - request.app.state.embedder_tokenizer: tokenizer

    Returns:
        The sentence embedding as a list of floats with unit L2 norm.
    """
    tokenizer = request.app.state.embedder_tokenizer
    sess: ort.InferenceSession = request.app.state.embedder_sess

    inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True, max_length=256)
    input_ids = inputs["input_ids"]
    # The ONNX model expects explicit position_ids alongside the tokenizer outputs.
    inputs["position_ids"] = _generate_position_ids(input_ids)
    ort_inputs = dict(inputs)
    ort_outs = sess.run(None, ort_inputs)
    # First output is assumed to be token embeddings of shape (batch, seq_len, dim).
    token_embeddings = ort_outs[0]  # shape (1, seq_len, dim)
    # Mean pooling over the token axis produces the sentence embedding.
    # NOTE(review): pooling ignores the attention mask; harmless for a single
    # text (no pad tokens are added), but must mask padding if batching is added.
    vec = token_embeddings.mean(axis=1)[0]  # shape (dim,)
    return _l2_normalize(vec)