# vector_store.py — recovered from a Hugging Face Spaces build-log paste.
# (The original paste's "Build error" banner lines were log noise, not source code.)
| import os | |
| import pickle | |
| from typing import List, Dict, Any, Optional | |
| from sentence_transformers import SentenceTransformer, util | |
| import numpy as np | |
| from datetime import datetime | |
| import streamlit as st | |
| import torch | |
| import json | |
| from pathlib import Path | |
class VectorStore:
    """Pickle-backed store of sentence-transformer embeddings.

    Documents are embedded with ``all-MiniLM-L6-v2`` and persisted to
    ``<storage_path>/vectors.pkl``. Similarity search combines cosine
    similarity (70%) with an ontology-link relevance boost (30%).
    """

    def __init__(self, storage_path: Optional[str] = None):
        """Initialize the store and load any previously persisted vectors.

        Args:
            storage_path: Directory for the pickle file. Defaults to
                ``/data/vectors`` on Hugging Face Spaces (persistent disk,
                detected via the SPACE_ID env var), else ``./data/vectors``.
        """
        if storage_path is None:
            if os.environ.get('SPACE_ID'):
                # Running inside a HF Space: /data is the persistent volume.
                storage_path = "/data/vectors"
            else:
                storage_path = os.path.join(os.getcwd(), "data", "vectors")
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)
        # Embedding model and the in-memory list of document records.
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.vectors = []
        self._load_vectors()

    def _load_vectors(self):
        """Load persisted vectors; fall back to an empty list on any error."""
        vector_file = os.path.join(self.storage_path, "vectors.pkl")
        try:
            if os.path.exists(vector_file):
                with open(vector_file, "rb") as f:
                    # NOTE(security): pickle is only safe because this file is
                    # written exclusively by _save_vectors below. Never load a
                    # vectors.pkl from an untrusted source.
                    self.vectors = pickle.load(f)
                # Guard against a corrupted/foreign pickle payload.
                if not isinstance(self.vectors, list):
                    self.vectors = []
        except Exception as e:
            # Best-effort load: a broken store starts empty rather than crashing.
            print(f"Error loading vectors: {str(e)}")
            self.vectors = []

    def _save_vectors(self):
        """Persist vectors atomically (write temp file, then os.replace).

        Raises:
            Exception: if the pickle dump or replace fails; the temp file is
                cleaned up first and the original cause is chained.
        """
        vector_file = os.path.join(self.storage_path, "vectors.pkl")
        temp_file = os.path.join(self.storage_path, "vectors.tmp.pkl")
        try:
            # Write to a sibling temp file first so a crash mid-dump can never
            # leave a truncated vectors.pkl behind.
            with open(temp_file, "wb") as f:
                pickle.dump(self.vectors, f)
            # os.replace is atomic on the same filesystem.
            os.replace(temp_file, vector_file)
        except Exception as e:
            if os.path.exists(temp_file):
                os.remove(temp_file)
            raise Exception(f"Error saving vectors: {str(e)}") from e

    def add_document(self, doc_id: str, text: str, metadata: Optional[Dict[str, Any]] = None):
        """Embed *text* and persist it under *doc_id*.

        Args:
            doc_id: Caller-chosen identifier (not checked for uniqueness).
            text: Raw document text to embed.
            metadata: Optional extra fields; an ``ontology_links`` list is
                added if absent. The caller's dict is copied, never mutated.

        Raises:
            Exception: wrapping any encoding or persistence failure.
        """
        try:
            vector = self.model.encode(text, convert_to_tensor=True)
            if metadata:
                # Copy so we never mutate the caller's dict in place.
                meta = dict(metadata)
                meta.setdefault('ontology_links', [])
            else:
                meta = {}
            doc_record = {
                "doc_id": doc_id,
                "vector": vector,
                "text": text,
                "metadata": meta,
            }
            # Self-heal if a bad load left self.vectors as a non-list.
            if not isinstance(self.vectors, list):
                self.vectors = []
            self.vectors.append(doc_record)
            self._save_vectors()
        except Exception as e:
            raise Exception(f"Error adding document: {str(e)}") from e

    def similarity_search(self, query: str, k: int = 3, filter_docs: Optional[List[str]] = None) -> List[Dict]:
        """Return the top-*k* documents ranked by blended relevance.

        Score = 0.7 * cosine similarity + 0.3 * ontology relevance.

        Args:
            query: Free-text query.
            k: Maximum number of results.
            filter_docs: If given, only these doc_ids are considered.

        Returns:
            Up to *k* result dicts (doc_id, text, metadata, score,
            base_similarity, ontology_boost), best first. Empty list on error
            or when the store is empty.
        """
        try:
            if not self.vectors:
                return []
            query_vector = self.model.encode(query, convert_to_tensor=True)
            results = []
            for doc in self.vectors:
                if filter_docs and doc["doc_id"] not in filter_docs:
                    continue
                try:
                    base_similarity = util.pytorch_cos_sim(query_vector, doc["vector"]).item()
                    ontology_boost = self._calculate_ontology_relevance(
                        query,
                        doc.get('metadata', {}).get('ontology_links', [])
                    )
                    final_score = (base_similarity * 0.7) + (ontology_boost * 0.3)
                    results.append({
                        "doc_id": doc["doc_id"],
                        "text": doc["text"],
                        "metadata": doc["metadata"],
                        "score": float(final_score),
                        "base_similarity": float(base_similarity),
                        "ontology_boost": float(ontology_boost),
                    })
                except Exception as e:
                    # Best-effort: one malformed record must not sink the search.
                    print(f"Error processing document: {str(e)}")
                    continue
            results.sort(key=lambda x: x["score"], reverse=True)
            return results[:k]
        except Exception as e:
            print(f"Error in similarity search: {str(e)}")
            return []

    def _calculate_ontology_relevance(self, query: str, ontology_links: List[Dict]) -> float:
        """Score how strongly *query* matches a document's ontology links.

        Per link: +0.3 for a direct concept substring match, +0.2 if any
        description word appears in the query, +0.1 per matching related
        concept. Result is clamped to [0, 1]. Links missing keys score 0
        for those components instead of raising KeyError.
        """
        if not ontology_links:
            return 0.0
        query_lower = query.lower()
        relevance_score = 0.0
        for link in ontology_links:
            # .get + truthiness guard: a missing or empty concept must not
            # match ("" is a substring of everything) nor raise KeyError.
            concept = link.get('concept')
            if concept and concept.lower() in query_lower:
                relevance_score += 0.3
            description = link.get('description')
            if description and any(term in query_lower
                                   for term in description.lower().split()):
                relevance_score += 0.2
            for related in link.get('relationships', []):
                if related.lower() in query_lower:
                    relevance_score += 0.1
        # Normalize score to [0, 1].
        return min(1.0, relevance_score)

    def delete_document(self, doc_id: str) -> bool:
        """Remove all records with *doc_id*; persist only if something changed.

        Returns:
            True if at least one record was removed, else False.

        Raises:
            Exception: wrapping any persistence failure.
        """
        try:
            remaining = [doc for doc in self.vectors if doc["doc_id"] != doc_id]
            removed = len(remaining) < len(self.vectors)
            if removed:
                # Skip the pickle rewrite entirely when nothing matched.
                self.vectors = remaining
                self._save_vectors()
            return removed
        except Exception as e:
            raise Exception(f"Error deleting document: {str(e)}") from e

    def clear(self):
        """Remove every document and persist the empty store."""
        self.vectors = []
        self._save_vectors()

    def get_document(self, doc_id: str) -> Optional[Dict]:
        """Return the first document matching *doc_id* (without its vector), or None."""
        for doc in self.vectors:
            if doc["doc_id"] == doc_id:
                return {
                    "doc_id": doc["doc_id"],
                    "text": doc["text"],
                    "metadata": doc["metadata"],
                }
        return None

    def __len__(self):
        """Number of stored documents (0 if the list was never initialized)."""
        return len(self.vectors) if self.vectors is not None else 0