scipious commited on
Commit
78f9356
·
verified ·
1 Parent(s): 2000df9

Upload reg_embedding_system_v02.py

Browse files
Files changed (1) hide show
  1. reg_embedding_system_v02.py +453 -0
reg_embedding_system_v02.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import json
3
+ import sqlite3
4
+ from pathlib import Path
5
+ from typing import Optional, Tuple, Any, Dict, List, Set, Union
6
+ import numpy as np
7
+
8
+ import faiss
9
+
10
+ # [수정됨] 패키지 구조 변경 반영 및 EnsembleRetriever 제거
11
+ from langchain_community.retrievers import BM25Retriever
12
+ from langchain_core.documents import Document
13
+ from langchain_community.vectorstores import FAISS
14
+ from sentence_transformers import SentenceTransformer
15
+
16
+ # 런타임에 Embeddings 클래스를 찾기 위한 로직
17
+ try:
18
+ from langchain_core.embeddings import Embeddings
19
+ except ImportError:
20
+ try:
21
+ from langchain.embeddings.base import Embeddings
22
+ except ImportError:
23
+ Embeddings = object
24
+
25
+ # --- SQLite 헬퍼 함수 ---
26
+ SQLITE_DB_NAME = "metadata_mapping.db"
27
+
28
+ # === IDSelector 클래스 정의 ===
29
+ class MetadataIDSelector(faiss.IDSelectorBatch):
30
+ def __init__(self, allowed_ids: Set[int]):
31
+ super().__init__(list(allowed_ids))
32
+
33
+ def get_db_connection(persist_directory: str) -> sqlite3.Connection:
34
+ """FAISS 저장 경로를 기반으로 SQLite 연결을 설정하고 반환합니다."""
35
+ db_path = Path(persist_directory) / SQLITE_DB_NAME
36
+ conn = sqlite3.connect(db_path)
37
+ return conn
38
+
39
+ def _create_and_populate_sqlite_db(chunks: List[Document], persist_directory: str):
40
+ """
41
+ 문서 청크를 기반으로 SQLite DB를 생성하고 채웁니다.
42
+ [업데이트 반영] 메타데이터 구조: regulation, chapter, section, standard
43
+ """
44
+ # 1. 입력 데이터 확인 (가장 중요한 체크 포인트)
45
+ if not chunks:
46
+ print("🚨 [오류] _create_and_populate_sqlite_db 함수에 전달된 chunks 리스트가 비어 있습니다!")
47
+ print(" -> load_chunks_from_jsonl 함수가 정상적으로 파일을 읽었는지 확인해주세요.")
48
+ return
49
+
50
+ # 2. 저장 경로 확인 및 생성
51
+ save_dir = Path(persist_directory)
52
+ save_dir.mkdir(parents=True, exist_ok=True)
53
+
54
+ conn = get_db_connection(persist_directory)
55
+
56
+ try:
57
+ cursor = conn.cursor()
58
+
59
+ # 3. 테이블 생성 (기존 테이블 삭제 후 재생성 옵션 고려)
60
+ # 스키마가 변경되었으므로 기존 테이블이 있다면 충돌날 수 있습니다.
61
+ # 안전하게 지우고 다시 만드는 방법을 추천합니다. (개발 단계)
62
+ cursor.execute("DROP TABLE IF EXISTS documents")
63
+
64
+ cursor.execute("""
65
+ CREATE TABLE documents (
66
+ faiss_id INTEGER PRIMARY KEY,
67
+ source TEXT,
68
+ regulation TEXT,
69
+ chapter TEXT,
70
+ section TEXT,
71
+ standard TEXT,
72
+ json_metadata TEXT
73
+ )
74
+ """)
75
+ # 테이블 생성 직후 커밋 (파일에 스키마 기록)
76
+ conn.commit()
77
+ print(f"📂 DB 테이블 생성 완료 (경로: {save_dir}/{SQLITE_DB_NAME})")
78
+
79
+ # 4. 데이터 채우기
80
+ inserted_count = 0
81
+ for i, doc in enumerate(chunks):
82
+ faiss_id = i
83
+ metadata_json = json.dumps(doc.metadata, ensure_ascii=False)
84
+
85
+ source_val = doc.metadata.get('source', '')
86
+ regulation_val = doc.metadata.get('regulation', '')
87
+ chapter_val = doc.metadata.get('chapter', '')
88
+ section_val = doc.metadata.get('section', '')
89
+ standard_val = doc.metadata.get('standard', '')
90
+
91
+ if isinstance(regulation_val, list): regulation_val = ', '.join(map(str, regulation_val))
92
+ if isinstance(chapter_val, list): chapter_val = ', '.join(map(str, chapter_val))
93
+ if isinstance(section_val, list): section_val = ', '.join(map(str, section_val))
94
+ if isinstance(standard_val, list): standard_val = ', '.join(map(str, standard_val))
95
+
96
+ doc.metadata['faiss_id'] = faiss_id
97
+
98
+ cursor.execute(
99
+ """
100
+ INSERT OR REPLACE INTO documents
101
+ (faiss_id, source, regulation, chapter, section, standard, json_metadata)
102
+ VALUES (?, ?, ?, ?, ?, ?, ?)
103
+ """,
104
+ (faiss_id, source_val, regulation_val, chapter_val, section_val, standard_val, metadata_json)
105
+ )
106
+ inserted_count += 1
107
+
108
+ # 5. 최종 커밋
109
+ conn.commit()
110
+ print(f"✅ SQLite 데이터 저장 완료: 총 {inserted_count}행이 삽입되었습니다.")
111
+
112
+ except Exception as e:
113
+ print(f"🚨 [DB 저장 중 에러 발생] {e}")
114
+ # 에러가 나도 traceback을 볼 수 있게 함
115
+ import traceback
116
+ traceback.print_exc()
117
+
118
+ finally:
119
+ # 6. 연결 확실히 종료
120
+ conn.close()
121
+
122
+ # --- LocalSentenceTransformerEmbeddings ---
123
+ class LocalSentenceTransformerEmbeddings(Embeddings):
124
+ def __init__(self, st_model, normalize_embeddings: bool = True, encode_batch_size: int = 32):
125
+ self.model = st_model
126
+ self.normalize = normalize_embeddings
127
+ self.encode_batch_size = encode_batch_size
128
+
129
+ def embed_documents(self, texts):
130
+ vecs = self.model.encode(
131
+ texts,
132
+ batch_size=self.encode_batch_size,
133
+ show_progress_bar=False,
134
+ normalize_embeddings=self.normalize,
135
+ convert_to_numpy=True,
136
+ )
137
+ return vecs.tolist()
138
+
139
+ def embed_query(self, text: str):
140
+ vec = self.model.encode(
141
+ [text],
142
+ batch_size=self.encode_batch_size,
143
+ show_progress_bar=False,
144
+ normalize_embeddings=self.normalize,
145
+ convert_to_numpy=True,
146
+ )[0]
147
+ return vec.tolist()
148
+
149
+ def load_chunks_from_jsonl(file_paths: Union[str, List[str]]):
150
+ """
151
+ JSONL 파일 로드 함수
152
+ """
153
+ if isinstance(file_paths, str):
154
+ file_paths = [file_paths]
155
+
156
+ restored_documents = []
157
+ print(f" 총 {len(file_paths)}개의 파일 병합 로드를 시작합니다...")
158
+
159
+ for file_path in file_paths:
160
+ try:
161
+ file_doc_count = 0
162
+ with open(file_path, 'r', encoding='utf-8') as f:
163
+ for line_number, line in enumerate(f):
164
+ line = line.strip()
165
+ if not line: continue
166
+ data = json.loads(line)
167
+ doc = Document(
168
+ page_content=data.get('page_content', ""),
169
+ metadata=data.get('metadata', {})
170
+ )
171
+ restored_documents.append(doc)
172
+ file_doc_count += 1
173
+ print(f" - [성공] {file_path}: {file_doc_count}개 Chunk")
174
+
175
+ except Exception as e:
176
+ print(f" [실패] 오류 ({file_path}): {e}")
177
+ continue
178
+
179
+ print(f"✅ 전체 로드 완료: 총 {len(restored_documents)}개의 Chunk가 복원되었습니다.")
180
+ return restored_documents
181
+
182
+ # --- save_embedding_system (수정됨: Ensemble 제거 및 개별 반환) ---
183
+ def save_embedding_system(
184
+ chunks,
185
+ persist_directory: str = r"D:/Project AI/RAG",
186
+ batch_size: int = 32,
187
+ device: str = 'cuda'
188
+ ):
189
+ Path(persist_directory).mkdir(parents=True, exist_ok=True)
190
+
191
+ # 1) SQLite DB 저장
192
+ _create_and_populate_sqlite_db(chunks, persist_directory)
193
+
194
+ # 2) 모델 로드
195
+ model = SentenceTransformer(
196
+ 'nomic-ai/nomic-embed-text-v2-moe',
197
+ trust_remote_code=True,
198
+ device=device
199
+ )
200
+
201
+ embeddings = LocalSentenceTransformerEmbeddings(
202
+ st_model=model,
203
+ normalize_embeddings=True,
204
+ encode_batch_size=batch_size
205
+ )
206
+
207
+ # 3) FAISS 생성
208
+ vectorstore = None
209
+ for i in range(0, len(chunks), batch_size):
210
+ batch = chunks[i:i + batch_size]
211
+ if vectorstore is None:
212
+ vectorstore = FAISS.from_documents(documents=batch, embedding=embeddings)
213
+ else:
214
+ vectorstore.add_documents(documents=batch)
215
+ gc.collect()
216
+
217
+ # 4) BM25 생성 (Ensemble 없이 독립 생성)
218
+ bm25_retriever = BM25Retriever.from_documents(chunks)
219
+ bm25_retriever.k = 5
220
+
221
+ # 5) 저장
222
+ vectorstore.save_local(persist_directory)
223
+
224
+ # 6) 연결 반환 (개별 요소 반환)
225
+ sqlite_conn = get_db_connection(persist_directory)
226
+ gc.collect()
227
+
228
+ return bm25_retriever, vectorstore, sqlite_conn
229
+
230
+ # --- load_embedding_from_faiss (수정됨: Ensemble 제거 및 개별 반환) ---
231
+ def load_embedding_from_faiss(
232
+ persist_directory: str = r"D:/Project AI/RAG",
233
+ top_k: int = 10,
234
+ bm25_k: int = 10,
235
+ embeddings: Optional[Any] = None,
236
+ device: str = 'cpu'
237
+ ) -> Tuple[Any, FAISS, sqlite3.Connection]:
238
+
239
+ if embeddings is None:
240
+ st_model = SentenceTransformer(
241
+ 'nomic-ai/nomic-embed-text-v2-moe',
242
+ trust_remote_code=True,
243
+ device=device
244
+ )
245
+ embeddings = LocalSentenceTransformerEmbeddings(
246
+ st_model=st_model,
247
+ normalize_embeddings=True,
248
+ encode_batch_size=32
249
+ )
250
+
251
+ persist_dir = Path(persist_directory)
252
+ if not persist_dir.exists():
253
+ raise FileNotFoundError(f"FAISS 경로가 없습니다: {persist_dir}")
254
+
255
+ # FAISS 로드
256
+ vectorstore = FAISS.load_local(
257
+ folder_path=str(persist_dir),
258
+ embeddings=embeddings,
259
+ allow_dangerous_deserialization=True
260
+ )
261
+
262
+ # BM25 복원 (저장된 문서로부터 재생성)
263
+ bm25_retriever = None
264
+ docs = []
265
+ try:
266
+ if hasattr(vectorstore, "docstore") and hasattr(vectorstore.docstore, "_dict"):
267
+ docs = list(vectorstore.docstore._dict.values())
268
+ if docs:
269
+ bm25_retriever = BM25Retriever.from_documents(docs)
270
+ bm25_retriever.k = bm25_k
271
+ else:
272
+ print("[경고] 저장된 문서를 찾을 수 없어 BM25를 생성하지 못했습니다.")
273
+ except Exception as e:
274
+ print(f"[경고] 저장된 문서를 읽는 중 문제가 발생했습니다: {e}")
275
+
276
+ sqlite_conn = get_db_connection(persist_directory)
277
+
278
+ return bm25_retriever, vectorstore, sqlite_conn
279
+
280
+ # --- search_vectorstore (단순 벡터 검색 헬퍼) ---
281
+ def search_vectorstore(bm25_retriever, vectorstore, query, k=5):
282
+ """
283
+ vectorstore와 bm25_retriever를 받아 앙상블(Hybrid) 검색을 수행하는 함수.
284
+ EnsembleRetriever(weights=[0.6, 0.4])와 유사한 결과를 반환합니다.
285
+ """
286
+ weights=[0.6, 0.4]
287
+ # 1. 벡터 검색 수행 (Vector Search)
288
+ # FAISS를 리트리버로 변환하여 검색
289
+ vec_retriever = vectorstore.as_retriever(search_kwargs={"k": k})
290
+ vec_docs = vec_retriever.invoke(query)
291
+
292
+ # 2. 키워드 검색 수행 (BM25 Search)
293
+ # 검색 개수를 k개로 맞춰서 실행
294
+ bm25_docs = bm25_retriever.invoke(query, config={"search_kwargs": {"k": k}})
295
+
296
+ # 3. 랭킹 퓨전 (Weighted Reciprocal Rank Fusion)
297
+ # 두 리스트의 순위를 기반으로 가중치를 적용해 점수를 매깁니다.
298
+
299
+ doc_scores = {} # 문서 내용(또는 ID) -> 점수
300
+ doc_map = {} # 문서 내용 -> 문서 객체 저장 (나중에 반환하기 위해)
301
+
302
+ # 내부 함수: 순위에 따른 점수 계산 (Weight / (Rank + 1))
303
+ def apply_rank_score(docs, weight):
304
+ for rank, doc in enumerate(docs):
305
+ # 고유 키 생성 (page_content가 고유하다고 가정하거나, doc_id가 있다면 사용)
306
+ doc_key = doc.page_content
307
+ doc_map[doc_key] = doc
308
+
309
+ if doc_key not in doc_scores:
310
+ doc_scores[doc_key] = 0.0
311
+
312
+ # 순위가 높을수록(rank가 작을수록) 점수가 높음
313
+ score = weight / (rank + 1)
314
+ doc_scores[doc_key] += score
315
+
316
+ # 벡터 검색 결과 점수 반영 (가중치 0.6)
317
+ apply_rank_score(vec_docs, weights[0])
318
+
319
+ # BM25 검색 결과 점수 반영 (가중치 0.4)
320
+ apply_rank_score(bm25_docs, weights[1])
321
+
322
+ # 4. 점수순 정렬 (높은 점수가 상위)
323
+ # 점수(item[1])를 기준으로 내림차순 정렬
324
+ sorted_docs = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)
325
+
326
+ # 5. Top-K 추출 및 문서 객체 반환
327
+ final_results = [doc_map[key] for key, score in sorted_docs[:k]]
328
+
329
+ return final_results
330
+
331
+ # --- search_with_metadata_filter (수정됨: 수동 병합 로직 구현) ---
332
+ def search_with_metadata_filter(
333
+ bm25_retriever: Any, # [변경] Ensemble 대신 BM25를 직접 받음
334
+ vectorstore: FAISS,
335
+ query: str,
336
+ k: int = 5,
337
+ metadata_filter: Optional[Dict[str, Any]] = None,
338
+ sqlite_conn: Optional[sqlite3.Connection] = None
339
+ ) -> List[Document]:
340
+ """
341
+ SQLite 사전 필터링 -> FAISS 벡터 검색 + BM25 검색 -> 결과 병합
342
+ """
343
+
344
+ # === 1. SQLite에서 필터링된 FAISS ID 추출 ===
345
+ filtered_ids = None
346
+ if metadata_filter and sqlite_conn:
347
+ cursor = sqlite_conn.cursor()
348
+ where_clauses = []
349
+ params = []
350
+
351
+ for key, value in metadata_filter.items():
352
+ if isinstance(value, list):
353
+ if not value: continue
354
+ placeholders = ', '.join(['?'] * len(value))
355
+ where_clauses.append(f"{key} IN ({placeholders})")
356
+ params.extend(value)
357
+ else:
358
+ where_clauses.append(f"{key} = ?")
359
+ params.append(value)
360
+
361
+ if where_clauses:
362
+ where_sql = " OR ".join(where_clauses)
363
+ sql_query = f"SELECT faiss_id FROM documents WHERE {where_sql}"
364
+
365
+ try:
366
+ cursor.execute(sql_query, params)
367
+ filtered_ids = {row[0] for row in cursor.fetchall()}
368
+ print(f"[사전 필터링] {len(filtered_ids)}개 ID 획득 → FAISS 검색 제한")
369
+ except Exception as e:
370
+ print(f"[경고] SQLite 필터링 실패: {e}")
371
+ filtered_ids = None
372
+ else:
373
+ print("[안내] 필터 조건 없음 → 전체 검색")
374
+ else:
375
+ print("[안내] 필터 또는 DB 없음 → 전체 검색")
376
+
377
+ # === 2. FAISS 벡터 검색 ===
378
+ vector_docs = []
379
+ if filtered_ids and len(filtered_ids) > 0:
380
+ selector = MetadataIDSelector(filtered_ids)
381
+ index: faiss.Index = vectorstore.index
382
+
383
+ query_embedding = np.array(vectorstore.embeddings.embed_query(query)).astype('float32')
384
+ query_embedding = query_embedding.reshape(1, -1)
385
+
386
+ search_params = faiss.SearchParametersIVF(sel=selector, nprobe=20)
387
+ _k = max(k * 10, 100)
388
+ D, I = index.search(query_embedding, _k, params=search_params)
389
+
390
+ valid_indices = [i for i in I[0] if i != -1]
391
+ for idx in valid_indices[:k]:
392
+ doc_id = vectorstore.index_to_docstore_id[idx]
393
+ doc = vectorstore.docstore.search(doc_id)
394
+ if isinstance(doc, Document):
395
+ vector_docs.append(doc)
396
+ print(f"[벡터 검색] {len(valid_indices)}개 후보 → {len(vector_docs)}개 유효")
397
+ else:
398
+ # 전체 검색
399
+ vector_retriever = vectorstore.as_retriever(search_kwargs={"k": k})
400
+ vector_docs = vector_retriever.invoke(query)
401
+ print(f"[벡터 검색] 전체 검색 → {len(vector_docs)}개 후보")
402
+
403
+ # === 3. BM25 검색 ===
404
+ bm25_docs = []
405
+ if bm25_retriever:
406
+ search_k = k * 5
407
+ candidates = bm25_retriever.invoke(query, config={"search_kwargs": {"k": search_k}})
408
+ if filtered_ids:
409
+ bm25_docs = [d for d in candidates if d.metadata.get('faiss_id') in filtered_ids]
410
+ else:
411
+ bm25_docs = candidates
412
+
413
+ # Top K 자르기
414
+ bm25_docs = bm25_docs[:k]
415
+ print(f"[BM25 검색] {len(candidates)}개 후보 → {len(bm25_docs)}개 필터링 후")
416
+
417
+ # === 4. 병합 (Vector 우선 + 중복 제거) ===
418
+ combined = {id(d): d for d in (vector_docs + bm25_docs)}.values()
419
+ final_results = list(combined)[:k]
420
+
421
+ print(f"[최종 결과] {len(final_results)}개 문서 반환")
422
+ return final_results
423
+
424
+ # --- get_unique_metadata_values (빠진 함수 추가) ---
425
+ def get_unique_metadata_values(
426
+ sqlite_conn: sqlite3.Connection,
427
+ key_name: str,
428
+ partial_match: Optional[str] = None
429
+ ) -> List[str]:
430
+ """
431
+ 고유 값 검색 함수.
432
+ key_name 인자로 'part', 'subpart', 'section', 'source' 등을 사용할 수 있습니다.
433
+ """
434
+ if not sqlite_conn:
435
+ return []
436
+
437
+ cursor = sqlite_conn.cursor()
438
+ # 안전을 위해 key_name은 컬럼명으로 직접 사용 (SQL Injection 주의: 내부 사용 전제)
439
+ # 실제 프로덕션에서는 key_name을 화이트리스트로 검증하는 것이 좋습니다.
440
+ sql_query = f"SELECT DISTINCT `{key_name}` FROM documents"
441
+ params = []
442
+
443
+ if partial_match:
444
+ sql_query += f" WHERE `{key_name}` LIKE ?"
445
+ params.append(f"%{partial_match}%")
446
+
447
+ try:
448
+ cursor.execute(sql_query, params)
449
+ unique_values = [row[0] for row in cursor.fetchall() if row[0] is not None]
450
+ return unique_values
451
+ except Exception as e:
452
+ print(f"[에러] 고유 값 검색 실패 ({key_name}): {e}")
453
+ return []