scvcoder commited on
Commit
f64a4c2
·
verified ·
1 Parent(s): 0e6bb26

Hybrid RAG: BM25+Dense (sqlite-vec/BGE-M3) + cross-encoder reranker (bge-reranker-v2-m3)

Browse files
Files changed (1) hide show
  1. src/kpaa/retrieval/retriever.py +132 -25
src/kpaa/retrieval/retriever.py CHANGED
@@ -27,12 +27,66 @@ from functools import lru_cache
27
  from pathlib import Path
28
  from typing import Any
29
 
 
 
30
  from kpaa.cases import CasesIndex
31
  from kpaa.guides import GuidesIndex
32
  from kpaa.law_api import KoreanLawClient
33
  from kpaa.retrieval.excerpts import Excerpt
34
  from kpaa.retrieval.router import RouterPlan
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # Progress callback signature: async fn(stage: str, payload: dict).
37
  # 단계별 SSE prelude 표시용. None 이면 silent.
38
  ProgressCB = Callable[[str, dict[str, Any]], Awaitable[None]] | None
@@ -211,22 +265,50 @@ async def _fetch_cases(
211
  if on_progress:
212
  await on_progress("fetch_done", {"source": "case", "count": 0, "keyword": ""})
213
  return []
214
- # 키워드 + 원본 질문 쿼리로 검색 RRF(Reciprocal Rank Fusion)로 결합 —
215
- # LLM 추출 키워드가 핵심 주제어 누락 시 원본 질문이 안전망. 단순 concat은
216
- # BM25 토큰 가중치 차이로 한쪽이 독점 가능하므로 rank 기반 결합 필요.
217
- _RRF_K = 60
218
- queries: list[str] = []
 
 
 
 
219
  if plan.search_keywords:
220
- queries.append(" ".join(plan.search_keywords[:3]))
221
- if plan.query and plan.query not in queries:
222
- queries.append(plan.query)
223
- rrf_scores: dict = {}
 
224
  hit_map: dict = {}
225
- for q in queries:
226
- for rank, h in enumerate(idx.search(q, k=k)):
 
 
 
227
  rrf_scores[h.ntt_id] = rrf_scores.get(h.ntt_id, 0.0) + 1.0 / (_RRF_K + rank)
228
  hit_map.setdefault(h.ntt_id, h)
229
- top_ids = sorted(rrf_scores, key=lambda i: -rrf_scores[i])[:k]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  hits = [hit_map[i] for i in top_ids]
231
  out: list[Excerpt] = []
232
  for h in hits:
@@ -253,6 +335,7 @@ async def _fetch_cases(
253
  recency_score=_recency_score(_yyyy_mmdd_to_year(h.case_year or h.reg_dt)),
254
  )
255
  )
 
256
  if on_progress:
