lexistudio / reg_embedding_system.py
scipious's picture
Upload 16 files
c514a40 verified
raw
history blame
20.5 kB
import gc
import json
import sqlite3 # SQLite 모듈 추가
from pathlib import Path
from typing import Optional, Tuple, Any, Dict, List, Set
from collections import Counter # 빈도 계산을 위한 추가
import numpy as np
import faiss
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
# 런타임에 Embeddings 클래스를 찾기 위한 로직 (원본 코드 유지)
try:
from langchain_core.embeddings import Embeddings
except ImportError:
try:
from langchain.embeddings.base import Embeddings
except ImportError:
Embeddings = object
# --- SQLite 헬퍼 함수 ---
SQLITE_DB_NAME = "metadata_mapping.db"
# === IDSelector 클래스 정의 (파일 상단 또는 함수 외부에 위치) ===
# IDSelector 대신 IDSelectorBatch 사용
class MetadataIDSelector(faiss.IDSelectorBatch):
def __init__(self, allowed_ids: Set[int]):
# IDSelectorBatch는 allowed_ids 리스트를 직접 받음
super().__init__(list(allowed_ids)) # 리스트로 변환해서 전달
def get_db_connection(persist_directory: str) -> sqlite3.Connection:
"""FAISS 저장 경로를 기반으로 SQLite 연결을 설정하고 반환합니다."""
db_path = Path(persist_directory) / SQLITE_DB_NAME
conn = sqlite3.connect(db_path)
return conn
def _create_and_populate_sqlite_db(chunks: List[Document], persist_directory: str):
"""문서 청크를 기반으로 SQLite DB를 생성하고 채웁니다. (save_embedding_system 내부용)"""
conn = get_db_connection(persist_directory)
cursor = conn.cursor()
# 1. 테이블 생성
cursor.execute("""
CREATE TABLE IF NOT EXISTS documents (
faiss_id INTEGER PRIMARY KEY,
regulation_part TEXT,
regulation_section TEXT,
chapter_section TEXT,
jo TEXT,
json_metadata TEXT
)
""")
conn.commit()
# 2. 데이터 채우기: FAISS ID는 청크 인덱스 i와 동일하게 매핑
for i, doc in enumerate(chunks):
faiss_id = i
metadata_json = json.dumps(doc.metadata, ensure_ascii=False)
reg_part = doc.metadata.get('regulation_part')
reg_section = doc.metadata.get('regulation_section')
reg_chapter = doc.metadata.get('chapter_section')
reg_jo = doc.metadata.get('jo')
# 변수가 리스트인 경우, 쉼표로 구분된 문자열로 변환
if isinstance(reg_section, list):
reg_section = ', '.join(map(str, reg_section))
if isinstance(reg_part, list):
reg_part = ', '.join(map(str, reg_part))
if isinstance(reg_chapter, list):
reg_chapter = ', '.join(map(str, reg_chapter))
if isinstance(reg_jo, list):
reg_jo = ', '.join(map(str, reg_jo))
# 문서 메타데이터에 FAISS ID 추가 (검색 후 필터링 함수에서 활용)
doc.metadata['faiss_id'] = faiss_id
cursor.execute(
"INSERT OR REPLACE INTO documents (faiss_id, regulation_part, regulation_section, chapter_section, jo, json_metadata) VALUES (?, ?, ?, ?, ?, ?)",
(faiss_id, reg_part, reg_section, reg_chapter, reg_jo, metadata_json)
)
conn.commit()
conn.close()
# --- LocalSentenceTransformerEmbeddings (변경 없음) ---
class LocalSentenceTransformerEmbeddings(Embeddings):
"""SentenceTransformer를 LangChain Embeddings 인터페이스로 래핑"""
def __init__(self, st_model, normalize_embeddings: bool = True, encode_batch_size: int = 32):
self.model = st_model
self.normalize = normalize_embeddings
self.encode_batch_size = encode_batch_size
def embed_documents(self, texts):
vecs = self.model.encode(
texts,
batch_size=self.encode_batch_size,
show_progress_bar=False,
normalize_embeddings=self.normalize,
convert_to_numpy=True,
)
return vecs.tolist()
def embed_query(self, text: str):
vec = self.model.encode(
[text],
batch_size=self.encode_batch_size,
show_progress_bar=False,
normalize_embeddings=self.normalize,
convert_to_numpy=True,
)[0]
return vec.tolist()
# --- save_embedding_system (SQLite 저장 로직 추가) ---
def save_embedding_system(
chunks,
persist_directory: str = r"D:/Project AI/RAG",
batch_size: int = 32,
device: str = 'cuda'
):
"""
청크를 임베딩하여 FAISS 벡터스토어와 앙상블 리트리버를 생성하고,
SQLite DB에 메타데이터를 저장합니다.
"""
Path(persist_directory).mkdir(parents=True, exist_ok=True)
# 1) SQLite DB에 메타데이터 저장 및 청크에 faiss_id 추가
# 이 함수 내에서 chunks 리스트의 metadata에 'faiss_id'가 추가됩니다.
_create_and_populate_sqlite_db(chunks, persist_directory)
# 2) SentenceTransformer 로드 (원래 코드와 동일)
model = SentenceTransformer(
'nomic-ai/nomic-embed-text-v2-moe',
trust_remote_code=True,
device=device
)
embeddings = LocalSentenceTransformerEmbeddings(
st_model=model,
normalize_embeddings=True,
encode_batch_size=batch_size
)
# 3) FAISS 벡터스토어 생성 (원래 코드와 동일)
vectorstore = None
for i in range(0, len(chunks), batch_size):
batch = chunks[i:i + batch_size]
if vectorstore is None:
vectorstore = FAISS.from_documents(documents=batch, embedding=embeddings)
else:
vectorstore.add_documents(documents=batch)
gc.collect()
# 4) BM25 + 벡터 앙상블 리트리버 생성 (원래 코드와 동일)
bm25_retriever = BM25Retriever.from_documents(chunks)
bm25_retriever.k = 5
vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
ensemble_retriever = EnsembleRetriever(
retrievers=[vector_retriever, bm25_retriever],
weights=[0.6, 0.4]
)
# 5) FAISS 인덱스 저장 (원래 코드와 동일)
vectorstore.save_local(persist_directory)
# 6) SQLite 연결
sqlite_conn = get_db_connection(persist_directory)
gc.collect()
return ensemble_retriever, vectorstore, sqlite_conn
# --- load_embedding_from_faiss (SQLite 연결 반환 추가) ---
def load_embedding_from_faiss(
persist_directory: str = r"D:/Project AI/RAG",
top_k: int = 10,
bm25_k: int = 10,
weights: Tuple[float, float] = (0.6, 0.4),
embeddings: Optional[Any] = None,
device: str = 'cpu'
) -> Tuple[Any, FAISS, sqlite3.Connection]: # 반환 값에 SQLite 연결 추가
"""
저장된 FAISS 인덱스와 SQLite 연결을 로드하여 앙상블 리트리버를 생성합니다.
"""
# 1) Embeddings 준비 (원래 코드와 동일)
if embeddings is None:
st_model = SentenceTransformer(
'nomic-ai/nomic-embed-text-v2-moe',
trust_remote_code=True,
device=device
)
embeddings = LocalSentenceTransformerEmbeddings(
st_model=st_model,
normalize_embeddings=True,
encode_batch_size=32
)
# 2) FAISS 벡터스토어 로드 (원래 코드와 동일)
persist_dir = Path(persist_directory)
if not persist_dir.exists():
raise FileNotFoundError(f"FAISS 경로가 없습니다: {persist_dir}")
vectorstore = FAISS.load_local(
folder_path=str(persist_dir),
embeddings=embeddings,
allow_dangerous_deserialization=True
)
# 3) BM25를 위한 문서 추출 (원래 코드와 동일)
docs = []
try:
# FAISS docstore에서 문서 추출
if hasattr(vectorstore, "docstore") and hasattr(vectorstore.docstore, "_dict"):
docs = list(vectorstore.docstore._dict.values())
except Exception as e:
print(f"[경고] 저장된 문서를 읽는 중 문제가 발생했습니다: {e}")
# 4) 앙상블 리트리버 구성 (원래 코드와 동일)
vector_retriever = vectorstore.as_retriever(search_kwargs={"k": top_k})
if docs:
bm25_retriever = BM25Retriever.from_documents(docs)
bm25_retriever.k = bm25_k
ensemble_retriever = EnsembleRetriever(
retrievers=[vector_retriever, bm25_retriever],
weights=list(weights)
)
else:
print("[안내] 문서를 찾지 못해 BM25 없이 벡터 리트리버만 반환합니다.")
ensemble_retriever = vector_retriever
# 5) SQLite 연결
sqlite_conn = get_db_connection(persist_directory)
return ensemble_retriever, vectorstore, sqlite_conn # SQLite 연결 반환
# --- search_vectorstore (변경 없음) ---
def search_vectorstore(retriever, query, k=5):
"""리트리버를 사용해 쿼리와 관련된 문서를 검색합니다."""
results = retriever.invoke(query)
return results[:k]
# === search_with_metadata_filter (사전 필터링 버전) ===
def search_with_metadata_filter(
ensemble_retriever: EnsembleRetriever,
vectorstore: FAISS,
query: str,
k: int = 5,
metadata_filter: Optional[Dict[str, Any]] = None,
sqlite_conn: Optional[sqlite3.Connection] = None,
exact_match: bool = True
) -> List[Document]:
"""
SQLite로 사전 필터링 → FAISS ID 추출 → IDSelector로 FAISS 검색 제한
→ BM25는 post-filtering (BM25는 IDSelector 미지원)
"""
vector_ret, bm25_ret = ensemble_retriever.retrievers
# === 1. SQLite에서 필터링된 FAISS ID 추출 ===
filtered_ids = None
if metadata_filter and sqlite_conn:
cursor = sqlite_conn.cursor()
where_clauses = []
params = []
for key, value in metadata_filter.items():
print(f"[key] {key}")
print(f"[value] {value}")
if isinstance(value, list):
# IN 쿼리: 리스트 값 지원
if not value:
continue # 빈 리스트면 무시
placeholders = ', '.join(['?'] * len(value))
where_clauses.append(f"{key} IN ({placeholders})")
params.extend(value)
else:
# 단일 값
where_clauses.append(f"{key} = ?")
params.append(value)
if where_clauses:
where_sql = " OR ".join(where_clauses)
sql_query = f"SELECT faiss_id FROM documents WHERE {where_sql}"
try:
cursor.execute(sql_query, params)
filtered_ids = {row[0] for row in cursor.fetchall()}
print(f"[사전 필터링] {len(filtered_ids)}개 ID 획득 → FAISS 검색 제한")
except Exception as e:
print(f"[경고] SQLite 필터링 실패: {e}")
filtered_ids = None
else:
print("[안내] 필터 조건 없음 → 전체 검색")
else:
print("[안내] 필터 또는 DB 없음 → 전체 검색")
# === 2. FAISS 벡터 검색 (IDSelector 기반 사전 필터링) ===
if filtered_ids and len(filtered_ids) > 0:
# IDSelector 생성
selector = MetadataIDSelector(filtered_ids)
# FAISS 인덱스 추출
index: faiss.Index = vectorstore.index
if not hasattr(index, "search"):
raise ValueError("FAISS 인덱스가 검색을 지원하지 않습니다.")
# 쿼리 임베딩
query_embedding = np.array(vectorstore.embeddings.embed_query(query)).astype('float32')
query_embedding = query_embedding.reshape(1, -1)
# 검색 파라미터 설정
search_params = faiss.SearchParametersIVF(
sel=selector,
nprobe=50 # 필요시 조정 (성능 vs 재현율)
)
# 여유 있게 k * 10개 후보 요청 (필터 후 부족 방지)
_k = max(k * 10, 100)
D, I = index.search(query_embedding, _k, params=search_params)
# 유효한 결과만 추출
valid_indices = [i for i in I[0] if i != -1]
vector_docs = []
for idx in valid_indices[:k]:
doc_id = vectorstore.index_to_docstore_id[idx]
doc = vectorstore.docstore.search(doc_id)
if isinstance(doc, Document):
vector_docs.append(doc)
print(f"[벡터 검색] {len(valid_indices)}개 후보 → {len(vector_docs)}개 유효")
else:
# 필터 없거나 실패 → 일반 검색 (기존 방식)
search_k = k * 5
vector_docs = vector_ret.invoke(query, config={"search_kwargs": {"k": search_k}})
print(f"[벡터 검색] 전체 검색 → {len(vector_docs)}개 후보")
# === 3. BM25 검색 (post-filtering, BM25는 IDSelector 미지원) ===
bm25_docs = []
if hasattr(bm25_ret, "invoke"):
search_k = k * 5
candidates = bm25_ret.invoke(query, config={"search_kwargs": {"k": search_k}})
if filtered_ids:
bm25_docs = [d for d in candidates if d.metadata.get('faiss_id') in filtered_ids]
else:
bm25_docs = candidates[:k]
print(f"[BM25 검색] {len(candidates)}개 후보 → {len(bm25_docs)}개 필터링 후")
# === 4. 병합 및 최종 k개 반환 ===
combined = {id(d): d for d in (vector_docs + bm25_docs)}.values()
final_results = list(combined)[:k]
print(f"[최종 결과] {len(final_results)}개 문서 반환")
return final_results
def get_unique_metadata_values(
sqlite_conn: sqlite3.Connection,
key_name: str,
partial_match: Optional[str] = None
) -> List[str]:
"""
SQLite 'documents' 테이블에서 특정 컬럼(key_name)의 중복되지 않은
모든 고유 값 리스트를 반환합니다.
Args:
sqlite_conn: SQLite 데이터베이스 연결 객체.
key_name: 고유한 값을 가져올 컬럼 이름 (예: 'regulation_name', 'part_name').
partial_match: (선택 사항) 해당 문자열을 포함하는 값만 검색할 때 사용.
Returns:
중복이 제거된 고유한 값들의 리스트.
"""
if not sqlite_conn:
print("[경고] SQLite 연결이 없어 고유 값 검색을 수행할 수 없습니다.")
return []
cursor = sqlite_conn.cursor()
# SQL 쿼리 구성
# 1. 컬럼 이름에 백틱(`)을 사용하여 안전성 확보
# 2. DISTINCT를 사용하여 중복 제거
sql_query = f"SELECT DISTINCT `{key_name}` FROM documents"
params = []
# 부분 문자열 검색 (LIKE) 조건 추가
if partial_match:
sql_query += f" WHERE `{key_name}` LIKE ?"
params.append(f"%{partial_match}%")
try:
cursor.execute(sql_query, params)
# 쿼리 결과에서 첫 번째 항목 (값)만 추출
unique_values = [row[0] for row in cursor.fetchall() if row[0] is not None]
return unique_values
except sqlite3.OperationalError as e:
# 컬럼 이름이 DB에 없을 때 발생하는 에러 처리
print(f"[에러] SQLite 쿼리 실행 실패 (컬럼 '{key_name}' 이름 오류 가능): {e}")
return []
except Exception as e:
print(f"[에러] 고유 값 검색 중 알 수 없는 오류 발생: {e}")
return []
def smart_search_vectorstore(
retriever,
query,
k=5,
vectorstore=None,
sqlite_conn=None,
enable_detailed_search=True
):
"""
리트리버를 사용해 쿼리와 관련된 문서를 검색하고,
선택적으로 가장 빈번한 regulation_part 카테고리에 대해 상세 검색을 수행합니다.
Args:
retriever: 기본 검색에 사용할 리트리버
query: 검색 쿼리
k: 반환할 문서 수
vectorstore: FAISS 벡터스토어 (상세 검색용)
sqlite_conn: SQLite 연결 (상세 검색용)
enable_detailed_search: 상세 검색 활성화 여부
Returns:
검색된 문서 리스트 (기본 검색 + 상세 검색 결과)
"""
# 1. 기본 검색 수행
basic_results = retriever.invoke(query)
basic_results = basic_results[:k]
print(f"[기본 검색] {len(basic_results)}개 문서 검색 완료")
# 상세 검색이 비활성화되었거나 필요한 컴포넌트가 없으면 기본 결과만 반환
if not enable_detailed_search or not vectorstore or not sqlite_conn:
print("[안내] 상세 검색 비활성화 또는 컴포넌트 부족 → 기본 검색 결과만 반환")
return basic_results
# 2. regulation_part 메타데이터 빈도 분석
regulation_parts = []
for doc in basic_results:
reg_part = doc.metadata.get('regulation_part')
if reg_part:
# regulation_part가 리스트인 경우 모든 항목 추가
if isinstance(reg_part, list):
regulation_parts.extend(reg_part)
elif isinstance(reg_part, str):
# 쉼표로 구분된 문자열인 경우 분리
if ',' in reg_part:
regulation_parts.extend([part.strip() for part in reg_part.split(',')])
else:
regulation_parts.append(reg_part)
# 빈도 계산 및 상위 카테고리 추출
if not regulation_parts:
print("[안내] regulation_part 메타데이터 없음 → 기본 검색 결과만 반환")
return basic_results
counter = Counter(regulation_parts)
most_extracted_category = counter.most_common(2) # 상위 2개 카테고리
print(f"[빈도 분석] regulation_part 빈도: {dict(counter)}")
print(f"[상위 카테고리] {most_extracted_category}")
# 3. 상위 카테고리에 대한 상세 검색 수행
detailed_results = []
for rank, (category, count) in enumerate(most_extracted_category, 1):
print(f"[상세 검색 {rank}순위] '{category}' 카테고리 검색 시작 (빈도: {count})")
# metadata_filter 구성
metadata_filter = {'regulation_part': category}
try:
# search_with_metadata_filter 호출
category_results = search_with_metadata_filter(
ensemble_retriever=retriever,
vectorstore=vectorstore,
query=query,
k=k,
metadata_filter=metadata_filter,
sqlite_conn=sqlite_conn
)
detailed_results.extend(category_results)
print(f"[상세 검색 {rank}순위] {len(category_results)}개 추가 문서 검색 완료")
except Exception as e:
print(f"[경고] 상세 검색 {rank}순위 실패 ({category}): {e}")
continue
# 4. 결과 병합 (중복 제거)
# Document 객체의 고유성을 위해 page_content와 metadata의 조합으로 중복 판단
seen = set()
final_results = []
# 기본 검색 결과 우선 추가
for doc in basic_results:
doc_signature = (doc.page_content, str(sorted(doc.metadata.items())))
if doc_signature not in seen:
seen.add(doc_signature)
final_results.append(doc)
# 상세 검색 결과 추가 (중복 제거)
for doc in detailed_results:
doc_signature = (doc.page_content, str(sorted(doc.metadata.items())))
if doc_signature not in seen:
seen.add(doc_signature)
final_results.append(doc)
# 최종 k개로 제한
final_results = final_results[:k]
print(f"[최종 결과] 기본 {len(basic_results)}개 + 상세 {len(detailed_results)}개 → 중복 제거 후 {len(final_results)}개 반환")
return final_results