# Grapherr / db_utils.py
# Uploaded by Pragatik771 ("Upload 13 files", commit 6ea0c39, verified)
import os
import logging
import json
import numpy as np
from typing import List, Dict, Any, Optional
from pathlib import Path
import faiss
from sentence_transformers import SentenceTransformer
from dataclasses import dataclass, asdict
import pickle
import logging
from typing import Any, Dict, List, Optional
import os
import sys
import tempfile
from pathlib import Path
import streamlit as st
import pandas as pd
from dotenv import load_dotenv
import numpy as np
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Add __module__ to ensure consistent pickling
class Document:
"""A document with content and optional metadata."""
__module__ = 'db_utils' # Explicitly set module for pickling
def __init__(self, page_content: str, metadata: Dict[str, Any] = None):
self.page_content = page_content
self.metadata = metadata or {}
def __getstate__(self) -> Dict[str, Any]:
return {
'page_content': self.page_content,
'metadata': self.metadata
}
def __setstate__(self, state: Dict[str, Any]) -> None:
self.page_content = state['page_content']
self.metadata = state.get('metadata', {})
def __repr__(self) -> str:
return f"Document(page_content='{self.page_content[:50]}...', metadata={self.metadata})"
def to_dict(self) -> Dict[str, Any]:
"""Convert document to dictionary."""
return {
'page_content': self.page_content,
'metadata': self.metadata
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Document':
"""Create document from dictionary."""
return cls(
page_content=data['page_content'],
metadata=data.get('metadata', {})
)
class VectorStore:
    """FAISS-backed store of Documents and their embedding vectors.

    Vectors are kept in a ``faiss.IndexFlatL2`` (exact search) and the matching
    Documents in ``self.documents``; row ``i`` of the index belongs to
    ``self.documents[i]``, so the two must always grow together. Both are
    persisted under ``index_path`` as ``index.faiss`` and ``documents.pkl``.
    """

    def __init__(self, dimension: int = 384, index_path: str = "faiss_index"):
        """Create the store, loading a persisted index from ``index_path`` if present.

        Args:
            dimension: Vector width for a newly created index. When an index is
                loaded from disk, that index's own width takes precedence.
            index_path: Directory that holds ``index.faiss`` and ``documents.pkl``.
        """
        self.dimension = dimension
        self.index_path = Path(index_path)
        self.documents = []
        self.index = None
        self._initialize_index()

    def _initialize_index(self):
        """Load the persisted index if ``index_path`` exists, else start empty."""
        if self.index_path.exists():
            self._load_index()
        else:
            # Exact brute-force search; IVF or HNSW indexes would trade accuracy
            # for speed on much larger collections.
            self.index = faiss.IndexFlatL2(self.dimension)

    def _new_empty_index(self):
        """Replace the in-memory index and document list with empty ones."""
        self.index = faiss.IndexFlatL2(self.dimension)
        self.documents = []

    def _save_index(self):
        """Persist the FAISS index and the documents to ``index_path``.

        Each file is written to a ``.tmp`` sibling first and then moved into
        place with ``os.replace``, so an interrupted save cannot leave a
        half-written ``index.faiss`` or ``documents.pkl`` behind.
        """
        self.index_path.mkdir(parents=True, exist_ok=True)
        index_file = self.index_path / "index.faiss"
        docs_file = self.index_path / "documents.pkl"

        tmp_index = index_file.with_name(index_file.name + ".tmp")
        faiss.write_index(self.index, str(tmp_index))
        os.replace(tmp_index, index_file)

        # Documents are stored as plain dicts so the pickle does not reference
        # the Document class itself.
        documents_data = [doc.to_dict() for doc in self.documents]
        tmp_docs = docs_file.with_name(docs_file.name + ".tmp")
        with open(tmp_docs, "wb") as f:
            pickle.dump(documents_data, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_docs, docs_file)

    def add_documents(self, documents: List[Document], model=None):
        """Add documents to the vector store and persist the result.

        Each document's vector comes from ``doc.metadata['embedding']`` when
        present, otherwise from ``model.encode([doc.page_content])[0]``; a
        freshly computed vector is cached back into the metadata as a list.
        Documents with no vector source, with a vector whose length differs
        from the index dimension, or that raise while being processed are
        skipped with a log message.

        Args:
            documents: Document objects to add.
            model: Optional encoder for documents that carry no embedding.

        Raises:
            ValueError: If none of ``documents`` yielded a usable vector.
        """
        if not documents:
            return
        try:
            if self.index is None:
                self.index = faiss.IndexFlatL2(self.dimension)
                logger.info("Initialized new FAISS index")

            vectors = []
            valid_docs = []
            for doc in documents:
                try:
                    vector, doc = self._embedding_for(doc, model)
                except Exception as e:
                    logger.error(f"Error processing document: {e}")
                    continue
                if vector is None:
                    logger.warning("Document has no embedding and no model provided to generate one")
                    continue
                if vector.shape[0] != self.index.d:
                    logger.warning(
                        f"Skipping document with embedding of size {vector.shape[0]}; "
                        f"index expects {self.index.d}"
                    )
                    continue
                vectors.append(vector)
                valid_docs.append(doc)

            if not valid_docs:
                raise ValueError("No valid documents to add")

            # FAISS requires a C-contiguous float32 (n, d) matrix.
            self.index.add(np.ascontiguousarray(np.vstack(vectors), dtype=np.float32))
            # Extend only after add() succeeded so index rows and documents stay aligned.
            self.documents.extend(valid_docs)
            self._save_index()
            logger.info(f"Successfully added {len(valid_docs)} documents to the index")
        except Exception as e:
            logger.error(f"Error adding documents: {str(e)}", exc_info=True)
            raise

    def _embedding_for(self, doc, model):
        """Return ``(vector, doc)`` for one document; ``vector`` is None if unavailable.

        ``doc`` is replaced by a new Document when it has no metadata dict in
        which to cache a newly computed embedding.
        """
        metadata = getattr(doc, 'metadata', None)
        if metadata and 'embedding' in metadata:
            return np.asarray(metadata['embedding'], dtype=np.float32).reshape(-1), doc
        if model is None:
            return None, doc
        vector = np.asarray(model.encode([doc.page_content])[0], dtype=np.float32).reshape(-1)
        # Cache the embedding so future loads/re-adds do not need the model.
        if metadata is None:
            doc = Document(doc.page_content, {'embedding': vector.tolist()})
        else:
            metadata['embedding'] = vector.tolist()
        return vector, doc

    def _load_index(self):
        """Load the FAISS index and documents from ``index_path``.

        Falls back to a fresh empty index when nothing is on disk, when loading
        fails, or when the number of stored vectors differs from the number of
        loaded documents (index rows would no longer line up with documents, so
        searches would return the wrong text).

        Returns:
            True if the persisted index was loaded, False if a fresh one was
            created instead.
        """
        index_file = self.index_path / "index.faiss"
        docs_file = self.index_path / "documents.pkl"
        try:
            if not index_file.exists():
                self._new_empty_index()
                logger.info("Initialized new FAISS index")
                return False

            self.index = faiss.read_index(str(index_file))
            self.documents = self._read_documents(docs_file) if docs_file.exists() else []

            if self.index.ntotal != len(self.documents):
                logger.warning(
                    f"Index holds {self.index.ntotal} vectors but {len(self.documents)} "
                    "documents were loaded; starting a fresh index to keep them aligned"
                )
                self._new_empty_index()
                return False

            if self.index.d != self.dimension:
                logger.warning(
                    f"Loaded index has dimension {self.index.d}, expected "
                    f"{self.dimension}; using the loaded index's dimension"
                )
                self.dimension = self.index.d
            return True
        except Exception as e:
            logger.error(f"Error loading index: {e}")
            # Fallback to a new index on error.
            self._new_empty_index()
            return False

    @staticmethod
    def _read_documents(docs_file):
        """Unpickle ``docs_file`` into a list of Documents ([] if it cannot be read)."""
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # index_path at files this application wrote itself.
        try:
            with open(docs_file, "rb") as f:
                documents_data = pickle.load(f)
            documents = []
            for item in documents_data:
                if isinstance(item, Document):
                    documents.append(item)
                elif isinstance(item, dict):
                    documents.append(Document(
                        page_content=item.get('page_content', ''),
                        metadata=item.get('metadata', {}),
                    ))
            logger.info(f"Loaded {len(documents)} documents from index")
            return documents
        except Exception as e:
            logger.error(f"Error loading documents: {e}")
            return []

    def similarity_search(
        self,
        query_embedding: np.ndarray,
        k: int = 4,
        file_hash: Optional[str] = None,
        score_threshold: float = 0.0
    ) -> List[Dict[str, Any]]:
        """Return up to ``k`` documents nearest to ``query_embedding``, nearest first.

        The score is ``1 / (1 + distance)`` using the distance FAISS reports
        (squared L2 for ``IndexFlatL2``), so it lies in (0, 1] and higher means
        closer. When ``file_hash`` is given, every vector is ranked and only
        documents whose ``metadata['file_hash']`` matches are kept. This means
        up to ``k`` matching documents are returned, not just the matches that
        happen to fall in the global top ``k``.

        Args:
            query_embedding: One query vector (1-D) or a 2-D array whose first
                row is used.
            k: Maximum number of results; ``k <= 0`` returns ``[]``.
            file_hash: Optional filter on ``metadata['file_hash']``.
            score_threshold: Minimum score for a result to be included.

        Returns:
            Dicts with keys ``id``, ``text``, ``metadata``, ``score``, ``file_hash``.
        """
        if not self.documents or self.index is None or k <= 0:
            return []

        query = np.asarray(query_embedding, dtype=np.float32)
        if query.ndim == 1:
            query = query.reshape(1, -1)
        query = np.ascontiguousarray(query)

        # The file_hash filter is applied after the search, so rank every vector
        # in that case; otherwise the global top k is all that is needed.
        n_candidates = self.index.ntotal if file_hash else k
        if n_candidates <= 0:
            return []
        distances, indices = self.index.search(query, n_candidates)
        logger.debug(f"Total documents in store: {len(self.documents)}")
        logger.debug(f"Indices from FAISS search: {indices}")

        results = []
        for distance, idx in zip(distances[0], indices[0]):
            if len(results) >= k:
                break
            if idx < 0:
                # FAISS pads with -1 when fewer than n_candidates vectors exist.
                continue
            if idx >= len(self.documents):
                logger.warning(f"Skipping invalid document index: {idx}")
                continue
            try:
                doc = self.documents[idx]
                metadata = getattr(doc, 'metadata', {}) or {}
                doc_file_hash = metadata.get('file_hash')
                if file_hash and doc_file_hash != file_hash:
                    continue
                similarity = 1.0 / (1.0 + float(distance))
                if similarity >= score_threshold:
                    results.append({
                        'id': metadata.get('id', ''),
                        'text': getattr(doc, 'page_content', ''),
                        'metadata': metadata,
                        'score': similarity,
                        'file_hash': doc_file_hash
                    })
            except Exception as e:
                logger.error(f"Error processing document: {e}")
        return results
def create_embeddings(
    texts: List[str],
    metadatas: List[Dict[str, Any]],
    ids: List[str],
    file_hash: str,
    model: 'SentenceTransformer',
    vector_store: 'VectorStore'  # Use string annotation to avoid circular import
):
    """Embed ``texts`` with ``model`` and add them to ``vector_store``.

    Embeddings are L2-normalised before storage. For two unit vectors the
    squared L2 distance equals ``2 - 2 * cosine``, so queries should be
    normalised the same way if results are meant to rank by cosine similarity.

    Args:
        texts: Texts to embed; an empty list is a no-op.
        metadatas: Per-text metadata. Texts beyond its length get ``{}``. Each
            dict is copied, never mutated.
        ids: Per-text ids stored as ``metadata['id']``. It must be at least as
            long as ``texts``, otherwise an IndexError is raised.
        file_hash: Stored as ``metadata['file_hash']`` on every document.
        model: Encoder exposing ``encode(texts, show_progress_bar=..., batch_size=...)``.
        vector_store: Destination store.

    Raises:
        Any exception from encoding or storing is logged and re-raised.
    """
    if not texts:
        # Nothing to store. An empty batch from model.encode may also not be the
        # 2-D array that faiss.normalize_L2 expects.
        return
    try:
        logger.info(f"Generating embeddings for {len(texts)} texts...")
        raw_embeddings = model.encode(texts, show_progress_bar=True, batch_size=32)
        logger.info("Embeddings generated successfully")

        # faiss.normalize_L2 normalises in place and requires a C-contiguous
        # float32 matrix.
        embeddings = np.ascontiguousarray(raw_embeddings, dtype=np.float32)
        faiss.normalize_L2(embeddings)

        documents = []
        for i, (text, embedding) in enumerate(zip(texts, embeddings)):
            metadata = metadatas[i].copy() if i < len(metadatas) else {}
            metadata.update({
                'id': ids[i],
                'file_hash': file_hash,
                'embedding': embedding,  # picked up by VectorStore.add_documents
            })
            documents.append(Document(page_content=text, metadata=metadata))

        # Every document already carries an embedding; the model is only a
        # fallback inside add_documents.
        vector_store.add_documents(documents, model=model)
        logger.info(f"Successfully stored {len(documents)} embeddings in the vector store")
    except Exception as e:
        logger.error(f"Error creating embeddings: {str(e)}", exc_info=True)
        raise
def reset_database(index_path: str = "faiss_index") -> bool:
    """Delete the on-disk vector store at ``index_path``.

    Only the filesystem is touched; a VectorStore already loaded in memory is
    not modified.

    Args:
        index_path: Directory to remove. The default matches VectorStore's
            default ``index_path``. A plain file at this path is removed too.

    Returns:
        True if something existed at ``index_path`` and was removed, else False.
    """
    import shutil

    path = Path(index_path)
    if not path.exists():
        return False
    if path.is_dir():
        shutil.rmtree(path)
    else:
        # rmtree refuses non-directories; remove a stray file directly.
        path.unlink()
    return True
def get_or_create_collection(file_hash: str, vector_store: 'VectorStore') -> bool:
    """Return True if the store holds any document tagged with ``file_hash``.

    FAISS has no collections; a "collection" here is simply the set of
    documents whose ``metadata['file_hash']`` equals ``file_hash``. Nothing is
    created despite the name. The document metadata is scanned directly
    because a k-limited similarity search could miss a file whose chunks are
    not among the k nearest vectors. A falsy ``file_hash`` applies no filter,
    matching ``similarity_search``, so the result is then whether the store
    holds any document at all.
    """
    documents = vector_store.documents
    if not file_hash:
        return bool(documents)
    for doc in documents:
        metadata = getattr(doc, 'metadata', None) or {}
        if metadata.get('file_hash') == file_hash:
            return True
    return False