File size: 1,136 Bytes
2aa7bf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cff5d3d
2aa7bf4
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# rag/models/reranker.py
import os
from typing import List, Dict
import onnxruntime as ort
from fastapi import Request

# Minimum reranker score a context must reach to be kept by rerank().
# Overridable through the RERANK_THRESHOLD environment variable (default 0.3).
THRESHOLD = float(os.environ.get("RERANK_THRESHOLD", "0.3"))

def rerank(request: Request, query: str, contexts: List[Dict]) -> List[Dict]:
    """
    Score every context against the query with the ONNX reranker and return
    the contexts whose score reaches THRESHOLD, best first.

    Objects read from the application state:
      request.app.state.reranker_sess      : ONNX Runtime InferenceSession
      request.app.state.reranker_tokenizer : tokenizer used to encode the pairs

    Args:
        request: incoming request; only ``request.app.state`` is accessed.
        query: query string paired with each context text.
        contexts: [{"id": ..., "text": ...}, ...]. Every dict must provide a
            "text" key. The dicts are updated in place with a float "score".

    Returns:
        The same dict objects with score >= THRESHOLD, sorted by score in
        descending order. Empty ``contexts`` returns ``[]`` without calling
        the tokenizer or the session.

    Raises:
        ValueError: if the model's first output does not hold exactly one
            score per context.
    """
    if not contexts:
        return []

    tokenizer = request.app.state.reranker_tokenizer
    sess: ort.InferenceSession = request.app.state.reranker_sess

    pairs = [(query, ctx["text"]) for ctx in contexts]
    inputs = tokenizer(
        pairs,
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=256,
    )

    # Feed only the tensors the ONNX graph declares: a tokenizer may emit
    # extras (e.g. token_type_ids) that the exported model does not accept,
    # and ONNX Runtime rejects unknown input names.
    declared = {node.name for node in sess.get_inputs()}
    ort_inputs = {k: v for k, v in inputs.items() if k in declared}

    # The first output is expected to be one score per pair, normally shaped
    # [batch, 1]; flattening also accepts a plain [batch] output.
    scores = sess.run(None, ort_inputs)[0].reshape(-1)
    if len(scores) != len(contexts):
        raise ValueError(
            f"reranker returned {len(scores)} scores for {len(contexts)} contexts"
        )

    for ctx, sc in zip(contexts, scores):
        ctx["score"] = float(sc)

    # Filter before sorting so NaN scores (which fail the >= test) are dropped
    # before they can disturb the ordering; for finite scores the result is
    # identical to sort-then-filter because sorted() is stable.
    kept = [c for c in contexts if c["score"] >= THRESHOLD]
    return sorted(kept, key=lambda c: c["score"], reverse=True)