# DetectifAI-Backend / DetectifAI_db / vector_index.py
# blacksinisterx's picture
# fix: keyframe images, video clips, evidence images, live stream webcam+URL, remove demo mode
# fd50325 verified
import faiss
import numpy as np
from pymongo import MongoClient
import os
from dotenv import load_dotenv
import pickle
from typing import List, Dict, Tuple, Optional
import logging
# Load variables from a .env file into the process environment so that
# os.getenv("MONGO_URI") further down in this module can see them.
load_dotenv()
# Configure logging: root logger at INFO level, plus a module-level logger
# used by everything in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class FAISSIndexManager:
    """Manage FAISS indices for text and visual embeddings.

    Two flat inner-product indices are kept in memory and mirrored to disk as
    an index file plus a pickled metadata dict. The metadata maps each FAISS
    row id to the identifier of the MongoDB document the vector belongs to
    (``description_id`` for text, ``event_id`` for visual).

    Inner product only equals cosine similarity when the stored vectors and
    the query are L2-normalized; this class does not normalize for you.

    NOTE(review): ``generate_text_embedding`` in this module returns
    768-dimensional vectors, which the default ``text_embedding_dim`` of 384
    rejects as a dimension mismatch. Pass ``text_embedding_dim=768`` if those
    are the vectors being indexed -- confirm against the embedding pipeline.
    """

    def __init__(self, mongo_uri: str, db_name: Optional[str] = None,
                 text_embedding_dim: int = 384,
                 visual_embedding_dim: int = 512):
        """Connect to MongoDB and load (or create) both indices.

        Args:
            mongo_uri: MongoDB connection string.
            db_name: Database to use; when empty, the default database named
                in ``mongo_uri`` is used.
            text_embedding_dim: Dimension of text embeddings (default 384).
            visual_embedding_dim: Dimension of visual embeddings (default 512).
        """
        self.mongo_client = MongoClient(mongo_uri)
        if db_name:
            self.db = self.mongo_client[db_name]
        else:
            self.db = self.mongo_client.get_default_database()

        # Collection references
        self.event_descriptions = self.db.event_description
        self.events = self.db.event

        # FAISS indices
        self.text_index = None
        self.visual_index = None

        # Index metadata: maps FAISS row id -> MongoDB document identifier
        self.text_index_metadata = {}
        self.visual_index_metadata = {}

        # Embedding dimensions (must match the embedding model in use)
        self.text_embedding_dim = text_embedding_dim
        self.visual_embedding_dim = visual_embedding_dim

        # Index file paths (relative to the current working directory)
        self.text_index_path = "faiss_text_index.bin"
        self.visual_index_path = "faiss_visual_index.bin"
        self.text_metadata_path = "faiss_text_metadata.pkl"
        self.visual_metadata_path = "faiss_visual_metadata.pkl"

        self._initialize_indices()

    # ------------------------------------------------------------------
    # Initialization / persistence
    # ------------------------------------------------------------------

    def _initialize_indices(self):
        """Initialize or load existing FAISS indices.

        Each index is handled independently, so a corrupt text index no
        longer causes a healthy visual index to be discarded (or vice versa).
        """
        self._initialize_one(
            "text", self.text_index_path, self.text_metadata_path,
            self._load_text_index, self._create_text_index,
        )
        self._initialize_one(
            "visual", self.visual_index_path, self.visual_metadata_path,
            self._load_visual_index, self._create_visual_index,
        )

    def _initialize_one(self, label, index_path, metadata_path, load, create):
        """Load one index if both of its files exist, else create it."""
        if os.path.exists(index_path) and os.path.exists(metadata_path):
            try:
                load()
            except Exception as e:
                logger.error(f"Error initializing indices: {e}")
                # Fallback to creating a new (empty) index for this kind only
                create()
                return
            logger.info(f"Loaded existing {label} index")
            return
        create()
        logger.info(f"Created new {label} index")

    def _create_text_index(self):
        """Create a new, empty FAISS index for text embeddings and save it."""
        self.text_index = faiss.IndexFlatIP(self.text_embedding_dim)  # Inner product
        self.text_index_metadata = {}
        self._save_text_index()

    def _create_visual_index(self):
        """Create a new, empty FAISS index for visual embeddings and save it."""
        self.visual_index = faiss.IndexFlatIP(self.visual_embedding_dim)  # Inner product
        self.visual_index_metadata = {}
        self._save_visual_index()

    @staticmethod
    def _read_index(index_path, metadata_path, expected_dim):
        """Read an index and its metadata from disk.

        Raises:
            ValueError: if the stored index dimension differs from
                ``expected_dim`` (such an index would fail on every add and
                search anyway).
        """
        index = faiss.read_index(index_path)
        if index.d != expected_dim:
            raise ValueError(
                f"Index at {index_path} has dimension {index.d}, "
                f"expected {expected_dim}"
            )
        # NOTE(security): pickle.load executes arbitrary code from the file.
        # Only safe while these metadata files are written exclusively by
        # this process and the directory is not writable by untrusted users.
        with open(metadata_path, 'rb') as f:
            metadata = pickle.load(f)
        return index, metadata

    @staticmethod
    def _write_index(index, metadata, index_path, metadata_path):
        """Write an index and its metadata to disk; no-op for a ``None`` index."""
        if index is None:
            return
        faiss.write_index(index, index_path)
        with open(metadata_path, 'wb') as f:
            pickle.dump(metadata, f)

    def _load_text_index(self):
        """Load text index from disk"""
        self.text_index, self.text_index_metadata = self._read_index(
            self.text_index_path, self.text_metadata_path, self.text_embedding_dim
        )

    def _load_visual_index(self):
        """Load visual index from disk"""
        self.visual_index, self.visual_index_metadata = self._read_index(
            self.visual_index_path, self.visual_metadata_path, self.visual_embedding_dim
        )

    def _save_text_index(self):
        """Save text index to disk"""
        self._write_index(
            self.text_index, self.text_index_metadata,
            self.text_index_path, self.text_metadata_path,
        )

    def _save_visual_index(self):
        """Save visual index to disk"""
        self._write_index(
            self.visual_index, self.visual_index_metadata,
            self.visual_index_path, self.visual_metadata_path,
        )

    # ------------------------------------------------------------------
    # Rebuild from MongoDB
    # ------------------------------------------------------------------

    @staticmethod
    def _build_index(collection, embedding_field, id_field, dim):
        """Build a fresh index from every usable embedding in ``collection``.

        Documents whose embedding is not a 1-D vector of length ``dim``, or
        that lack ``id_field``, are counted as skipped instead of aborting
        the whole rebuild.

        Returns:
            ``(index, metadata, skipped)``.
        """
        cursor = collection.find(
            {embedding_field: {"$exists": True, "$ne": []}},
            {"_id": 0, id_field: 1, embedding_field: 1},
        )
        vectors = []
        metadata = {}
        skipped = 0
        for doc in cursor:
            try:
                vector = np.asarray(doc[embedding_field], dtype=np.float32)
            except (TypeError, ValueError):
                vector = None
            usable = (
                vector is not None
                and vector.ndim == 1
                and vector.shape[0] == dim
                and id_field in doc
            )
            if not usable:
                skipped += 1
                continue
            # FAISS assigns sequential row ids in insertion order
            metadata[len(vectors)] = doc[id_field]
            vectors.append(vector)

        index = faiss.IndexFlatIP(dim)
        if vectors:
            index.add(np.vstack(vectors))
        return index, metadata, skipped

    def rebuild_text_index(self):
        """Rebuild text index from MongoDB data.

        The replacement index is built completely before it is swapped in, so
        a MongoDB failure leaves the current index (and its files) untouched.
        """
        logger.info("Rebuilding text index from MongoDB...")
        index, metadata, skipped = self._build_index(
            self.event_descriptions, "text_embedding", "description_id",
            self.text_embedding_dim,
        )
        self.text_index = index
        self.text_index_metadata = metadata
        self._save_text_index()
        if skipped:
            logger.warning(f"Skipped {skipped} text embeddings with a missing id or unexpected shape")
        if metadata:
            logger.info(f"Rebuilt text index with {len(metadata)} embeddings")
        else:
            logger.warning("No text embeddings found in MongoDB")

    def rebuild_visual_index(self):
        """Rebuild visual index from MongoDB data.

        The replacement index is built completely before it is swapped in, so
        a MongoDB failure leaves the current index (and its files) untouched.
        """
        logger.info("Rebuilding visual index from MongoDB...")
        index, metadata, skipped = self._build_index(
            self.events, "visual_embedding", "event_id",
            self.visual_embedding_dim,
        )
        self.visual_index = index
        self.visual_index_metadata = metadata
        self._save_visual_index()
        if skipped:
            logger.warning(f"Skipped {skipped} visual embeddings with a missing id or unexpected shape")
        if metadata:
            logger.info(f"Rebuilt visual index with {len(metadata)} embeddings")
        else:
            logger.warning("No visual embeddings found in MongoDB")

    # ------------------------------------------------------------------
    # Add
    # ------------------------------------------------------------------

    @staticmethod
    def _to_row(embedding, expected_dim, label):
        """Convert ``embedding`` to a float32 ``(1, dim)`` array.

        Returns ``None`` (after logging an error) when the flattened length
        differs from ``expected_dim``.
        """
        row = np.array(embedding, dtype=np.float32).reshape(1, -1)
        if row.shape[1] != expected_dim:
            logger.error(f"{label} embedding dimension mismatch: expected {expected_dim}, got {row.shape[1]}")
            return None
        return row

    def add_text_embedding(self, description_id: str, embedding: List[float]) -> bool:
        """Add a text embedding to the index and persist it.

        Returns:
            True on success, False on dimension mismatch or any error.
        """
        try:
            row = self._to_row(embedding, self.text_embedding_dim, "Text")
            if row is None:
                return False
            faiss_id = self.text_index.ntotal  # next sequential row id
            self.text_index.add(row)
            self.text_index_metadata[faiss_id] = description_id
            self._save_text_index()
            logger.info(f"Added text embedding for description_id: {description_id}")
            return True
        except Exception as e:
            logger.error(f"Error adding text embedding: {e}")
            return False

    def add_visual_embedding(self, event_id: str, embedding: List[float]) -> bool:
        """Add a visual embedding to the index and persist it.

        Returns:
            True on success, False on dimension mismatch or any error.
        """
        try:
            row = self._to_row(embedding, self.visual_embedding_dim, "Visual")
            if row is None:
                return False
            faiss_id = self.visual_index.ntotal  # next sequential row id
            self.visual_index.add(row)
            self.visual_index_metadata[faiss_id] = event_id
            self._save_visual_index()
            logger.info(f"Added visual embedding for event_id: {event_id}")
            return True
        except Exception as e:
            logger.error(f"Error adding visual embedding: {e}")
            return False

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    def _search(self, index, metadata, collection, id_field, query_embedding, k, dim):
        """Run a top-``k`` search and attach the matching MongoDB documents.

        Returns an empty list when there is nothing to search (index unset or
        empty, or ``k <= 0``) or when the query has the wrong dimension.
        """
        if index is None or index.ntotal == 0 or k <= 0:
            return []
        query = self._to_row(query_embedding, dim, "Query")
        if query is None:
            return []

        scores, indices = index.search(query, min(k, index.ntotal))

        results = []
        for score, idx in zip(scores[0], indices[0]):
            row_id = int(idx)  # FAISS pads missing hits with -1, never a key
            if row_id not in metadata:
                continue
            doc = collection.find_one({id_field: metadata[row_id]}, {"_id": 0})
            if doc:
                doc["similarity_score"] = float(score)
                results.append(doc)
        return results

    def search_text_embeddings(self, query_embedding: List[float], k: int = 10) -> List[Dict]:
        """Search for similar text embeddings.

        Returns:
            Up to ``k`` ``event_description`` documents (without ``_id``),
            best match first, each with an added ``similarity_score``.
            Empty list on any error.
        """
        try:
            return self._search(
                self.text_index, self.text_index_metadata,
                self.event_descriptions, "description_id",
                query_embedding, k, self.text_embedding_dim,
            )
        except Exception as e:
            logger.error(f"Error searching text embeddings: {e}")
            return []

    def search_visual_embeddings(self, query_embedding: List[float], k: int = 10) -> List[Dict]:
        """Search for similar visual embeddings.

        Returns:
            Up to ``k`` ``event`` documents (without ``_id``), best match
            first, each with an added ``similarity_score``. Empty list on
            any error.
        """
        try:
            return self._search(
                self.visual_index, self.visual_index_metadata,
                self.events, "event_id",
                query_embedding, k, self.visual_embedding_dim,
            )
        except Exception as e:
            logger.error(f"Error searching visual embeddings: {e}")
            return []

    # ------------------------------------------------------------------
    # Misc
    # ------------------------------------------------------------------

    def get_index_stats(self) -> Dict:
        """Get statistics about the indices (sizes and embedding dimensions)."""
        return {
            "text_index_size": self.text_index.ntotal if self.text_index is not None else 0,
            "visual_index_size": self.visual_index.ntotal if self.visual_index is not None else 0,
            "text_embedding_dim": self.text_embedding_dim,
            "visual_embedding_dim": self.visual_embedding_dim
        }

    def close(self):
        """Save both indices and close the MongoDB client.

        The client is closed even if saving raises, so a disk error does not
        leak the connection.
        """
        try:
            self._save_text_index()
            self._save_visual_index()
        finally:
            self.mongo_client.close()
