MarlonKegel committed on
Commit
8cd5cc6
·
1 Parent(s): 95f11ac

added diversity cap when searching across all sources

Browse files
Files changed (1) hide show
  1. rag_ui.py +33 -2
rag_ui.py CHANGED
@@ -22,6 +22,7 @@ from rank_bm25 import BM25Okapi
22
  import io
23
  from docx import Document
24
  import hashlib
 
25
 
26
  # Caching for search results function
27
  @st.cache_data(show_spinner=False, max_entries=256)
@@ -182,11 +183,34 @@ def hybrid_search(query, chunk_idx_pool=None, k_sparse=TOPK_SPARSE, k_dense=TOPK
182
  if neighbor is not None:
183
  extra_idxs.add(neighbor)
184
  all_final_idxs = list(dict.fromkeys(list(best_idxs) + list(extra_idxs)))
 
185
  selected_chunks = []
 
 
 
 
 
 
 
 
 
 
186
  for i in all_final_idxs:
187
  if i < len(chunks) and (chunk_idx_pool is None or i in chunk_idx_pool):
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  rationale = []
189
- # Trace rationale category:
190
  sparse_rank = sparse_ranks.get(i)
191
  dense_rank = dense_ranks.get(i)
192
  combined_rank = list(sorted(hybrid_scores, key=hybrid_scores.get, reverse=True)).index(i) if i in hybrid_scores else None
@@ -196,9 +220,16 @@ def hybrid_search(query, chunk_idx_pool=None, k_sparse=TOPK_SPARSE, k_dense=TOPK
196
  rationale.append("high dense similarity (embedding rank top-3)")
197
  if combined_rank is not None and combined_rank < 3:
198
  rationale.append("high combined score (RRF top-3)")
199
- selected_chunk = dict(chunks[i]) # shallow copy, to avoid mutating source
200
  selected_chunk["retrieval_rationale"] = rationale if rationale else ["selected via hybrid search"]
201
  selected_chunks.append(selected_chunk)
 
 
 
 
 
 
 
202
  # --- Sort so that, within each paper_id, chunk_id is ascending ---
203
  selected_chunks.sort(key=lambda c: (c['paper_id'], c['chunk_id']))
204
  return selected_chunks
 
22
  import io
23
  from docx import Document
24
  import hashlib
25
+ import math
26
 
27
  # Caching for search results function
28
  @st.cache_data(show_spinner=False, max_entries=256)
 
183
  if neighbor is not None:
184
  extra_idxs.add(neighbor)
185
  all_final_idxs = list(dict.fromkeys(list(best_idxs) + list(extra_idxs)))
186
+
187
  selected_chunks = []
188
+ source_counts = {}
189
+ author_counts = {}
190
+
191
+ if chunk_idx_pool is None: # Only apply capping when searching all sources
192
+ max_per_source = math.ceil(n_final * 0.5)
193
+ max_per_author = math.ceil(n_final * 0.7)
194
+ else:
195
+ # If subset, no caps
196
+ max_per_source = max_per_author = n_final
197
+
198
  for i in all_final_idxs:
199
  if i < len(chunks) and (chunk_idx_pool is None or i in chunk_idx_pool):
200
+ chunk = chunks[i]
201
+ meta = chunk["metadata"]
202
+ source_id = (meta.get("title", ""), meta.get("authors", "")) # By title & authors (source)
203
+ author_id = meta.get("authors", "")
204
+
205
+ # Count how many from this source and author so far
206
+ s_count = source_counts.get(source_id, 0)
207
+ a_count = author_counts.get(author_id, 0)
208
+
209
+ # Enforce cap only if no source filter
210
+ if s_count >= max_per_source or a_count >= max_per_author:
211
+ continue
212
+
213
  rationale = []
 
214
  sparse_rank = sparse_ranks.get(i)
215
  dense_rank = dense_ranks.get(i)
216
  combined_rank = list(sorted(hybrid_scores, key=hybrid_scores.get, reverse=True)).index(i) if i in hybrid_scores else None
 
220
  rationale.append("high dense similarity (embedding rank top-3)")
221
  if combined_rank is not None and combined_rank < 3:
222
  rationale.append("high combined score (RRF top-3)")
223
+ selected_chunk = dict(chunk) # shallow copy, to avoid mutating source
224
  selected_chunk["retrieval_rationale"] = rationale if rationale else ["selected via hybrid search"]
225
  selected_chunks.append(selected_chunk)
226
+ # Update counts
227
+ source_counts[source_id] = s_count + 1
228
+ author_counts[author_id] = a_count + 1
229
+ # Stop early if we have enough
230
+ if len(selected_chunks) >= n_final:
231
+ break
232
+
233
  # --- Sort so that, within each paper_id, chunk_id is ascending ---
234
  selected_chunks.sort(key=lambda c: (c['paper_id'], c['chunk_id']))
235
  return selected_chunks