Spaces:
Runtime error
Runtime error
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import chromadb
from chromadb.config import Settings as ChromaSettings
from sentence_transformers import SentenceTransformer

from config.settings import Settings
# Module logger, plus a process-wide default logging setup: basicConfig()
# gives the root logger an INFO level and a stderr handler, but only if the
# root logger has not already been configured elsewhere.
logger = logging.getLogger(name=__name__)
logging.basicConfig(level=logging.INFO)
class VectorStore:
    """Persistent ChromaDB knowledge base searched with SentenceTransformer embeddings.

    Embeddings are computed locally with ``self.embedding_model`` and passed to
    Chroma explicitly, for both stored documents and queries. Construction and
    :meth:`initialize_collection` raise on failure; the remaining methods are
    best-effort and log errors, returning a neutral value (``False``, ``""``,
    ``[]`` or an empty search result) instead of raising.
    """

    # SentenceTransformer model loaded when __init__ is not given one.
    DEFAULT_EMBEDDING_MODEL = 'all-MiniLM-L6-v2'

    def __init__(self, embedding_model_name: str = DEFAULT_EMBEDDING_MODEL):
        """Open the persistent Chroma client, load the embedder and the main collection.

        Args:
            embedding_model_name: SentenceTransformer model name to load.
                Defaults to ``DEFAULT_EMBEDDING_MODEL``.

        Raises:
            Exception: whatever client/model construction or
                :meth:`initialize_collection` raises.
        """
        self.client = chromadb.PersistentClient(
            path=Settings.CHROMA_PERSIST_DIR,
            settings=ChromaSettings(anonymized_telemetry=False),
        )
        self.embedding_model = SentenceTransformer(embedding_model_name)
        self.collection = None
        self.initialize_collection()

    def initialize_collection(self):
        """Get or create the main collection (``Settings.COLLECTION_NAME``) and make it active.

        Raises:
            Exception: re-raised after logging if Chroma fails.
        """
        try:
            self.collection = self.client.get_or_create_collection(
                name=Settings.COLLECTION_NAME,
                metadata={"description": "General knowledge base for reasoning copilot"},
            )
            logger.info("Initialized collection: %s", Settings.COLLECTION_NAME)
        except Exception as e:
            logger.error("Error initializing collection: %s", e)
            raise

    def add_documents(self, documents: List[str], metadata: Optional[List[Dict]] = None,
                      ids: Optional[List[str]] = None) -> bool:
        """Embed ``documents`` and add them to the active collection.

        Args:
            documents: Texts to store. An empty list is a no-op returning False.
            metadata: One metadata dict per document; defaults to
                ``{"source": "user_upload", "type": "document"}`` for each.
            ids: One id per document; defaults to random UUID4 strings.

        Returns:
            True when Chroma accepted the batch; False on empty input or on any
            error (which is logged).
        """
        try:
            if not documents:
                return False
            if ids is None:
                ids = [str(uuid.uuid4()) for _ in documents]
            embeddings = self.embedding_model.encode(documents).tolist()
            if metadata is None:
                metadata = [{"source": "user_upload", "type": "document"} for _ in documents]
            self.collection.add(
                documents=documents,
                embeddings=embeddings,
                metadatas=metadata,
                ids=ids,
            )
            logger.info("Added %d documents to vector store", len(documents))
            return True
        except Exception as e:
            logger.error("Error adding documents: %s", e)
            return False

    def search_similar(self, query: str, n_results: int = 5,
                       where: Optional[Dict] = None) -> Dict[str, Any]:
        """Return up to ``n_results`` stored documents nearest to ``query``.

        Args:
            query: Free text, embedded with the same model used for storage.
            n_results: Maximum number of hits to request from Chroma.
            where: Optional Chroma metadata filter. Several conditions must be
                combined with ``{"$and": [...]}``; recent Chroma versions reject
                a dict with more than one top-level key.

        Returns:
            Dict with parallel lists ``documents``, ``metadatas`` and
            ``distances`` plus ``count`` (the number of documents). On any error
            (logged) every list is empty and ``count`` is 0.
        """
        try:
            query_embedding = self.embedding_model.encode([query]).tolist()[0]
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
                where=where,
                include=['documents', 'metadatas', 'distances'],
            )
            # Chroma returns one inner list per query embedding; exactly one was sent.
            documents = results['documents'][0] if results['documents'] else []
            return {
                'documents': documents,
                'metadatas': results['metadatas'][0] if results['metadatas'] else [],
                'distances': results['distances'][0] if results['distances'] else [],
                'count': len(documents),
            }
        except Exception as e:
            logger.error("Error searching documents: %s", e)
            return {'documents': [], 'metadatas': [], 'distances': [], 'count': 0}

    @staticmethod
    def _format_context(documents: List[Optional[str]],
                        metadatas: List[Optional[Dict[str, Any]]],
                        max_context_length: int,
                        *,
                        separator: str = "\n---\n",
                        snippet_chars: int = 500) -> str:
        """Join "Source/Content" snippets, in order, without exceeding ``max_context_length``.

        Each document is cut to ``snippet_chars`` characters ("..." marks a cut).
        Separator lengths count towards the budget, so the returned string is
        never longer than ``max_context_length``. Stops at the first snippet
        that does not fit; returns "" when nothing fits.
        """
        parts: List[str] = []
        used = 0
        for doc, meta in zip(documents, metadatas):
            doc = doc or ""
            # Chroma can return None for a record stored without metadata.
            source = (meta or {}).get('source', 'Unknown')
            content = doc[:snippet_chars] + ("..." if len(doc) > snippet_chars else "")
            snippet = f"Source: {source}\nContent: {content}\n"
            cost = len(snippet) + (len(separator) if parts else 0)
            if used + cost > max_context_length:
                break
            parts.append(snippet)
            used += cost
        return separator.join(parts)

    def get_relevant_context(self, query: str, max_context_length: int = 2000) -> str:
        """Return the top-5 matches for ``query`` formatted as LLM context.

        Snippets are separated by ``"\\n---\\n"`` and the whole string (separators
        included) is at most ``max_context_length`` characters. Returns "" when
        there are no matches or the search failed.
        """
        results = self.search_similar(query, n_results=5)
        return self._format_context(
            results['documents'], results['metadatas'], max_context_length
        )

    def add_conversation_memory(self, user_input: str, assistant_response: str, session_id: str):
        """Store one user/assistant exchange as a ``type="conversation"`` document.

        Returns:
            The result of :meth:`add_documents` (True on success), or False if
            building the record failed.
        """
        try:
            memory_doc = f"User: {user_input}\nAssistant: {assistant_response}"
            metadata = {
                "type": "conversation",
                "session_id": session_id,
                # ISO-8601 UTC string: an actual, sortable time that is also a
                # valid Chroma metadata value (datetime objects are not).
                "timestamp": datetime.now(timezone.utc).isoformat(),
            }
            return self.add_documents([memory_doc], [metadata])
        except Exception as e:
            logger.error("Error adding conversation memory: %s", e)
            return False

    @staticmethod
    def _conversation_filter(session_id: Optional[str] = None) -> Dict[str, Any]:
        """Build the Chroma ``where`` filter for conversation memories.

        A single condition is returned as-is. Two conditions are wrapped in
        ``$and``, because Chroma refuses a where dict with several top-level
        keys and requires at least two clauses inside ``$and``.
        """
        clauses: List[Dict[str, Any]] = [{"type": "conversation"}]
        if session_id:
            clauses.append({"session_id": session_id})
        return clauses[0] if len(clauses) == 1 else {"$and": clauses}

    def search_conversations(self, query: str, session_id: Optional[str] = None) -> List[str]:
        """Return up to 3 stored conversation exchanges most similar to ``query``.

        Args:
            query: Free text to match.
            session_id: When truthy, only exchanges from this session are searched.
        """
        where_clause = self._conversation_filter(session_id)
        results = self.search_similar(query, n_results=3, where=where_clause)
        return results['documents']

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return the document count and name of the *active* collection.

        On error (logged) returns ``{"total_documents": 0, "collection_name": "unknown"}``.
        """
        try:
            count = self.collection.count()
            return {
                "total_documents": count,
                # Name of the collection actually counted; differs from
                # Settings.COLLECTION_NAME after switch_collection().
                "collection_name": self.collection.name,
            }
        except Exception as e:
            logger.error("Error getting collection stats: %s", e)
            return {"total_documents": 0, "collection_name": "unknown"}

    def delete_documents(self, ids: List[str]) -> bool:
        """Delete documents with the given ids from the active collection.

        Returns:
            True on success; False for an empty id list or on error (logged).
        """
        # Never forward an empty id list: depending on the Chroma version it is
        # either rejected or treated as "no id filter" (i.e. delete everything).
        if not ids:
            return False
        try:
            self.collection.delete(ids=ids)
            logger.info("Deleted %d documents", len(ids))
            return True
        except Exception as e:
            logger.error("Error deleting documents: %s", e)
            return False

    def clear_collection(self) -> bool:
        """Delete and recreate the main collection (``Settings.COLLECTION_NAME``).

        This always targets the main collection, not whichever collection
        :meth:`switch_collection` made active, and leaves the main collection
        active afterwards.

        Returns:
            True on success, False on error (logged).
        """
        try:
            self.client.delete_collection(Settings.COLLECTION_NAME)
            self.initialize_collection()
            logger.info("Cleared all documents from collection")
            return True
        except Exception as e:
            logger.error("Error clearing collection: %s", e)
            return False

    def create_specialized_collection(self, name: str, description: str) -> bool:
        """Get or create collection ``name`` without making it active.

        Returns:
            True on success, False on error (logged).
        """
        try:
            self.client.get_or_create_collection(
                name=name,
                metadata={"description": description},
            )
            logger.info("Created specialized collection: %s", name)
            return True
        except Exception as e:
            logger.error("Error creating specialized collection: %s", e)
            return False

    def switch_collection(self, name: str) -> bool:
        """Make the existing collection ``name`` the active one.

        Returns:
            True on success; False (logged) if it does not exist or Chroma fails,
            in which case the active collection is unchanged.
        """
        try:
            self.collection = self.client.get_collection(name=name)
            logger.info("Switched to collection: %s", name)
            return True
        except Exception as e:
            logger.error("Error switching to collection %s: %s", name, e)
            return False

    @staticmethod
    def _collection_names(collections) -> List[str]:
        """Normalize ``client.list_collections()`` output to a list of names.

        Chroma >= 0.6 returns plain name strings; older versions return
        Collection objects exposing ``.name``.
        """
        return [col if isinstance(col, str) else col.name for col in collections]

    def list_collections(self) -> List[str]:
        """Return the names of all collections, or [] on error (logged)."""
        try:
            return self._collection_names(self.client.list_collections())
        except Exception as e:
            logger.error("Error listing collections: %s", e)
            return []