Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| from typing import List, Dict | |
| import chromadb | |
| from sentence_transformers import SentenceTransformer | |
| import numpy as np | |
| from config import CHROMA_DB_PATH, EMBEDDING_MODEL, EMBEDDING_DIM | |
class CLIPEmbedder:
    """Wrap a SentenceTransformer model so embeddings come back as plain Python lists.

    Despite the class name, whatever model name ``SentenceTransformer`` accepts
    can be loaded; the default is ``EMBEDDING_MODEL`` from ``config``.
    Encoding failures are printed and replaced by all-zero vectors of length
    ``EMBEDDING_DIM`` so callers always receive one vector per input.

    NOTE(review): an all-zero vector has no defined cosine similarity; if these
    vectors are indexed with cosine distance, consider raising instead.
    """

    def __init__(self, model_name: str = EMBEDDING_MODEL):
        """Load the SentenceTransformer model named *model_name*."""
        print(f"π Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        print("β Model loaded successfully")

    @staticmethod
    def _zero_vector() -> List[float]:
        """Return a fresh all-zero fallback vector (a new list on every call)."""
        return [0.0] * EMBEDDING_DIM

    def embed(self, text: str) -> List[float]:
        """Embed a single string.

        Args:
            text: The string to encode.

        Returns:
            The embedding as a list of floats, or an all-zero vector of length
            ``EMBEDDING_DIM`` if encoding raised (the error is printed).
        """
        try:
            embedding = self.model.encode(text, convert_to_numpy=True)
            return np.asarray(embedding).tolist()
        except Exception as e:
            print(f"Error embedding text: {e}")
            return self._zero_vector()

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed several strings in one model call.

        Args:
            texts: Strings to encode.

        Returns:
            One embedding (list of floats) per input, in input order. An empty
            input gives an empty list. If encoding raised, the error is printed
            and one independent all-zero vector is returned per input.
        """
        if not texts:
            return []
        try:
            embeddings = self.model.encode(texts, convert_to_numpy=True)
            # One bulk ndarray -> nested-list conversion instead of one per row.
            return np.asarray(embeddings).tolist()
        except Exception as e:
            print(f"Error embedding batch: {e}")
            # Build a distinct list per row: `[[0.0] * d] * n` would make every
            # row the same list object, so mutating one would change them all.
            return [self._zero_vector() for _ in texts]
class VectorStore:
    """ChromaDB-backed store for text chunks, image OCR text and table content.

    Every item is embedded with ``CLIPEmbedder`` and written to a single
    collection in a ``chromadb.PersistentClient`` rooted at ``CHROMA_DB_PATH``.
    Most operations print errors instead of raising them.
    """

    # Name of the single collection every method reads from and writes to.
    COLLECTION_NAME = "multimodal_rag"

    def __init__(self):
        """Load the embedding model, open the persistent client and the collection."""
        self.persist_directory = CHROMA_DB_PATH
        self.embedder = CLIPEmbedder()
        print(f"\nπ Initializing ChromaDB at: {self.persist_directory}")
        try:
            self.client = chromadb.PersistentClient(path=self.persist_directory)
            print("β ChromaDB PersistentClient initialized")
        except Exception as e:
            print(f"β Error initializing ChromaDB: {e}")
            print("Trying fallback initialization...")
            # Retry once with identical settings; this only helps if the first
            # failure was transient. A second failure propagates to the caller.
            self.client = chromadb.PersistentClient(path=self.persist_directory)
        try:
            self.collection = self._open_collection()
            count = self.collection.count()
            print(f"β Collection loaded: {count} items in store")
        except Exception as e:
            print(f"Error with collection: {e}")
            # NOTE(review): if this call creates the collection, it is created
            # without the cosine setting (Chroma's default distance) — confirm.
            self.collection = self.client.get_or_create_collection(
                name=self.COLLECTION_NAME
            )

    def _open_collection(self):
        """Get the store's collection, creating it with cosine distance if absent."""
        return self.client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    def add_documents(self, documents: Dict, doc_id: str):
        """Chunk, embed and store the content extracted from one source document.

        Args:
            documents: Mapping of extracted content, accessed by key:
                ``'text'`` — full text, split into 1000-character chunks that
                overlap by 200 characters;
                ``'images'`` — iterable of dicts; only entries with a truthy
                ``'ocr_text'`` are stored (``'path'`` goes into metadata);
                ``'tables'`` — iterable of dicts; only entries with a truthy
                ``'content'`` are stored.
                Missing keys, or keys set to ``None``, are skipped.
            doc_id: Stored as ``doc_id`` metadata on every item and used as the
                prefix of each item ID (``{doc_id}_text_{i}``,
                ``{doc_id}_image_{i}``, ``{doc_id}_table_{i}``).

        Failures while writing to the collection are printed, not raised.
        """
        texts: List[str] = []
        metadatas: List[Dict] = []
        ids: List[str] = []
        print(f"\nπ Adding documents for: {doc_id}")

        full_text = documents.get('text')
        if full_text:
            chunks = self._chunk_text(full_text, chunk_size=1000, overlap=200)
            for idx, chunk in enumerate(chunks):
                texts.append(chunk)
                metadatas.append({
                    'doc_id': doc_id,
                    'type': 'text',
                    'chunk_idx': str(idx),
                })
                ids.append(f"{doc_id}_text_{idx}")
            print(f" β Text: {len(chunks)} chunks")

        image_count = 0
        # `or []` also tolerates a key that is present but set to None.
        for idx, image_data in enumerate(documents.get('images') or []):
            ocr_text = image_data.get('ocr_text')
            if not ocr_text:
                continue
            texts.append(f"Image {idx}: {ocr_text}")
            metadatas.append({
                'doc_id': doc_id,
                'type': 'image',
                'image_idx': str(idx),
                'image_path': image_data.get('path', ''),
            })
            ids.append(f"{doc_id}_image_{idx}")
            image_count += 1
        if image_count > 0:
            print(f" β Images: {image_count} with OCR text")

        table_count = 0
        for idx, table_data in enumerate(documents.get('tables') or []):
            content = table_data.get('content')
            if not content:
                continue
            texts.append(f"Table {idx}: {content}")
            metadatas.append({
                'doc_id': doc_id,
                'type': 'table',
                'table_idx': str(idx),
            })
            ids.append(f"{doc_id}_table_{idx}")
            table_count += 1
        if table_count > 0:
            print(f" β Tables: {table_count}")

        if not texts:
            return
        print(f" π Generating {len(texts)} embeddings...")
        embeddings = self.embedder.embed_batch(texts)
        try:
            self.collection.add(
                ids=ids,
                documents=texts,
                embeddings=embeddings,
                metadatas=metadatas,
            )
            print(f"β Successfully added {len(texts)} items to vector store")
            print(f"β Data persisted automatically to: {self.persist_directory}")
        except Exception as e:
            print(f"β Error adding to collection: {e}")

    def search(self, query: str, n_results: int = 5) -> List[Dict]:
        """Return up to *n_results* stored items most similar to *query*.

        Args:
            query: Free-text query, embedded with the same model as stored items.
            n_results: Maximum number of hits; values <= 0 yield ``[]``. It is
                capped at the collection size, so small stores are not asked
                for more items than they hold.

        Returns:
            Dicts with keys ``content``, ``metadata``, ``distance`` and ``type``,
            in the order returned by Chroma's query. Returns ``[]`` when the
            store is empty or on error (the error is printed).
        """
        if n_results <= 0:
            return []
        try:
            available = self.collection.count()
            if available == 0:
                return []
            query_embedding = self.embedder.embed(query)
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=min(n_results, available),
            )
            if not results.get('documents'):
                return []
            # Chroma returns one inner list per query embedding; we sent one.
            metadatas = results['metadatas'][0] if results.get('metadatas') else None
            distances = results['distances'][0] if results.get('distances') else None
            formatted_results = []
            for i, doc in enumerate(results['documents'][0]):
                metadata = (metadatas[i] if metadatas else None) or {}
                distance = distances[i] if distances else 0
                formatted_results.append({
                    'content': doc,
                    'metadata': metadata,
                    'distance': distance,
                    'type': metadata.get('type', 'unknown'),
                })
            return formatted_results
        except Exception as e:
            print(f"Error searching vector store: {e}")
            return []

    def _chunk_text(self, text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
        """Split *text* into windows of *chunk_size* characters overlapping by *overlap*.

        Consecutive chunks start ``chunk_size - overlap`` characters apart. The
        last chunk is the first one that reaches the end of *text*, so no chunk
        is merely a suffix of its predecessor. Empty text yields ``[]``.

        Raises:
            ValueError: if *chunk_size* is not positive or *overlap* is not
                smaller than *chunk_size* (the window would never advance).
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if overlap >= chunk_size:
            raise ValueError(
                f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
            )
        step = chunk_size - overlap
        chunks = []
        start = 0
        while start < len(text):
            end = start + chunk_size
            chunks.append(text[start:end])
            if end >= len(text):
                break
            start += step
        return chunks

    def get_collection_info(self) -> Dict:
        """Summarize the collection: name, item count, status and storage path.

        On error, returns ``{'status': 'error', 'message': <error text>}``.
        """
        try:
            count = self.collection.count()
            return {
                'name': self.COLLECTION_NAME,
                'count': count,
                'status': 'active',
                'persist_path': self.persist_directory,
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {'status': 'error', 'message': str(e)}

    def delete_by_doc_id(self, doc_id: str):
        """Delete every stored item whose ``doc_id`` metadata equals *doc_id*.

        Errors are printed, not raised.
        """
        try:
            matching_ids = self.collection.get(where={'doc_id': doc_id})['ids']
            if matching_ids:
                self.collection.delete(ids=matching_ids)
                print(f"β Deleted {len(matching_ids)} documents for {doc_id}")
            # The persistent client saves changes itself; no explicit save step.
            print("β Changes persisted automatically")
        except Exception as e:
            print(f"Error deleting documents: {e}")

    def persist(self):
        """No-op kept for API compatibility; the persistent client saves on write."""
        print("β Vector store is using auto-persist")

    def clear_all(self):
        """Drop the collection and rebind ``self.collection`` to a fresh, empty one.

        The collection is re-opened even if the delete fails, so
        ``self.collection`` never refers to a deleted collection. Errors are
        printed, not raised.
        """
        try:
            self.client.delete_collection(name=self.COLLECTION_NAME)
            cleared = True
        except Exception as e:
            print(f"Error clearing collection: {e}")
            cleared = False
        try:
            self.collection = self._open_collection()
        except Exception as e:
            print(f"Error clearing collection: {e}")
            return
        if cleared:
            print("β Collection cleared and reset")