# rag/modules/reranker.py import os from typing import List, Dict from huggingface_hub import InferenceClient # 환경변수에서 모델명과 토큰 불러오기 HF_TOKEN = os.getenv("HF_TOKEN") RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-large") _client = InferenceClient(model=RERANK_MODEL, token=HF_TOKEN) # threshold 값은 환경변수나 config에서 관리 가능 THRESHOLD = float(os.getenv("RERANK_THRESHOLD", "0.3")) def rerank(query: str, contexts: List[Dict]) -> List[Dict]: """ contexts: [{"id": ..., "text": ...}, ...] 반환: threshold 이상 점수만 포함된 reranked contexts """ if not contexts: return [] # reranker 입력: (query, passage) 쌍 리스트 pairs = [(query, ctx["text"]) for ctx in contexts] # Inference API 호출 → 각 쌍에 대한 점수 반환 scores = _client.rerank(inputs=pairs) # scores는 [{"score": float}, ...] 형태 for ctx, sc in zip(contexts, scores): ctx["score"] = sc["score"] # 점수 내림차순 정렬 reranked = sorted(contexts, key=lambda x: x["score"], reverse=True) # threshold 이상만 필터링 reranked = [c for c in reranked if c["score"] >= THRESHOLD] return reranked