"""
Caption Search Module for DetectifAI
This module provides caption-based video search: text queries are embedded,
matched against a FAISS index of captions, and the matching video
descriptions are retrieved from MongoDB.
"""
import os
import json
import logging
import numpy as np
import faiss
from typing import List, Dict, Optional, Tuple
from pymongo import MongoClient
from dotenv import load_dotenv
logger = logging.getLogger(__name__)

# Optional import for sentence transformers: without it the engine can still be
# constructed, but it has no model to embed queries with.
try:
    from sentence_transformers import SentenceTransformer
    SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
    SentenceTransformer = None
    SENTENCE_TRANSFORMERS_AVAILABLE = False
    # Log through the module logger. The module-level ``logging.warning``
    # helper calls ``logging.basicConfig()`` on the root logger when no
    # handlers exist yet, which would reconfigure logging as an import-time
    # side effect of this module.
    logger.warning("sentence-transformers not available - caption search will not work")

# Populate os.environ from a .env file (python-dotenv) before MONGO_URI is read.
load_dotenv()

# Paths for FAISS index and id map (stored next to this module).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
FAISS_INDEX_PATH = os.path.join(BASE_DIR, "faiss_captions.index")
# JSON list whose position i holds the description_id of FAISS vector i.
FAISS_IDMAP_PATH = os.path.join(BASE_DIR, "faiss_captions_idmap.json")

# MongoDB connection string. It should name a database, because the engine
# resolves its collection via ``get_default_database()``.
MONGO_URI = os.getenv("MONGO_URI", "mongodb://localhost:27017/detectifai")

