# soyailabs/app/vector_db.py
# (origin: SOY NV AI — "feat: Add Re-ranking system and improve AI response prompts", commit fa87e9c)
"""
Vector DB and embedding utilities.

Vector search plus a re-ranking system backed by Chroma DB.
"""
import os
import json
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer, CrossEncoder
from pathlib import Path
import numpy as np
# Vector DB storage path: a 'vector_db' directory two levels above this file.
VECTOR_DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'vector_db')
# Embedding model (Korean-language support)
EMBEDDING_MODEL_NAME = "jhgan/ko-sroberta-multitask" # Korean-capable model
# English-centric alternative: "all-MiniLM-L6-v2" (faster, weaker Korean quality)
# Cross-Encoder model (used for re-ranking)
# No dedicated Korean re-ranker available, so a general model is used instead
RERANKER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2" # general-purpose re-ranker (handles Korean to some extent)
# Alternative: "BAAI/bge-reranker-base" (better quality)
class VectorDBManager:
    """Manage the Chroma vector store: embed, index, search and re-rank chunks.

    Both the embedding model and the cross-encoder re-ranker are lazy-loaded
    on first use, so constructing this class does not download models.
    All logging is done via print(), matching the rest of this module.
    """

    def __init__(self):
        """Initialize the persistent Chroma client and the chunk collection.

        Raises:
            Exception: re-raised if the collection cannot be created/loaded.
        """
        # Lazy-loaded model handles (see get_embedding_model / get_reranker_model).
        self.embedding_model = None
        self.reranker_model = None
        self.client = None
        self.collection = None
        # Make sure the on-disk DB directory exists.
        os.makedirs(VECTOR_DB_PATH, exist_ok=True)
        # Persistent Chroma client; telemetry off, reset allowed (for maintenance).
        self.client = chromadb.PersistentClient(
            path=VECTOR_DB_PATH,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True
            )
        )
        # Create the collection if missing, otherwise load the existing one.
        try:
            self.collection = self.client.get_or_create_collection(
                name="document_chunks",
                metadata={"hnsw:space": "cosine"}  # use cosine similarity
            )
            # FIX: use count() rather than len(get()['ids']) — get() would
            # materialize every stored document just to read its length.
            print(f"[๋ฒกํ„ฐ DB] ์ปฌ๋ ‰์…˜ ๋กœ๋“œ/์ƒ์„ฑ ์™„๋ฃŒ: {self.collection.count()}๊ฐœ ๋ฌธ์„œ")
        except Exception as e:
            print(f"[๋ฒกํ„ฐ DB] ์ปฌ๋ ‰์…˜ ์ƒ์„ฑ ์˜ค๋ฅ˜: {e}")
            raise

    def get_embedding_model(self):
        """Load the sentence-embedding model on first call and cache it.

        Falls back to the English "all-MiniLM-L6-v2" model if the primary
        (Korean) model cannot be loaded; re-raises if both fail.

        Returns:
            SentenceTransformer: the loaded embedding model.
        """
        if self.embedding_model is None:
            print(f"[์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ] ๋กœ๋”ฉ ์ค‘: {EMBEDDING_MODEL_NAME}")
            try:
                self.embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
                print(f"[์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ] ๋กœ๋”ฉ ์™„๋ฃŒ")
            except Exception as e:
                print(f"[์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ] ๋กœ๋”ฉ ์˜ค๋ฅ˜: {e}")
                # Try the smaller English fallback model.
                try:
                    self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
                    print(f"[์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ] ๋Œ€์ฒด ๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ: all-MiniLM-L6-v2")
                except Exception as e2:
                    print(f"[์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ] ๋Œ€์ฒด ๋ชจ๋ธ๋„ ๋กœ๋”ฉ ์‹คํŒจ: {e2}")
                    raise
        return self.embedding_model

    def get_reranker_model(self):
        """Load the Cross-Encoder re-ranker on first call and cache it.

        Falls back to "BAAI/bge-reranker-base" (higher quality, not lighter)
        if the primary model fails. Unlike the embedding model, a total
        failure here is non-fatal: the method returns None and callers are
        expected to degrade to distance-based ranking.

        Returns:
            CrossEncoder | None: the loaded re-ranker, or None if unavailable.
        """
        if self.reranker_model is None:
            print(f"[๋ฆฌ๋žญ์ปค ๋ชจ๋ธ] ๋กœ๋”ฉ ์ค‘: {RERANKER_MODEL_NAME}")
            try:
                self.reranker_model = CrossEncoder(RERANKER_MODEL_NAME)
                print(f"[๋ฆฌ๋žญ์ปค ๋ชจ๋ธ] ๋กœ๋”ฉ ์™„๋ฃŒ")
            except Exception as e:
                print(f"[๋ฆฌ๋žญ์ปค ๋ชจ๋ธ] ๋กœ๋”ฉ ์˜ค๋ฅ˜: {e}")
                # Alternate model (better quality per the module header,
                # NOT lighter — the original comment was misleading).
                try:
                    print(f"[๋ฆฌ๋žญ์ปค ๋ชจ๋ธ] ๋Œ€์ฒด ๋ชจ๋ธ ์‹œ๋„: BAAI/bge-reranker-base")
                    self.reranker_model = CrossEncoder("BAAI/bge-reranker-base", max_length=512)
                    print(f"[๋ฆฌ๋žญ์ปค ๋ชจ๋ธ] ๋Œ€์ฒด ๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ")
                except Exception as e2:
                    print(f"[๋ฆฌ๋žญ์ปค ๋ชจ๋ธ] ๋Œ€์ฒด ๋ชจ๋ธ๋„ ๋กœ๋”ฉ ์‹คํŒจ: {e2}")
                    # Proceed without re-ranking (warning only).
                    print(f"[๋ฆฌ๋žญ์ปค ๋ชจ๋ธ] โš ๏ธ ๊ฒฝ๊ณ : ๋ฆฌ๋žญ์ปค ๋ชจ๋ธ์„ ๋กœ๋“œํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ๋ฆฌ๋žญํ‚น ์—†์ด ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค.")
                    self.reranker_model = None
        return self.reranker_model

    def generate_embedding(self, text):
        """Return the embedding vector for *text* as a plain Python list.

        Returns:
            list[float] | None: the embedding, or None on any failure
            (errors are logged, never raised).
        """
        try:
            model = self.get_embedding_model()
            return model.encode(text, convert_to_numpy=True).tolist()
        except Exception as e:
            print(f"[์ž„๋ฒ ๋”ฉ ์ƒ์„ฑ] ์˜ค๋ฅ˜: {e}")
            return None

    def add_chunk(self, chunk_id, chunk_content, file_id, chunk_index, metadata=None):
        """Embed a chunk and store it in the vector DB.

        Args:
            chunk_id: unique id for the chunk (stored as str).
            chunk_content: chunk text to embed and store.
            file_id: id of the owning file (stored as str metadata).
            chunk_index: position of the chunk within the file.
            metadata: optional extra metadata merged over the defaults.

        Returns:
            bool: True on success, False on any failure (logged, not raised).
        """
        try:
            embedding = self.generate_embedding(chunk_content)
            if embedding is None:
                return False
            # Chroma metadata values are stored as strings for consistency
            # with the "$in" string filters used in search/delete.
            chunk_metadata = {
                'file_id': str(file_id),
                'chunk_index': str(chunk_index),
                'content_length': str(len(chunk_content))
            }
            if metadata:
                chunk_metadata.update(metadata)
            self.collection.add(
                ids=[str(chunk_id)],
                embeddings=[embedding],
                documents=[chunk_content],
                metadatas=[chunk_metadata]
            )
            return True
        except Exception as e:
            print(f"[๋ฒกํ„ฐ DB ์ถ”๊ฐ€] ์˜ค๋ฅ˜: {e}")
            import traceback
            traceback.print_exc()
            return False

    def search_chunks(self, query, file_ids=None, top_k=30):
        """First-stage retrieval: vector search for chunks relevant to *query*.

        Args:
            query: natural-language query string.
            file_ids: optional iterable of file ids to restrict the search to.
            top_k: number of candidates to retrieve (hard-capped at 30).

        Returns:
            list[dict]: dicts with keys 'chunk_id' (int), 'content',
            'metadata', 'distance'; empty list on error or no results.
        """
        try:
            query_embedding = self.generate_embedding(query)
            if query_embedding is None:
                return []
            # Optional file filter (ids are stored as strings, see add_chunk).
            where_clause = None
            if file_ids:
                where_clause = {"file_id": {"$in": [str(fid) for fid in file_ids]}}
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=min(top_k, 30),  # cap candidate pool at 30
                where=where_clause
            )
            # Parse results defensively: Chroma may return None for the
            # documents/metadatas/distances keys, and 'ids' may be empty.
            chunks = []
            ids = (results or {}).get('ids') or []
            if ids and ids[0]:
                docs = results.get('documents') or [[]]
                metas = results.get('metadatas') or [[]]
                dists = results.get('distances') or [[]]
                for i, chunk_id in enumerate(ids[0]):
                    chunks.append({
                        'chunk_id': int(chunk_id),
                        'content': docs[0][i] if docs[0] else '',
                        'metadata': metas[0][i] if metas[0] else {},
                        'distance': dists[0][i] if dists[0] else 1.0
                    })
            print(f"[๋ฒกํ„ฐ ๊ฒ€์ƒ‰] {len(chunks)}๊ฐœ ์ฒญํฌ ๊ฒ€์ƒ‰ ์™„๋ฃŒ")
            return chunks
        except Exception as e:
            print(f"[๋ฒกํ„ฐ ๊ฒ€์ƒ‰] ์˜ค๋ฅ˜: {e}")
            import traceback
            traceback.print_exc()
            return []

    def _rank_by_distance(self, chunks, top_k):
        """Fallback ranking: score = 1 - cosine distance, sorted descending.

        Produces the same dict schema as the cross-encoder path so callers
        see a uniform shape regardless of which ranking path ran.
        """
        scored = [
            {
                'chunk_id': chunk['chunk_id'],
                'content': chunk['content'],
                'metadata': chunk['metadata'],
                'rerank_score': 1.0 - chunk.get('distance', 1.0),
                'original_distance': chunk.get('distance', 1.0)
            }
            for chunk in chunks
        ]
        scored.sort(key=lambda x: x['rerank_score'], reverse=True)
        return scored[:top_k]

    def rerank_chunks(self, query, chunks, top_k=5):
        """Second-stage ranking: re-rank *chunks* with a Cross-Encoder.

        Degrades gracefully: if no re-ranker is available, or scoring fails,
        chunks are ranked by their original vector distance instead.

        Args:
            query: the user query.
            chunks: candidate dicts from search_chunks().
            top_k: number of top-ranked chunks to return.

        Returns:
            list[dict]: top_k dicts with 'chunk_id', 'content', 'metadata',
            'rerank_score' and 'original_distance', best first.
            (FIX: the error path now returns the same schema as the other
            paths instead of raw chunks without scores.)
        """
        try:
            if not chunks:
                return []
            reranker = self.get_reranker_model()
            # No re-ranker available: distance-based ordering only.
            if reranker is None:
                print(f"[๋ฆฌ๋žญํ‚น] โš ๏ธ ๋ฆฌ๋žญ์ปค ๋ชจ๋ธ ์—†์Œ, ๊ฑฐ๋ฆฌ ๊ธฐ๋ฐ˜ ์ •๋ ฌ๋งŒ ์ˆ˜ํ–‰")
                return self._rank_by_distance(chunks, top_k)
            # Build (query, content) pairs; truncate very long chunks so the
            # cross-encoder input stays within its max sequence length.
            max_content_length = 500
            pairs = [[query, chunk['content'][:max_content_length]] for chunk in chunks]
            print(f"[๋ฆฌ๋žญํ‚น] {len(pairs)}๊ฐœ ์ฒญํฌ์— ๋Œ€ํ•œ ๋ฆฌ๋žญํ‚น ์‹œ์ž‘...")
            scores = reranker.predict(pairs)
            # Attach scores (higher = more relevant) and sort.
            scored_chunks = [
                {
                    'chunk_id': chunk['chunk_id'],
                    'content': chunk['content'],
                    'metadata': chunk['metadata'],
                    'rerank_score': float(scores[i]),
                    'original_distance': chunk.get('distance', 1.0)
                }
                for i, chunk in enumerate(chunks)
            ]
            scored_chunks.sort(key=lambda x: x['rerank_score'], reverse=True)
            top_chunks = scored_chunks[:top_k]
            print(f"[๋ฆฌ๋žญํ‚น] ์™„๋ฃŒ: ์ƒ์œ„ {len(top_chunks)}๊ฐœ ์ฒญํฌ ์„ ํƒ")
            for i, chunk in enumerate(top_chunks):
                print(f" {i+1}. ์ ์ˆ˜: {chunk['rerank_score']:.4f}, ์ฒญํฌ ID: {chunk['chunk_id']}")
            return top_chunks
        except Exception as e:
            print(f"[๋ฆฌ๋žญํ‚น] ์˜ค๋ฅ˜: {e}")
            import traceback
            traceback.print_exc()
            # On error, fall back to distance ranking with the same schema.
            return self._rank_by_distance(chunks, top_k)

    def delete_chunks_by_file_id(self, file_id):
        """Delete every stored chunk belonging to *file_id*.

        Returns:
            bool: True if chunks were found and deleted, False otherwise
            (including on error — errors are logged, not raised).
        """
        try:
            results = self.collection.get(
                where={"file_id": str(file_id)}
            )
            if results and results.get('ids'):
                self.collection.delete(ids=results['ids'])
                print(f"[๋ฒกํ„ฐ DB ์‚ญ์ œ] ํŒŒ์ผ ID {file_id}์˜ {len(results['ids'])}๊ฐœ ์ฒญํฌ ์‚ญ์ œ ์™„๋ฃŒ")
                return True
            return False
        except Exception as e:
            print(f"[๋ฒกํ„ฐ DB ์‚ญ์ œ] ์˜ค๋ฅ˜: {e}")
            return False

    def get_chunk_count(self):
        """Return the number of chunks in the collection (0 on error)."""
        try:
            return self.collection.count()
        except Exception as e:
            print(f"[๋ฒกํ„ฐ DB] ์ฒญํฌ ๊ฐœ์ˆ˜ ์กฐํšŒ ์˜ค๋ฅ˜: {e}")
            return 0
# ์ „์—ญ ๋ฒกํ„ฐ DB ๋งค๋‹ˆ์ € ์ธ์Šคํ„ด์Šค
_vector_db_manager = None
def get_vector_db():
"""๋ฒกํ„ฐ DB ๋งค๋‹ˆ์ € ์‹ฑ๊ธ€ํ†ค ์ธ์Šคํ„ด์Šค ๋ฐ˜ํ™˜"""
global _vector_db_manager
if _vector_db_manager is None:
_vector_db_manager = VectorDBManager()
return _vector_db_manager