import json
import os
import pickle
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util


class VectorStoreError(Exception):
    """Raised when a VectorStore persistence or indexing operation fails.

    Subclasses ``Exception`` so existing ``except Exception`` callers
    keep working unchanged.
    """


class VectorStore:
    """Embedding-based document store with ontology-aware similarity search.

    Documents are embedded with a SentenceTransformer model and persisted
    to disk as a pickle file. Search blends cosine similarity (weight 0.7)
    with a heuristic ontology-relevance boost (weight 0.3).
    """

    def __init__(self, storage_path: Optional[str] = None):
        """Initialize the store and load any previously saved vectors.

        Args:
            storage_path: Directory for the pickled vector file. Defaults
                to ``/data/vectors`` when the ``SPACE_ID`` env var is set
                (presumably a Hugging Face Space persistent volume — TODO
                confirm) or ``<cwd>/data/vectors`` otherwise.
        """
        if storage_path is None:
            if os.environ.get('SPACE_ID'):
                storage_path = "/data/vectors"
            else:
                storage_path = os.path.join(os.getcwd(), "data", "vectors")

        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)

        # Embedding model and the in-memory document list.
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.vectors: List[Dict[str, Any]] = []
        self._load_vectors()

    def _load_vectors(self):
        """Load persisted vectors; fall back to an empty store on any error.

        NOTE(security): ``pickle.load`` executes arbitrary code if the file
        is attacker-controlled — only load files this process wrote itself.
        """
        vector_file = os.path.join(self.storage_path, "vectors.pkl")
        try:
            if os.path.exists(vector_file):
                with open(vector_file, "rb") as f:
                    self.vectors = pickle.load(f)
                # Guard against a corrupt/foreign pickle payload.
                if not isinstance(self.vectors, list):
                    self.vectors = []
        except Exception as e:
            # Best-effort load: a corrupt store must not prevent startup.
            print(f"Error loading vectors: {str(e)}")
            self.vectors = []

    def _save_vectors(self):
        """Persist vectors atomically (write temp file, then rename).

        Raises:
            VectorStoreError: if the pickle dump or rename fails; the
                partial temp file is removed first.
        """
        vector_file = os.path.join(self.storage_path, "vectors.pkl")
        temp_file = os.path.join(self.storage_path, "vectors.tmp.pkl")
        try:
            # Write to a temp file first so a crash mid-dump cannot
            # corrupt the existing store.
            with open(temp_file, "wb") as f:
                pickle.dump(self.vectors, f)
            # os.replace is atomic on POSIX and Windows.
            os.replace(temp_file, vector_file)
        except Exception as e:
            if os.path.exists(temp_file):
                os.remove(temp_file)
            raise VectorStoreError(f"Error saving vectors: {str(e)}") from e

    def add_document(self, doc_id: str, text: str, metadata: Dict[str, Any] = None):
        """Embed *text* and append it to the store, then persist.

        Args:
            doc_id: Caller-chosen identifier (not checked for uniqueness).
            text: Raw document text to embed.
            metadata: Optional metadata; a defensive copy is stored, and an
                empty ``ontology_links`` list is injected when absent.

        Raises:
            VectorStoreError: if embedding or persistence fails.
        """
        try:
            vector = self.model.encode(text, convert_to_tensor=True)

            # Copy before mutating so the caller's dict is never modified
            # when we inject the ontology_links default.
            meta = dict(metadata) if metadata else {}
            if meta and 'ontology_links' not in meta:
                meta['ontology_links'] = []

            doc_record = {
                "doc_id": doc_id,
                "vector": vector,
                "text": text,
                "metadata": meta,
            }

            # Self-heal if a bad pickle left self.vectors as a non-list.
            if not isinstance(self.vectors, list):
                self.vectors = []
            self.vectors.append(doc_record)
            self._save_vectors()
        except Exception as e:
            raise VectorStoreError(f"Error adding document: {str(e)}") from e

    def similarity_search(self, query: str, k: int = 3,
                          filter_docs: Optional[List[str]] = None) -> List[Dict]:
        """Return the top-*k* documents ranked by blended relevance.

        Score = 0.7 * cosine similarity + 0.3 * ontology boost.

        Args:
            query: Free-text query to embed and compare.
            k: Maximum number of results.
            filter_docs: If given, only doc_ids in this list are considered.

        Returns:
            A list of result dicts (doc_id, text, metadata, score,
            base_similarity, ontology_boost), best first. Empty on error
            or when the store is empty.
        """
        try:
            if not self.vectors:
                return []

            query_vector = self.model.encode(query, convert_to_tensor=True)

            results = []
            for doc in self.vectors:
                # Honor the optional doc_id allow-list.
                if filter_docs and doc["doc_id"] not in filter_docs:
                    continue
                try:
                    base_similarity = util.pytorch_cos_sim(
                        query_vector, doc["vector"]
                    ).item()

                    ontology_boost = self._calculate_ontology_relevance(
                        query,
                        doc.get('metadata', {}).get('ontology_links', []),
                    )

                    final_score = (base_similarity * 0.7) + (ontology_boost * 0.3)

                    results.append({
                        "doc_id": doc["doc_id"],
                        "text": doc["text"],
                        "metadata": doc["metadata"],
                        "score": float(final_score),
                        "base_similarity": float(base_similarity),
                        "ontology_boost": float(ontology_boost),
                    })
                except Exception as e:
                    # One malformed record should not sink the whole query.
                    print(f"Error processing document: {str(e)}")
                    continue

            results.sort(key=lambda x: x["score"], reverse=True)
            return results[:k]
        except Exception as e:
            print(f"Error in similarity search: {str(e)}")
            return []

    def _calculate_ontology_relevance(self, query: str,
                                      ontology_links: List[Dict]) -> float:
        """Score how strongly *query* overlaps the document's ontology links.

        Heuristic, additive, clamped to [0, 1]:
          +0.3 per concept name found in the query,
          +0.2 per link whose description shares a meaningful word,
          +0.1 per related concept found in the query.

        Links missing expected keys contribute nothing instead of raising.
        """
        if not ontology_links:
            return 0.0

        query_lower = query.lower()
        relevance_score = 0.0

        for link in ontology_links:
            # Direct concept match (substring, case-insensitive).
            if link.get('concept', '').lower() in query_lower and link.get('concept'):
                relevance_score += 0.3

            # Description overlap: ignore tokens shorter than 3 chars so
            # stopwords like "a"/"of" cannot trigger spurious boosts.
            description = link.get('description')
            if description and any(
                term in query_lower
                for term in description.lower().split()
                if len(term) >= 3
            ):
                relevance_score += 0.2

            # Related-concept matches.
            for related in link.get('relationships', []):
                if related.lower() in query_lower:
                    relevance_score += 0.1

        return min(1.0, relevance_score)

    def delete_document(self, doc_id: str) -> bool:
        """Remove all records with *doc_id* and persist.

        Returns:
            True if at least one record was removed.

        Raises:
            VectorStoreError: if persistence fails.
        """
        try:
            initial_length = len(self.vectors)
            self.vectors = [doc for doc in self.vectors if doc["doc_id"] != doc_id]
            self._save_vectors()
            return len(self.vectors) < initial_length
        except Exception as e:
            raise VectorStoreError(f"Error deleting document: {str(e)}") from e

    def clear(self):
        """Delete every document and persist the empty store."""
        self.vectors = []
        self._save_vectors()

    def get_document(self, doc_id: str) -> Optional[Dict]:
        """Return the first document matching *doc_id* (without its vector),
        or None if absent."""
        for doc in self.vectors:
            if doc["doc_id"] == doc_id:
                return {
                    "doc_id": doc["doc_id"],
                    "text": doc["text"],
                    "metadata": doc["metadata"],
                }
        return None

    def __len__(self):
        """Number of stored documents (0 if the list was never initialized)."""
        return len(self.vectors) if self.vectors is not None else 0