scipious commited on
Commit
1cbc57f
·
verified ·
1 Parent(s): e28ac0d

Update reg_embedding_system.py

Browse files
Files changed (1) hide show
  1. reg_embedding_system.py +40 -29
reg_embedding_system.py CHANGED
@@ -255,7 +255,7 @@ def search_vectorstore(retriever, query, k=5):
255
  results = retriever.invoke(query)
256
  return results[:k]
257
 
258
- # === search_with_metadata_filter ===
259
  def search_with_metadata_filter(
260
  ensemble_retriever: EnsembleRetriever,
261
  vectorstore: FAISS,
@@ -263,11 +263,16 @@ def search_with_metadata_filter(
263
  k: int = 5,
264
  metadata_filter: Optional[Dict[str, Any]] = None,
265
  sqlite_conn: Optional[sqlite3.Connection] = None,
266
- exact_match: bool = True
267
  ) -> List[Document]:
268
- """SQLite로 사전 필터링 후 FAISS 검색"""
 
 
 
269
  vector_ret, bm25_ret = ensemble_retriever.retrievers
270
 
 
 
271
  # === 1. SQLite에서 필터링된 FAISS ID 추출 ===
272
  filtered_ids = None
273
  if metadata_filter and sqlite_conn:
@@ -276,18 +281,21 @@ def search_with_metadata_filter(
276
  params = []
277
 
278
  for key, value in metadata_filter.items():
279
- #logger.info(f"[key] {key}")
280
- #logger.info(f"[value] {value}")
281
  if isinstance(value, list):
 
282
  if not value:
283
- continue
284
  placeholders = ', '.join(['?'] * len(value))
285
  where_clauses.append(f"{key} IN ({placeholders})")
286
  params.extend(value)
287
  else:
 
288
  where_clauses.append(f"{key} = ?")
289
  params.append(value)
290
 
 
291
  if where_clauses:
292
  where_sql = " OR ".join(where_clauses)
293
  sql_query = f"SELECT faiss_id FROM documents WHERE {where_sql}"
@@ -295,34 +303,39 @@ def search_with_metadata_filter(
295
  try:
296
  cursor.execute(sql_query, params)
297
  filtered_ids = {row[0] for row in cursor.fetchall()}
298
- #logger.info(f"[사전 필터링] {len(filtered_ids)}개 ID 획득 → FAISS 검색 제한")
299
  except Exception as e:
300
  logger.info(f"[경고] SQLite 필터링 실패: {e}")
301
  filtered_ids = None
302
- #else:
303
- #logger.info("[안내] 필터 조건 없음 → 전체 검색")
304
- #else:
305
- #logger.info("[안내] 필터 또는 DB 없음 → 전체 검색")
306
 
307
- # === 2. FAISS 벡터 검색 ===
308
  if filtered_ids and len(filtered_ids) > 0:
 
309
  selector = MetadataIDSelector(filtered_ids)
 
 
310
  index: faiss.Index = vectorstore.index
311
-
312
  if not hasattr(index, "search"):
313
  raise ValueError("FAISS 인덱스가 검색을 지원하지 않습니다.")
314
 
 
315
  query_embedding = np.array(vectorstore.embeddings.embed_query(query)).astype('float32')
316
  query_embedding = query_embedding.reshape(1, -1)
317
 
 
318
  search_params = faiss.SearchParametersIVF(
319
  sel=selector,
320
- nprobe=50
321
  )
322
 
 
323
  _k = max(k * 10, 100)
324
  D, I = index.search(query_embedding, _k, params=search_params)
325
 
 
326
  valid_indices = [i for i in I[0] if i != -1]
327
  vector_docs = []
328
  for idx in valid_indices[:k]:
@@ -330,29 +343,27 @@ def search_with_metadata_filter(
330
  doc = vectorstore.docstore.search(doc_id)
331
  if isinstance(doc, Document):
332
  vector_docs.append(doc)
333
-
334
- #logger.info(f"[벡터 검색] {len(valid_indices)}개 후보 → {len(vector_docs)}개 유효")
335
  else:
336
- search_k = k * 5
337
- vector_docs = vector_ret.invoke(query, config={"search_kwargs": {"k": search_k}})
338
- #logger.info(f"[벡터 검색] 전체 검색 → {len(vector_docs)}개 후보")
 
339
 
340
- # === 3. BM25 검색 ===
341
  bm25_docs = []
342
- if hasattr(bm25_ret, "invoke"):
343
- search_k = k * 5
344
- candidates = bm25_ret.invoke(query, config={"search_kwargs": {"k": search_k}})
345
- if filtered_ids:
346
- bm25_docs = [d for d in candidates if d.metadata.get('faiss_id') in filtered_ids]
347
- else:
348
- bm25_docs = candidates[:k]
349
- #logger.info(f"[BM25 검색] {len(candidates)}개 후보 → {len(bm25_docs)}개 필터링 후")
350
 
351
  # === 4. 병합 및 최종 k개 반환 ===
352
  combined = {id(d): d for d in (vector_docs + bm25_docs)}.values()
353
  final_results = list(combined)[:k]
354
 
355
- #logger.info(f"[최종 결과] {len(final_results)}개 문서 반환")
356
  return final_results
357
 
358
  def get_unique_metadata_values(
 
255
  results = retriever.invoke(query)
256
  return results[:k]
257
 
258
+ # === search_with_metadata_filter (사전 필터링 버전) ===
259
  def search_with_metadata_filter(
260
  ensemble_retriever: EnsembleRetriever,
261
  vectorstore: FAISS,
 
263
  k: int = 5,
264
  metadata_filter: Optional[Dict[str, Any]] = None,
265
  sqlite_conn: Optional[sqlite3.Connection] = None,
266
+ failsafe_search: bool = True
267
  ) -> List[Document]:
268
+ """
269
+ SQLite로 사전 필터링 → FAISS ID 추출 → IDSelector로 FAISS 검색 제한
270
+ → BM25는 post-filtering (BM25는 IDSelector 미지원)
271
+ """
272
  vector_ret, bm25_ret = ensemble_retriever.retrievers
273
 
274
+ vector_docs = []
275
+
276
  # === 1. SQLite에서 필터링된 FAISS ID 추출 ===
277
  filtered_ids = None
278
  if metadata_filter and sqlite_conn:
 
281
  params = []
282
 
283
  for key, value in metadata_filter.items():
284
+ print(f"[key] {key}")
285
+ print(f"[value] {value}")
286
  if isinstance(value, list):
287
+ # IN 쿼리: 리스트 값 지원
288
  if not value:
289
+ continue # 빈 리스트면 무시
290
  placeholders = ', '.join(['?'] * len(value))
291
  where_clauses.append(f"{key} IN ({placeholders})")
292
  params.extend(value)
293
  else:
294
+ # 단일 값
295
  where_clauses.append(f"{key} = ?")
296
  params.append(value)
297
 
298
+
299
  if where_clauses:
300
  where_sql = " OR ".join(where_clauses)
301
  sql_query = f"SELECT faiss_id FROM documents WHERE {where_sql}"
 
303
  try:
304
  cursor.execute(sql_query, params)
305
  filtered_ids = {row[0] for row in cursor.fetchall()}
 
306
  except Exception as e:
307
  logger.info(f"[경고] SQLite 필터링 실패: {e}")
308
  filtered_ids = None
309
+ else:
310
+ logger.info("[안내] 필터 조건 없음 → 전체 검색")
311
+ else:
312
+ logger.info("[안내] 필터 또는 DB 없음 → 전체 검색")
313
 
314
+ # === 2. FAISS 벡터 검색 (IDSelector 기반 사전 필터링) ===
315
  if filtered_ids and len(filtered_ids) > 0:
316
+ # IDSelector 생성
317
  selector = MetadataIDSelector(filtered_ids)
318
+
319
+ # FAISS 인덱스 추출
320
  index: faiss.Index = vectorstore.index
 
321
  if not hasattr(index, "search"):
322
  raise ValueError("FAISS 인덱스가 검색을 지원하지 않습니다.")
323
 
324
+ # 쿼리 임베딩
325
  query_embedding = np.array(vectorstore.embeddings.embed_query(query)).astype('float32')
326
  query_embedding = query_embedding.reshape(1, -1)
327
 
328
+ # 검색 파라미터 설정
329
  search_params = faiss.SearchParametersIVF(
330
  sel=selector,
331
+ nprobe=50 # 필요시 조정 (성능 vs 재현율)
332
  )
333
 
334
+ # 여유 있게 k * 10개 후보 요청 (필터 후 부족 방지)
335
  _k = max(k * 10, 100)
336
  D, I = index.search(query_embedding, _k, params=search_params)
337
 
338
+ # 유효한 결과만 추출
339
  valid_indices = [i for i in I[0] if i != -1]
340
  vector_docs = []
341
  for idx in valid_indices[:k]:
 
343
  doc = vectorstore.docstore.search(doc_id)
344
  if isinstance(doc, Document):
345
  vector_docs.append(doc)
 
 
346
  else:
347
+ if failsafe_search:
348
+ # 필터 없거나 실패 일반 검색 (기존 방식)
349
+ search_k = k * 5
350
+ vector_docs = vector_ret.invoke(query, config={"search_kwargs": {"k": search_k}})
351
 
352
+ # === 3. BM25 검색 (post-filtering, BM25는 IDSelector 미지원) ===
353
  bm25_docs = []
354
+ if failsafe_search:
355
+ if hasattr(bm25_ret, "invoke"):
356
+ search_k = k * 5
357
+ candidates = bm25_ret.invoke(query, config={"search_kwargs": {"k": search_k}})
358
+ if filtered_ids:
359
+ bm25_docs = [d for d in candidates if d.metadata.get('faiss_id') in filtered_ids]
360
+ else:
361
+ bm25_docs = candidates[:k]
362
 
363
  # === 4. 병합 및 최종 k개 반환 ===
364
  combined = {id(d): d for d in (vector_docs + bm25_docs)}.values()
365
  final_results = list(combined)[:k]
366
 
 
367
  return final_results
368
 
369
  def get_unique_metadata_values(