pls-rag / modules /reranker.py
m97j's picture
Initial codes commit
4fdc679
raw
history blame
1.25 kB
# rag/modules/reranker.py
import os
from typing import List, Dict, Optional

from huggingface_hub import InferenceClient
# --- Configuration read from the environment -------------------------------
# Access token for the Hugging Face Inference API (None when HF_TOKEN is unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Reranker model id; falls back to BAAI/bge-reranker-large.
RERANK_MODEL = os.environ.get("RERANK_MODEL", "BAAI/bge-reranker-large")

# One shared client, constructed once when this module is imported.
_client = InferenceClient(token=HF_TOKEN, model=RERANK_MODEL)

# Minimum score a context needs in order to be kept by rerank().
# Override with the RERANK_THRESHOLD environment variable (or via config);
# the value is parsed with float(), so a malformed value fails at import time.
THRESHOLD = float(os.environ.get("RERANK_THRESHOLD", "0.3"))
def rerank(
    query: str,
    contexts: List[Dict],
    threshold: Optional[float] = None,
) -> List[Dict]:
    """Score ``contexts`` against ``query`` and return the relevant ones, best first.

    Args:
        query: The search query.
        contexts: Retrieved passages, e.g. ``[{"id": ..., "text": ...}, ...]``.
            Every item must have a ``"text"`` key. The input dicts are NOT
            modified; shallow copies carrying a ``"score"`` key are returned.
        threshold: Minimum score a context must reach to be kept. ``None``
            (the default) uses the module-level ``THRESHOLD``
            (env ``RERANK_THRESHOLD``).

    Returns:
        Copies of the contexts whose score is ``>= threshold``, sorted by
        score in descending order (ties keep their input order). An empty
        list when ``contexts`` is empty; the reranker is not called then.

    Raises:
        ValueError: If the reranker response does not provide exactly one
            score per context.
    """
    if not contexts:
        return []
    if threshold is None:
        threshold = THRESHOLD

    # Reranker input: one (query, passage) pair per context, in input order.
    pairs = [(query, ctx["text"]) for ctx in contexts]
    # NOTE(review): confirm that the installed huggingface_hub InferenceClient
    # actually exposes `rerank(inputs=...)`, and check its response format.
    raw_scores = _client.rerank(inputs=pairs)
    scores = _align_scores(raw_scores, len(contexts))

    # Filter first, then sort: sorted() is stable, so this yields the same
    # order as sort-then-filter while sorting fewer items. Copies avoid
    # mutating the caller's dicts (and leaving stale scores on them).
    kept = [
        {**ctx, "score": score}
        for ctx, score in zip(contexts, scores)
        if score >= threshold
    ]
    return sorted(kept, key=lambda c: c["score"], reverse=True)


def _align_scores(raw_scores, expected: int) -> List[float]:
    """Turn a reranker response into exactly one float score per input position.

    Items may be plain numbers or mappings with a ``"score"`` key. A mapping
    that also carries an ``"index"`` is placed at that input position, since a
    response may be ordered by score rather than by input; otherwise the
    item's position in the response is used.
    """
    scores: List[Optional[float]] = [None] * expected
    for pos, item in enumerate(raw_scores):
        if isinstance(item, dict):
            idx = item.get("index", pos)
            value = item["score"]
        else:
            idx, value = pos, item
        if not 0 <= idx < expected or scores[idx] is not None:
            raise ValueError(
                f"reranker returned an invalid or duplicate index {idx!r} "
                f"for {expected} contexts"
            )
        scores[idx] = float(value)

    missing = sum(1 for s in scores if s is None)
    if missing:
        raise ValueError(
            f"reranker returned no score for {missing} of {expected} contexts"
        )
    return scores