Rifqi Hafizuddin commited on
Commit
c9d3b33
·
1 Parent(s): 4150ba7

fix: fix dedup logic

Browse files
Files changed (1) hide show
  1. src/rag/retrievers/schema.py +20 -10
src/rag/retrievers/schema.py CHANGED
@@ -178,6 +178,22 @@ class SchemaRetriever(BaseRetriever):
178
  for row in rows
179
  ]
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  def _rrf_merge(
182
  self,
183
  *ranked_lists: list[RetrievalResult],
@@ -190,25 +206,19 @@ class SchemaRetriever(BaseRetriever):
190
 
191
  for ranked in ranked_lists:
192
  for rank, result in enumerate(ranked):
193
- data = result.metadata.get("data", {})
194
- key = (data.get("table_name"), data.get("column_name") or data.get("filename"))
195
  scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
196
  if key not in index or result.score > index[key].score:
197
  index[key] = result
198
 
199
- def _key(r: RetrievalResult) -> tuple:
200
- d = r.metadata.get("data", {})
201
- return (d.get("table_name"), d.get("column_name") or d.get("filename"))
202
-
203
- merged = sorted(index.values(), key=lambda r: scores[_key(r)], reverse=True)
204
  return merged[:top_k]
205
 
206
  def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
207
- """Deduplicate by (table_name, column_name), keeping highest score per unique column."""
208
  seen: dict[tuple, RetrievalResult] = {}
209
  for r in results:
210
- data = r.metadata.get("data", {})
211
- key = (data.get("table_name"), data.get("column_name") or data.get("filename"))
212
  if key not in seen or r.score > seen[key].score:
213
  seen[key] = r
214
  return sorted(seen.values(), key=lambda r: r.score, reverse=True)
 
178
  for row in rows
179
  ]
180
 
181
+ @staticmethod
182
+ def _chunk_key(r: RetrievalResult) -> tuple:
183
+ """Stable identity for dedup/RRF.
184
+
185
+ Includes filename and sheet_name so that tabular column chunks with
186
+ the same column name across different files (e.g. `id` in two CSVs)
187
+ and future sheet-level chunks across XLSX sheets don't collide.
188
+ """
189
+ d = r.metadata.get("data", {})
190
+ return (
191
+ d.get("table_name"),
192
+ d.get("column_name"),
193
+ d.get("filename"),
194
+ d.get("sheet_name"),
195
+ )
196
+
197
  def _rrf_merge(
198
  self,
199
  *ranked_lists: list[RetrievalResult],
 
206
 
207
  for ranked in ranked_lists:
208
  for rank, result in enumerate(ranked):
209
+ key = self._chunk_key(result)
 
210
  scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
211
  if key not in index or result.score > index[key].score:
212
  index[key] = result
213
 
214
+ merged = sorted(index.values(), key=lambda r: scores[self._chunk_key(r)], reverse=True)
 
 
 
 
215
  return merged[:top_k]
216
 
217
  def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
218
+ """Deduplicate by chunk identity, keeping highest score per unique key."""
219
  seen: dict[tuple, RetrievalResult] = {}
220
  for r in results:
221
+ key = self._chunk_key(r)
 
222
  if key not in seen or r.score > seen[key].score:
223
  seen[key] = r
224
  return sorted(seen.values(), key=lambda r: r.score, reverse=True)