# Embedding model used to encode queries. Query vectors are searched against
# the FAISS index, so the index is expected to hold vectors of this dimension.
EMBEDDING_MODEL = "all-mpnet-base-v2"
EMBEDDING_DIM = 768  # Dimension for all-mpnet-base-v2
class CaptionSearchEngine:
    """Caption-based video search backed by a FAISS index and MongoDB.

    Construction loads three independent components; each may fail without
    raising (the failure is logged and the component is left as ``None``):

    * the FAISS index at ``FAISS_INDEX_PATH`` and the JSON id map at
      ``FAISS_IDMAP_PATH`` (a list whose position ``i`` holds the
      ``description_id`` of FAISS vector ``i``);
    * the sentence-transformers model named by ``EMBEDDING_MODEL``;
    * a MongoDB client for ``MONGO_URI`` and its ``event_descriptions``
      collection.

    Call :meth:`is_ready` (or inspect :meth:`get_stats`) to check that all
    components are available before relying on :meth:`search`.
    """

    def __init__(self):
        """Initialize the caption search engine and load all components."""
        self.faiss_index = None
        self.id_map = {}  # FAISS vector position -> description_id
        self.embedding_model = None
        self.mongo_client = None
        self.db = None
        self.collection = None

        # Initialize components
        self._load_faiss_index()
        self._load_embedding_model()
        self._connect_mongodb()

    def _load_faiss_index(self):
        """Load the FAISS index and its id map from disk.

        Errors are logged, not raised; on error the index is reset to ``None``
        and the id map to empty so :meth:`is_ready` reports ``False``.
        """
        try:
            if not os.path.exists(FAISS_INDEX_PATH):
                logger.warning(f"⚠️ FAISS index not found at {FAISS_INDEX_PATH}")
                return

            self.faiss_index = faiss.read_index(FAISS_INDEX_PATH)
            logger.info(f"✅ Loaded FAISS index from {FAISS_INDEX_PATH}")
            logger.info(f" Index size: {self.faiss_index.ntotal} vectors")
            if self.faiss_index.d != EMBEDDING_DIM:
                # FAISS rejects query vectors whose dimension differs from the
                # index's, so every search would fail; surface it at load time.
                logger.warning(
                    f"⚠️ FAISS index dimension {self.faiss_index.d} does not match "
                    f"embedding dimension {EMBEDDING_DIM}"
                )

            if not os.path.exists(FAISS_IDMAP_PATH):
                logger.warning(f"⚠️ FAISS id map not found at {FAISS_IDMAP_PATH}")
                return

            with open(FAISS_IDMAP_PATH, 'r', encoding='utf-8') as f:
                id_map_list = json.load(f)
            # Convert list to dict: index -> description_id
            self.id_map = dict(enumerate(id_map_list))
            logger.info(f"✅ Loaded FAISS id map from {FAISS_IDMAP_PATH}")
            logger.info(f" Mapped {len(self.id_map)} indices")
            if len(self.id_map) != self.faiss_index.ntotal:
                # Vectors without a mapping are silently skipped by search().
                logger.warning(
                    f"⚠️ FAISS id map has {len(self.id_map)} entries but the index "
                    f"has {self.faiss_index.ntotal} vectors"
                )
        except Exception as e:
            logger.error(f"❌ Error loading FAISS index: {e}")
            self.faiss_index = None
            self.id_map = {}

    def _load_embedding_model(self):
        """Load the sentence-transformers model used to embed query text.

        Leaves ``self.embedding_model`` as ``None`` (and logs why) when the
        package is unavailable or the model fails to load.
        """
        if not SENTENCE_TRANSFORMERS_AVAILABLE:
            logger.warning("⚠️ sentence-transformers not available - cannot generate embeddings")
            return
        try:
            logger.info(f"Loading embedding model: {EMBEDDING_MODEL}...")
            self.embedding_model = SentenceTransformer(EMBEDDING_MODEL)
            logger.info(f"✅ Loaded embedding model: {EMBEDDING_MODEL}")
        except Exception as e:
            logger.error(f"❌ Error loading embedding model: {e}")
            self.embedding_model = None

    def _connect_mongodb(self):
        """Create the MongoDB client and resolve the ``event_descriptions`` collection.

        ``get_default_database()`` uses the database named in ``MONGO_URI``.
        ``MongoClient`` connects lazily, so an unreachable server is normally
        only reported by the first query rather than here. On failure the
        client is closed and all MongoDB attributes are reset to ``None``.
        """
        try:
            self.mongo_client = MongoClient(MONGO_URI)
            self.db = self.mongo_client.get_default_database()
            self.collection = self.db["event_descriptions"]
            logger.info("✅ Connected to MongoDB")
        except Exception as e:
            logger.error(f"❌ Error connecting to MongoDB: {e}")
            # Don't leak the client's background resources when setup fails
            # after it was created (e.g. URI without a default database).
            if self.mongo_client is not None:
                self.mongo_client.close()
            self.mongo_client = None
            self.db = None
            self.collection = None

    def is_ready(self) -> bool:
        """Return True when the index, id map, model and MongoDB client are all usable."""
        return (
            self.faiss_index is not None
            and self.faiss_index.ntotal > 0
            # Without an id map no hit can be resolved to a description.
            and bool(self.id_map)
            and self.embedding_model is not None
            and self.mongo_client is not None
        )

    def search(self, query_text: str, top_k: int = 10, min_score: float = 0.0) -> List[Dict]:
        """
        Search for captions similar to the query text.

        Args:
            query_text: Text query to search for. Empty/whitespace-only
                (or non-string) queries return no results.
            top_k: Maximum number of results to return; values <= 0 return
                no results.
            min_score: Minimum similarity score; hits scoring below it are
                dropped.

        Returns:
            List of result dictionaries (caption, event/video reference and
            similarity score) in FAISS ranking order. Returns an empty list
            when the engine is not ready or an error occurs (errors are
            logged, not raised).
        """
        # Nothing meaningful to search for; also keeps k > 0 for FAISS.
        if not isinstance(query_text, str) or not query_text.strip() or top_k <= 0:
            return []

        if not self.is_ready():
            logger.warning("⚠️ Search engine not ready - missing components")
            return []

        try:
            # Normalized embedding reshaped to the (1, dim) float32 batch FAISS expects.
            query_embedding = self.embedding_model.encode(
                query_text,
                normalize_embeddings=True,
                show_progress_bar=False,
            ).astype("float32").reshape(1, -1)

            k = min(top_k, self.faiss_index.ntotal)
            scores, indices = self.faiss_index.search(query_embedding, k)

            # Collect (description_id, score) in ranking order. Negative ids are
            # FAISS padding when fewer than k hits exist. "Higher is better" is
            # assumed for min_score.
            # NOTE(review): this presumes an inner-product/cosine index; confirm.
            hits = []
            for score, idx in zip(scores[0], indices[0]):
                idx = int(idx)
                if idx < 0 or idx not in self.id_map:
                    continue
                if score < min_score:
                    continue
                hits.append((self.id_map[idx], float(score)))

            docs_by_id = self._fetch_descriptions([desc_id for desc_id, _ in hits])
            results = [
                self._doc_to_result(docs_by_id[desc_id], score)
                for desc_id, score in hits
                if desc_id in docs_by_id
            ]

            preview = query_text if len(query_text) <= 50 else f"{query_text[:50]}..."
            logger.info(f"✅ Found {len(results)} results for query: '{preview}'")
            return results
        except Exception as e:
            logger.exception(f"❌ Error during search: {e}")
            return []

    def _fetch_descriptions(self, description_ids: List) -> Dict:
        """Fetch description documents for ``description_ids`` in a single query.

        Returns a mapping ``description_id -> document`` (``_id`` excluded).
        If several documents share an id, the first one yielded by the cursor
        is kept. An empty input returns ``{}`` without querying MongoDB.
        """
        if not description_ids:
            return {}
        # De-duplicate while preserving order; the result is keyed by id anyway.
        unique_ids = list(dict.fromkeys(description_ids))
        cursor = self.collection.find(
            {"description_id": {"$in": unique_ids}},
            {"_id": 0},
        )
        docs_by_id = {}
        for doc in cursor:
            docs_by_id.setdefault(doc.get("description_id"), doc)
        return docs_by_id

    @staticmethod
    def _doc_to_result(doc: Dict, score: float) -> Dict:
        """Shape a MongoDB description document plus its score into a result dict."""
        created_at = doc.get("created_at")
        if not created_at:
            created_at = None
        elif hasattr(created_at, "isoformat"):
            created_at = created_at.isoformat()
        else:
            # Not a date/datetime (e.g. already a string): calling .isoformat()
            # on it would raise and abort the whole search, so pass it through.
            created_at = str(created_at)
        return {
            "description_id": doc.get("description_id"),
            "event_id": doc.get("event_id"),
            "caption": doc.get("caption"),
            "confidence": doc.get("confidence", 0.0),
            "similarity_score": float(score),
            "video_reference": doc.get("video_reference", {}),
            "created_at": created_at,
        }

    def get_stats(self) -> Dict:
        """Return a snapshot of which components are loaded, plus their sizes."""
        index_loaded = self.faiss_index is not None
        return {
            "faiss_index_loaded": index_loaded,
            "faiss_index_size": self.faiss_index.ntotal if index_loaded else 0,
            "id_map_size": len(self.id_map),
            "embedding_model_loaded": self.embedding_model is not None,
            "embedding_model": EMBEDDING_MODEL if self.embedding_model is not None else None,
            "embedding_dim": EMBEDDING_DIM,
            "mongodb_connected": self.mongo_client is not None,
            "ready": self.is_ready(),
        }
# Process-wide engine, created on the first call to the accessor below.
_caption_search_engine = None


def get_caption_search_engine() -> CaptionSearchEngine:
    """Return the shared CaptionSearchEngine, constructing it on first use.

    Later calls return the same instance, even if some of its components
    failed to load.
    """
    global _caption_search_engine
    engine = _caption_search_engine
    if engine is None:
        engine = CaptionSearchEngine()
        _caption_search_engine = engine
    return engine
|