Rifqi Hafizuddin committed on
Commit
de32ab0
·
1 Parent(s): c9d3b33

[NOTICKET] rrf merge now at router level

Browse files
Files changed (1) hide show
  1. src/rag/router.py +119 -19
src/rag/router.py CHANGED
@@ -1,8 +1,14 @@
1
- """Routes retrieval requests to the appropriate retriever based on source_hint."""
 
 
 
 
 
2
 
3
  import asyncio
4
  import hashlib
5
  import json
 
6
  from typing import Literal
7
 
8
  from src.db.redis.connection import get_redis
@@ -12,9 +18,73 @@ from src.rag.base import BaseRetriever, RetrievalResult
12
  logger = get_logger("retrieval_router")
13
 
14
  _CACHE_TTL = 3600 # 1 hour
 
 
15
  SourceHint = Literal["document", "schema", "both"]
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  class RetrievalRouter:
19
  def __init__(
20
  self,
@@ -26,12 +96,12 @@ class RetrievalRouter:
26
  "document": document_retriever,
27
  }
28
 
29
- def _route(self, source_hint: SourceHint) -> list[BaseRetriever]:
30
  if source_hint == "schema":
31
- return [self._retrievers["schema"]]
32
  if source_hint == "document":
33
- return [self._retrievers["document"]]
34
- return list(self._retrievers.values())
35
 
36
  async def retrieve(
37
  self,
@@ -42,7 +112,7 @@ class RetrievalRouter:
42
  ) -> list[RetrievalResult]:
43
  redis = await get_redis()
44
  query_hash = hashlib.md5(query.encode()).hexdigest()
45
- cache_key = f"retrieval:{user_id}:{source_hint}:{query_hash}:{k}"
46
 
47
  cached = await redis.get(cache_key)
48
  if cached:
@@ -50,26 +120,56 @@ class RetrievalRouter:
50
  raw = json.loads(cached)
51
  return [RetrievalResult(**r) for r in raw]
52
 
53
- retrievers = self._route(source_hint)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  batches = await asyncio.gather(
55
- *[r.retrieve(query, user_id, k) for r in retrievers],
56
  return_exceptions=True,
57
  )
58
 
59
- results: list[RetrievalResult] = []
60
- for batch in batches:
 
61
  if isinstance(batch, Exception):
62
- logger.error("retriever failed", error=str(batch))
 
63
  continue
64
- results.extend(batch)
 
65
 
66
- results.sort(key=lambda r: r.score, reverse=True)
67
- results = results[:k]
68
 
69
- logger.info("retrieved chunks", count=len(results), source_hint=source_hint)
70
- await redis.setex(
71
- cache_key,
72
- _CACHE_TTL,
73
- json.dumps([vars(r) for r in results]),
 
 
74
  )
75
  return results
 
1
+ """Routes retrieval requests to the appropriate retriever based on source_hint.
2
+
3
+ Cross-retriever merging uses Reciprocal Rank Fusion (RRF) on per-retriever
4
+ ranked lists — score scales differ across retrievers (RRF, cosine, distance)
5
+ and aren't directly comparable, so we rank-merge instead of score-merge.
6
+ """
7
 
8
  import asyncio
9
  import hashlib
10
  import json
11
+ from dataclasses import asdict
12
  from typing import Literal
13
 
14
  from src.db.redis.connection import get_redis
 
18
  logger = get_logger("retrieval_router")
19
 
20
  _CACHE_TTL = 3600 # 1 hour
21
+ _CACHE_KEY_PREFIX = "retrieval"
22
+ _RRF_K = 60 # standard RRF constant
23
  SourceHint = Literal["document", "schema", "both"]
24
 
25
 
26
+ def _result_dedup_key(r: RetrievalResult) -> tuple:
27
+ """Cross-retriever dedup key — distinguishes DB columns vs DB tables vs
28
+ tabular columns vs prose chunks vs sheet-level (future)."""
29
+ data = r.metadata.get("data", {})
30
+ return (
31
+ r.source_type,
32
+ data.get("table_name"),
33
+ data.get("column_name"),
34
+ data.get("filename"),
35
+ data.get("sheet_name"),
36
+ data.get("chunk_index"), # disambiguates multiple prose chunks per doc
37
+ )
38
+
39
+
40
def _rrf_merge(
    ranked_lists: list[list[RetrievalResult]],
    top_k: int,
    k_rrf: int = _RRF_K,
) -> list[RetrievalResult]:
    """Reciprocal Rank Fusion across retriever batches.

    Each input list is treated as already best-first ordered. Items are
    deduped via _result_dedup_key and re-ranked by aggregated reciprocal
    rank across all lists: score(item) = sum over lists of
    1 / (k_rrf + rank + 1). Rank-merging is used because native score
    scales differ across retrievers and aren't directly comparable.

    Args:
        ranked_lists: One best-first result list per retriever leg.
        top_k: Maximum number of fused results to return.
        k_rrf: RRF smoothing constant (60 is the standard value).

    Returns:
        Up to top_k deduped results, best RRF score first. The returned
        results have `score` overwritten with the aggregated RRF score
        (uniform scale across legs); dropped candidates are left untouched.
    """
    scores: dict[tuple, float] = {}
    index: dict[tuple, RetrievalResult] = {}

    for ranked in ranked_lists:
        for rank, result in enumerate(ranked):
            # Compute the dedup key exactly once per occurrence; the
            # original recomputed it in the sort key and again when
            # overwriting scores.
            key = _result_dedup_key(result)
            scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
            # Keep the first occurrence; metadata is identical for the same
            # key across lists, so any copy is fine.
            index.setdefault(key, result)

    # Sort keys rather than results so ties keep first-seen order (sorted()
    # is stable and dicts preserve insertion order) without re-keying.
    best_keys = sorted(scores, key=scores.__getitem__, reverse=True)[:top_k]

    merged: list[RetrievalResult] = []
    for key in best_keys:
        result = index[key]
        # Overwrite score with the RRF score so downstream consumers see a
        # uniform scale. Only returned items are mutated — the original also
        # clobbered scores on candidates it then dropped, a surprising side
        # effect on objects the retriever legs may still reference.
        result.score = scores[key]
        merged.append(result)
    return merged
