|
|
|
|
|
import os
from typing import Dict, List, Optional

from huggingface_hub import InferenceClient
|
|
|
|
|
|
|
|
# Hugging Face API token; None when unset (client falls back to unauthenticated calls).
HF_TOKEN = os.getenv("HF_TOKEN")


# Reranker model id; override with the RERANK_MODEL env var.
RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-large")


# Shared inference client, constructed once at import time.
# NOTE(review): built at import — misconfiguration surfaces as an import-time error.
_client = InferenceClient(model=RERANK_MODEL, token=HF_TOKEN)


# Minimum relevance score a context must reach to survive reranking (default 0.3).
THRESHOLD = float(os.getenv("RERANK_THRESHOLD", "0.3"))
|
|
|
|
|
def rerank(
    query: str,
    contexts: List[Dict],
    threshold: Optional[float] = None,
) -> List[Dict]:
    """Score contexts against the query and keep those above the threshold.

    Args:
        query: The search query to rank contexts against.
        contexts: List of dicts shaped like ``{"id": ..., "text": ...}``.
        threshold: Minimum score a context must reach to be returned.
            Defaults to the module-level ``THRESHOLD`` when ``None``.

    Returns:
        The surviving contexts, each with a ``"score"`` key added, sorted
        by score in descending order. Note: the input dicts are mutated
        in place (the ``"score"`` key is written into them).

    Raises:
        ValueError: If the reranker returns a different number of scores
            than there are contexts (previously this was silently
            truncated by ``zip``, mis-scoring or dropping contexts).
    """
    if not contexts:
        return []

    if threshold is None:
        threshold = THRESHOLD

    pairs = [(query, ctx["text"]) for ctx in contexts]

    # NOTE(review): assumes the client returns one {"score": ...} dict per
    # input pair, in the same order as `pairs` — confirm against the rerank
    # endpoint, since some rerank APIs instead return results sorted by
    # score with an "index" field.
    scores = _client.rerank(inputs=pairs)

    # Guard against silent truncation: a bare zip() would quietly drop
    # items on a length mismatch.
    if len(scores) != len(contexts):
        raise ValueError(
            f"reranker returned {len(scores)} scores for "
            f"{len(contexts)} contexts"
        )

    for ctx, sc in zip(contexts, scores):
        ctx["score"] = sc["score"]

    # Filter first, then sort the survivors highest-score-first.
    reranked = sorted(
        (c for c in contexts if c["score"] >= threshold),
        key=lambda c: c["score"],
        reverse=True,
    )
    return reranked
|
|
|