# rag/models/reranker.py
import os
from typing import Dict, List, Optional

import onnxruntime as ort
from fastapi import Request

# Default minimum score a context must reach to be kept. Read once at import
# time from the RERANK_THRESHOLD environment variable (default "0.3").
# NOTE(review): whether the model emits raw logits or probabilities depends on
# the exported ONNX graph — confirm this threshold is on the right scale.
THRESHOLD = float(os.getenv("RERANK_THRESHOLD", "0.3"))


def rerank(
    request: Request,
    query: str,
    contexts: List[Dict],
    *,
    threshold: Optional[float] = None,
    max_length: int = 256,
) -> List[Dict]:
    """Score each (query, context) pair with the ONNX reranker and keep the best.

    Reads two objects from application state:
        request.app.state.reranker_sess      : ONNX Runtime InferenceSession
        request.app.state.reranker_tokenizer : tokenizer producing numpy tensors

    Args:
        request: FastAPI request whose ``app.state`` holds the session/tokenizer.
        query: The search query paired with every context.
        contexts: ``[{"id": ..., "text": ...}, ...]``. Each dict is MUTATED in
            place: a float ``"score"`` key is added, including for contexts
            that are later filtered out.
        threshold: Minimum score to keep a context; defaults to the
            module-level ``THRESHOLD``.
        max_length: Maximum token length for each truncated (query, text) pair.

    Returns:
        The contexts whose score is ``>= threshold``, sorted by score
        descending (the same dict objects that were passed in).

    Raises:
        KeyError: if a context has no ``"text"`` key.
        ValueError: if the model does not produce exactly one score per pair.
    """
    if not contexts:
        return []

    tokenizer = request.app.state.reranker_tokenizer
    sess: ort.InferenceSession = request.app.state.reranker_sess
    cutoff = THRESHOLD if threshold is None else threshold

    pairs = [(query, ctx["text"]) for ctx in contexts]
    inputs = tokenizer(
        pairs,
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=max_length,
    )

    # Feed only tensors the graph actually declares: tokenizers often return
    # extra keys (e.g. token_type_ids) that an exported model does not accept,
    # and ONNX Runtime rejects unknown feed names.
    input_names = {node.name for node in sess.get_inputs()}
    ort_inputs = {k: v for k, v in inputs.items() if k in input_names}

    outputs = sess.run(None, ort_inputs)[0]
    # Normalise to (batch, n_scores) so both [batch] and [batch, 1] outputs
    # work; squeeze(-1) alone fails on a 1-D [batch] output.
    outputs = outputs.reshape(len(pairs), -1)
    if outputs.shape[1] != 1:
        raise ValueError(
            f"reranker must output one score per pair, got shape {outputs.shape}"
        )
    scores = outputs[:, 0]

    for ctx, sc in zip(contexts, scores):
        ctx["score"] = float(sc)

    # Filter before sorting: NaN scores fail the >= test and are dropped, so
    # they cannot disturb the sort order. Stable sort keeps the same result
    # as sorting first for finite scores.
    kept = [c for c in contexts if c["score"] >= cutoff]
    return sorted(kept, key=lambda c: c["score"], reverse=True)