# ai-reasoning-copilot / models/vector_store.py
# Initial commit for the ai-reasoning-copilot (commit b1f00a0)
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import chromadb
from chromadb.config import Settings as ChromaSettings
from sentence_transformers import SentenceTransformer

from config.settings import Settings
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class VectorStore:
    """Chroma-backed vector store with sentence-transformer embeddings.

    Wraps a persistent ChromaDB client plus an ``all-MiniLM-L6-v2``
    embedding model and exposes helpers for document storage, similarity
    search, LLM context assembly, conversation memory, and collection
    management. Most methods log and swallow errors, returning a neutral
    value (``False`` / empty result) so callers can degrade gracefully.
    """

    def __init__(self):
        # Persistent client so indexed data survives process restarts;
        # anonymized telemetry is explicitly disabled.
        self.client = chromadb.PersistentClient(
            path=Settings.CHROMA_PERSIST_DIR,
            settings=ChromaSettings(anonymized_telemetry=False)
        )
        # Small general-purpose embedding model; loaded once per instance.
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.collection = None  # set by initialize_collection()
        self.initialize_collection()

    def initialize_collection(self) -> None:
        """Create or fetch the main knowledge-base collection.

        Raises:
            Exception: re-raised after logging — the store is unusable
                without a collection, so failure here must not be silent.
        """
        try:
            self.collection = self.client.get_or_create_collection(
                name=Settings.COLLECTION_NAME,
                metadata={"description": "General knowledge base for reasoning copilot"}
            )
            logger.info(f"Initialized collection: {Settings.COLLECTION_NAME}")
        except Exception as e:
            logger.error(f"Error initializing collection: {e}")
            raise

    def add_documents(self, documents: List[str], metadata: Optional[List[Dict]] = None,
                      ids: Optional[List[str]] = None) -> bool:
        """Embed *documents* and add them to the current collection.

        Args:
            documents: Raw text chunks to index. An empty list is a no-op
                and returns ``False``.
            metadata: Optional per-document metadata dicts, parallel to
                ``documents``. Defaults to a generic user-upload tag.
            ids: Optional stable IDs, parallel to ``documents``. Random
                UUID4 strings are generated when omitted.

        Returns:
            ``True`` on success, ``False`` on empty input, mis-aligned
            inputs, or any storage/embedding error (logged).
        """
        try:
            if not documents:
                return False
            if ids is None:
                ids = [str(uuid.uuid4()) for _ in documents]
            if metadata is None:
                metadata = [{"source": "user_upload", "type": "document"} for _ in documents]
            # Guard against silently mis-aligned parallel lists before
            # handing them to Chroma.
            if len(ids) != len(documents) or len(metadata) != len(documents):
                logger.error("documents, metadata and ids must have equal lengths")
                return False
            # encode() returns a numpy array; Chroma expects plain lists.
            embeddings = self.embedding_model.encode(documents).tolist()
            self.collection.add(
                documents=documents,
                embeddings=embeddings,
                metadatas=metadata,
                ids=ids
            )
            logger.info(f"Added {len(documents)} documents to vector store")
            return True
        except Exception as e:
            logger.error(f"Error adding documents: {e}")
            return False

    def search_similar(self, query: str, n_results: int = 5,
                       where: Optional[Dict] = None) -> Dict[str, Any]:
        """Return the documents most similar to *query*.

        Args:
            query: Free-text query to embed and match.
            n_results: Maximum number of hits to return.
            where: Optional Chroma metadata filter passed through verbatim.

        Returns:
            Dict with ``documents``, ``metadatas``, ``distances`` (parallel
            lists for the single query) and ``count``. All-empty on error.
        """
        try:
            # Chroma's query API is batched; we send a single-element batch
            # and unwrap the single result row below.
            query_embedding = self.embedding_model.encode([query]).tolist()[0]
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
                where=where,
                include=['documents', 'metadatas', 'distances']
            )
            return {
                'documents': results['documents'][0] if results['documents'] else [],
                'metadatas': results['metadatas'][0] if results['metadatas'] else [],
                'distances': results['distances'][0] if results['distances'] else [],
                'count': len(results['documents'][0]) if results['documents'] else 0
            }
        except Exception as e:
            logger.error(f"Error searching documents: {e}")
            return {'documents': [], 'metadatas': [], 'distances': [], 'count': 0}

    def get_relevant_context(self, query: str, max_context_length: int = 2000) -> str:
        """Return retrieved context for *query*, formatted for an LLM prompt.

        Snippets are capped at 500 characters each and accumulated until
        adding another would exceed *max_context_length*; snippets are
        joined with a ``---`` separator.
        """
        results = self.search_similar(query, n_results=5)
        if not results['documents']:
            return ""
        context_parts = []
        current_length = 0
        for doc, meta in zip(results['documents'], results['metadatas']):
            source = meta.get('source', 'Unknown')
            # Only mark truncation when the document was actually cut.
            body = doc[:500] + ("..." if len(doc) > 500 else "")
            snippet = f"Source: {source}\nContent: {body}\n"
            if current_length + len(snippet) > max_context_length:
                break
            context_parts.append(snippet)
            current_length += len(snippet)
        return "\n---\n".join(context_parts)

    def add_conversation_memory(self, user_input: str, assistant_response: str, session_id: str) -> bool:
        """Store one user/assistant exchange as a searchable memory document.

        Returns ``True`` on success, ``False`` on any error (logged).
        """
        try:
            memory_doc = f"User: {user_input}\nAssistant: {assistant_response}"
            metadata = {
                "type": "conversation",
                "session_id": session_id,
                # Bug fix: the original stored a random UUID under
                # "timestamp", which is not a timestamp at all and cannot
                # be ordered or filtered by time. Use a real UTC instant.
                "timestamp": datetime.now(timezone.utc).isoformat()
            }
            return self.add_documents([memory_doc], [metadata])
        except Exception as e:
            logger.error(f"Error adding conversation memory: {e}")
            return False

    def search_conversations(self, query: str, session_id: Optional[str] = None) -> List[str]:
        """Search stored conversation memories similar to *query*.

        Args:
            query: Free-text query.
            session_id: When given, restrict hits to that session.

        Returns:
            Up to three matching conversation documents (possibly empty).
        """
        if session_id:
            # Bug fix: Chroma requires multiple equality conditions to be
            # combined with an explicit $and operator — a flat dict with
            # two keys is rejected by the query API.
            where_clause: Dict[str, Any] = {
                "$and": [
                    {"type": "conversation"},
                    {"session_id": session_id},
                ]
            }
        else:
            where_clause = {"type": "conversation"}
        results = self.search_similar(query, n_results=3, where=where_clause)
        return results['documents']

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return document count and name for the active collection.

        Falls back to zero / ``"unknown"`` on error (logged).
        """
        try:
            count = self.collection.count()
            return {
                "total_documents": count,
                "collection_name": Settings.COLLECTION_NAME
            }
        except Exception as e:
            logger.error(f"Error getting collection stats: {e}")
            return {"total_documents": 0, "collection_name": "unknown"}

    def delete_documents(self, ids: List[str]) -> bool:
        """Delete documents by ID; returns ``True`` on success."""
        try:
            self.collection.delete(ids=ids)
            logger.info(f"Deleted {len(ids)} documents")
            return True
        except Exception as e:
            logger.error(f"Error deleting documents: {e}")
            return False

    def clear_collection(self) -> bool:
        """Remove all documents by dropping and recreating the collection."""
        try:
            self.client.delete_collection(Settings.COLLECTION_NAME)
            self.initialize_collection()
            logger.info("Cleared all documents from collection")
            return True
        except Exception as e:
            logger.error(f"Error clearing collection: {e}")
            return False

    def create_specialized_collection(self, name: str, description: str) -> bool:
        """Create (or fetch) a domain-specific collection.

        Does not switch to it — use :meth:`switch_collection` for that.
        """
        try:
            # Result intentionally discarded; this call only ensures the
            # collection exists (the original bound it to an unused local).
            self.client.get_or_create_collection(
                name=name,
                metadata={"description": description}
            )
            logger.info(f"Created specialized collection: {name}")
            return True
        except Exception as e:
            logger.error(f"Error creating specialized collection: {e}")
            return False

    def switch_collection(self, name: str) -> bool:
        """Point subsequent operations at the collection called *name*.

        On failure the previous collection remains active.
        """
        try:
            self.collection = self.client.get_collection(name=name)
            logger.info(f"Switched to collection: {name}")
            return True
        except Exception as e:
            logger.error(f"Error switching to collection {name}: {e}")
            return False

    def list_collections(self) -> List[str]:
        """Return the names of all collections, or ``[]`` on error."""
        try:
            collections = self.client.list_collections()
            return [col.name for col in collections]
        except Exception as e:
            logger.error(f"Error listing collections: {e}")
            return []