"""
Vector Database integration (using ChromaDB)
"""
from typing import List, Dict, Optional, Any

import chromadb
from chromadb.config import Settings
from loguru import logger
from pathlib import Path
class VectorStore:
    """Vector store backed by a persistent ChromaDB collection.

    Wraps collection creation, batched document insertion, and
    similarity search over embedded document chunks.
    """

    # Single source of truth for the collection metadata, so __init__ and
    # reset_collection() cannot drift apart.
    _COLLECTION_METADATA = {"description": "Financial and Economics research papers"}

    def __init__(
        self,
        persist_directory: str = "./data/chroma_db",
        collection_name: str = "financial_papers"
    ):
        """
        Args:
            persist_directory: Path where ChromaDB persists its data.
            collection_name: Name of the collection to create or reuse.
        """
        self.persist_directory = Path(persist_directory)
        self.collection_name = collection_name

        # Make sure the storage directory exists before the client touches it.
        self.persist_directory.mkdir(parents=True, exist_ok=True)

        # Initialize the persistent ChromaDB client.
        logger.info(f"Initializing ChromaDB at {persist_directory}")
        self.client = chromadb.PersistentClient(
            path=str(self.persist_directory)
        )

        # Create the collection on first use; reuse it on later runs.
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata=self._COLLECTION_METADATA
        )
        logger.info(f"Collection '{collection_name}' ready. Current count: {self.collection.count()}")

    def add_documents(
        self,
        chunks: List[Dict[str, Any]],
        embeddings: List[List[float]]
    ) -> None:
        """
        Add document chunks and their embeddings to the vector store.

        Args:
            chunks: Chunk dicts; each must carry 'text', 'metadata',
                'source_filename', 'source_filepath', 'chunk_id',
                'total_chunks' and 'page_count'.
            embeddings: One embedding vector per chunk, same order as chunks.

        Raises:
            ValueError: If chunks and embeddings differ in length.
        """
        if len(chunks) != len(embeddings):
            raise ValueError("Number of chunks and embeddings must match")

        logger.info(f"Adding {len(chunks)} documents to vector store...")

        # Convert to the parallel-list format ChromaDB expects.
        # IDs combine filename and chunk index so they stay unique per file.
        ids = [f"{chunk['source_filename']}_{chunk['chunk_id']}" for chunk in chunks]
        documents = [chunk['text'] for chunk in chunks]
        # ChromaDB metadata values must be scalars, hence the str() coercions.
        metadatas = [
            {
                'source_filename': chunk['source_filename'],
                'source_filepath': chunk['source_filepath'],
                'chunk_id': str(chunk['chunk_id']),
                'total_chunks': str(chunk['total_chunks']),
                'title': chunk['metadata'].get('title', ''),
                'author': chunk['metadata'].get('author', ''),
                'page_count': str(chunk['page_count'])
            }
            for chunk in chunks
        ]

        # Insert in batches to keep individual requests small.
        batch_size = 100
        total_batches = (len(chunks) + batch_size - 1) // batch_size
        for i in range(0, len(chunks), batch_size):
            batch_end = min(i + batch_size, len(chunks))
            self.collection.add(
                ids=ids[i:batch_end],
                embeddings=embeddings[i:batch_end],
                documents=documents[i:batch_end],
                metadatas=metadatas[i:batch_end]
            )
            logger.info(f"Added batch {i // batch_size + 1}/{total_batches}")

        logger.info(f"Successfully added {len(chunks)} documents. Total in collection: {self.collection.count()}")

    @staticmethod
    def _unwrap_query_results(results: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten ChromaDB's per-query nested lists for a single-query result."""
        return {
            'documents': results['documents'][0] if results['documents'] else [],
            'metadatas': results['metadatas'][0] if results['metadatas'] else [],
            'distances': results['distances'][0] if results['distances'] else [],
            'ids': results['ids'][0] if results['ids'] else []
        }

    def search(
        self,
        query_embedding: List[float],
        top_k: int = 5,
        filter_metadata: Optional[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """
        Run a vector similarity search.

        Args:
            query_embedding: Embedding vector of the query.
            top_k: Number of results to return.
            filter_metadata: Optional metadata filter (ChromaDB 'where' clause).

        Returns:
            Dict with 'documents', 'metadatas', 'distances' and 'ids' lists.
        """
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            where=filter_metadata
        )
        return self._unwrap_query_results(results)

    def search_by_text(
        self,
        query_text: str,
        top_k: int = 5,
        filter_metadata: Optional[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """
        Search by raw text (ChromaDB embeds the query automatically).

        Args:
            query_text: Query text to search for.
            top_k: Number of results to return.
            filter_metadata: Optional metadata filter (ChromaDB 'where' clause).

        Returns:
            Dict with 'documents', 'metadatas', 'distances' and 'ids' lists.
        """
        results = self.collection.query(
            query_texts=[query_text],
            n_results=top_k,
            where=filter_metadata
        )
        return self._unwrap_query_results(results)

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return basic statistics about the collection."""
        # NOTE: the previous version also called collection.peek(limit=1)
        # into an unused local; that dead DB call has been removed.
        count = self.collection.count()
        return {
            'collection_name': self.collection_name,
            'total_documents': count,
            'persist_directory': str(self.persist_directory),
            'has_data': count > 0
        }

    def delete_collection(self) -> None:
        """Delete the collection (warning: removes all stored data)."""
        logger.warning(f"Deleting collection '{self.collection_name}'")
        self.client.delete_collection(name=self.collection_name)
        logger.info("Collection deleted")

    def reset_collection(self) -> None:
        """Reset the collection: delete it, then recreate it empty."""
        self.delete_collection()
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name,
            metadata=self._COLLECTION_METADATA
        )
        logger.info("Collection reset")