import os
import json
from typing import List, Dict

import chromadb
from sentence_transformers import SentenceTransformer
import numpy as np

from config import CHROMA_DB_PATH, EMBEDDING_MODEL, EMBEDDING_DIM


class CLIPEmbedder:
    """Thin wrapper around a SentenceTransformer model.

    Produces embeddings as plain Python lists of floats and degrades to
    zero vectors (of length ``EMBEDDING_DIM``) when encoding fails, so
    callers never have to handle encoder exceptions themselves.
    """

    def __init__(self, model_name: str = EMBEDDING_MODEL):
        print(f"šŸ”„ Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        print(f"āœ… Model loaded successfully")

    def embed(self, text: str) -> List[float]:
        """Embed a single string; return a zero vector on failure."""
        try:
            embedding = self.model.encode(text, convert_to_numpy=False)
            # encode() may return a tensor/ndarray or already a list.
            return embedding.tolist() if hasattr(embedding, 'tolist') else embedding
        except Exception as e:
            print(f"Error embedding text: {e}")
            return [0.0] * EMBEDDING_DIM

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of strings; return zero vectors on failure.

        BUG FIX: the error fallback previously built ``[[0.0] * DIM] * n``,
        which repeats ONE list object n times — mutating any row would have
        mutated every row. Each fallback row is now an independent list.
        """
        try:
            embeddings = self.model.encode(texts, convert_to_numpy=False)
            return [e.tolist() if hasattr(e, 'tolist') else e for e in embeddings]
        except Exception as e:
            print(f"Error embedding batch: {e}")
            return [[0.0] * EMBEDDING_DIM for _ in texts]


class VectorStore:
    """Persistent ChromaDB-backed store for multimodal document chunks.

    Holds text chunks, image OCR text, and table content in a single
    collection named ``multimodal_rag`` using cosine distance.
    """

    def __init__(self):
        self.persist_directory = CHROMA_DB_PATH
        self.embedder = CLIPEmbedder()
        print(f"\nšŸ”„ Initializing ChromaDB at: {self.persist_directory}")
        try:
            self.client = chromadb.PersistentClient(
                path=self.persist_directory
            )
            print(f"āœ… ChromaDB PersistentClient initialized")
        except Exception as e:
            print(f"āŒ Error initializing ChromaDB: {e}")
            print(f"Trying fallback initialization...")
            # NOTE(review): this retry is identical to the first attempt;
            # it only helps with transient errors (e.g. a lock held briefly).
            self.client = chromadb.PersistentClient(
                path=self.persist_directory
            )
        try:
            self.collection = self.client.get_or_create_collection(
                name="multimodal_rag",
                metadata={"hnsw:space": "cosine"}
            )
            count = self.collection.count()
            print(f"āœ… Collection loaded: {count} items in store")
        except Exception as e:
            print(f"Error with collection: {e}")
            # Fall back to default distance metric if the metadata variant fails.
            self.collection = self.client.get_or_create_collection(
                name="multimodal_rag"
            )

    def add_documents(self, documents: Dict, doc_id: str):
        """Index the extracted content of one document.

        BUG FIX: the parameter was annotated ``List[Dict]`` but the body
        indexes it as a mapping (``documents['text']`` etc.); the
        annotation now matches actual usage.

        Args:
            documents: Mapping with optional keys:
                ``'text'`` (str), ``'images'`` (list of dicts with
                ``'ocr_text'``/``'path'``), ``'tables'`` (list of dicts
                with ``'content'``).
            doc_id: Unique identifier used to namespace the stored items.
        """
        texts = []
        metadatas = []
        ids = []
        print(f"\nšŸ“š Adding documents for: {doc_id}")

        if 'text' in documents and documents['text']:
            chunks = self._chunk_text(documents['text'], chunk_size=1000, overlap=200)
            for idx, chunk in enumerate(chunks):
                texts.append(chunk)
                metadatas.append({
                    'doc_id': doc_id,
                    'type': 'text',
                    'chunk_idx': str(idx)
                })
                ids.append(f"{doc_id}_text_{idx}")
            print(f"  āœ… Text: {len(chunks)} chunks")

        if 'images' in documents:
            image_count = 0
            for idx, image_data in enumerate(documents['images']):
                # Only images with OCR text are indexable as text embeddings.
                if image_data.get('ocr_text'):
                    texts.append(f"Image {idx}: {image_data['ocr_text']}")
                    metadatas.append({
                        'doc_id': doc_id,
                        'type': 'image',
                        'image_idx': str(idx),
                        'image_path': image_data.get('path', '')
                    })
                    ids.append(f"{doc_id}_image_{idx}")
                    image_count += 1
            if image_count > 0:
                print(f"  āœ… Images: {image_count} with OCR text")

        if 'tables' in documents:
            table_count = 0
            for idx, table_data in enumerate(documents['tables']):
                if table_data.get('content'):
                    texts.append(f"Table {idx}: {table_data.get('content', '')}")
                    metadatas.append({
                        'doc_id': doc_id,
                        'type': 'table',
                        'table_idx': str(idx)
                    })
                    ids.append(f"{doc_id}_table_{idx}")
                    table_count += 1
            if table_count > 0:
                print(f"  āœ… Tables: {table_count}")

        if texts:
            print(f"  šŸ”„ Generating {len(texts)} embeddings...")
            embeddings = self.embedder.embed_batch(texts)
            try:
                self.collection.add(
                    ids=ids,
                    documents=texts,
                    embeddings=embeddings,
                    metadatas=metadatas
                )
                print(f"āœ… Successfully added {len(texts)} items to vector store")
                print(f"āœ… Data persisted automatically to: {self.persist_directory}")
            except Exception as e:
                print(f"āŒ Error adding to collection: {e}")

    def search(self, query: str, n_results: int = 5) -> List[Dict]:
        """Return up to *n_results* nearest items for *query*.

        Each result dict has ``content``, ``metadata``, ``distance`` and
        ``type`` keys; an empty list is returned on any error.
        """
        try:
            query_embedding = self.embedder.embed(query)
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results
            )
            formatted_results = []
            if results['documents']:
                for i, doc in enumerate(results['documents'][0]):
                    metadata = results['metadatas'][0][i] if results['metadatas'] else {}
                    distance = results['distances'][0][i] if results['distances'] else 0.0
                    formatted_results.append({
                        'content': doc,
                        'metadata': metadata,
                        'distance': distance,
                        'type': metadata.get('type', 'unknown')
                    })
            return formatted_results
        except Exception as e:
            print(f"Error searching vector store: {e}")
            return []

    def _chunk_text(self, text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
        """Split *text* into overlapping chunks of at most *chunk_size* chars.

        BUG FIXES over the original sliding-window loop:
        * ``overlap >= chunk_size`` made ``start = end - overlap`` advance
          by zero or move backwards — an infinite loop. The overlap is now
          clamped strictly below the chunk size.
        * Once the final chunk reached the end of the text, the rewound
          start could still land inside the text and emit a duplicate tail
          chunk (e.g. a 1000-char text produced chunks ``[0:1000]`` AND
          ``[800:1000]``). The loop now stops when the end is reached.
        """
        if not text:
            return []
        if overlap >= chunk_size:
            # Guarantee forward progress: step must be >= 1.
            overlap = max(chunk_size - 1, 0)
        chunks = []
        start = 0
        while start < len(text):
            end = start + chunk_size
            chunks.append(text[start:end])
            if end >= len(text):
                break
            start = end - overlap
        return chunks

    def get_collection_info(self) -> Dict:
        """Return a small status dict for the collection (or an error dict)."""
        try:
            count = self.collection.count()
            return {
                'name': 'multimodal_rag',
                'count': count,
                'status': 'active',
                'persist_path': self.persist_directory
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {'status': 'error', 'message': str(e)}

    def delete_by_doc_id(self, doc_id: str):
        """Delete every stored item whose metadata matches *doc_id*."""
        try:
            # Get all IDs with this doc_id
            results = self.collection.get(where={'doc_id': doc_id})
            if results['ids']:
                self.collection.delete(ids=results['ids'])
                print(f"āœ… Deleted {len(results['ids'])} documents for {doc_id}")
            # Auto-persist on delete
            print(f"āœ… Changes persisted automatically")
        except Exception as e:
            print(f"Error deleting documents: {e}")

    def persist(self):
        """No-op kept for API compatibility: PersistentClient auto-persists."""
        print("āœ… Vector store is using auto-persist")

    def clear_all(self):
        """Drop and recreate the collection, emptying the store."""
        try:
            self.client.delete_collection(name="multimodal_rag")
            self.collection = self.client.get_or_create_collection(
                name="multimodal_rag",
                metadata={"hnsw:space": "cosine"}
            )
            print("āœ… Collection cleared and reset")
        except Exception as e:
            print(f"Error clearing collection: {e}")