Spaces:
Sleeping
Sleeping
| """Vector database management for Francis Botcon.""" | |
| import json | |
| from pathlib import Path | |
| from typing import List, Dict, Tuple | |
| import numpy as np | |
| from src.embeddings import EmbeddingGenerator | |
| from src.logger import LoggerSetup | |
| from src.config_loader import config | |
| logger = LoggerSetup.setup().getChild(__name__) | |
class VectorDatabase:
    """Manage vector embeddings and retrieval using ChromaDB or FAISS."""

    def __init__(self, db_type: str = None, db_path: str = None):
        """Initialize vector database.

        Args:
            db_type: Type of database ('chromadb' or 'faiss'). Defaults to
                the "vector_db.type" config value ("chromadb").
            db_path: Path to database. Defaults to the "vector_db.db_path"
                config value; the directory is created if missing.

        Raises:
            ValueError: If the resolved database type is unsupported.
        """
        self.db_type = db_type or config.get("vector_db.type", "chromadb")
        self.db_path = Path(db_path or config.get("vector_db.db_path", "./data/vectordb"))
        self.db_path.mkdir(parents=True, exist_ok=True)
        self.embedding_generator = EmbeddingGenerator()
        # Retrieval defaults, overridable per-call in search().
        self.top_k = config.get("vector_db.top_k", 5)
        self.similarity_threshold = config.get("vector_db.similarity_threshold", 0.6)
        logger.info(f"Initializing {self.db_type} database at {self.db_path}")
        if self.db_type == "chromadb":
            self._init_chromadb()
        elif self.db_type == "faiss":
            self._init_faiss()
        else:
            raise ValueError(f"Unsupported database type: {self.db_type}")

    def _init_chromadb(self):
        """Initialize ChromaDB client; load the existing collection if present.

        Raises:
            ImportError: If the chromadb package is not installed.
        """
        try:
            import chromadb
            self.client = chromadb.PersistentClient(path=str(self.db_path))
            self.collection = None
            # Try to load existing collection
            try:
                self.collection = self.client.get_collection(name="francis_bacon")
                logger.info("✓ ChromaDB initialized - loaded existing collection")
            except Exception as e:
                # Collection is created lazily in _add_to_chromadb.
                logger.debug(f"No existing collection found: {e}. Will create on first add_documents call.")
                logger.info("✓ ChromaDB initialized")
        except ImportError:
            logger.error("ChromaDB not installed. Install with: pip install chromadb")
            raise

    def _init_faiss(self):
        """Initialize FAISS module handle and empty in-memory state.

        Raises:
            ImportError: If the faiss package is not installed.
        """
        try:
            import faiss
            # Keep the module handle so other methods don't re-import it.
            self.faiss = faiss
            self.index = None
            self.documents = []
            logger.info("✓ FAISS initialized")
        except ImportError:
            logger.error("FAISS not installed. Install with: pip install faiss-cpu")
            raise

    def add_documents(self, documents: List[Dict[str, str]], batch_size: int = 32):
        """Add documents to vector database.

        Args:
            documents: List of documents with 'id', 'text', and metadata
            batch_size: Batch size for embedding generation
        """
        # Guard: embedding an empty batch (and shape[1] on the result) would fail.
        if not documents:
            logger.info("No documents to add")
            return
        logger.info(f"Adding {len(documents)} documents to {self.db_type} database")
        # Extract texts for embedding
        texts = [doc["text"] for doc in documents]
        # Generate embeddings
        embeddings = self.embedding_generator.embed(texts, batch_size=batch_size)
        if self.db_type == "chromadb":
            self._add_to_chromadb(documents, embeddings, texts)
        elif self.db_type == "faiss":
            self._add_to_faiss(documents, embeddings, texts)
        logger.info("✓ Documents added successfully")

    def _add_to_chromadb(self, documents: List[Dict], embeddings: np.ndarray, texts: List[str]):
        """Add documents to ChromaDB.

        Args:
            documents: Document list
            embeddings: Embedding vectors
            texts: Text strings
        """
        # Create collection if not exists
        if self.collection is None:
            self.collection = self.client.get_or_create_collection(
                name="francis_bacon",
                metadata={"hnsw:space": "cosine"}
            )
        # Prepare ids and per-document metadata.
        ids = [doc["id"] for doc in documents]
        metadatas = [
            {
                "source": doc.get("source", ""),
                "title": doc.get("title", ""),
                "author": doc.get("author", ""),
                # ChromaDB metadata values are stored as strings here.
                "segment_index": str(doc.get("segment_index", 0)),
            }
            for doc in documents
        ]
        # Add to collection
        self.collection.add(
            ids=ids,
            embeddings=embeddings.tolist(),
            documents=texts,
            metadatas=metadatas
        )

    def _add_to_faiss(self, documents: List[Dict], embeddings: np.ndarray, texts: List[str]):
        """Add documents to FAISS.

        Args:
            documents: Document list
            embeddings: Embedding vectors
            texts: Text strings
        """
        # Initialize index if needed
        if self.index is None:
            embedding_dim = embeddings.shape[1]
            self.index = self.faiss.IndexFlatL2(embedding_dim)
        # FAISS requires float32 vectors.
        self.index.add(embeddings.astype(np.float32))
        # Store documents; copy each dict so the caller's input is not mutated.
        for doc, text in zip(documents, texts):
            self.documents.append({
                **doc,
                "embedding_index": len(self.documents),
                "text": text,
            })
        # Save index
        self._save_faiss_index()

    def search(self, query: str, top_k: int = None) -> List[Tuple[str, float, Dict]]:
        """Search for similar documents.

        Args:
            query: Query text
            top_k: Number of results to return (defaults to configured top_k)

        Returns:
            List of (text, score, metadata) tuples

        Raises:
            ValueError: If the database type is unsupported (defensive;
                __init__ already validates it).
        """
        top_k = top_k or self.top_k
        # Generate query embedding
        query_embedding = self.embedding_generator.embed_single(query)
        if self.db_type == "chromadb":
            return self._search_chromadb(query_embedding, top_k)
        elif self.db_type == "faiss":
            return self._search_faiss(query_embedding, top_k)
        raise ValueError(f"Unsupported database type: {self.db_type}")

    def _search_chromadb(self, query_embedding: np.ndarray, top_k: int) -> List[Tuple[str, float, Dict]]:
        """Search ChromaDB.

        Args:
            query_embedding: Query embedding vector
            top_k: Number of results

        Returns:
            Search results filtered by similarity_threshold
        """
        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=top_k,
            include=["documents", "distances", "metadatas"]
        )
        output = []
        if results["documents"] and len(results["documents"]) > 0:
            for i, doc in enumerate(results["documents"][0]):
                # ChromaDB cosine distance is 1 - cosine_similarity,
                # so similarity = 1 - distance (the previous 1 - d/2
                # formula overstated similarity for distant matches).
                distance = results["distances"][0][i]
                similarity = 1 - distance
                metadata = results["metadatas"][0][i] if results["metadatas"] else {}
                if similarity >= self.similarity_threshold:
                    output.append((doc, similarity, metadata))
        return output

    def _search_faiss(self, query_embedding: np.ndarray, top_k: int) -> List[Tuple[str, float, Dict]]:
        """Search FAISS.

        Args:
            query_embedding: Query embedding vector
            top_k: Number of results

        Returns:
            Search results filtered by similarity_threshold
        """
        query_embedding_float32 = query_embedding.astype(np.float32).reshape(1, -1)
        distances, indices = self.index.search(query_embedding_float32, top_k)
        output = []
        for i, idx in enumerate(indices[0]):
            # FAISS pads missing results with index -1.
            if idx != -1:
                # Map L2 distance to a (0, 1] similarity score.
                distance = distances[0][i]
                similarity = 1 / (1 + distance)
                if similarity >= self.similarity_threshold:
                    doc_info = self.documents[idx]
                    metadata = {
                        "source": doc_info.get("source", ""),
                        "title": doc_info.get("title", ""),
                        "author": doc_info.get("author", "")
                    }
                    output.append((doc_info["text"], similarity, metadata))
        return output

    def _save_faiss_index(self):
        """Save FAISS index and documents to db_path."""
        if self.db_type == "faiss":
            index_path = self.db_path / "faiss_index.bin"
            docs_path = self.db_path / "documents.json"
            self.faiss.write_index(self.index, str(index_path))
            # encoding must be utf-8: ensure_ascii=False writes raw non-ASCII
            # characters, which breaks under a non-UTF-8 default locale.
            with open(docs_path, 'w', encoding='utf-8') as f:
                json.dump(self.documents, f, ensure_ascii=False, indent=2)
            logger.debug(f"FAISS index saved to {index_path}")

    def load_index(self):
        """Load existing FAISS index and document store from disk.

        Returns:
            True if both index and documents were found and loaded,
            False otherwise (also for non-FAISS databases).
        """
        if self.db_type == "faiss":
            index_path = self.db_path / "faiss_index.bin"
            docs_path = self.db_path / "documents.json"
            if index_path.exists() and docs_path.exists():
                self.index = self.faiss.read_index(str(index_path))
                # Read with utf-8 to mirror _save_faiss_index.
                with open(docs_path, 'r', encoding='utf-8') as f:
                    self.documents = json.load(f)
                logger.info("✓ FAISS index loaded")
                return True
        return False