257
  await on_progress(
258
  "fetch_done", {"source": "case", "count": len(out), "keyword": plan.top_keyword}
@@ -274,23 +357,46 @@ async def _fetch_guides(
274
  if on_progress:
275
  await on_progress("fetch_done", {"source": "guide", "count": 0, "keyword": ""})
276
  return []
277
- # 키워드 + 원본 질문 쿼리로 검색 RRF로 결합 — LLM 추출 키워드가 핵심
278
- # 주제어를 누락해도 원본 질문이 안전망 (e.g. "처방전 보관기간" → ["보관기간"]만
279
- # 추출돼도 원본 query에서 "처방전" 토큰 hit 가능). 단순 concat은 BM25 가중치
280
- # 차이로 한쪽독점하므로 rank union 필요.
281
- _RRF_K = 60
282
- queries: list[str] = []
 
283
  if plan.search_keywords:
284
- queries.append(" ".join(plan.search_keywords[:3]))
285
- if plan.query and plan.query not in queries:
286
- queries.append(plan.query)
287
- rrf_scores: dict = {}
 
288
  hit_map: dict = {}
289
- for q in queries:
290
- for rank, h in enumerate(idx.search(q, k=k)):
 
 
 
291
  rrf_scores[h.chunk_id] = rrf_scores.get(h.chunk_id, 0.0) + 1.0 / (_RRF_K + rank)
292
  hit_map.setdefault(h.chunk_id, h)
293
- top_ids = sorted(rrf_scores, key=lambda i: -rrf_scores[i])[:k]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  hits = [hit_map[i] for i in top_ids]
295
  out: list[Excerpt] = []
296
  for h in hits:
@@ -318,6 +424,7 @@ async def _fetch_guides(
318
  recency_score=_recency_score(year),
319
  )
320
  )
 
321
  if on_progress:
322
  await on_progress(
323
  "fetch_done",
 
27
  from pathlib import Path
28
  from typing import Any
29
 
30
+ import os
31
+
32
  from kpaa.cases import CasesIndex
33
  from kpaa.guides import GuidesIndex
34
  from kpaa.law_api import KoreanLawClient
35
  from kpaa.retrieval.excerpts import Excerpt
36
  from kpaa.retrieval.router import RouterPlan
37
 
38
+
39
+ # ─── Hybrid retrieval (BM25 + Dense via sqlite-vec) ────────────────────────
40
+ # `kpaa build-embeddings` 로 data/embeddings.sqlite 가 빌드되어 있으면 자동 사용.
41
+ # 빌드 안 된 환경 / 실패 시 BM25 단독으로 fallback.
42
+ #
43
+ # - KPAA_DENSE_RETRIEVAL=off 면 비활성. (default on)
44
+ # - RRF k 상수: 60 (Cormack et al. 2009 권장값).
45
+ _RRF_K = 60
46
+
47
+
48
+ def _dense_enabled() -> bool:
49
+ return os.environ.get("KPAA_DENSE_RETRIEVAL", "on").lower() not in ("off", "0", "false", "no")
50
+
51
+
52
+ def _safe_dense_search(query: str, *, source_type: str, k: int) -> list:
53
+ """Dense 검색 — 인덱스 없거나 모델 로드 실패 시 [] 반환 (BM25 fallback)."""
54
+ if not _dense_enabled() or not query:
55
+ return []
56
+ try:
57
+ from kpaa.embeddings.index import search_embed
58
+
59
+ return search_embed(query, source_type=source_type, k=k)
60
+ except Exception as e: # noqa: BLE001
61
+ logger.warning("Dense retrieval skipped (%s)", e)
62
+ return []
63
+
64
+
65
+ def _disabled_rerank() -> bool:
66
+ """Reranker 비활성 여부 — 후보 풀 크기 결정에 사용 (활성 시 크게)."""
67
+ return os.environ.get("KPAA_RERANKER", "").lower() in ("off", "0", "false", "no", "disabled")
68
+
69
+
70
+ def _maybe_rerank(query: str, excerpts: list[Excerpt], *, k: int) -> list[Excerpt]:
71
+ """Cross-encoder reranker 활성 시 top-k 정밀 정렬, 미설치/disabled 면 원순서 유지."""
72
+ if not query or not excerpts:
73
+ return excerpts[:k]
74
+ try:
75
+ from kpaa.retrieval.reranker import Reranker
76
+
77
+ rr = Reranker.default()
78
+ except Exception as e: # noqa: BLE001
79
+ logger.warning("Reranker import failed (%s) — original order", e)
80
+ return excerpts[:k]
81
+ if rr is None or len(excerpts) <= k:
82
+ return excerpts[:k]
83
+ return rr.rerank(
84
+ query, excerpts,
85
+ text_fn=lambda e: f"{e.title}\n{(e.content or '')[:1500]}" if e.title else (e.content or "")[:1500],
86
+ top_k=k,
87
+ )
88
+ # ──────────────────────────────────────────────────────────────────────────
89
+
90
  # Progress callback signature: async fn(stage: str, payload: dict).
91
  # 단계별 SSE prelude 표시용. None 이면 silent.
92
  ProgressCB = Callable[[str, dict[str, Any]], Awaitable[None]] | None
 
265
  if on_progress:
266
  await on_progress("fetch_done", {"source": "case", "count": 0, "keyword": ""})
267
  return []
268
+ # Hybrid retrieval BM25 + Dense RRF.
269
+ #
270
+ # BM25 입력: search_keywords 결합 query + 원본 질문 (LLM 키워드 추출이 핵심 주제어
271
+ # 누락 원본이 안전망)
272
+ # Dense 입력: 원본 질문 1회 (semantic 검색은 자연어 길수록 좋음. 임베딩 비용도
273
+ # 1회만)
274
+ #
275
+ # RRF 로 두 신호 통합. dense 인덱스 없으면 BM25 단독으로 fallback.
276
+ bm25_queries: list[str] = []
277
  if plan.search_keywords:
278
+ bm25_queries.append(" ".join(plan.search_keywords[:3]))
279
+ if plan.query and plan.query not in bm25_queries:
280
+ bm25_queries.append(plan.query)
281
+
282
+ rrf_scores: dict[str, float] = {}
283
  hit_map: dict = {}
284
+ pool = max(k * 3, 30)
285
+
286
+ # BM25 — 두 query 시도해 RRF 누적
287
+ for q in bm25_queries:
288
+ for rank, h in enumerate(idx.search(q, k=pool)):
289
  rrf_scores[h.ntt_id] = rrf_scores.get(h.ntt_id, 0.0) + 1.0 / (_RRF_K + rank)
290
  hit_map.setdefault(h.ntt_id, h)
291
+
292
+ # Dense — 원본 질문으로 1회. 결과는 EmbedHit(chunk_id='case_<ntt_id>'..)
293
+ dense_ids: list[str] = []
294
+ for rank, eh in enumerate(_safe_dense_search(plan.query, source_type="case", k=pool)):
295
+ ntt_id = eh.chunk_id.removeprefix("case_")
296
+ rrf_scores[ntt_id] = rrf_scores.get(ntt_id, 0.0) + 1.0 / (_RRF_K + rank)
297
+ if ntt_id not in hit_map:
298
+ dense_ids.append(ntt_id)
299
+
300
+ # Dense-only id 들의 Case 본문 lookup (BM25 결과에 없는 것만)
301
+ if dense_ids:
302
+ from kpaa.cases.index import get_cases
303
+
304
+ extra = get_cases(dense_ids)
305
+ hit_map.update(extra)
306
+
307
+ # Reranker 가용 시 더 큰 후보 풀(k*3 ~ 20)을 reranker 에 넘겨 정밀 정렬
308
+ rerank_pool = max(k * 3, 15) if not _disabled_rerank() else k
309
+ top_ids = [i for i in sorted(rrf_scores, key=lambda i: -rrf_scores[i]) if i in hit_map][
310
+ :rerank_pool
311
+ ]
312
  hits = [hit_map[i] for i in top_ids]
313
  out: list[Excerpt] = []
314
  for h in hits:
 
335
  recency_score=_recency_score(_yyyy_mmdd_to_year(h.case_year or h.reg_dt)),
336
  )
337
  )