69
+
70
+
71
async def invalidate_retrieval_cache(user_id: str) -> int:
    """Drop every cached retrieval entry belonging to `user_id`.

    Ingest/upload/delete API handlers call this after a successful write so
    the next retrieval sees the new data instead of a stale cached top-k.

    Returns:
        The number of cache keys removed (0 when nothing was cached).
    """
    client = await get_redis()
    stale_keys = []
    async for key in client.scan_iter(match=f"{_CACHE_KEY_PREFIX}:{user_id}:*"):
        stale_keys.append(key)
    if not stale_keys:
        return 0
    removed = await client.delete(*stale_keys)
    logger.info("retrieval cache invalidated", user_id=user_id, deleted=removed)
    return int(removed)
86
+
87
+
88
  class RetrievalRouter:
89
  def __init__(
90
  self,
 
96
  "document": document_retriever,
97
  }
98
 
99
+ def _route(self, source_hint: SourceHint) -> list[tuple[str, BaseRetriever]]:
100
  if source_hint == "schema":
101
+ return [("schema", self._retrievers["schema"])]
102
  if source_hint == "document":
103
+ return [("document", self._retrievers["document"])]
104
+ return list(self._retrievers.items())
105
 
106
  async def retrieve(
107
  self,
 
112
  ) -> list[RetrievalResult]:
113
  redis = await get_redis()
114
  query_hash = hashlib.md5(query.encode()).hexdigest()
115
+ cache_key = f"{_CACHE_KEY_PREFIX}:{user_id}:{source_hint}:{query_hash}:{k}"
116
 
117
  cached = await redis.get(cache_key)
118
  if cached:
 
120
  raw = json.loads(cached)
121
  return [RetrievalResult(**r) for r in raw]
122
 
123
+ results = await self._retrieve_uncached(query, user_id, source_hint, k)
124
+
125
+ # Empty-result fallback: orchestrator may have misclassified intent.
126
+ # Retry once with "both" before giving up. No-op when source_hint is
127
+ # already "both".
128
+ if not results and source_hint != "both":
129
+ logger.warning(
130
+ "empty retrieval, falling back to source_hint='both'",
131
+ original_source_hint=source_hint,
132
+ )
133
+ results = await self._retrieve_uncached(query, user_id, "both", k)
134
+
135
+ await redis.setex(
136
+ cache_key,
137
+ _CACHE_TTL,
138
+ json.dumps([asdict(r) for r in results]),
139
+ )
140
+ return results
141
+
142
+ async def _retrieve_uncached(
143
+ self,
144
+ query: str,
145
+ user_id: str,
146
+ source_hint: SourceHint,
147
+ k: int,
148
+ ) -> list[RetrievalResult]:
149
+ routed = self._route(source_hint)
150
  batches = await asyncio.gather(
151
+ *[r.retrieve(query, user_id, k) for _, r in routed],
152
  return_exceptions=True,
153
  )
154
 
155
+ valid_lists: list[list[RetrievalResult]] = []
156
+ per_retriever: dict[str, int | str] = {}
157
+ for (name, _), batch in zip(routed, batches):
158
  if isinstance(batch, Exception):
159
+ logger.error("retriever failed", retriever=name, error=str(batch))
160
+ per_retriever[name] = "error"
161
  continue
162
+ valid_lists.append(batch)
163
+ per_retriever[name] = len(batch)
164
 
165
+ results = _rrf_merge(valid_lists, top_k=k)
 
166
 
167
+ logger.info(
168
+ "router result",
169
+ source_hint=source_hint,
170
+ per_retriever=per_retriever,
171
+ final_count=len(results),
172
+ top_score=results[0].score if results else None,
173
+ bottom_score=results[-1].score if results else None,
174
  )
175
  return results