""" 벡터 DB 및 임베딩 관련 기능 Chroma DB를 사용한 벡터 검색 및 Re-ranking 시스템 """ import os import json import chromadb from chromadb.config import Settings from sentence_transformers import SentenceTransformer, CrossEncoder from pathlib import Path import numpy as np # 벡터 DB 경로 VECTOR_DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'vector_db') # 임베딩 모델 (한국어 지원) EMBEDDING_MODEL_NAME = "jhgan/ko-sroberta-multitask" # 한국어 지원 모델 # 또는 영어 중심: "all-MiniLM-L6-v2" (더 빠르지만 한국어 성능 낮음) # Cross-Encoder 모델 (리랭킹용) # 한국어 리랭커를 찾을 수 없으면 영어 모델 사용 RERANKER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2" # 범용 리랭커 (한국어도 어느 정도 지원) # 또는: "BAAI/bge-reranker-base" (더 나은 성능) class VectorDBManager: """벡터 DB 관리 클래스""" def __init__(self): """벡터 DB 초기화""" self.embedding_model = None self.reranker_model = None self.client = None self.collection = None # 벡터 DB 폴더 생성 os.makedirs(VECTOR_DB_PATH, exist_ok=True) # Chroma DB 클라이언트 초기화 self.client = chromadb.PersistentClient( path=VECTOR_DB_PATH, settings=Settings( anonymized_telemetry=False, allow_reset=True ) ) # 컬렉션 생성 (없으면 생성, 있으면 가져오기) try: self.collection = self.client.get_or_create_collection( name="document_chunks", metadata={"hnsw:space": "cosine"} # 코사인 유사도 사용 ) print(f"[벡터 DB] 컬렉션 로드/생성 완료: {len(self.collection.get()['ids'])}개 문서") except Exception as e: print(f"[벡터 DB] 컬렉션 생성 오류: {e}") raise def get_embedding_model(self): """임베딩 모델 로드 (지연 로딩)""" if self.embedding_model is None: print(f"[임베딩 모델] 로딩 중: {EMBEDDING_MODEL_NAME}") try: self.embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME) print(f"[임베딩 모델] 로딩 완료") except Exception as e: print(f"[임베딩 모델] 로딩 오류: {e}") # 대체 모델 시도 try: self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2") print(f"[임베딩 모델] 대체 모델 로딩 완료: all-MiniLM-L6-v2") except Exception as e2: print(f"[임베딩 모델] 대체 모델도 로딩 실패: {e2}") raise return self.embedding_model def get_reranker_model(self): """Cross-Encoder 리랭커 모델 로드 (지연 로딩)""" if self.reranker_model is None: print(f"[리랭커 모델] 로딩 중: {RERANKER_MODEL_NAME}") try: self.reranker_model = CrossEncoder(RERANKER_MODEL_NAME) print(f"[리랭커 모델] 로딩 완료") except Exception as e: print(f"[리랭커 모델] 로딩 오류: {e}") # 대체 모델 시도 (더 가벼운 모델) try: print(f"[리랭커 모델] 대체 모델 시도: BAAI/bge-reranker-base") self.reranker_model = CrossEncoder("BAAI/bge-reranker-base", max_length=512) print(f"[리랭커 모델] 대체 모델 로딩 완료") except Exception as e2: print(f"[리랭커 모델] 대체 모델도 로딩 실패: {e2}") # 리랭킹 없이 진행 (경고만 출력) print(f"[리랭커 모델] ⚠️ 경고: 리랭커 모델을 로드할 수 없습니다. 리랭킹 없이 진행합니다.") self.reranker_model = None return self.reranker_model def generate_embedding(self, text): """텍스트에 대한 임베딩 생성""" try: model = self.get_embedding_model() embedding = model.encode(text, convert_to_numpy=True).tolist() return embedding except Exception as e: print(f"[임베딩 생성] 오류: {e}") return None def add_chunk(self, chunk_id, chunk_content, file_id, chunk_index, metadata=None): """청크를 벡터 DB에 추가""" try: # 임베딩 생성 embedding = self.generate_embedding(chunk_content) if embedding is None: return False # 메타데이터 준비 chunk_metadata = { 'file_id': str(file_id), 'chunk_index': str(chunk_index), 'content_length': str(len(chunk_content)) } if metadata: chunk_metadata.update(metadata) # 벡터 DB에 추가 self.collection.add( ids=[str(chunk_id)], embeddings=[embedding], documents=[chunk_content], metadatas=[chunk_metadata] ) return True except Exception as e: print(f"[벡터 DB 추가] 오류: {e}") import traceback traceback.print_exc() return False def search_chunks(self, query, file_ids=None, top_k=30): """벡터 검색으로 관련 청크 검색 (초기 검색, top_k=30)""" try: # 쿼리 임베딩 생성 query_embedding = self.generate_embedding(query) if query_embedding is None: return [] # 필터 조건 설정 where_clause = None if file_ids: where_clause = {"file_id": {"$in": [str(fid) for fid in file_ids]}} # 벡터 검색 results = self.collection.query( query_embeddings=[query_embedding], n_results=min(top_k, 30), # 최대 30개 where=where_clause ) # 결과 파싱 chunks = [] if results and 'ids' in results and len(results['ids'][0]) > 0: for i, chunk_id in enumerate(results['ids'][0]): chunks.append({ 'chunk_id': int(chunk_id), 'content': results['documents'][0][i] if 'documents' in results else '', 'metadata': results['metadatas'][0][i] if 'metadatas' in results else {}, 'distance': results['distances'][0][i] if 'distances' in results else 1.0 }) print(f"[벡터 검색] {len(chunks)}개 청크 검색 완료") return chunks except Exception as e: print(f"[벡터 검색] 오류: {e}") import traceback traceback.print_exc() return [] def rerank_chunks(self, query, chunks, top_k=5): """Cross-Encoder를 사용하여 청크 리랭킹 (상위 top_k개 반환)""" try: if not chunks or len(chunks) == 0: return [] # 리랭커 모델 로드 reranker = self.get_reranker_model() # 리랭커 모델이 없으면 거리 기반 정렬만 수행 if reranker is None: print(f"[리랭킹] ⚠️ 리랭커 모델 없음, 거리 기반 정렬만 수행") scored_chunks = [] for chunk in chunks: scored_chunks.append({ 'chunk_id': chunk['chunk_id'], 'content': chunk['content'], 'metadata': chunk['metadata'], 'rerank_score': 1.0 - chunk.get('distance', 1.0), # 거리를 점수로 변환 'original_distance': chunk.get('distance', 1.0) }) scored_chunks.sort(key=lambda x: x['rerank_score'], reverse=True) return scored_chunks[:top_k] # 쿼리-문서 쌍 준비 (최대 길이 제한) pairs = [] max_content_length = 500 # 청크 내용이 너무 길면 잘라냄 for chunk in chunks: content = chunk['content'] if len(content) > max_content_length: content = content[:max_content_length] pairs.append([query, content]) # 리랭킹 점수 계산 print(f"[리랭킹] {len(pairs)}개 청크에 대한 리랭킹 시작...") scores = reranker.predict(pairs) # 점수와 청크 결합 scored_chunks = [] for i, chunk in enumerate(chunks): scored_chunks.append({ 'chunk_id': chunk['chunk_id'], 'content': chunk['content'], 'metadata': chunk['metadata'], 'rerank_score': float(scores[i]), 'original_distance': chunk.get('distance', 1.0) }) # 점수 순으로 정렬 (높은 점수 = 더 관련성 높음) scored_chunks.sort(key=lambda x: x['rerank_score'], reverse=True) # 상위 top_k개만 선택 top_chunks = scored_chunks[:top_k] print(f"[리랭킹] 완료: 상위 {len(top_chunks)}개 청크 선택") for i, chunk in enumerate(top_chunks): print(f" {i+1}. 점수: {chunk['rerank_score']:.4f}, 청크 ID: {chunk['chunk_id']}") return top_chunks except Exception as e: print(f"[리랭킹] 오류: {e}") import traceback traceback.print_exc() # 오류 시 원본 청크 상위 top_k개 반환 (거리 기준) chunks_sorted = sorted(chunks, key=lambda x: x.get('distance', 1.0)) return chunks_sorted[:top_k] def delete_chunks_by_file_id(self, file_id): """파일 ID로 해당 파일의 모든 청크 삭제""" try: # 해당 파일의 모든 청크 찾기 results = self.collection.get( where={"file_id": str(file_id)} ) if results and 'ids' in results and len(results['ids']) > 0: # 청크 삭제 self.collection.delete(ids=results['ids']) print(f"[벡터 DB 삭제] 파일 ID {file_id}의 {len(results['ids'])}개 청크 삭제 완료") return True return False except Exception as e: print(f"[벡터 DB 삭제] 오류: {e}") return False def get_chunk_count(self): """벡터 DB에 저장된 청크 개수 반환""" try: return self.collection.count() except Exception as e: print(f"[벡터 DB] 청크 개수 조회 오류: {e}") return 0 # 전역 벡터 DB 매니저 인스턴스 _vector_db_manager = None def get_vector_db(): """벡터 DB 매니저 싱글톤 인스턴스 반환""" global _vector_db_manager if _vector_db_manager is None: _vector_db_manager = VectorDBManager() return _vector_db_manager