|
|
""" |
|
|
๋ฒกํฐ DB ๋ฐ ์๋ฒ ๋ฉ ๊ด๋ จ ๊ธฐ๋ฅ |
|
|
Chroma DB๋ฅผ ์ฌ์ฉํ ๋ฒกํฐ ๊ฒ์ ๋ฐ Re-ranking ์์คํ
|
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import chromadb |
|
|
from chromadb.config import Settings |
|
|
from sentence_transformers import SentenceTransformer, CrossEncoder |
|
|
from pathlib import Path |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Persistent Chroma store: `<parent of this module's directory>/vector_db`.
# abspath() makes the location independent of the current working directory
# even if the module was imported through a relative path.
VECTOR_DB_PATH = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    'vector_db',
)

# SentenceTransformer model used to embed both chunks and queries.
EMBEDDING_MODEL_NAME = "jhgan/ko-sroberta-multitask"

# CrossEncoder model used to re-rank the initial vector-search candidates.
RERANKER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
|
|
|
|
|
|
|
class VectorDBManager:
    """Manage the Chroma vector store plus the embedding and re-ranking models.

    The Chroma client and the ``document_chunks`` collection are opened
    eagerly in ``__init__``; the SentenceTransformer embedding model and the
    CrossEncoder re-ranker are loaded lazily on first use.
    """

    # Hard cap on how many candidates the initial vector search returns.
    _MAX_INITIAL_RESULTS = 30
    # Chunk text is cut to this many characters before cross-encoding.
    _RERANK_MAX_CONTENT_CHARS = 500

    def __init__(self):
        """Open (or create) the persistent ``document_chunks`` collection.

        Raises:
            Exception: whatever Chroma raised while opening the collection
                (logged, then re-raised).
        """
        self.embedding_model = None
        self.reranker_model = None
        # Set once every re-ranker candidate has failed to load, so later
        # calls do not retry the slow model loads on every query.
        self._reranker_load_failed = False
        self.client = None
        self.collection = None

        os.makedirs(VECTOR_DB_PATH, exist_ok=True)

        self.client = chromadb.PersistentClient(
            path=VECTOR_DB_PATH,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True,
            ),
        )

        try:
            # Cosine space; the distance-only ranking below treats
            # `1 - distance` as a similarity score.
            self.collection = self.client.get_or_create_collection(
                name="document_chunks",
                metadata={"hnsw:space": "cosine"},
            )
            # count() instead of get(): no need to fetch every stored
            # document just to log how many there are.
            print(f"[벡터 DB] 컬렉션 로드/생성 완료: {self.collection.count()}개 문서")
        except Exception as e:
            print(f"[벡터 DB] 컬렉션 생성 오류: {e}")
            raise

    def get_embedding_model(self):
        """Return the embedding model, loading it on first call.

        Tries ``EMBEDDING_MODEL_NAME`` first, then ``all-MiniLM-L6-v2``.

        Raises:
            Exception: the fallback model's load error when both loads fail.
        """
        if self.embedding_model is not None:
            return self.embedding_model

        print(f"[임베딩 모델] 로딩 중: {EMBEDDING_MODEL_NAME}")
        try:
            self.embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
            print("[임베딩 모델] 로딩 완료")
        except Exception as e:
            print(f"[임베딩 모델] 로딩 오류: {e}")
            # NOTE(review): the fallback is a different model; its vectors may
            # not be comparable with (or the same size as) embeddings already
            # stored by the primary model — confirm this fallback is safe.
            try:
                self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
                print("[임베딩 모델] 대체 모델 로딩 완료: all-MiniLM-L6-v2")
            except Exception as e2:
                print(f"[임베딩 모델] 대체 모델도 로딩 실패: {e2}")
                raise
        return self.embedding_model

    def get_reranker_model(self):
        """Return the CrossEncoder re-ranker, loading it on first call.

        Tries ``RERANKER_MODEL_NAME`` and then ``BAAI/bge-reranker-base``.

        Returns:
            The loaded model, or None if neither could be loaded. The failure
            is remembered, so subsequent calls return None immediately
            instead of re-attempting both loads.
        """
        if self.reranker_model is not None:
            return self.reranker_model
        # getattr: tolerate instances created without running __init__.
        if getattr(self, "_reranker_load_failed", False):
            return None

        print(f"[리랭커 모델] 로딩 중: {RERANKER_MODEL_NAME}")
        try:
            self.reranker_model = CrossEncoder(RERANKER_MODEL_NAME)
            print("[리랭커 모델] 로딩 완료")
        except Exception as e:
            print(f"[리랭커 모델] 로딩 오류: {e}")
            try:
                print("[리랭커 모델] 대체 모델 시도: BAAI/bge-reranker-base")
                self.reranker_model = CrossEncoder("BAAI/bge-reranker-base", max_length=512)
                print("[리랭커 모델] 대체 모델 로딩 완료")
            except Exception as e2:
                print(f"[리랭커 모델] 대체 모델도 로딩 실패: {e2}")
                print("[리랭커 모델] ⚠️ 경고: 리랭커 모델을 로드할 수 없습니다. 리랭킹 없이 진행합니다.")
                self.reranker_model = None
                self._reranker_load_failed = True
        return self.reranker_model

    def generate_embedding(self, text):
        """Embed ``text`` and return the vector as a plain Python list.

        Returns None (after logging) if the model cannot be loaded or
        encoding fails.
        """
        try:
            model = self.get_embedding_model()
            return model.encode(text, convert_to_numpy=True).tolist()
        except Exception as e:
            print(f"[임베딩 생성] 오류: {e}")
            return None

    def add_chunk(self, chunk_id, chunk_content, file_id, chunk_index, metadata=None):
        """Embed one chunk and add it to the collection.

        Args:
            chunk_id: Stored as ``str(chunk_id)``; ``search_chunks`` converts
                it back with ``int()``, so it should be integer-like.
            chunk_content: Text that is embedded and stored as the document.
            file_id: Stored as a string under the ``file_id`` metadata key,
                which the file filters in this class match against.
            chunk_index: Stored as a string under ``chunk_index``
                (presumably the chunk's position within its file).
            metadata: Optional extra metadata; its keys override the defaults.

        Returns:
            True on success, False if embedding or insertion failed.
        """
        try:
            embedding = self.generate_embedding(chunk_content)
            if embedding is None:
                return False

            chunk_metadata = {
                'file_id': str(file_id),
                'chunk_index': str(chunk_index),
                'content_length': str(len(chunk_content)),
            }
            if metadata:
                chunk_metadata.update(metadata)

            # NOTE(review): depending on the Chroma version, `add` with an id
            # that already exists may not overwrite the stored entry — confirm
            # whether `upsert` is wanted for re-indexing.
            self.collection.add(
                ids=[str(chunk_id)],
                embeddings=[embedding],
                documents=[chunk_content],
                metadatas=[chunk_metadata],
            )
            return True
        except Exception as e:
            print(f"[벡터 DB 추가] 오류: {e}")
            import traceback
            traceback.print_exc()
            return False

    def search_chunks(self, query, file_ids=None, top_k=30):
        """Vector-search the chunks closest to ``query`` (initial retrieval).

        Args:
            query: Query text to embed.
            file_ids: Optional iterable of file ids; when non-empty, only
                chunks whose ``file_id`` metadata matches one of them are
                considered.
            top_k: Requested number of candidates; capped at 30 and at the
                number of stored chunks.

        Returns:
            A list of dicts with ``chunk_id`` (int), ``content``, ``metadata``
            and ``distance``, in the order Chroma returns them. Returns []
            on error, when ``top_k <= 0``, or when the collection is empty.
        """
        try:
            if top_k <= 0:
                return []

            # Never request more results than the collection holds; Chroma
            # may warn or raise (version-dependent) when n_results exceeds it.
            total = self.collection.count()
            if total == 0:
                return []
            n_results = min(top_k, self._MAX_INITIAL_RESULTS, total)

            query_embedding = self.generate_embedding(query)
            if query_embedding is None:
                return []

            where_clause = None
            if file_ids:
                where_clause = {"file_id": {"$in": [str(fid) for fid in file_ids]}}

            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
                where=where_clause,
            )

            def first_row(key):
                # One row per query embedding; exactly one was sent.
                rows = results.get(key) if results else None
                return rows[0] if rows else None

            ids = first_row('ids') or []
            documents = first_row('documents')
            metadatas = first_row('metadatas')
            distances = first_row('distances')

            chunks = [
                {
                    'chunk_id': int(chunk_id),
                    'content': documents[i] if documents is not None else '',
                    'metadata': metadatas[i] if metadatas is not None else {},
                    'distance': distances[i] if distances is not None else 1.0,
                }
                for i, chunk_id in enumerate(ids)
            ]

            print(f"[벡터 검색] {len(chunks)}개 청크 검색 완료")
            return chunks
        except Exception as e:
            print(f"[벡터 검색] 오류: {e}")
            import traceback
            traceback.print_exc()
            return []

    def rerank_chunks(self, query, chunks, top_k=5):
        """Re-rank ``chunks`` against ``query`` and return the best ``top_k``.

        Every returned dict has ``chunk_id``, ``content``, ``metadata``,
        ``rerank_score`` (higher is better) and ``original_distance``. When no
        re-ranker is available, or cross-encoder scoring raises,
        ``rerank_score`` falls back to ``1 - distance`` so callers always
        receive the same shape.
        """
        if not chunks:
            return []

        try:
            reranker = self.get_reranker_model()
            if reranker is None:
                print("[리랭킹] ⚠️ 리랭커 모델 없음, 거리 기반 정렬만 수행")
                return self._rank_by_distance(chunks, top_k)

            limit = self._RERANK_MAX_CONTENT_CHARS
            pairs = [[query, chunk['content'][:limit]] for chunk in chunks]

            print(f"[리랭킹] {len(pairs)}개 청크에 대한 리랭킹 시작...")
            scores = reranker.predict(pairs)

            scored_chunks = [
                {
                    'chunk_id': chunk['chunk_id'],
                    'content': chunk['content'],
                    'metadata': chunk['metadata'],
                    'rerank_score': float(score),
                    'original_distance': chunk.get('distance', 1.0),
                }
                for chunk, score in zip(chunks, scores)
            ]
            scored_chunks.sort(key=lambda x: x['rerank_score'], reverse=True)
            top_chunks = scored_chunks[:top_k]

            print(f"[리랭킹] 완료: 상위 {len(top_chunks)}개 청크 선택")
            for rank, chunk in enumerate(top_chunks, start=1):
                print(f" {rank}. 점수: {chunk['rerank_score']:.4f}, 청크 ID: {chunk['chunk_id']}")

            return top_chunks
        except Exception as e:
            print(f"[리랭킹] 오류: {e}")
            import traceback
            traceback.print_exc()
            # Same output shape as the normal path (previously raw chunks
            # without rerank_score/original_distance were returned here).
            return self._rank_by_distance(chunks, top_k)

    def _rank_by_distance(self, chunks, top_k):
        """Rank chunks by ``1 - distance`` (a missing distance counts as 1.0)."""
        scored_chunks = []
        for chunk in chunks:
            distance = chunk.get('distance', 1.0)
            scored_chunks.append({
                'chunk_id': chunk.get('chunk_id'),
                'content': chunk.get('content', ''),
                'metadata': chunk.get('metadata', {}),
                'rerank_score': 1.0 - distance,
                'original_distance': distance,
            })
        scored_chunks.sort(key=lambda x: x['rerank_score'], reverse=True)
        return scored_chunks[:top_k]

    def delete_chunks_by_file_id(self, file_id):
        """Delete every chunk whose ``file_id`` metadata equals ``str(file_id)``.

        Returns:
            True if at least one chunk was deleted; False if none matched or
            an error occurred (errors are logged, not raised).
        """
        try:
            # include=[]: only the ids are needed, not documents/metadata.
            results = self.collection.get(where={"file_id": str(file_id)}, include=[])
            ids = results.get('ids') if results else None
            if not ids:
                return False

            self.collection.delete(ids=ids)
            print(f"[벡터 DB 삭제] 파일 ID {file_id}의 {len(ids)}개 청크 삭제 완료")
            return True
        except Exception as e:
            print(f"[벡터 DB 삭제] 오류: {e}")
            return False

    def get_chunk_count(self):
        """Return the number of stored chunks, or 0 if the count fails."""
        try:
            return self.collection.count()
        except Exception as e:
            print(f"[벡터 DB] 청크 개수 조회 오류: {e}")
            return 0
|
|
|
|
|
|
|
|
# Process-wide VectorDBManager, created lazily by get_vector_db().
_vector_db_manager = None


def get_vector_db():
    """Return the shared VectorDBManager, constructing it on the first call."""
    global _vector_db_manager
    if _vector_db_manager is not None:
        return _vector_db_manager
    _vector_db_manager = VectorDBManager()
    return _vector_db_manager
|
|
|
|
|
|