File size: 1,246 Bytes
4fdc679
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# rag/modules/reranker.py
import os
from typing import List, Dict
from huggingface_hub import InferenceClient

# Model name and API token are read from environment variables.
HF_TOKEN = os.getenv("HF_TOKEN")  # may be None -> unauthenticated requests
RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-large")

# Shared Inference API client, constructed once at import time.
_client = InferenceClient(model=RERANK_MODEL, token=HF_TOKEN)

# Minimum relevance score; contexts scoring below this are dropped.
# Overridable via the RERANK_THRESHOLD env var (default 0.3).
THRESHOLD = float(os.getenv("RERANK_THRESHOLD", "0.3"))

def rerank(query: str, contexts: List[Dict]) -> List[Dict]:
    """Score *contexts* against *query* and keep only the relevant ones.

    Args:
        query: The search query string.
        contexts: List of dicts shaped like ``{"id": ..., "text": ...}``.
            The ``"text"`` key is required for every entry.

    Returns:
        A new list of context dicts — copies, so the caller's input dicts
        are NOT mutated — each augmented with a ``"score"`` key, filtered
        to ``score >= THRESHOLD`` and sorted by score, descending.
        Empty input yields an empty list.
    """
    if not contexts:
        return []

    # Reranker input: one (query, passage) pair per candidate context.
    pairs = [(query, ctx["text"]) for ctx in contexts]

    # Inference API call -> one score per pair, in input order.
    # NOTE(review): assumes the client returns [{"score": float}, ...];
    # confirm against the installed huggingface_hub version.
    scores = _client.rerank(inputs=pairs)

    # Build copies instead of writing "score" into the caller's dicts
    # (the previous implementation mutated its input in place).
    scored = [{**ctx, "score": sc["score"]} for ctx, sc in zip(contexts, scores)]

    # Filter first, then sort: no point ordering contexts we discard.
    kept = [c for c in scored if c["score"] >= THRESHOLD]
    return sorted(kept, key=lambda c: c["score"], reverse=True)