# reg_embedding_system.py
import gc
import json
import logging
import re
import sqlite3
from pathlib import Path
from typing import Optional, Tuple, Any, Dict, List, Set
from collections import Counter

import numpy as np
import faiss
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
# Locate the Embeddings base class at runtime across LangChain versions.
try:
from langchain_core.embeddings import Embeddings
except ImportError:
try:
from langchain.embeddings.base import Embeddings
except ImportError:
Embeddings = object
# Logger setup: INFO level with a timestamped format.
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# --- SQLite helpers ---
SQLITE_DB_NAME = "metadata_mapping.db"
# === IDSelector definition ===
class MetadataIDSelector(faiss.IDSelectorBatch):
    """Restricts a FAISS search to an allow-list of vector IDs."""
    def __init__(self, allowed_ids: Set[int]):
        super().__init__(list(allowed_ids))
def get_db_connection(persist_directory: str) -> sqlite3.Connection:
"""FAISS 저장 경로를 기반으로 SQLite 연결을 설정하고 반환합니다."""
db_path = Path(persist_directory) / SQLITE_DB_NAME
conn = sqlite3.connect(db_path)
return conn
def _create_and_populate_sqlite_db(chunks: List[Document], persist_directory: str):
"""문서 청크를 기반으로 SQLite DB를 생성하고 채웁니다."""
conn = get_db_connection(persist_directory)
cursor = conn.cursor()
    # 1. Create the table
cursor.execute("""
CREATE TABLE IF NOT EXISTS documents (
faiss_id INTEGER PRIMARY KEY,
regulation_part TEXT,
regulation_section TEXT,
chapter_section TEXT,
jo TEXT,
json_metadata TEXT
)
""")
conn.commit()
    # 2. Populate rows
for i, doc in enumerate(chunks):
faiss_id = i
metadata_json = json.dumps(doc.metadata, ensure_ascii=False)
reg_part = doc.metadata.get('regulation_part')
reg_section = doc.metadata.get('regulation_section')
reg_chapter = doc.metadata.get('chapter_section')
reg_jo = doc.metadata.get('jo')
        # If a value is a list, flatten it to a comma-separated string
        reg_part, reg_section, reg_chapter, reg_jo = (
            ', '.join(map(str, v)) if isinstance(v, list) else v
            for v in (reg_part, reg_section, reg_chapter, reg_jo)
        )
        # Record the FAISS ID on the document metadata
doc.metadata['faiss_id'] = faiss_id
cursor.execute(
"INSERT OR REPLACE INTO documents (faiss_id, regulation_part, regulation_section, chapter_section, jo, json_metadata) VALUES (?, ?, ?, ?, ?, ?)",
(faiss_id, reg_part, reg_section, reg_chapter, reg_jo, metadata_json)
)
conn.commit()
conn.close()
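
# Usage sketch for the metadata DB (hypothetical path and article number,
# not values from this repo):
#
#   conn = get_db_connection("D:/Project AI/RAG")
#   row = conn.execute(
#       "SELECT json_metadata FROM documents WHERE jo = ? LIMIT 1",
#       ("제3조",),
#   ).fetchone()
#   if row:
#       metadata = json.loads(row[0])  # full metadata dict for that chunk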
# --- LocalSentenceTransformerEmbeddings ---
class LocalSentenceTransformerEmbeddings(Embeddings):
"""SentenceTransformer를 LangChain Embeddings 인터페이스로 래핑"""
def __init__(self, st_model, normalize_embeddings: bool = True, encode_batch_size: int = 32):
self.model = st_model
self.normalize = normalize_embeddings
self.encode_batch_size = encode_batch_size
def embed_documents(self, texts):
vecs = self.model.encode(
texts,
batch_size=self.encode_batch_size,
show_progress_bar=False,
normalize_embeddings=self.normalize,
convert_to_numpy=True,
)
return vecs.tolist()
def embed_query(self, text: str):
vec = self.model.encode(
[text],
batch_size=self.encode_batch_size,
show_progress_bar=False,
normalize_embeddings=self.normalize,
convert_to_numpy=True,
)[0]
return vec.tolist()
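
# Usage sketch (any SentenceTransformer checkpoint works; the model name
# below is the one used elsewhere in this module):
#
#   model = SentenceTransformer('nomic-ai/nomic-embed-text-v2-moe',
#                               trust_remote_code=True, device='cpu')
#   emb = LocalSentenceTransformerEmbeddings(model)
#   vec = emb.embed_query("What does Article 3 regulate?")  # -> list[float]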
# --- save_embedding_system ---
def save_embedding_system(
chunks,
persist_directory: str = r"D:/Project AI/RAG",
batch_size: int = 32,
device: str = 'cuda'
):
"""
청크를 임베딩하여 FAISS 벡터스토어와 앙상블 리트리버를 생성하고,
SQLite DB에 메타데이터를 저장합니다.
"""
Path(persist_directory).mkdir(parents=True, exist_ok=True)
    # 1) Store metadata in SQLite and stamp each chunk with its faiss_id
_create_and_populate_sqlite_db(chunks, persist_directory)
    # 2) Load the SentenceTransformer
model = SentenceTransformer(
'nomic-ai/nomic-embed-text-v2-moe',
trust_remote_code=True,
device=device
)
embeddings = LocalSentenceTransformerEmbeddings(
st_model=model,
normalize_embeddings=True,
encode_batch_size=batch_size
)
    # 3) Build the FAISS vector store in batches
vectorstore = None
for i in range(0, len(chunks), batch_size):
batch = chunks[i:i + batch_size]
if vectorstore is None:
vectorstore = FAISS.from_documents(documents=batch, embedding=embeddings)
else:
vectorstore.add_documents(documents=batch)
gc.collect()
    # 4) Build the BM25 + vector ensemble retriever
bm25_retriever = BM25Retriever.from_documents(chunks)
bm25_retriever.k = 5
vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
ensemble_retriever = EnsembleRetriever(
retrievers=[vector_retriever, bm25_retriever],
weights=[0.6, 0.4]
)
    # 5) Persist the FAISS index
vectorstore.save_local(persist_directory)
    # 6) Open the SQLite connection returned to the caller
sqlite_conn = get_db_connection(persist_directory)
gc.collect()
return ensemble_retriever, vectorstore, sqlite_conn
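
# Usage sketch (hypothetical chunks; metadata keys mirror the SQLite schema):
#
#   chunks = [Document(page_content="...", metadata={"jo": "제1조"})]
#   retriever, vectorstore, conn = save_embedding_system(
#       chunks, persist_directory="D:/Project AI/RAG", device="cpu")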
# --- load_embedding_from_faiss ---
def load_embedding_from_faiss(
persist_directory: str = r"D:/Project AI/RAG",
top_k: int = 10,
bm25_k: int = 10,
weights: Tuple[float, float] = (0.6, 0.4),
embeddings: Optional[Any] = None,
device: str = 'cpu'
) -> Tuple[Any, FAISS, sqlite3.Connection]:
"""
저장된 FAISS 인덱스와 SQLite 연결을 로드하여 앙상블 리트리버를 생성합니다.
"""
    # 1) Prepare embeddings
if embeddings is None:
st_model = SentenceTransformer(
'nomic-ai/nomic-embed-text-v2-moe',
trust_remote_code=True,
device=device
)
embeddings = LocalSentenceTransformerEmbeddings(
st_model=st_model,
normalize_embeddings=True,
encode_batch_size=32
)
    # 2) Load the FAISS vector store (allow_dangerous_deserialization opts in to pickle loading)
persist_dir = Path(persist_directory)
if not persist_dir.exists():
        raise FileNotFoundError(f"FAISS directory not found: {persist_dir}")
try:
vectorstore = FAISS.load_local(
folder_path=str(persist_dir),
embeddings=embeddings,
allow_dangerous_deserialization=True
)
logger.info(f"[로드 성공] FAISS 인덱스 로드 완료: {persist_dir}")
except Exception as e:
logger.info(f"[로드 오류] FAISS 로드 실패: {e}")
raise
    # 3) Extract stored documents for BM25
docs = []
try:
if hasattr(vectorstore, "docstore") and hasattr(vectorstore.docstore, "_dict"):
docs = list(vectorstore.docstore._dict.values())
except Exception as e:
logger.info(f"[경고] 저장된 문서를 읽는 중 문제가 발생했습니다: {e}")
    # 4) Assemble the ensemble retriever
vector_retriever = vectorstore.as_retriever(search_kwargs={"k": top_k})
if docs:
bm25_retriever = BM25Retriever.from_documents(docs)
bm25_retriever.k = bm25_k
ensemble_retriever = EnsembleRetriever(
retrievers=[vector_retriever, bm25_retriever],
weights=list(weights)
)
else:
logger.info("[안내] 문서를 찾지 못해 BM25 없이 벡터 리트리버만 반환합니다.")
ensemble_retriever = vector_retriever
    # 5) Open the SQLite connection
sqlite_conn = get_db_connection(persist_directory)
return ensemble_retriever, vectorstore, sqlite_conn
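
# Usage sketch (assumes an index previously written by save_embedding_system):
#
#   retriever, vectorstore, conn = load_embedding_from_faiss(
#       persist_directory="D:/Project AI/RAG", top_k=10, device="cpu")
#   docs = retriever.invoke("example query")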
# --- search_vectorstore ---
def search_vectorstore(retriever, query, k=5):
"""리트리버를 사용해 쿼리와 관련된 문서를 검색합니다."""
results = retriever.invoke(query)
return results[:k]
# === search_with_metadata_filter ===
def search_with_metadata_filter(
ensemble_retriever: EnsembleRetriever,
vectorstore: FAISS,
query: str,
k: int = 5,
metadata_filter: Optional[Dict[str, Any]] = None,
sqlite_conn: Optional[sqlite3.Connection] = None,
exact_match: bool = True
) -> List[Document]:
"""SQLite로 사전 필터링 후 FAISS 검색"""
    # Order matches construction above: [vector_retriever, bm25_retriever]
    vector_ret, bm25_ret = ensemble_retriever.retrievers
    # === 1. Collect allowed FAISS IDs from SQLite ===
filtered_ids = None
if metadata_filter and sqlite_conn:
cursor = sqlite_conn.cursor()
where_clauses = []
params = []
for key, value in metadata_filter.items():
#logger.info(f"[key] {key}")
#logger.info(f"[value] {value}")
if isinstance(value, list):
if not value:
continue
placeholders = ', '.join(['?'] * len(value))
where_clauses.append(f"{key} IN ({placeholders})")
params.extend(value)
else:
where_clauses.append(f"{key} = ?")
params.append(value)
        if where_clauses:
            # Note: multiple filter keys are OR-combined (any match qualifies).
            where_sql = " OR ".join(where_clauses)
            sql_query = f"SELECT faiss_id FROM documents WHERE {where_sql}"
try:
cursor.execute(sql_query, params)
filtered_ids = {row[0] for row in cursor.fetchall()}
#logger.info(f"[사전 필터링] {len(filtered_ids)}개 ID 획득 → FAISS 검색 제한")
except Exception as e:
logger.info(f"[경고] SQLite 필터링 실패: {e}")
filtered_ids = None
    # === 2. FAISS vector search ===
if filtered_ids and len(filtered_ids) > 0:
selector = MetadataIDSelector(filtered_ids)
        index: faiss.Index = vectorstore.index
        if not hasattr(index, "search"):
            raise ValueError("The FAISS index does not support search.")
        query_embedding = np.array(vectorstore.embeddings.embed_query(query)).astype('float32')
        query_embedding = query_embedding.reshape(1, -1)
        # nprobe only applies to IVF indexes; flat indexes simply honor `sel`.
        search_params = faiss.SearchParametersIVF(
            sel=selector,
            nprobe=50
        )
        # Over-fetch so enough candidates survive the ID filter.
        _k = max(k * 10, 100)
        D, I = index.search(query_embedding, _k, params=search_params)
valid_indices = [i for i in I[0] if i != -1]
vector_docs = []
for idx in valid_indices[:k]:
doc_id = vectorstore.index_to_docstore_id[idx]
doc = vectorstore.docstore.search(doc_id)
if isinstance(doc, Document):
vector_docs.append(doc)
#logger.info(f"[벡터 검색] {len(valid_indices)}개 후보 → {len(vector_docs)}개 유효")
else:
search_k = k * 5
vector_docs = vector_ret.invoke(query, config={"search_kwargs": {"k": search_k}})
#logger.info(f"[벡터 검색] 전체 검색 → {len(vector_docs)}개 후보")
    # === 3. BM25 search ===
    bm25_docs = []
    if hasattr(bm25_ret, "invoke"):
        # BM25Retriever reads its candidate count from `.k`, not from `config`.
        search_k = k * 5
        prev_k = bm25_ret.k
        bm25_ret.k = search_k
        try:
            candidates = bm25_ret.invoke(query)
        finally:
            bm25_ret.k = prev_k
        if filtered_ids:
            bm25_docs = [d for d in candidates if d.metadata.get('faiss_id') in filtered_ids]
        else:
            bm25_docs = candidates[:k]
    # === 4. Merge and return the top k ===
    # Deduplicate by content + faiss_id; `id()` would miss equal Documents
    # materialized separately by the two retrievers.
    combined = {}
    for d in vector_docs + bm25_docs:
        key = (d.page_content, d.metadata.get('faiss_id'))
        combined.setdefault(key, d)
    final_results = list(combined.values())[:k]
return final_results
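
# Usage sketch (hypothetical filter values; keys must match the SQLite columns):
#
#   docs = search_with_metadata_filter(
#       ensemble_retriever=retriever, vectorstore=vectorstore,
#       query="example query", k=5,
#       metadata_filter={"regulation_part": ["인사규정"]},
#       sqlite_conn=conn)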
def get_unique_metadata_values(
sqlite_conn: sqlite3.Connection,
key_name: str,
partial_match: Optional[str] = None
) -> List[str]:
"""SQLite에서 특정 컬럼의 고유한 값 리스트를 반환합니다."""
if not sqlite_conn:
logger.info("[경고] SQLite 연결이 없어 고유 값 검색을 수행할 수 없습니다.")
return []
cursor = sqlite_conn.cursor()
sql_query = f"SELECT DISTINCT `{key_name}` FROM documents"
params = []
if partial_match:
sql_query += f" WHERE `{key_name}` LIKE ?"
params.append(f"%{partial_match}%")
try:
cursor.execute(sql_query, params)
unique_values = [row[0] for row in cursor.fetchall() if row[0] is not None]
return unique_values
except sqlite3.OperationalError as e:
logger.info(f"[에러] SQLite 쿼리 실행 실패 (컬럼 '{key_name}' 이름 오류 가능): {e}")
return []
except Exception as e:
logger.info(f"[에러] 고유 값 검색 중 알 수 없는 오류 발생: {e}")
return []
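
# Usage sketch:
#
#   parts = get_unique_metadata_values(conn, "regulation_part")
#   jos = get_unique_metadata_values(conn, "jo", partial_match="3")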
def smart_search_vectorstore(
retriever,
query,
k=5,
vectorstore=None,
sqlite_conn=None,
enable_detailed_search=True
):
"""기본 검색 + 상세 검색 수행"""
# 1. 기본 검색
basic_results = retriever.invoke(query)
basic_results = basic_results[:k]
#logger.info(f"[기본 검색] {len(basic_results)}개 문서 검색 완료")
if not enable_detailed_search or not vectorstore or not sqlite_conn:
logger.info("[안내] 상세 검색 비활성화 또는 컴포넌트 부족 → 기본 검색 결과만 반환")
return basic_results
    # 2. Frequency analysis of regulation_part
regulation_parts = []
for doc in basic_results:
reg_part = doc.metadata.get('regulation_part')
if reg_part:
if isinstance(reg_part, list):
regulation_parts.extend(reg_part)
elif isinstance(reg_part, str):
if ',' in reg_part:
regulation_parts.extend([part.strip() for part in reg_part.split(',')])
else:
regulation_parts.append(reg_part)
if not regulation_parts:
logger.info("[안내] regulation_part 메타데이터 없음 → 기본 검색 결과만 반환")
return basic_results
    counter = Counter(regulation_parts)
    # Take the two most frequent categories as targets for the detailed search.
    most_extracted_category = counter.most_common(2)
    # 3. Detailed search
detailed_results = []
for rank, (category, count) in enumerate(most_extracted_category, 1):
#logger.info(f"[상세 검색 {rank}순위] '{category}' 카테고리 검색 시작 (빈도: {count})")
metadata_filter = {'regulation_part': category}
try:
category_results = search_with_metadata_filter(
ensemble_retriever=retriever,
vectorstore=vectorstore,
query=query,
k=k,
metadata_filter=metadata_filter,
sqlite_conn=sqlite_conn
)
detailed_results.extend(category_results)
#logger.info(f"[상세 검색 {rank}순위] {len(category_results)}개 추가 문서 검색 완료")
except Exception as e:
#logger.info(f"[경고] 상세 검색 {rank}순위 실패 ({category}): {e}")
continue
    # 4. Merge results
seen = set()
final_results = []
    # Add detailed-search results first so they rank ahead of basic ones
for doc in detailed_results:
doc_signature = (doc.page_content, str(sorted(doc.metadata.items())))
if doc_signature not in seen:
seen.add(doc_signature)
final_results.append(doc)
for doc in basic_results:
doc_signature = (doc.page_content, str(sorted(doc.metadata.items())))
if doc_signature not in seen:
seen.add(doc_signature)
final_results.append(doc)
final_results = final_results[:k]
#logger.info(f"[최종 결과] 기본 {len(basic_results)}개 + 상세 {len(detailed_results)}개 → 중복 제거 후 {len(final_results)}개 반환")
return final_results
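
# Usage sketch (components come from load_embedding_from_faiss):
#
#   docs = smart_search_vectorstore(
#       retriever, "example query", k=5,
#       vectorstore=vectorstore, sqlite_conn=conn)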
# natural_sort_key helper (used by app.py); `re` is imported at the top.
def natural_sort_key(s):
    """Key function for natural (human-friendly) sorting."""
return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', str(s))]
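
# Example: sorted(["제10조", "제2조"], key=natural_sort_key)
# yields ["제2조", "제10조"] instead of the lexicographic ["제10조", "제2조"].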