# Legal_AI_Agent / utils / vector_store.py
# Author: cryogenic22
# Commit: f13c348 (verified) — "Update utils/vector_store.py"
import os
import pickle
from typing import List, Dict, Any, Optional
from sentence_transformers import SentenceTransformer, util
import numpy as np
from datetime import datetime
import streamlit as st
import torch
import json
from pathlib import Path
class VectorStore:
    """Pickle-backed semantic vector store for legal-document retrieval.

    Documents are embedded with a SentenceTransformer model and persisted
    to ``vectors.pkl`` inside ``storage_path``. Search blends cosine
    similarity (70%) with an ontology-link relevance boost (30%).
    """

    def __init__(self, storage_path: Optional[str] = None):
        """Initialize the store: resolve the storage directory, load the
        embedding model, and restore any previously persisted vectors.

        Args:
            storage_path: Directory for the pickle file. Defaults to
                ``/data/vectors`` when running on a Hugging Face Space
                (``SPACE_ID`` env var set), else ``<cwd>/data/vectors``.
        """
        if storage_path is None:
            # On HF Spaces the persistent volume is mounted at /data.
            if os.environ.get('SPACE_ID'):
                storage_path = "/data/vectors"
            else:
                storage_path = os.path.join(os.getcwd(), "data", "vectors")
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)

        # Embedding model plus the in-memory list of document records.
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.vectors: List[Dict[str, Any]] = []
        self._load_vectors()

    def _load_vectors(self):
        """Load persisted vectors; fall back to an empty list on any error."""
        vector_file = os.path.join(self.storage_path, "vectors.pkl")
        try:
            if os.path.exists(vector_file):
                with open(vector_file, "rb") as f:
                    loaded = pickle.load(f)
                # Guard against a corrupted or foreign pickle payload.
                self.vectors = loaded if isinstance(loaded, list) else []
        except Exception as e:
            print(f"Error loading vectors: {str(e)}")
            self.vectors = []

    def _save_vectors(self):
        """Persist vectors atomically (write temp file, then rename).

        Raises:
            Exception: if the pickle dump or the rename fails; the partial
                temp file is removed first.
        """
        vector_file = os.path.join(self.storage_path, "vectors.pkl")
        temp_file = os.path.join(self.storage_path, "vectors.tmp.pkl")
        try:
            # Write to a temp file first so a crash never corrupts the
            # existing vectors.pkl; os.replace is atomic on POSIX.
            with open(temp_file, "wb") as f:
                pickle.dump(self.vectors, f)
            os.replace(temp_file, vector_file)
        except Exception as e:
            if os.path.exists(temp_file):
                os.remove(temp_file)
            raise Exception(f"Error saving vectors: {str(e)}") from e

    def add_document(self, doc_id: str, text: str, metadata: Optional[Dict[str, Any]] = None):
        """Embed ``text``, record it with its metadata, and persist the store.

        Args:
            doc_id: Caller-supplied document identifier.
            text: Raw document text to embed.
            metadata: Optional metadata; a non-empty dict is guaranteed to
                carry an ``ontology_links`` key in the stored record.

        Raises:
            Exception: wrapping any embedding or persistence failure.
        """
        try:
            vector = self.model.encode(text, convert_to_tensor=True)
            # Copy so the caller's dict is never mutated (the original code
            # injected 'ontology_links' into the caller's object).
            meta = dict(metadata) if metadata else {}
            if meta and 'ontology_links' not in meta:
                meta['ontology_links'] = []
            doc_record = {
                "doc_id": doc_id,
                "vector": vector,
                "text": text,
                "metadata": meta,
            }
            if not isinstance(self.vectors, list):
                self.vectors = []
            self.vectors.append(doc_record)
            self._save_vectors()
        except Exception as e:
            raise Exception(f"Error adding document: {str(e)}") from e

    def similarity_search(self, query: str, k: int = 3, filter_docs: Optional[List[str]] = None) -> List[Dict]:
        """Return the top-``k`` documents ranked by blended relevance.

        Score = 0.7 * cosine similarity + 0.3 * ontology boost.

        Args:
            query: Free-text query.
            k: Maximum number of results.
            filter_docs: If given, only these doc_ids are considered.

        Returns:
            Result dicts with doc_id/text/metadata/score/base_similarity/
            ontology_boost, best first; ``[]`` on any failure or empty store.
        """
        try:
            if not self.vectors:
                return []
            query_vector = self.model.encode(query, convert_to_tensor=True)
            results = []
            for doc in self.vectors:
                # Honor the allow-list filter, if any.
                if filter_docs and doc["doc_id"] not in filter_docs:
                    continue
                try:
                    base_similarity = util.pytorch_cos_sim(query_vector, doc["vector"]).item()
                    ontology_boost = self._calculate_ontology_relevance(
                        query,
                        doc.get('metadata', {}).get('ontology_links', [])
                    )
                    # Blend vector similarity with ontology relevance.
                    final_score = (base_similarity * 0.7) + (ontology_boost * 0.3)
                    results.append({
                        "doc_id": doc["doc_id"],
                        "text": doc["text"],
                        "metadata": doc["metadata"],
                        "score": float(final_score),
                        "base_similarity": float(base_similarity),
                        "ontology_boost": float(ontology_boost)
                    })
                except Exception as e:
                    # Best-effort: skip an individual malformed record.
                    print(f"Error processing document: {str(e)}")
                    continue
            results.sort(key=lambda x: x["score"], reverse=True)
            return results[:k]
        except Exception as e:
            print(f"Error in similarity search: {str(e)}")
            return []

    def _calculate_ontology_relevance(self, query: str, ontology_links: List[Dict]) -> float:
        """Score ontology links against the query, clamped to [0, 1].

        +0.3 per direct concept match, +0.2 per description-term match,
        +0.1 per related-concept match. Link dicts missing keys contribute
        nothing (the original code raised KeyError on a missing 'concept').
        """
        if not ontology_links:
            return 0.0
        query_lower = query.lower()
        relevance_score = 0.0
        for link in ontology_links:
            # Direct concept match. .get() guards malformed link dicts;
            # the truthiness check also avoids "" matching everything.
            concept = link.get('concept')
            if concept and concept.lower() in query_lower:
                relevance_score += 0.3
            # Description match: any description word appearing in the query.
            description = link.get('description')
            if description and any(term in query_lower
                                   for term in description.lower().split()):
                relevance_score += 0.2
            # Related concepts match.
            for related in link.get('relationships', []):
                if related.lower() in query_lower:
                    relevance_score += 0.1
        # Normalize score to [0, 1].
        return min(1.0, relevance_score)

    def delete_document(self, doc_id: str) -> bool:
        """Remove a document by id and persist the change.

        Returns:
            True if a document was removed, False if ``doc_id`` was absent.

        Raises:
            Exception: wrapping any persistence failure.
        """
        try:
            initial_length = len(self.vectors)
            self.vectors = [doc for doc in self.vectors if doc["doc_id"] != doc_id]
            self._save_vectors()
            return len(self.vectors) < initial_length
        except Exception as e:
            raise Exception(f"Error deleting document: {str(e)}") from e

    def clear(self):
        """Drop all documents and persist the now-empty store."""
        self.vectors = []
        self._save_vectors()

    def get_document(self, doc_id: str) -> Optional[Dict]:
        """Return the stored record (without its vector) or None if absent."""
        for doc in self.vectors:
            if doc["doc_id"] == doc_id:
                return {
                    "doc_id": doc["doc_id"],
                    "text": doc["text"],
                    "metadata": doc["metadata"]
                }
        return None

    def __len__(self):
        """Number of documents currently in the store."""
        return len(self.vectors) if self.vectors is not None else 0