FinSightAI / backend / db / faiss_client.py
Aniket2003333333's picture
start
7248d39
Raw
History Blame Contribute Delete
7.9 kB
"""FAISS vector store with hybrid (vector + BM25) search."""
from __future__ import annotations
import json
import logging
import re
import threading
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
from config import settings
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Width of every stored vector: used to size the FAISS index and the empty
# vector matrix. NOTE(review): presumably the embedding model's output size —
# confirm against the embedder configuration.
EMBED_DIM = 2304
# Matches maximal runs of word characters; _tokenize uses it to split text.
_TOKEN_RE = re.compile(r"\w+")
def _tokenize(text: str) -> List[str]:
    """Lower-case *text* and split it into word tokens.

    Tokens are the non-overlapping matches of the module-level ``_TOKEN_RE``
    pattern, returned in the order they occur in the text.
    """
    lowered = text.lower()
    # The pattern has no capture groups, so each match's group(0) is exactly
    # what ``findall`` would yield.
    return [match.group(0) for match in _TOKEN_RE.finditer(lowered)]
def _normalize_vector(vector: List[float]) -> np.ndarray:
    """Return *vector* as an L2-normalised float32 array.

    The input is copied into a fresh float32 array, so the caller's data is
    never modified by the in-place normalisation.
    """
    # faiss.normalize_L2 works in place on a 2-D matrix, so wrap the vector
    # as a single-row matrix first and unwrap it afterwards.
    matrix = np.reshape(np.array(vector, dtype=np.float32), (1, -1))
    faiss.normalize_L2(matrix)
    unit_row = matrix[0]
    return unit_row
def _min_max_normalize(scores: Dict[int, float]) -> Dict[int, float]:
    """Rescale the values of *scores* linearly onto the range [0, 1].

    An empty mapping yields an empty mapping. When all values are (almost)
    identical — spread below 1e-9 — every key is mapped to 1.0 instead of
    dividing by a near-zero range.
    """
    if len(scores) == 0:
        return {}
    lowest = min(scores.values())
    highest = max(scores.values())
    if highest - lowest < 1e-9:
        # Degenerate range: all entries are treated as equally good.
        return dict.fromkeys(scores, 1.0)
    rescaled: Dict[int, float] = {}
    for key, raw in scores.items():
        rescaled[key] = (raw - lowest) / (highest - lowest)
    return rescaled
