Spaces:
Running
Running
| import gc | |
| import json | |
| import sqlite3 | |
| from pathlib import Path | |
| from typing import Optional, Tuple, Any, Dict, List, Set | |
| from collections import Counter | |
| import numpy as np | |
| import faiss | |
| from langchain.retrievers import BM25Retriever, EnsembleRetriever | |
| from langchain_core.documents import Document | |
| from langchain_community.vectorstores import FAISS | |
| from sentence_transformers import SentenceTransformer | |
| # 런타임에 Embeddings 클래스를 찾기 위한 로직 | |
| try: | |
| from langchain_core.embeddings import Embeddings | |
| except ImportError: | |
| try: | |
| from langchain.embeddings.base import Embeddings | |
| except ImportError: | |
| Embeddings = object | |
| import logging | |
| # 로거 설정: 레벨을 INFO로 설정하고, 포맷을 지정합니다. | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | |
| ) | |
| logger = logging.getLogger(__name__) | |
| # --- SQLite 헬퍼 함수 --- | |
| SQLITE_DB_NAME = "metadata_mapping.db" | |
| # === IDSelector 클래스 정의 === | |
| class MetadataIDSelector(faiss.IDSelectorBatch): | |
| def __init__(self, allowed_ids: Set[int]): | |
| super().__init__(list(allowed_ids)) | |
| def get_db_connection(persist_directory: str) -> sqlite3.Connection: | |
| """FAISS 저장 경로를 기반으로 SQLite 연결을 설정하고 반환합니다.""" | |
| db_path = Path(persist_directory) / SQLITE_DB_NAME | |
| conn = sqlite3.connect(db_path) | |
| return conn | |
| def _create_and_populate_sqlite_db(chunks: List[Document], persist_directory: str): | |
| """문서 청크를 기반으로 SQLite DB를 생성하고 채웁니다.""" | |
| conn = get_db_connection(persist_directory) | |
| cursor = conn.cursor() | |
| # 1. 테이블 생성 | |
| cursor.execute(""" | |
| CREATE TABLE IF NOT EXISTS documents ( | |
| faiss_id INTEGER PRIMARY KEY, | |
| regulation_part TEXT, | |
| regulation_section TEXT, | |
| chapter_section TEXT, | |
| jo TEXT, | |
| json_metadata TEXT | |
| ) | |
| """) | |
| conn.commit() | |
| # 2. 데이터 채우기 | |
| for i, doc in enumerate(chunks): | |
| faiss_id = i | |
| metadata_json = json.dumps(doc.metadata, ensure_ascii=False) | |
| reg_part = doc.metadata.get('regulation_part') | |
| reg_section = doc.metadata.get('regulation_section') | |
| reg_chapter = doc.metadata.get('chapter_section') | |
| reg_jo = doc.metadata.get('jo') | |
| # 변수가 리스트인 경우, 쉼표로 구분된 문자열로 변환 | |
| if isinstance(reg_section, list): | |
| reg_section = ', '.join(map(str, reg_section)) | |
| if isinstance(reg_part, list): | |
| reg_part = ', '.join(map(str, reg_part)) | |
| if isinstance(reg_chapter, list): | |
| reg_chapter = ', '.join(map(str, reg_chapter)) | |
| if isinstance(reg_jo, list): | |
| reg_jo = ', '.join(map(str, reg_jo)) | |
| # 문서 메타데이터에 FAISS ID 추가 | |
| doc.metadata['faiss_id'] = faiss_id | |
| cursor.execute( | |
| "INSERT OR REPLACE INTO documents (faiss_id, regulation_part, regulation_section, chapter_section, jo, json_metadata) VALUES (?, ?, ?, ?, ?, ?)", | |
| (faiss_id, reg_part, reg_section, reg_chapter, reg_jo, metadata_json) | |
| ) | |
| conn.commit() | |
| conn.close() | |
| # --- LocalSentenceTransformerEmbeddings --- | |
| class LocalSentenceTransformerEmbeddings(Embeddings): | |
| """SentenceTransformer를 LangChain Embeddings 인터페이스로 래핑""" | |
| def __init__(self, st_model, normalize_embeddings: bool = True, encode_batch_size: int = 32): | |
| self.model = st_model | |
| self.normalize = normalize_embeddings | |
| self.encode_batch_size = encode_batch_size | |
| def embed_documents(self, texts): | |
| vecs = self.model.encode( | |
| texts, | |
| batch_size=self.encode_batch_size, | |
| show_progress_bar=False, | |
| normalize_embeddings=self.normalize, | |
| convert_to_numpy=True, | |
| ) | |
| return vecs.tolist() | |
| def embed_query(self, text: str): | |
| vec = self.model.encode( | |
| [text], | |
| batch_size=self.encode_batch_size, | |
| show_progress_bar=False, | |
| normalize_embeddings=self.normalize, | |
| convert_to_numpy=True, | |
| )[0] | |
| return vec.tolist() | |
| # --- save_embedding_system --- | |
| def save_embedding_system( | |
| chunks, | |
| persist_directory: str = r"D:/Project AI/RAG", | |
| batch_size: int = 32, | |
| device: str = 'cuda' | |
| ): | |
| """ | |
| 청크를 임베딩하여 FAISS 벡터스토어와 앙상블 리트리버를 생성하고, | |
| SQLite DB에 메타데이터를 저장합니다. | |
| """ | |
| Path(persist_directory).mkdir(parents=True, exist_ok=True) | |
| # 1) SQLite DB에 메타데이터 저장 및 청크에 faiss_id 추가 | |
| _create_and_populate_sqlite_db(chunks, persist_directory) | |
| # 2) SentenceTransformer 로드 | |
| model = SentenceTransformer( | |
| 'nomic-ai/nomic-embed-text-v2-moe', | |
| trust_remote_code=True, | |
| device=device | |
| ) | |
| embeddings = LocalSentenceTransformerEmbeddings( | |
| st_model=model, | |
| normalize_embeddings=True, | |
| encode_batch_size=batch_size | |
| ) | |
| # 3) FAISS 벡터스토어 생성 | |
| vectorstore = None | |
| for i in range(0, len(chunks), batch_size): | |
| batch = chunks[i:i + batch_size] | |
| if vectorstore is None: | |
| vectorstore = FAISS.from_documents(documents=batch, embedding=embeddings) | |
| else: | |
| vectorstore.add_documents(documents=batch) | |
| gc.collect() | |
| # 4) BM25 + 벡터 앙상블 리트리버 생성 | |
| bm25_retriever = BM25Retriever.from_documents(chunks) | |
| bm25_retriever.k = 5 | |
| vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 5}) | |
| ensemble_retriever = EnsembleRetriever( | |
| retrievers=[vector_retriever, bm25_retriever], | |
| weights=[0.6, 0.4] | |
| ) | |
| # 5) FAISS 인덱스 저장 | |
| vectorstore.save_local(persist_directory) | |
| # 6) SQLite 연결 | |
| sqlite_conn = get_db_connection(persist_directory) | |
| gc.collect() | |
| return ensemble_retriever, vectorstore, sqlite_conn | |
| # --- load_embedding_from_faiss --- | |
| def load_embedding_from_faiss( | |
| persist_directory: str = r"D:/Project AI/RAG", | |
| top_k: int = 10, | |
| bm25_k: int = 10, | |
| weights: Tuple[float, float] = (0.6, 0.4), | |
| embeddings: Optional[Any] = None, | |
| device: str = 'cpu' | |
| ) -> Tuple[Any, FAISS, sqlite3.Connection]: | |
| """ | |
| 저장된 FAISS 인덱스와 SQLite 연결을 로드하여 앙상블 리트리버를 생성합니다. | |
| """ | |
| # 1) Embeddings 준비 | |
| if embeddings is None: | |
| st_model = SentenceTransformer( | |
| 'nomic-ai/nomic-embed-text-v2-moe', | |
| trust_remote_code=True, | |
| device=device | |
| ) | |
| embeddings = LocalSentenceTransformerEmbeddings( | |
| st_model=st_model, | |
| normalize_embeddings=True, | |
| encode_batch_size=32 | |
| ) | |
| # 2) FAISS 벡터스토어 로드 (Pydantic v1 호환 옵션 추가) | |
| persist_dir = Path(persist_directory) | |
| if not persist_dir.exists(): | |
| raise FileNotFoundError(f"FAISS 경로가 없습니다: {persist_dir}") | |
| try: | |
| vectorstore = FAISS.load_local( | |
| folder_path=str(persist_dir), | |
| embeddings=embeddings, | |
| allow_dangerous_deserialization=True | |
| ) | |
| logger.info(f"[로드 성공] FAISS 인덱스 로드 완료: {persist_dir}") | |
| except Exception as e: | |
| logger.info(f"[로드 오류] FAISS 로드 실패: {e}") | |
| raise | |
| # 3) BM25를 위한 문서 추출 | |
| docs = [] | |
| try: | |
| if hasattr(vectorstore, "docstore") and hasattr(vectorstore.docstore, "_dict"): | |
| docs = list(vectorstore.docstore._dict.values()) | |
| except Exception as e: | |
| logger.info(f"[경고] 저장된 문서를 읽는 중 문제가 발생했습니다: {e}") | |
| # 4) 앙상블 리트리버 구성 | |
| vector_retriever = vectorstore.as_retriever(search_kwargs={"k": top_k}) | |
| if docs: | |
| bm25_retriever = BM25Retriever.from_documents(docs) | |
| bm25_retriever.k = bm25_k | |
| ensemble_retriever = EnsembleRetriever( | |
| retrievers=[vector_retriever, bm25_retriever], | |
| weights=list(weights) | |
| ) | |
| else: | |
| logger.info("[안내] 문서를 찾지 못해 BM25 없이 벡터 리트리버만 반환합니다.") | |
| ensemble_retriever = vector_retriever | |
| # 5) SQLite 연결 | |
| sqlite_conn = get_db_connection(persist_directory) | |
| return ensemble_retriever, vectorstore, sqlite_conn | |
| # --- search_vectorstore --- | |
| def search_vectorstore(retriever, query, k=5): | |
| """리트리버를 사용해 쿼리와 관련된 문서를 검색합니다.""" | |
| results = retriever.invoke(query) | |
| return results[:k] | |
| # === search_with_metadata_filter === | |
| def search_with_metadata_filter( | |
| ensemble_retriever: EnsembleRetriever, | |
| vectorstore: FAISS, | |
| query: str, | |
| k: int = 5, | |
| metadata_filter: Optional[Dict[str, Any]] = None, | |
| sqlite_conn: Optional[sqlite3.Connection] = None, | |
| exact_match: bool = True | |
| ) -> List[Document]: | |
| """SQLite로 사전 필터링 후 FAISS 검색""" | |
| vector_ret, bm25_ret = ensemble_retriever.retrievers | |
| # === 1. SQLite에서 필터링된 FAISS ID 추출 === | |
| filtered_ids = None | |
| if metadata_filter and sqlite_conn: | |
| cursor = sqlite_conn.cursor() | |
| where_clauses = [] | |
| params = [] | |
| for key, value in metadata_filter.items(): | |
| #logger.info(f"[key] {key}") | |
| #logger.info(f"[value] {value}") | |
| if isinstance(value, list): | |
| if not value: | |
| continue | |
| placeholders = ', '.join(['?'] * len(value)) | |
| where_clauses.append(f"{key} IN ({placeholders})") | |
| params.extend(value) | |
| else: | |
| where_clauses.append(f"{key} = ?") | |
| params.append(value) | |
| if where_clauses: | |
| where_sql = " OR ".join(where_clauses) | |
| sql_query = f"SELECT faiss_id FROM documents WHERE {where_sql}" | |
| try: | |
| cursor.execute(sql_query, params) | |
| filtered_ids = {row[0] for row in cursor.fetchall()} | |
| #logger.info(f"[사전 필터링] {len(filtered_ids)}개 ID 획득 → FAISS 검색 제한") | |
| except Exception as e: | |
| logger.info(f"[경고] SQLite 필터링 실패: {e}") | |
| filtered_ids = None | |
| #else: | |
| #logger.info("[안내] 필터 조건 없음 → 전체 검색") | |
| #else: | |
| #logger.info("[안내] 필터 또는 DB 없음 → 전체 검색") | |
| # === 2. FAISS 벡터 검색 === | |
| if filtered_ids and len(filtered_ids) > 0: | |
| selector = MetadataIDSelector(filtered_ids) | |
| index: faiss.Index = vectorstore.index | |
| if not hasattr(index, "search"): | |
| raise ValueError("FAISS 인덱스가 검색을 지원하지 않습니다.") | |
| query_embedding = np.array(vectorstore.embeddings.embed_query(query)).astype('float32') | |
| query_embedding = query_embedding.reshape(1, -1) | |
| search_params = faiss.SearchParametersIVF( | |
| sel=selector, | |
| nprobe=50 | |
| ) | |
| _k = max(k * 10, 100) | |
| D, I = index.search(query_embedding, _k, params=search_params) | |
| valid_indices = [i for i in I[0] if i != -1] | |
| vector_docs = [] | |
| for idx in valid_indices[:k]: | |
| doc_id = vectorstore.index_to_docstore_id[idx] | |
| doc = vectorstore.docstore.search(doc_id) | |
| if isinstance(doc, Document): | |
| vector_docs.append(doc) | |
| #logger.info(f"[벡터 검색] {len(valid_indices)}개 후보 → {len(vector_docs)}개 유효") | |
| else: | |
| search_k = k * 5 | |
| vector_docs = vector_ret.invoke(query, config={"search_kwargs": {"k": search_k}}) | |
| #logger.info(f"[벡터 검색] 전체 검색 → {len(vector_docs)}개 후보") | |
| # === 3. BM25 검색 === | |
| bm25_docs = [] | |
| if hasattr(bm25_ret, "invoke"): | |
| search_k = k * 5 | |
| candidates = bm25_ret.invoke(query, config={"search_kwargs": {"k": search_k}}) | |
| if filtered_ids: | |
| bm25_docs = [d for d in candidates if d.metadata.get('faiss_id') in filtered_ids] | |
| else: | |
| bm25_docs = candidates[:k] | |
| #logger.info(f"[BM25 검색] {len(candidates)}개 후보 → {len(bm25_docs)}개 필터링 후") | |
| # === 4. 병합 및 최종 k개 반환 === | |
| combined = {id(d): d for d in (vector_docs + bm25_docs)}.values() | |
| final_results = list(combined)[:k] | |
| #logger.info(f"[최종 결과] {len(final_results)}개 문서 반환") | |
| return final_results | |
| def get_unique_metadata_values( | |
| sqlite_conn: sqlite3.Connection, | |
| key_name: str, | |
| partial_match: Optional[str] = None | |
| ) -> List[str]: | |
| """SQLite에서 특정 컬럼의 고유한 값 리스트를 반환합니다.""" | |
| if not sqlite_conn: | |
| logger.info("[경고] SQLite 연결이 없어 고유 값 검색을 수행할 수 없습니다.") | |
| return [] | |
| cursor = sqlite_conn.cursor() | |
| sql_query = f"SELECT DISTINCT `{key_name}` FROM documents" | |
| params = [] | |
| if partial_match: | |
| sql_query += f" WHERE `{key_name}` LIKE ?" | |
| params.append(f"%{partial_match}%") | |
| try: | |
| cursor.execute(sql_query, params) | |
| unique_values = [row[0] for row in cursor.fetchall() if row[0] is not None] | |
| return unique_values | |
| except sqlite3.OperationalError as e: | |
| logger.info(f"[에러] SQLite 쿼리 실행 실패 (컬럼 '{key_name}' 이름 오류 가능): {e}") | |
| return [] | |
| except Exception as e: | |
| logger.info(f"[에러] 고유 값 검색 중 알 수 없는 오류 발생: {e}") | |
| return [] | |
| def smart_search_vectorstore( | |
| retriever, | |
| query, | |
| k=5, | |
| vectorstore=None, | |
| sqlite_conn=None, | |
| enable_detailed_search=True | |
| ): | |
| """기본 검색 + 상세 검색 수행""" | |
| # 1. 기본 검색 | |
| basic_results = retriever.invoke(query) | |
| basic_results = basic_results[:k] | |
| #logger.info(f"[기본 검색] {len(basic_results)}개 문서 검색 완료") | |
| if not enable_detailed_search or not vectorstore or not sqlite_conn: | |
| logger.info("[안내] 상세 검색 비활성화 또는 컴포넌트 부족 → 기본 검색 결과만 반환") | |
| return basic_results | |
| # 2. regulation_part 빈도 분석 | |
| regulation_parts = [] | |
| for doc in basic_results: | |
| reg_part = doc.metadata.get('regulation_part') | |
| if reg_part: | |
| if isinstance(reg_part, list): | |
| regulation_parts.extend(reg_part) | |
| elif isinstance(reg_part, str): | |
| if ',' in reg_part: | |
| regulation_parts.extend([part.strip() for part in reg_part.split(',')]) | |
| else: | |
| regulation_parts.append(reg_part) | |
| if not regulation_parts: | |
| logger.info("[안내] regulation_part 메타데이터 없음 → 기본 검색 결과만 반환") | |
| return basic_results | |
| counter = Counter(regulation_parts) | |
| most_extracted_category = counter.most_common(2) | |
| #logger.info(f"[빈도 분석] regulation_part 빈도: {dict(counter)}") | |
| #logger.info(f"[상위 카테고리] {most_extracted_category}") | |
| # 3. 상세 검색 | |
| detailed_results = [] | |
| for rank, (category, count) in enumerate(most_extracted_category, 1): | |
| #logger.info(f"[상세 검색 {rank}순위] '{category}' 카테고리 검색 시작 (빈도: {count})") | |
| metadata_filter = {'regulation_part': category} | |
| try: | |
| category_results = search_with_metadata_filter( | |
| ensemble_retriever=retriever, | |
| vectorstore=vectorstore, | |
| query=query, | |
| k=k, | |
| metadata_filter=metadata_filter, | |
| sqlite_conn=sqlite_conn | |
| ) | |
| detailed_results.extend(category_results) | |
| #logger.info(f"[상세 검색 {rank}순위] {len(category_results)}개 추가 문서 검색 완료") | |
| except Exception as e: | |
| #logger.info(f"[경고] 상세 검색 {rank}순위 실패 ({category}): {e}") | |
| continue | |
| # 4. 결과 병합 | |
| seen = set() | |
| final_results = [] | |
| #Detailed 검색 결과를 먼저 추가 | |
| for doc in detailed_results: | |
| doc_signature = (doc.page_content, str(sorted(doc.metadata.items()))) | |
| if doc_signature not in seen: | |
| seen.add(doc_signature) | |
| final_results.append(doc) | |
| for doc in basic_results: | |
| doc_signature = (doc.page_content, str(sorted(doc.metadata.items()))) | |
| if doc_signature not in seen: | |
| seen.add(doc_signature) | |
| final_results.append(doc) | |
| final_results = final_results[:k] | |
| #logger.info(f"[최종 결과] 기본 {len(basic_results)}개 + 상세 {len(detailed_results)}개 → 중복 제거 후 {len(final_results)}개 반환") | |
| return final_results | |
| # natural_sort_key 함수 추가 (app.py에서 사용됨) | |
| import re | |
| def natural_sort_key(s): | |
| """자연스러운 정렬을 위한 키 함수""" | |
| return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', str(s))] |