Hybrid RAG: BM25+Dense (sqlite-vec/BGE-M3) + cross-encoder reranker (bge-reranker-v2-m3)
Browse files- src/kpaa/retrieval/retriever.py +132 -25
src/kpaa/retrieval/retriever.py
CHANGED
|
@@ -27,12 +27,66 @@ from functools import lru_cache
|
|
| 27 |
from pathlib import Path
|
| 28 |
from typing import Any
|
| 29 |
|
|
|
|
|
|
|
| 30 |
from kpaa.cases import CasesIndex
|
| 31 |
from kpaa.guides import GuidesIndex
|
| 32 |
from kpaa.law_api import KoreanLawClient
|
| 33 |
from kpaa.retrieval.excerpts import Excerpt
|
| 34 |
from kpaa.retrieval.router import RouterPlan
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
# Progress callback signature: async fn(stage: str, payload: dict).
|
| 37 |
# 단계별 SSE prelude 표시용. None 이면 silent.
|
| 38 |
ProgressCB = Callable[[str, dict[str, Any]], Awaitable[None]] | None
|
|
@@ -211,22 +265,50 @@ async def _fetch_cases(
|
|
| 211 |
if on_progress:
|
| 212 |
await on_progress("fetch_done", {"source": "case", "count": 0, "keyword": ""})
|
| 213 |
return []
|
| 214 |
-
#
|
| 215 |
-
#
|
| 216 |
-
# BM25
|
| 217 |
-
|
| 218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
if plan.search_keywords:
|
| 220 |
-
|
| 221 |
-
if plan.query and plan.query not in
|
| 222 |
-
|
| 223 |
-
|
|
|
|
| 224 |
hit_map: dict = {}
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
| 227 |
rrf_scores[h.ntt_id] = rrf_scores.get(h.ntt_id, 0.0) + 1.0 / (_RRF_K + rank)
|
| 228 |
hit_map.setdefault(h.ntt_id, h)
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
hits = [hit_map[i] for i in top_ids]
|
| 231 |
out: list[Excerpt] = []
|
| 232 |
for h in hits:
|
|
@@ -253,6 +335,7 @@ async def _fetch_cases(
|
|
| 253 |
recency_score=_recency_score(_yyyy_mmdd_to_year(h.case_year or h.reg_dt)),
|
| 254 |
)
|
| 255 |
)
|
|
|
|
| 256 |
if on_progress:
|
| 257 |
await on_progress(
|
| 258 |
"fetch_done", {"source": "case", "count": len(out), "keyword": plan.top_keyword}
|
|
@@ -274,23 +357,46 @@ async def _fetch_guides(
|
|
| 274 |
if on_progress:
|
| 275 |
await on_progress("fetch_done", {"source": "guide", "count": 0, "keyword": ""})
|
| 276 |
return []
|
| 277 |
-
#
|
| 278 |
-
#
|
| 279 |
-
#
|
| 280 |
-
#
|
| 281 |
-
|
| 282 |
-
|
|
|
|
| 283 |
if plan.search_keywords:
|
| 284 |
-
|
| 285 |
-
if plan.query and plan.query not in
|
| 286 |
-
|
| 287 |
-
|
|
|
|
| 288 |
hit_map: dict = {}
|
| 289 |
-
|
| 290 |
-
|
|
|
|
|
|
|
|
|
|
| 291 |
rrf_scores[h.chunk_id] = rrf_scores.get(h.chunk_id, 0.0) + 1.0 / (_RRF_K + rank)
|
| 292 |
hit_map.setdefault(h.chunk_id, h)
|
| 293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
hits = [hit_map[i] for i in top_ids]
|
| 295 |
out: list[Excerpt] = []
|
| 296 |
for h in hits:
|
|
@@ -318,6 +424,7 @@ async def _fetch_guides(
|
|
| 318 |
recency_score=_recency_score(year),
|
| 319 |
)
|
| 320 |
)
|
|
|
|
| 321 |
if on_progress:
|
| 322 |
await on_progress(
|
| 323 |
"fetch_done",
|
|
|
|
| 27 |
from pathlib import Path
|
| 28 |
from typing import Any
|
| 29 |
|
| 30 |
+
import os
|
| 31 |
+
|
| 32 |
from kpaa.cases import CasesIndex
|
| 33 |
from kpaa.guides import GuidesIndex
|
| 34 |
from kpaa.law_api import KoreanLawClient
|
| 35 |
from kpaa.retrieval.excerpts import Excerpt
|
| 36 |
from kpaa.retrieval.router import RouterPlan
|
| 37 |
|
| 38 |
+
|
| 39 |
+
# ─── Hybrid retrieval (BM25 + Dense via sqlite-vec) ────────────────────────
|
| 40 |
+
# `kpaa build-embeddings` 로 data/embeddings.sqlite 가 빌드되어 있으면 자동 사용.
|
| 41 |
+
# 빌드 안 된 환경 / 실패 시 BM25 단독으로 fallback.
|
| 42 |
+
#
|
| 43 |
+
# - KPAA_DENSE_RETRIEVAL=off 면 비활성. (default on)
|
| 44 |
+
# - RRF k 상수: 60 (Cormack et al. 2009 권장값).
|
| 45 |
+
_RRF_K = 60
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _dense_enabled() -> bool:
|
| 49 |
+
return os.environ.get("KPAA_DENSE_RETRIEVAL", "on").lower() not in ("off", "0", "false", "no")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _safe_dense_search(query: str, *, source_type: str, k: int) -> list:
|
| 53 |
+
"""Dense 검색 — 인덱스 없거나 모델 로드 실패 시 [] 반환 (BM25 fallback)."""
|
| 54 |
+
if not _dense_enabled() or not query:
|
| 55 |
+
return []
|
| 56 |
+
try:
|
| 57 |
+
from kpaa.embeddings.index import search_embed
|
| 58 |
+
|
| 59 |
+
return search_embed(query, source_type=source_type, k=k)
|
| 60 |
+
except Exception as e: # noqa: BLE001
|
| 61 |
+
logger.warning("Dense retrieval skipped (%s)", e)
|
| 62 |
+
return []
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _disabled_rerank() -> bool:
|
| 66 |
+
"""Reranker 비활성 여부 — 후보 풀 크기 결정에 사용 (활성 시 크게)."""
|
| 67 |
+
return os.environ.get("KPAA_RERANKER", "").lower() in ("off", "0", "false", "no", "disabled")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _maybe_rerank(query: str, excerpts: list[Excerpt], *, k: int) -> list[Excerpt]:
|
| 71 |
+
"""Cross-encoder reranker 활성 시 top-k 정밀 정렬, 미설치/disabled 면 원순서 유지."""
|
| 72 |
+
if not query or not excerpts:
|
| 73 |
+
return excerpts[:k]
|
| 74 |
+
try:
|
| 75 |
+
from kpaa.retrieval.reranker import Reranker
|
| 76 |
+
|
| 77 |
+
rr = Reranker.default()
|
| 78 |
+
except Exception as e: # noqa: BLE001
|
| 79 |
+
logger.warning("Reranker import failed (%s) — original order", e)
|
| 80 |
+
return excerpts[:k]
|
| 81 |
+
if rr is None or len(excerpts) <= k:
|
| 82 |
+
return excerpts[:k]
|
| 83 |
+
return rr.rerank(
|
| 84 |
+
query, excerpts,
|
| 85 |
+
text_fn=lambda e: f"{e.title}\n{(e.content or '')[:1500]}" if e.title else (e.content or "")[:1500],
|
| 86 |
+
top_k=k,
|
| 87 |
+
)
|
| 88 |
+
# ──────────────────────────────────────────────────────────────────────────
|
| 89 |
+
|
| 90 |
# Progress callback signature: async fn(stage: str, payload: dict).
|
| 91 |
# 단계별 SSE prelude 표시용. None 이면 silent.
|
| 92 |
ProgressCB = Callable[[str, dict[str, Any]], Awaitable[None]] | None
|
|
|
|
| 265 |
if on_progress:
|
| 266 |
await on_progress("fetch_done", {"source": "case", "count": 0, "keyword": ""})
|
| 267 |
return []
|
| 268 |
+
# Hybrid retrieval — BM25 + Dense → RRF.
|
| 269 |
+
#
|
| 270 |
+
# BM25 입력: search_keywords 결합 query + 원본 질문 (LLM 키워드 추출이 핵심 주제어
|
| 271 |
+
# 누락 시 원본이 안전망)
|
| 272 |
+
# Dense 입력: 원본 질문 1회 (semantic 검색은 자연어 길수록 좋음. 임베딩 비용도
|
| 273 |
+
# 1회만)
|
| 274 |
+
#
|
| 275 |
+
# RRF 로 두 신호 통합. dense 인덱스 없으면 BM25 단독으로 fallback.
|
| 276 |
+
bm25_queries: list[str] = []
|
| 277 |
if plan.search_keywords:
|
| 278 |
+
bm25_queries.append(" ".join(plan.search_keywords[:3]))
|
| 279 |
+
if plan.query and plan.query not in bm25_queries:
|
| 280 |
+
bm25_queries.append(plan.query)
|
| 281 |
+
|
| 282 |
+
rrf_scores: dict[str, float] = {}
|
| 283 |
hit_map: dict = {}
|
| 284 |
+
pool = max(k * 3, 30)
|
| 285 |
+
|
| 286 |
+
# BM25 — 두 query 시도해 RRF 누적
|
| 287 |
+
for q in bm25_queries:
|
| 288 |
+
for rank, h in enumerate(idx.search(q, k=pool)):
|
| 289 |
rrf_scores[h.ntt_id] = rrf_scores.get(h.ntt_id, 0.0) + 1.0 / (_RRF_K + rank)
|
| 290 |
hit_map.setdefault(h.ntt_id, h)
|
| 291 |
+
|
| 292 |
+
# Dense — 원본 질문으로 1회. 결과는 EmbedHit(chunk_id='case_<ntt_id>'..)
|
| 293 |
+
dense_ids: list[str] = []
|
| 294 |
+
for rank, eh in enumerate(_safe_dense_search(plan.query, source_type="case", k=pool)):
|
| 295 |
+
ntt_id = eh.chunk_id.removeprefix("case_")
|
| 296 |
+
rrf_scores[ntt_id] = rrf_scores.get(ntt_id, 0.0) + 1.0 / (_RRF_K + rank)
|
| 297 |
+
if ntt_id not in hit_map:
|
| 298 |
+
dense_ids.append(ntt_id)
|
| 299 |
+
|
| 300 |
+
# Dense-only id 들의 Case 본문 lookup (BM25 결과에 없는 것만)
|
| 301 |
+
if dense_ids:
|
| 302 |
+
from kpaa.cases.index import get_cases
|
| 303 |
+
|
| 304 |
+
extra = get_cases(dense_ids)
|
| 305 |
+
hit_map.update(extra)
|
| 306 |
+
|
| 307 |
+
# Reranker 가용 시 더 큰 후보 풀(k*3 ~ 20)을 reranker 에 넘겨 정밀 정렬
|
| 308 |
+
rerank_pool = max(k * 3, 15) if not _disabled_rerank() else k
|
| 309 |
+
top_ids = [i for i in sorted(rrf_scores, key=lambda i: -rrf_scores[i]) if i in hit_map][
|
| 310 |
+
:rerank_pool
|
| 311 |
+
]
|
| 312 |
hits = [hit_map[i] for i in top_ids]
|
| 313 |
out: list[Excerpt] = []
|
| 314 |
for h in hits:
|
|
|
|
| 335 |
recency_score=_recency_score(_yyyy_mmdd_to_year(h.case_year or h.reg_dt)),
|
| 336 |
)
|
| 337 |
)
|
| 338 |
+
out = _maybe_rerank(plan.query, out, k=k)
|
| 339 |
if on_progress:
|
| 340 |
await on_progress(
|
| 341 |
"fetch_done", {"source": "case", "count": len(out), "keyword": plan.top_keyword}
|
|
|
|
| 357 |
if on_progress:
|
| 358 |
await on_progress("fetch_done", {"source": "guide", "count": 0, "keyword": ""})
|
| 359 |
return []
|
| 360 |
+
# Hybrid retrieval — BM25 + Dense → RRF.
|
| 361 |
+
#
|
| 362 |
+
# BM25 입력: search_keywords 결합 query + 원본 질문 (LLM 키워드 추출이 핵심 주제어
|
| 363 |
+
# 누락 시 원본이 안전망. e.g. "처방전 보관기간" → ["보관기간"]만 추출돼도 원본
|
| 364 |
+
# query 에서 "처방전" 토큰 hit 가능)
|
| 365 |
+
# Dense 입력: 원본 질문 1회 (semantic 검색은 자연어 길수록 좋음)
|
| 366 |
+
bm25_queries: list[str] = []
|
| 367 |
if plan.search_keywords:
|
| 368 |
+
bm25_queries.append(" ".join(plan.search_keywords[:3]))
|
| 369 |
+
if plan.query and plan.query not in bm25_queries:
|
| 370 |
+
bm25_queries.append(plan.query)
|
| 371 |
+
|
| 372 |
+
rrf_scores: dict[str, float] = {}
|
| 373 |
hit_map: dict = {}
|
| 374 |
+
pool = max(k * 3, 30)
|
| 375 |
+
|
| 376 |
+
# BM25
|
| 377 |
+
for q in bm25_queries:
|
| 378 |
+
for rank, h in enumerate(idx.search(q, k=pool)):
|
| 379 |
rrf_scores[h.chunk_id] = rrf_scores.get(h.chunk_id, 0.0) + 1.0 / (_RRF_K + rank)
|
| 380 |
hit_map.setdefault(h.chunk_id, h)
|
| 381 |
+
|
| 382 |
+
# Dense — 원본 질문 1회 (chunk_id 그대로)
|
| 383 |
+
dense_ids: list[str] = []
|
| 384 |
+
for rank, eh in enumerate(_safe_dense_search(plan.query, source_type="guide", k=pool)):
|
| 385 |
+
rrf_scores[eh.chunk_id] = rrf_scores.get(eh.chunk_id, 0.0) + 1.0 / (_RRF_K + rank)
|
| 386 |
+
if eh.chunk_id not in hit_map:
|
| 387 |
+
dense_ids.append(eh.chunk_id)
|
| 388 |
+
|
| 389 |
+
# Dense-only chunk_id 들의 GuideChunk 본문 lookup
|
| 390 |
+
if dense_ids:
|
| 391 |
+
from kpaa.guides.index import get_chunks
|
| 392 |
+
|
| 393 |
+
extra = get_chunks(dense_ids)
|
| 394 |
+
hit_map.update(extra)
|
| 395 |
+
|
| 396 |
+
rerank_pool = max(k * 3, 15) if not _disabled_rerank() else k
|
| 397 |
+
top_ids = [i for i in sorted(rrf_scores, key=lambda i: -rrf_scores[i]) if i in hit_map][
|
| 398 |
+
:rerank_pool
|
| 399 |
+
]
|
| 400 |
hits = [hit_map[i] for i in top_ids]
|
| 401 |
out: list[Excerpt] = []
|
| 402 |
for h in hits:
|
|
|
|
| 424 |
recency_score=_recency_score(year),
|
| 425 |
)
|
| 426 |
)
|
| 427 |
+
out = _maybe_rerank(plan.query, out, k=k)
|
| 428 |
if on_progress:
|
| 429 |
await on_progress(
|
| 430 |
"fetch_done",
|