# Process-wide singleton, created lazily by get_faiss_manager()
faiss_manager = None


def get_faiss_manager() -> FAISSIndexManager:
    """Return the shared FAISSIndexManager, building it on first use.

    The connection string is read from the MONGO_URI environment variable at
    the moment the instance is first created.
    """
    global faiss_manager
    if faiss_manager is not None:
        return faiss_manager
    faiss_manager = FAISSIndexManager(os.getenv("MONGO_URI"))
    return faiss_manager
# Global model cache (filled lazily by generate_text_embedding)
_text_embedding_model = None

# Length of the fallback vectors returned when the model is unavailable
_TEXT_EMBEDDING_FALLBACK_DIM = 768


def _fallback_text_embedding(text: str) -> List[float]:
    """Return a pseudo-random unit vector derived deterministically from *text*.

    The seed comes from a SHA-256 digest rather than ``hash()``, because
    ``hash()`` of a string is salted per process and would give a different
    vector on every run. A private generator is used so the global NumPy RNG
    state is left untouched. The vector is L2-normalized to match the
    normalized embeddings the real model returns.
    """
    import hashlib

    digest = hashlib.sha256(str(text).encode("utf-8", errors="replace")).digest()
    seed = int.from_bytes(digest[:8], "big")
    vector = np.random.default_rng(seed).standard_normal(_TEXT_EMBEDDING_FALLBACK_DIM)
    vector /= np.linalg.norm(vector)
    return vector.astype(np.float32).tolist()


