Spaces:
Running
Running
Update reg_embedding_system.py
Browse files- reg_embedding_system.py +469 -527
reg_embedding_system.py
CHANGED
|
@@ -1,527 +1,469 @@
|
|
| 1 |
-
|
| 2 |
-
import
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
import
|
| 8 |
-
|
| 9 |
-
import
|
| 10 |
-
from
|
| 11 |
-
from
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
cursor.
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
self.
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
#
|
| 134 |
-
_create_and_populate_sqlite_db(chunks, persist_directory)
|
| 135 |
-
|
| 136 |
-
# 2) SentenceTransformer 로드
|
| 137 |
-
model = SentenceTransformer(
|
| 138 |
-
'nomic-ai/nomic-embed-text-v2-moe',
|
| 139 |
-
trust_remote_code=True,
|
| 140 |
-
device=device
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
embeddings = LocalSentenceTransformerEmbeddings(
|
| 144 |
-
st_model=model,
|
| 145 |
-
normalize_embeddings=True,
|
| 146 |
-
encode_batch_size=batch_size
|
| 147 |
-
)
|
| 148 |
-
|
| 149 |
-
# 3) FAISS 벡터스토어 생성
|
| 150 |
-
vectorstore = None
|
| 151 |
-
for i in range(0, len(chunks), batch_size):
|
| 152 |
-
batch = chunks[i:i + batch_size]
|
| 153 |
-
if vectorstore is None:
|
| 154 |
-
vectorstore = FAISS.from_documents(documents=batch, embedding=embeddings)
|
| 155 |
-
else:
|
| 156 |
-
vectorstore.add_documents(documents=batch)
|
| 157 |
-
gc.collect()
|
| 158 |
-
|
| 159 |
-
# 4) BM25 + 벡터 앙상블 리트리버 생성
|
| 160 |
-
bm25_retriever = BM25Retriever.from_documents(chunks)
|
| 161 |
-
bm25_retriever.k = 5
|
| 162 |
-
|
| 163 |
-
vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
|
| 164 |
-
|
| 165 |
-
ensemble_retriever = EnsembleRetriever(
|
| 166 |
-
retrievers=[vector_retriever, bm25_retriever],
|
| 167 |
-
weights=[0.6, 0.4]
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
# 5) FAISS 인덱스 저장
|
| 171 |
-
vectorstore.save_local(persist_directory)
|
| 172 |
-
|
| 173 |
-
# 6) SQLite 연결
|
| 174 |
-
sqlite_conn = get_db_connection(persist_directory)
|
| 175 |
-
gc.collect()
|
| 176 |
-
|
| 177 |
-
return ensemble_retriever, vectorstore, sqlite_conn
|
| 178 |
-
|
| 179 |
-
# --- load_embedding_from_faiss
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
print(f"[
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
filtered_ids =
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
vector_docs =
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
print(f"[
|
| 345 |
-
|
| 346 |
-
# ===
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
)
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
if
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
counter = Counter(regulation_parts)
|
| 471 |
-
most_extracted_category = counter.most_common(2) # 상위 2개 카테고리
|
| 472 |
-
|
| 473 |
-
print(f"[빈도 분석] regulation_part 빈도: {dict(counter)}")
|
| 474 |
-
print(f"[상위 카테고리] {most_extracted_category}")
|
| 475 |
-
|
| 476 |
-
# 3. 상위 카테고리에 대한 상세 검색 수행
|
| 477 |
-
detailed_results = []
|
| 478 |
-
|
| 479 |
-
for rank, (category, count) in enumerate(most_extracted_category, 1):
|
| 480 |
-
print(f"[상세 검색 {rank}순위] '{category}' 카테고리 검색 시작 (빈도: {count})")
|
| 481 |
-
|
| 482 |
-
# metadata_filter 구성
|
| 483 |
-
metadata_filter = {'regulation_part': category}
|
| 484 |
-
|
| 485 |
-
try:
|
| 486 |
-
# search_with_metadata_filter 호출
|
| 487 |
-
category_results = search_with_metadata_filter(
|
| 488 |
-
ensemble_retriever=retriever,
|
| 489 |
-
vectorstore=vectorstore,
|
| 490 |
-
query=query,
|
| 491 |
-
k=k,
|
| 492 |
-
metadata_filter=metadata_filter,
|
| 493 |
-
sqlite_conn=sqlite_conn
|
| 494 |
-
)
|
| 495 |
-
|
| 496 |
-
detailed_results.extend(category_results)
|
| 497 |
-
print(f"[상세 검색 {rank}순위] {len(category_results)}개 추가 문서 검색 완료")
|
| 498 |
-
|
| 499 |
-
except Exception as e:
|
| 500 |
-
print(f"[경고] 상세 검색 {rank}순위 실패 ({category}): {e}")
|
| 501 |
-
continue
|
| 502 |
-
|
| 503 |
-
# 4. 결과 병합 (중복 제거)
|
| 504 |
-
# Document 객체의 고유성을 위해 page_content와 metadata의 조합으로 중복 판단
|
| 505 |
-
seen = set()
|
| 506 |
-
final_results = []
|
| 507 |
-
|
| 508 |
-
# 기본 검색 결과 우선 추가
|
| 509 |
-
for doc in basic_results:
|
| 510 |
-
doc_signature = (doc.page_content, str(sorted(doc.metadata.items())))
|
| 511 |
-
if doc_signature not in seen:
|
| 512 |
-
seen.add(doc_signature)
|
| 513 |
-
final_results.append(doc)
|
| 514 |
-
|
| 515 |
-
# 상세 검색 결과 추가 (중복 제거)
|
| 516 |
-
for doc in detailed_results:
|
| 517 |
-
doc_signature = (doc.page_content, str(sorted(doc.metadata.items())))
|
| 518 |
-
if doc_signature not in seen:
|
| 519 |
-
seen.add(doc_signature)
|
| 520 |
-
final_results.append(doc)
|
| 521 |
-
|
| 522 |
-
# 최종 k개로 제한
|
| 523 |
-
final_results = final_results[:k]
|
| 524 |
-
|
| 525 |
-
print(f"[최종 결과] 기본 {len(basic_results)}개 + 상세 {len(detailed_results)}개 → 중복 제거 후 {len(final_results)}개 반환")
|
| 526 |
-
|
| 527 |
-
return final_results
|
|
|
|
| 1 |
+
# ===== Pydantic v1 compatibility settings (must stay at the very top of the file) =====
# These environment variables are set before any langchain/pydantic import below.
import os
# NOTE(review): these flags look intended to relax pydantic v2 validation for
# langchain compatibility — confirm the installed pydantic/langchain versions
# actually honor them.
os.environ["PYDANTIC_V1_STYLE"] = "1"
os.environ["PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS"] = "1"
|
| 5 |
+
|
| 6 |
+
import gc
|
| 7 |
+
import json
|
| 8 |
+
import sqlite3
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Optional, Tuple, Any, Dict, List, Set
|
| 11 |
+
from collections import Counter
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
import faiss
|
| 15 |
+
from langchain.retrievers import BM25Retriever, EnsembleRetriever
|
| 16 |
+
from langchain_core.documents import Document
|
| 17 |
+
from langchain_community.vectorstores import FAISS
|
| 18 |
+
from sentence_transformers import SentenceTransformer
|
| 19 |
+
|
| 20 |
+
# 런타임에 Embeddings 클래스를 찾기 위한 로직
|
| 21 |
+
try:
|
| 22 |
+
from langchain_core.embeddings import Embeddings
|
| 23 |
+
except ImportError:
|
| 24 |
+
try:
|
| 25 |
+
from langchain.embeddings.base import Embeddings
|
| 26 |
+
except ImportError:
|
| 27 |
+
Embeddings = object
|
| 28 |
+
|
| 29 |
+
# --- SQLite helpers ---
# File name of the metadata-mapping DB, created inside the FAISS persist directory.
SQLITE_DB_NAME = "metadata_mapping.db"
|
| 31 |
+
|
| 32 |
+
# === IDSelector 클래스 정의 ===
|
| 33 |
+
class MetadataIDSelector(faiss.IDSelectorBatch):
    """FAISS ID selector built from a Python set of allowed document ids.

    Passed as ``sel=`` in FAISS search parameters so a vector search is
    restricted to the ids that survived the SQLite metadata pre-filter.
    """

    def __init__(self, allowed_ids: Set[int]):
        # IDSelectorBatch expects a sequence, not a set, so convert first.
        super().__init__(list(allowed_ids))
|
| 36 |
+
|
| 37 |
+
def get_db_connection(persist_directory: str) -> sqlite3.Connection:
    """Open and return a SQLite connection to the metadata-mapping DB.

    The database file lives inside *persist_directory* (next to the FAISS
    index) under the fixed name ``SQLITE_DB_NAME``.
    """
    return sqlite3.connect(Path(persist_directory) / SQLITE_DB_NAME)
|
| 42 |
+
|
| 43 |
+
def _create_and_populate_sqlite_db(chunks: List[Document], persist_directory: str):
    """Create the ``documents`` metadata table and fill it from the chunks.

    Each chunk is keyed by its position in *chunks* (its FAISS id), and that
    id is also written back into the chunk's metadata under ``faiss_id``.
    Note: the serialized ``json_metadata`` column is dumped *before* the
    ``faiss_id`` key is injected into the metadata dict.
    """
    def _flatten(value):
        # List-valued metadata is stored as a comma-separated string.
        if isinstance(value, list):
            return ', '.join(map(str, value))
        return value

    conn = get_db_connection(persist_directory)
    cursor = conn.cursor()

    # 1. Create the table on first use.
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS documents (
            faiss_id INTEGER PRIMARY KEY,
            regulation_part TEXT,
            regulation_section TEXT,
            chapter_section TEXT,
            jo TEXT,
            json_metadata TEXT
        )
    """)
    conn.commit()

    # 2. Insert one row per chunk.
    for faiss_id, doc in enumerate(chunks):
        # Serialize the metadata first, then tag the doc with its FAISS id.
        metadata_json = json.dumps(doc.metadata, ensure_ascii=False)
        filter_columns = tuple(
            _flatten(doc.metadata.get(key))
            for key in ('regulation_part', 'regulation_section', 'chapter_section', 'jo')
        )
        doc.metadata['faiss_id'] = faiss_id

        cursor.execute(
            "INSERT OR REPLACE INTO documents (faiss_id, regulation_part, regulation_section, chapter_section, jo, json_metadata) VALUES (?, ?, ?, ?, ?, ?)",
            (faiss_id,) + filter_columns + (metadata_json,)
        )

    conn.commit()
    conn.close()
|
| 90 |
+
|
| 91 |
+
# --- LocalSentenceTransformerEmbeddings ---
|
| 92 |
+
class LocalSentenceTransformerEmbeddings(Embeddings):
    """Wrap a SentenceTransformer model behind the LangChain Embeddings interface."""

    def __init__(self, st_model, normalize_embeddings: bool = True, encode_batch_size: int = 32):
        # Keep attribute names stable: other code reads .model/.normalize.
        self.model = st_model
        self.normalize = normalize_embeddings
        self.encode_batch_size = encode_batch_size

    def embed_documents(self, texts):
        """Encode a batch of texts and return a list of plain-float vectors."""
        encoded = self.model.encode(
            texts,
            batch_size=self.encode_batch_size,
            show_progress_bar=False,
            normalize_embeddings=self.normalize,
            convert_to_numpy=True,
        )
        return encoded.tolist()

    def embed_query(self, text: str):
        """Encode a single query string and return one plain-float vector."""
        # Same encode() call as embed_documents on a one-element batch,
        # so the two paths can never drift apart.
        return self.embed_documents([text])[0]
|
| 119 |
+
|
| 120 |
+
# --- save_embedding_system ---
|
| 121 |
+
# --- save_embedding_system ---
def save_embedding_system(
    chunks,
    persist_directory: str = r"D:/Project AI/RAG",
    batch_size: int = 32,
    device: str = 'cuda',
    top_k: int = 5,
    bm25_k: int = 5,
    weights: Tuple[float, float] = (0.6, 0.4)
):
    """
    Embed the chunks into a FAISS vector store, build a BM25 + vector
    ensemble retriever, and persist the FAISS index plus a SQLite
    metadata DB under *persist_directory*.

    Args:
        chunks: LangChain Document chunks to index. Their metadata gains a
            ``faiss_id`` key as a side effect.
        persist_directory: Directory for the FAISS index and SQLite DB.
        batch_size: Encoding and FAISS-insertion batch size.
        device: Device for the SentenceTransformer ('cuda' or 'cpu').
        top_k: k for the vector retriever (previously hard-coded to 5).
        bm25_k: k for the BM25 retriever (previously hard-coded to 5).
        weights: Ensemble weights (vector, bm25) — matches the
            parameterization of load_embedding_from_faiss.

    Returns:
        (ensemble_retriever, vectorstore, sqlite_conn)

    Raises:
        ValueError: If *chunks* is empty (the FAISS store cannot be built).
    """
    if not chunks:
        # Previously this crashed later with an opaque AttributeError on None.
        raise ValueError("chunks is empty; nothing to embed")

    Path(persist_directory).mkdir(parents=True, exist_ok=True)

    # 1) Store metadata in SQLite and tag each chunk with its faiss_id.
    _create_and_populate_sqlite_db(chunks, persist_directory)

    # 2) Load the SentenceTransformer.
    model = SentenceTransformer(
        'nomic-ai/nomic-embed-text-v2-moe',
        trust_remote_code=True,
        device=device
    )

    embeddings = LocalSentenceTransformerEmbeddings(
        st_model=model,
        normalize_embeddings=True,
        encode_batch_size=batch_size
    )

    # 3) Build the FAISS vector store incrementally to bound peak memory.
    vectorstore = None
    for i in range(0, len(chunks), batch_size):
        batch = chunks[i:i + batch_size]
        if vectorstore is None:
            vectorstore = FAISS.from_documents(documents=batch, embedding=embeddings)
        else:
            vectorstore.add_documents(documents=batch)
        gc.collect()

    # 4) BM25 + vector ensemble retriever.
    bm25_retriever = BM25Retriever.from_documents(chunks)
    bm25_retriever.k = bm25_k

    vector_retriever = vectorstore.as_retriever(search_kwargs={"k": top_k})

    ensemble_retriever = EnsembleRetriever(
        retrievers=[vector_retriever, bm25_retriever],
        weights=list(weights)
    )

    # 5) Persist the FAISS index.
    vectorstore.save_local(persist_directory)

    # 6) Open the SQLite connection that is handed back to the caller.
    sqlite_conn = get_db_connection(persist_directory)
    gc.collect()

    return ensemble_retriever, vectorstore, sqlite_conn
|
| 178 |
+
|
| 179 |
+
# --- load_embedding_from_faiss ---
|
| 180 |
+
def load_embedding_from_faiss(
    persist_directory: str = r"D:/Project AI/RAG",
    top_k: int = 10,
    bm25_k: int = 10,
    weights: Tuple[float, float] = (0.6, 0.4),
    embeddings: Optional[Any] = None,
    device: str = 'cpu'
) -> Tuple[Any, FAISS, sqlite3.Connection]:
    """
    Load the persisted FAISS index and SQLite connection and build an
    ensemble retriever.

    Args:
        persist_directory: Directory containing the saved FAISS index and
            the metadata SQLite DB.
        top_k: k for the vector retriever.
        bm25_k: k for the BM25 retriever.
        weights: Ensemble weights (vector, bm25).
        embeddings: Optional pre-built embeddings object; when None a
            SentenceTransformer is loaded on *device*.
        device: Device for the fallback SentenceTransformer load.

    Returns:
        (ensemble_retriever, vectorstore, sqlite_conn). When no documents
        can be recovered for BM25, the first element is the plain vector
        retriever instead of an EnsembleRetriever.

    Raises:
        FileNotFoundError: If *persist_directory* does not exist.
        Exception: Re-raises whatever FAISS.load_local fails with.
    """
    # 1) Prepare embeddings (lazy-load the model only when none was given).
    if embeddings is None:
        st_model = SentenceTransformer(
            'nomic-ai/nomic-embed-text-v2-moe',
            trust_remote_code=True,
            device=device
        )
        embeddings = LocalSentenceTransformerEmbeddings(
            st_model=st_model,
            normalize_embeddings=True,
            encode_batch_size=32
        )

    # 2) Load the FAISS vector store (with the pydantic-v1-compat option).
    persist_dir = Path(persist_directory)
    if not persist_dir.exists():
        raise FileNotFoundError(f"FAISS 경로가 없습니다: {persist_dir}")

    try:
        # allow_dangerous_deserialization is required because the docstore
        # is pickled; only load indexes this process itself created.
        vectorstore = FAISS.load_local(
            folder_path=str(persist_dir),
            embeddings=embeddings,
            allow_dangerous_deserialization=True
        )
        print(f"[로드 성공] FAISS 인덱스 로드 완료: {persist_dir}")
    except Exception as e:
        print(f"[로드 오류] FAISS 로드 실패: {e}")
        raise

    # 3) Recover the raw documents for BM25 from the in-memory docstore.
    #    NOTE(review): relies on the private _dict attribute of the
    #    LangChain docstore — may break on a langchain upgrade.
    docs = []
    try:
        if hasattr(vectorstore, "docstore") and hasattr(vectorstore.docstore, "_dict"):
            docs = list(vectorstore.docstore._dict.values())
    except Exception as e:
        print(f"[경고] 저장된 문서를 읽는 중 문제가 발생했습니다: {e}")

    # 4) Build the ensemble retriever (vector-only fallback when no docs).
    vector_retriever = vectorstore.as_retriever(search_kwargs={"k": top_k})

    if docs:
        bm25_retriever = BM25Retriever.from_documents(docs)
        bm25_retriever.k = bm25_k
        ensemble_retriever = EnsembleRetriever(
            retrievers=[vector_retriever, bm25_retriever],
            weights=list(weights)
        )
    else:
        print("[안내] 문서를 찾지 못해 BM25 없이 벡터 리트리버만 반환합니다.")
        ensemble_retriever = vector_retriever

    # 5) Open the SQLite metadata connection.
    sqlite_conn = get_db_connection(persist_directory)

    return ensemble_retriever, vectorstore, sqlite_conn
|
| 246 |
+
|
| 247 |
+
# --- search_vectorstore ---
|
| 248 |
+
def search_vectorstore(retriever, query, k=5):
    """Run the retriever for *query* and return at most the first *k* documents."""
    return retriever.invoke(query)[:k]
|
| 252 |
+
|
| 253 |
+
# === search_with_metadata_filter ===
|
| 254 |
+
def search_with_metadata_filter(
    ensemble_retriever: EnsembleRetriever,
    vectorstore: FAISS,
    query: str,
    k: int = 5,
    metadata_filter: Optional[Dict[str, Any]] = None,
    sqlite_conn: Optional[sqlite3.Connection] = None,
    exact_match: bool = True
) -> List[Document]:
    """Pre-filter candidate ids via SQLite, then run a restricted FAISS search.

    Args:
        ensemble_retriever: Must hold exactly [vector_retriever, bm25_retriever].
        vectorstore: FAISS store whose index/docstore are queried directly.
        query: Search text.
        k: Number of documents to return.
        metadata_filter: Column → value (or list of values) conditions on the
            SQLite ``documents`` table; conditions are OR-combined.
        sqlite_conn: Connection to the metadata DB; without it (or without a
            filter) a plain unrestricted search is performed.
        exact_match: Currently unused — kept for interface compatibility.

    Returns:
        Up to *k* merged vector+BM25 documents.
    """
    vector_ret, bm25_ret = ensemble_retriever.retrievers

    # === 1. Extract the filtered FAISS ids from SQLite ===
    filtered_ids = None
    if metadata_filter and sqlite_conn:
        cursor = sqlite_conn.cursor()
        where_clauses = []
        params = []

        for key, value in metadata_filter.items():
            print(f"[key] {key}")
            print(f"[value] {value}")
            if isinstance(value, list):
                if not value:
                    continue
                placeholders = ', '.join(['?'] * len(value))
                # NOTE(review): `key` is interpolated into the SQL text —
                # only values are bound. Safe only while filter keys come
                # from trusted code, never from user input.
                where_clauses.append(f"{key} IN ({placeholders})")
                params.extend(value)
            else:
                where_clauses.append(f"{key} = ?")
                params.append(value)

        if where_clauses:
            # Conditions are OR-combined: any matching column qualifies a row.
            where_sql = " OR ".join(where_clauses)
            sql_query = f"SELECT faiss_id FROM documents WHERE {where_sql}"

            try:
                cursor.execute(sql_query, params)
                filtered_ids = {row[0] for row in cursor.fetchall()}
                print(f"[사전 필터링] {len(filtered_ids)}개 ID 획득 → FAISS 검색 제한")
            except Exception as e:
                # On any SQL failure fall back to an unrestricted search.
                print(f"[경고] SQLite 필터링 실패: {e}")
                filtered_ids = None
        else:
            print("[안내] 필터 조건 없음 → 전체 검색")
    else:
        print("[안내] 필터 또는 DB 없음 → 전체 검색")

    # === 2. FAISS vector search ===
    if filtered_ids and len(filtered_ids) > 0:
        # Restrict the raw index search to the pre-filtered ids.
        selector = MetadataIDSelector(filtered_ids)
        index: faiss.Index = vectorstore.index

        if not hasattr(index, "search"):
            raise ValueError("FAISS 인덱스가 검색을 지원하지 않습니다.")

        # FAISS expects a float32 (1, dim) query matrix.
        query_embedding = np.array(vectorstore.embeddings.embed_query(query)).astype('float32')
        query_embedding = query_embedding.reshape(1, -1)

        # NOTE(review): SearchParametersIVF (and nprobe) only apply to IVF
        # indexes; a flat index built by FAISS.from_documents may ignore or
        # reject these parameters — confirm against the index type in use.
        search_params = faiss.SearchParametersIVF(
            sel=selector,
            nprobe=50
        )

        # Over-fetch so enough candidates survive the id restriction.
        _k = max(k * 10, 100)
        D, I = index.search(query_embedding, _k, params=search_params)

        # -1 marks empty result slots returned by FAISS.
        valid_indices = [i for i in I[0] if i != -1]
        vector_docs = []
        for idx in valid_indices[:k]:
            doc_id = vectorstore.index_to_docstore_id[idx]
            doc = vectorstore.docstore.search(doc_id)
            if isinstance(doc, Document):
                vector_docs.append(doc)

        print(f"[벡터 검색] {len(valid_indices)}개 후보 → {len(vector_docs)}개 유효")
    else:
        # Unrestricted path: go through the LangChain retriever.
        search_k = k * 5
        vector_docs = vector_ret.invoke(query, config={"search_kwargs": {"k": search_k}})
        print(f"[벡터 검색] 전체 검색 → {len(vector_docs)}개 후보")

    # === 3. BM25 search ===
    bm25_docs = []
    if hasattr(bm25_ret, "invoke"):
        search_k = k * 5
        candidates = bm25_ret.invoke(query, config={"search_kwargs": {"k": search_k}})
        if filtered_ids:
            # Post-filter BM25 hits by the same id set (relies on faiss_id
            # having been written into chunk metadata at index-build time).
            bm25_docs = [d for d in candidates if d.metadata.get('faiss_id') in filtered_ids]
        else:
            bm25_docs = candidates[:k]
        print(f"[BM25 검색] {len(candidates)}개 후보 → {len(bm25_docs)}개 필터링 후")

    # === 4. Merge and return the final k ===
    # NOTE(review): keying on id(d) only removes the *same object* appearing
    # twice; two distinct Document objects with equal content are both kept.
    combined = {id(d): d for d in (vector_docs + bm25_docs)}.values()
    final_results = list(combined)[:k]

    print(f"[최종 결과] {len(final_results)}개 문서 반환")
    return final_results
|
| 352 |
+
|
| 353 |
+
def get_unique_metadata_values(
    sqlite_conn: sqlite3.Connection,
    key_name: str,
    partial_match: Optional[str] = None
) -> List[str]:
    """Return the distinct non-NULL values stored in one metadata column.

    Args:
        sqlite_conn: Open connection to the metadata DB (may be None).
        key_name: Column name in the ``documents`` table. Must be a plain
            Python identifier; anything else is rejected up front, because
            column names cannot be bound as SQL parameters and would
            otherwise be interpolated into the query text verbatim.
        partial_match: Optional substring; when given, only values
            containing it (SQL LIKE) are returned.

    Returns:
        List of distinct values, or [] on any failure.
    """
    if not sqlite_conn:
        print("[경고] SQLite 연결이 없어 고유 값 검색을 수행할 수 없습니다.")
        return []

    # Security: column names can't be parameterized, so whitelist the
    # identifier before it is interpolated into the SQL text.
    if not key_name.isidentifier():
        print(f"[에러] SQLite 쿼리 실행 실패 (컬럼 '{key_name}' 이름 오류 가능): invalid identifier")
        return []

    cursor = sqlite_conn.cursor()
    sql_query = f"SELECT DISTINCT `{key_name}` FROM documents"
    params = []

    if partial_match:
        sql_query += f" WHERE `{key_name}` LIKE ?"
        params.append(f"%{partial_match}%")

    try:
        cursor.execute(sql_query, params)
        # Drop NULLs: rows where this metadata key was absent.
        unique_values = [row[0] for row in cursor.fetchall() if row[0] is not None]
        return unique_values
    except sqlite3.OperationalError as e:
        # Typically a non-existent column name.
        print(f"[에러] SQLite 쿼리 실행 실패 (컬럼 '{key_name}' 이름 오류 가능): {e}")
        return []
    except Exception as e:
        print(f"[에러] 고유 값 검색 중 알 수 없는 오류 발생: {e}")
        return []
|
| 381 |
+
|
| 382 |
+
def smart_search_vectorstore(
    retriever,
    query,
    k=5,
    vectorstore=None,
    sqlite_conn=None,
    enable_detailed_search=True
):
    """Basic retrieval plus a category-focused detailed follow-up search.

    First runs the plain retriever. When detailed search is enabled and both
    the FAISS store and SQLite connection are available, the most frequent
    'regulation_part' values among the basic hits become metadata filters
    for follow-up searches; results are merged, de-duplicated, and trimmed
    to *k* documents (basic hits keep priority).

    Args:
        retriever: Object with an .invoke(query) -> List[Document] method.
        query: Search text.
        k: Maximum number of documents to return.
        vectorstore: FAISS store for the detailed pass (None disables it).
        sqlite_conn: Metadata DB connection for the detailed pass.
        enable_detailed_search: Master switch for the detailed pass.

    Returns:
        Up to *k* documents.
    """
    # 1. Basic search.
    basic_results = retriever.invoke(query)[:k]
    print(f"[기본 검색] {len(basic_results)}개 문서 검색 완료")

    if not enable_detailed_search or not vectorstore or not sqlite_conn:
        print("[안내] 상세 검색 비활성화 또는 컴포넌트 부족 → 기본 검색 결과만 반환")
        return basic_results

    # 2. Frequency analysis of the regulation_part metadata.
    regulation_parts = []
    for doc in basic_results:
        reg_part = doc.metadata.get('regulation_part')
        if not reg_part:
            continue
        if isinstance(reg_part, list):
            regulation_parts.extend(reg_part)
        elif isinstance(reg_part, str):
            if ',' in reg_part:
                # List values were flattened to comma-separated strings at
                # index-build time; undo that here.
                regulation_parts.extend(part.strip() for part in reg_part.split(','))
            else:
                regulation_parts.append(reg_part)

    if not regulation_parts:
        print("[안내] regulation_part 메타데이터 없음 → 기본 검색 결과만 반환")
        return basic_results

    counter = Counter(regulation_parts)
    most_extracted_category = counter.most_common(2)  # top-2 categories
    print(f"[빈도 분석] regulation_part 빈도: {dict(counter)}")
    print(f"[상위 카테고리] {most_extracted_category}")

    # 3. Detailed search per top category (failures skip that category only).
    detailed_results = []
    for rank, (category, count) in enumerate(most_extracted_category, 1):
        print(f"[상세 검색 {rank}순위] '{category}' 카테고리 검색 시작 (빈도: {count})")
        try:
            category_results = search_with_metadata_filter(
                ensemble_retriever=retriever,
                vectorstore=vectorstore,
                query=query,
                k=k,
                metadata_filter={'regulation_part': category},
                sqlite_conn=sqlite_conn
            )
            detailed_results.extend(category_results)
            print(f"[상세 검색 {rank}순위] {len(category_results)}개 추가 문서 검색 완료")
        except Exception as e:
            print(f"[경고] 상세 검색 {rank}순위 실패 ({category}): {e}")
            continue

    # 4. Merge with de-duplication. The two previously duplicated loops are
    #    folded into one pass; basic results come first so they win ties.
    seen = set()
    final_results = []
    for doc in basic_results + detailed_results:
        doc_signature = (doc.page_content, str(sorted(doc.metadata.items())))
        if doc_signature not in seen:
            seen.add(doc_signature)
            final_results.append(doc)

    final_results = final_results[:k]
    print(f"[최종 결과] 기본 {len(basic_results)}개 + 상세 {len(detailed_results)}개 → 중복 제거 후 {len(final_results)}개 반환")

    return final_results
|
| 463 |
+
|
| 464 |
+
# natural_sort_key 함수 추가 (app.py에서 사용됨)
|
| 465 |
+
import re
|
| 466 |
+
|
| 467 |
+
def natural_sort_key(s):
    """Sort key that orders embedded numbers numerically ("a2" before "a10")."""
    # Split into alternating non-digit / digit runs; the capture group keeps
    # the digit runs in the result.
    pieces = re.split('([0-9]+)', str(s))
    key = []
    for piece in pieces:
        key.append(int(piece) if piece.isdigit() else piece.lower())
    return key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|