File size: 1,246 Bytes
4fdc679 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
# rag/modules/reranker.py
import os
from typing import List, Dict
from huggingface_hub import InferenceClient
# Load the model name and API token from environment variables.
HF_TOKEN = os.getenv("HF_TOKEN")
RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-large")
# Module-level client shared by all rerank() calls.
_client = InferenceClient(model=RERANK_MODEL, token=HF_TOKEN)
# Threshold value can be managed via environment variable or config.
THRESHOLD = float(os.getenv("RERANK_THRESHOLD", "0.3"))
def rerank(query: str, contexts: List[Dict]) -> List[Dict]:
    """
    Rerank retrieved contexts against *query* via the HF Inference API.

    Args:
        query: The user query string.
        contexts: [{"id": ..., "text": ...}, ...]

    Returns:
        New context dicts with an added "score" key, sorted by score in
        descending order, keeping only entries with score >= THRESHOLD.
        The caller's input dicts are not mutated.
    """
    if not contexts:
        return []
    # Reranker input: list of (query, passage) pairs.
    pairs = [(query, ctx["text"]) for ctx in contexts]
    # Inference API call -> one score per pair.
    # NOTE(review): assumes the response shape is [{"score": float}, ...] —
    # confirm against the installed huggingface_hub InferenceClient version.
    scores = _client.rerank(inputs=pairs)
    # Build enriched copies instead of writing into the caller's dicts.
    scored = [{**ctx, "score": sc["score"]} for ctx, sc in zip(contexts, scores)]
    # Sort by score descending, then keep only entries at or above the threshold.
    scored.sort(key=lambda c: c["score"], reverse=True)
    return [c for c in scored if c["score"] >= THRESHOLD]
|