def generate_text_embedding(text: str) -> List[float]:
    """
    Generate text embeddings using SentenceTransformer.
    Uses all-mpnet-base-v2 for compatibility with NLP search (query_retreival.py).
    Model is lazy-loaded and cached on first call.

    If the model cannot be loaded or encoding fails, the error is logged and a
    deterministic pseudo-random unit vector of length 768 is returned instead
    (graceful degradation -- such vectors carry no semantic meaning).
    """
    global _text_embedding_model
    if _text_embedding_model is None:
        try:
            # Imported lazily so the module can be imported without the package
            from sentence_transformers import SentenceTransformer
            _text_embedding_model = SentenceTransformer('all-mpnet-base-v2')
            logger.info("✅ Loaded SentenceTransformer (all-mpnet-base-v2) for text embeddings")
        except Exception as e:
            logger.error(f"Failed to load SentenceTransformer: {e}")
            return _fallback_text_embedding(text)
    try:
        embedding = _text_embedding_model.encode(text, normalize_embeddings=True)
        return embedding.astype(np.float32).tolist()
    except Exception as e:
        logger.error(f"Failed to generate embedding for text: {e}")
        return _fallback_text_embedding(text)
def generate_visual_embedding(image_data: bytes = None) -> List[float]:
    """
    Placeholder function to generate visual embeddings.
    Replace this with your actual visual embedding model.

    ``image_data`` is currently ignored: every call returns the same
    512-element vector drawn from a generator seeded with 42.
    """
    # For now, return a random embedding of the correct dimension
    # In production, use a proper visual embedding model
    # A private RandomState yields the same values as seeding the global RNG
    # with 42, without resetting the global NumPy RNG for the whole process.
    rng = np.random.RandomState(42)  # Fixed seed for demo
    return rng.randn(512).astype(np.float32).tolist()