class FaissDB:
    """Local FAISS index with chunk metadata and hybrid retrieval.

    Three structures are kept aligned by row position:

    * ``self.metadata`` — one dict per stored chunk,
    * ``self.vectors``  — one L2-normalised float32 row per chunk,
    * ``self.index``    — a FAISS inner-product index over the same rows.

    A BM25 model over the chunk texts is rebuilt after every mutation.
    Mutations and searches are serialised through ``self._lock``; every
    mutation is written to the directory named by ``settings.FAISS_DATA_DIR``.
    """

    def __init__(self):
        """Open the store, creating the data directory if it is missing."""
        self.data_dir = Path(settings.FAISS_DATA_DIR)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.index_file = self.data_dir / "index.faiss"
        self.meta_file = self.data_dir / "metadata.json"
        self.vectors_file = self.data_dir / "vectors.npy"
        self._lock = threading.Lock()
        self.index: faiss.IndexFlatIP = faiss.IndexFlatIP(EMBED_DIM)
        self.metadata: List[Dict[str, Any]] = []
        self.vectors = np.zeros((0, EMBED_DIM), dtype=np.float32)
        self._bm25: Optional[BM25Okapi] = None
        self._load()
        # Builds both the FAISS index and the BM25 model from the loaded data.
        self._sync_index()

    def _load(self) -> None:
        """Read persisted metadata and vectors from disk and reconcile them.

        ``vectors.npy`` is the preferred source of vectors. When it is missing
        but ``index.faiss`` exists, the rows are recovered from the index so
        that metadata and vectors stay usable together.
        """
        if self.meta_file.exists():
            self.metadata = json.loads(self.meta_file.read_text(encoding="utf-8"))
        if self.vectors_file.exists():
            self.vectors = np.ascontiguousarray(
                np.load(self.vectors_file), dtype=np.float32
            )
        elif self.index_file.exists():
            # Previously the index was read here and then immediately thrown
            # away by _sync_index(), leaving metadata without any vectors.
            stored = faiss.read_index(str(self.index_file))
            if stored.ntotal:
                self.vectors = np.ascontiguousarray(
                    stored.reconstruct_n(0, stored.ntotal), dtype=np.float32
                )
        self._reconcile()

    def _reconcile(self) -> None:
        """Validate loaded vectors and trim metadata/vectors to equal length.

        Raises:
            ValueError: if the stored vectors are not an ``(n, EMBED_DIM)``
                matrix (for example after ``EMBED_DIM`` was changed).
        """
        if self.vectors.ndim != 2 or self.vectors.shape[1] != EMBED_DIM:
            raise ValueError(
                f"Stored vectors have shape {self.vectors.shape}; "
                f"expected (n, {EMBED_DIM})"
            )
        rows, chunks = len(self.vectors), len(self.metadata)
        if rows != chunks:
            # Best effort: _persist() writes vectors before metadata, so an
            # interrupted append leaves extra trailing rows; keeping the
            # common prefix restores row alignment in that case.
            keep = min(rows, chunks)
            logger.warning(
                "Store is inconsistent (%d chunks, %d vectors); keeping first %d",
                chunks,
                rows,
                keep,
            )
            self.metadata = self.metadata[:keep]
            self.vectors = self.vectors[:keep]

    def _persist(self) -> None:
        """Write vectors, the FAISS index and metadata to disk, in that order."""
        np.save(self.vectors_file, self.vectors)
        faiss.write_index(self.index, str(self.index_file))
        self.meta_file.write_text(json.dumps(self.metadata), encoding="utf-8")

    def _sync_index(self) -> None:
        """Rebuild the FAISS index and the BM25 model from in-memory data."""
        self.index = faiss.IndexFlatIP(EMBED_DIM)
        if len(self.vectors):
            self.index.add(self.vectors)
        self._rebuild_bm25()

    def _rebuild_bm25(self) -> None:
        """Rebuild the BM25 model over chunk texts (``None`` when empty)."""
        corpus = [_tokenize(chunk.get("text", "")) for chunk in self.metadata]
        self._bm25 = BM25Okapi(corpus) if corpus else None

    @staticmethod
    def _chunk_view(chunk: Dict[str, Any]) -> Dict[str, Any]:
        """Return the caller-facing fields shared by search and fetch results."""
        return {
            "text": chunk.get("text", ""),
            "document_name": chunk.get("document_name", ""),
            "document_id": chunk.get("document_id", ""),
            "page_number": chunk.get("page_number", 0),
            "section": chunk.get("section", ""),
        }

    def upsert_chunks(self, chunks: List[Dict], vectors: List[List[float]]) -> None:
        """Append *chunks* and their embedding *vectors* to the store.

        Chunks are appended; existing entries are not replaced. Chunks without
        a truthy ``created_at`` are stamped in place with the current UTC
        time. The store is persisted before returning.

        Args:
            chunks: Chunk metadata dicts; ``chunks[i]`` belongs to ``vectors[i]``.
            vectors: One embedding of width ``EMBED_DIM`` per chunk.

        Raises:
            ValueError: if the two lists differ in length or the vectors do
                not have ``EMBED_DIM`` components.
        """
        if not chunks:
            return
        # Validate before touching any state: a mismatch would otherwise
        # shift every later chunk onto the wrong vector row.
        if len(chunks) != len(vectors):
            raise ValueError(
                f"Got {len(chunks)} chunks but {len(vectors)} vectors; "
                "counts must match"
            )
        now = datetime.now(timezone.utc).isoformat()
        normalized = np.vstack([_normalize_vector(vector) for vector in vectors])
        if normalized.shape[1] != EMBED_DIM:
            raise ValueError(
                f"Vectors have dimension {normalized.shape[1]}; expected {EMBED_DIM}"
            )
        with self._lock:
            for chunk in chunks:
                if not chunk.get("created_at"):
                    chunk["created_at"] = now
            if len(self.vectors):
                self.vectors = np.vstack([self.vectors, normalized])
            else:
                self.vectors = normalized
            self.metadata.extend(chunks)
            self.index.add(normalized)
            self._rebuild_bm25()
            self._persist()
        logger.info("Stored %d chunks (total %d)", len(chunks), len(self.metadata))

    def _active_indices(
        self, document_ids: Optional[List[str]] = None
    ) -> List[int]:
        """Return row indices eligible for search.

        With ``document_ids`` empty or ``None`` every row is eligible;
        otherwise only rows whose ``document_id`` is in the given list.
        """
        if not document_ids:
            return list(range(len(self.metadata)))
        allowed = set(document_ids)
        return [
            i
            for i, chunk in enumerate(self.metadata)
            if chunk.get("document_id") in allowed
        ]

    def hybrid_search(
        self,
        query_vector: List[float],
        query_text: str,
        top_k: int = 6,
        document_ids: Optional[List[str]] = None,
        alpha: Optional[float] = None,
    ) -> List[Dict]:
        """Rank chunks by a blend of vector similarity and BM25 relevance.

        Both score sets are min-max normalised over the candidate rows and
        combined as ``alpha * vector + (1 - alpha) * bm25``. When no BM25
        model exists or ``query_text`` is blank, the BM25 term contributes 0.

        Args:
            query_vector: Query embedding (normalised internally).
            query_text: Raw query text for the BM25 side.
            top_k: Maximum number of hits; values <= 0 yield ``[]``.
            document_ids: Optional restriction to these documents.
            alpha: Vector weight; defaults to ``settings.HYBRID_ALPHA``.

        Returns:
            Up to ``top_k`` dicts, best first, each with ``text``,
            ``document_name``, ``document_id``, ``page_number``, ``section``
            and ``score``.
        """
        if top_k <= 0:
            # A negative slice bound would silently drop hits from the tail.
            return []
        blend = alpha if alpha is not None else settings.HYBRID_ALPHA
        # Hold the lock so metadata, vectors and BM25 cannot be swapped by a
        # concurrent upsert/delete between computing indices and using them.
        with self._lock:
            active = self._active_indices(document_ids)
            if not active:
                return []
            query_norm = _normalize_vector(query_vector)
            # One matrix-vector product instead of a Python-level dot per row;
            # skip the fancy-index copy when every row is a candidate.
            if len(active) == len(self.vectors):
                rows = self.vectors
            else:
                rows = self.vectors[active]
            similarities = rows @ query_norm
            vec_norm = _min_max_normalize(
                {idx: float(sim) for idx, sim in zip(active, similarities)}
            )
            bm25_norm: Dict[int, float] = {}
            if self._bm25 is not None and query_text.strip():
                raw_bm25 = self._bm25.get_scores(_tokenize(query_text))
                bm25_norm = _min_max_normalize(
                    {idx: float(raw_bm25[idx]) for idx in active}
                )
            combined = {
                idx: blend * vec_norm.get(idx, 0.0)
                + (1.0 - blend) * bm25_norm.get(idx, 0.0)
                for idx in active
            }
            ranked = sorted(
                combined.items(), key=lambda item: item[1], reverse=True
            )[:top_k]
            results: List[Dict] = []
            for idx, score in ranked:
                hit = self._chunk_view(self.metadata[idx])
                hit["score"] = score
                results.append(hit)
            return results

    def fetch_chunks_by_document_id(
        self, document_id: str, limit: int = 100
    ) -> List[Dict]:
        """Return up to *limit* chunks of one document.

        Chunks are ordered by ``(page_number, chunk_index)``; each result has
        the common chunk fields plus ``chunk_index``.
        """
        chunks = []
        for chunk in self.metadata:
            if chunk.get("document_id") != document_id:
                continue
            view = self._chunk_view(chunk)
            view["chunk_index"] = chunk.get("chunk_index", 0)
            chunks.append(view)
        chunks.sort(key=lambda c: (c.get("page_number", 0), c.get("chunk_index", 0)))
        return chunks[:limit]

    def delete_document(self, document_id: str) -> None:
        """Remove every chunk of *document_id*; no-op if none are stored."""
        with self._lock:
            keep = [
                i
                for i, chunk in enumerate(self.metadata)
                if chunk.get("document_id") != document_id
            ]
            if len(keep) == len(self.metadata):
                return
            self.metadata = [self.metadata[i] for i in keep]
            self.vectors = self.vectors[keep] if len(keep) else np.zeros(
                (0, EMBED_DIM), dtype=np.float32
            )
            # Row positions changed, so index and BM25 must be rebuilt.
            self._sync_index()
            self._persist()
        logger.info("Deleted document %s", document_id)

    def list_documents(self) -> List[Dict[str, Any]]:
        """Summarise stored documents.

        Returns one dict per distinct non-empty ``document_id`` with
        ``document_id``, ``document_name``, ``chunk_count`` and ``created_at``
        (taken from the first chunk seen for that document). The name is the
        last non-empty ``document_name`` among its chunks, falling back to
        the id.
        """
        docs: Dict[str, Dict[str, Any]] = {}
        for chunk in self.metadata:
            doc_id = chunk.get("document_id", "")
            if not doc_id:
                continue
            if doc_id not in docs:
                docs[doc_id] = {
                    "document_id": doc_id,
                    "document_name": chunk.get("document_name", doc_id),
                    "chunk_count": 0,
                    "created_at": chunk.get("created_at"),
                }
            docs[doc_id]["chunk_count"] += 1
            if chunk.get("document_name"):
                docs[doc_id]["document_name"] = chunk["document_name"]
        return list(docs.values())

    def close(self) -> None:
        """Persist the current state to disk."""
        with self._lock:
            self._persist()
        logger.info("FAISS store saved and closed.")