# financial-rag-chatbot/services/vector_store.py
"""
Vector Database ํ†ตํ•ฉ (ChromaDB ์‚ฌ์šฉ)
"""
from typing import List, Dict, Optional, Any
import chromadb
from chromadb.config import Settings
from loguru import logger
from pathlib import Path
class VectorStore:
    """Vector store for financial research papers, backed by ChromaDB."""

    # Single source of truth for the collection metadata, shared by
    # __init__ and reset_collection so the two can never drift apart.
    _COLLECTION_METADATA = {"description": "Financial and Economics research papers"}

    def __init__(
        self,
        persist_directory: str = "./data/chroma_db",
        collection_name: str = "financial_papers"
    ):
        """
        Args:
            persist_directory: Path where ChromaDB persists its data.
            collection_name: Name of the collection to create or reuse.
        """
        self.persist_directory = Path(persist_directory)
        self.collection_name = collection_name
        # Make sure the storage directory exists before ChromaDB opens it.
        self.persist_directory.mkdir(parents=True, exist_ok=True)
        logger.info(f"Initializing ChromaDB at {persist_directory}")
        self.client = chromadb.PersistentClient(
            path=str(self.persist_directory)
        )
        # Reuse an existing collection or create it on first run.
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata=self._COLLECTION_METADATA
        )
        logger.info(f"Collection '{collection_name}' ready. Current count: {self.collection.count()}")

    @staticmethod
    def _format_results(results: Dict[str, Any]) -> Dict[str, Any]:
        """Unwrap ChromaDB's per-query nested lists into flat lists.

        ChromaDB returns one inner list per query; this class always issues
        a single query, so index [0] holds the only result set.
        """
        return {
            key: results[key][0] if results.get(key) else []
            for key in ('documents', 'metadatas', 'distances', 'ids')
        }

    def add_documents(
        self,
        chunks: List[Dict[str, Any]],
        embeddings: List[List[float]]
    ) -> None:
        """
        Add document chunks and their embeddings to the vector store.

        Args:
            chunks: Chunk dicts; each must provide 'text', 'source_filename',
                'source_filepath', 'chunk_id', 'total_chunks', 'page_count'
                and a nested 'metadata' dict (with optional 'title'/'author').
            embeddings: One embedding vector per chunk, in the same order.

        Raises:
            ValueError: If chunks and embeddings differ in length.
        """
        if len(chunks) != len(embeddings):
            raise ValueError("Number of chunks and embeddings must match")
        logger.info(f"Adding {len(chunks)} documents to vector store...")
        # Convert to the parallel-list format the ChromaDB API expects.
        # IDs combine filename and chunk index so they are unique per chunk.
        ids = [f"{chunk['source_filename']}_{chunk['chunk_id']}" for chunk in chunks]
        documents = [chunk['text'] for chunk in chunks]
        # ChromaDB metadata values must be scalars, hence the str() casts.
        metadatas = [
            {
                'source_filename': chunk['source_filename'],
                'source_filepath': chunk['source_filepath'],
                'chunk_id': str(chunk['chunk_id']),
                'total_chunks': str(chunk['total_chunks']),
                'title': chunk['metadata'].get('title', ''),
                'author': chunk['metadata'].get('author', ''),
                'page_count': str(chunk['page_count'])
            }
            for chunk in chunks
        ]
        # Insert in batches to keep each request to ChromaDB small.
        batch_size = 100
        total_batches = (len(chunks) + batch_size - 1) // batch_size
        for i in range(0, len(chunks), batch_size):
            batch_end = min(i + batch_size, len(chunks))
            self.collection.add(
                ids=ids[i:batch_end],
                embeddings=embeddings[i:batch_end],
                documents=documents[i:batch_end],
                metadatas=metadatas[i:batch_end]
            )
            logger.info(f"Added batch {i // batch_size + 1}/{total_batches}")
        logger.info(f"Successfully added {len(chunks)} documents. Total in collection: {self.collection.count()}")

    def search(
        self,
        query_embedding: List[float],
        top_k: int = 5,
        filter_metadata: Optional[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """
        Run a vector search with a precomputed query embedding.

        Args:
            query_embedding: Embedding vector of the query.
            top_k: Number of results to return.
            filter_metadata: Optional metadata filter (ChromaDB `where` clause).

        Returns:
            Dict with flat 'documents', 'metadatas', 'distances' and 'ids' lists.
        """
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            where=filter_metadata
        )
        return self._format_results(results)

    def search_by_text(
        self,
        query_text: str,
        top_k: int = 5,
        filter_metadata: Optional[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """
        Search by raw text (ChromaDB embeds the query automatically).

        Args:
            query_text: Query string to search for.
            top_k: Number of results to return.
            filter_metadata: Optional metadata filter (ChromaDB `where` clause).

        Returns:
            Dict with flat 'documents', 'metadatas', 'distances' and 'ids' lists.
        """
        results = self.collection.query(
            query_texts=[query_text],
            n_results=top_k,
            where=filter_metadata
        )
        return self._format_results(results)

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return basic statistics about the collection.

        Note: a previous version also called `collection.peek()` and discarded
        the result; that redundant DB round-trip has been removed.
        """
        count = self.collection.count()
        return {
            'collection_name': self.collection_name,
            'total_documents': count,
            'persist_directory': str(self.persist_directory),
            'has_data': count > 0
        }

    def delete_collection(self) -> None:
        """Delete the collection (warning: removes ALL stored data)."""
        logger.warning(f"Deleting collection '{self.collection_name}'")
        self.client.delete_collection(name=self.collection_name)
        logger.info("Collection deleted")

    def reset_collection(self) -> None:
        """Reset the collection (delete, then recreate it empty)."""
        self.delete_collection()
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name,
            metadata=self._COLLECTION_METADATA
        )
        logger.info("Collection reset")