338
+ out = _maybe_rerank(plan.query, out, k=k)
339
  if on_progress:
340
  await on_progress(
341
  "fetch_done", {"source": "case", "count": len(out), "keyword": plan.top_keyword}
 
357
  if on_progress:
358
  await on_progress("fetch_done", {"source": "guide", "count": 0, "keyword": ""})
359
  return []
360
+ # Hybrid retrieval BM25 + Dense RRF.
361
+ #
362
+ # BM25 입력: search_keywords 결합 query + 원본 질문 (LLM 키워드 추출이 핵심 주제어
363
+ # 누락 원본안전망. e.g. "처방전 보관간" ["보관기간"]만 추출돼도 원본
364
+ # query 에서 "처방전" 토큰 hit 가능)
365
+ # Dense 입력: 원본 질문 1회 (semantic 검색은 자연어 길수록 좋음)
366
+ bm25_queries: list[str] = []
367
  if plan.search_keywords:
368
+ bm25_queries.append(" ".join(plan.search_keywords[:3]))
369
+ if plan.query and plan.query not in bm25_queries:
370
+ bm25_queries.append(plan.query)
371
+
372
+ rrf_scores: dict[str, float] = {}
373
  hit_map: dict = {}
374
+ pool = max(k * 3, 30)
375
+
376
+ # BM25
377
+ for q in bm25_queries:
378
+ for rank, h in enumerate(idx.search(q, k=pool)):
379
  rrf_scores[h.chunk_id] = rrf_scores.get(h.chunk_id, 0.0) + 1.0 / (_RRF_K + rank)
380
  hit_map.setdefault(h.chunk_id, h)
381
+
382
+ # Dense — 원본 질문 1회 (chunk_id 그대로)
383
+ dense_ids: list[str] = []
384
+ for rank, eh in enumerate(_safe_dense_search(plan.query, source_type="guide", k=pool)):
385
+ rrf_scores[eh.chunk_id] = rrf_scores.get(eh.chunk_id, 0.0) + 1.0 / (_RRF_K + rank)
386
+ if eh.chunk_id not in hit_map:
387
+ dense_ids.append(eh.chunk_id)
388
+
389
+ # Dense-only chunk_id 들의 GuideChunk 본문 lookup
390
+ if dense_ids:
391
+ from kpaa.guides.index import get_chunks
392
+
393
+ extra = get_chunks(dense_ids)
394
+ hit_map.update(extra)
395
+
396
+ rerank_pool = max(k * 3, 15) if not _disabled_rerank() else k
397
+ top_ids = [i for i in sorted(rrf_scores, key=lambda i: -rrf_scores[i]) if i in hit_map][
398
+ :rerank_pool
399
+ ]
400
  hits = [hit_map[i] for i in top_ids]
401
  out: list[Excerpt] = []
402
  for h in hits:
 
424
  recency_score=_recency_score(year),
425
  )
426
  )
427
+ out = _maybe_rerank(plan.query, out, k=k)
428
  if on_progress:
429
  await on_progress(
430
  "fetch_done",