Spaces:
Sleeping
Sleeping
| """FAISS vector store with hybrid (vector + BM25) search.""" | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import re | |
| import threading | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| import faiss | |
| import numpy as np | |
| from rank_bm25 import BM25Okapi | |
| from config import settings | |
# Module-level logger, named after this module.
| logger = logging.getLogger(__name__) | |
# Embedding width: the FAISS index and the stored vector matrix are both
# created with this many columns.
| EMBED_DIM = 2304 | |
# Matches runs of word characters; _tokenize uses it to split text into terms.
| _TOKEN_RE = re.compile(r"\w+") | |
def _tokenize(text: str) -> List[str]:
    """Split *text* into lowercase word tokens.

    The text is lowercased first, then every run of word characters matched
    by ``_TOKEN_RE`` becomes one token, in order of appearance.
    """
    lowered = text.lower()
    tokens = _TOKEN_RE.findall(lowered)
    return tokens
def _normalize_vector(vector: List[float]) -> np.ndarray:
    """Return *vector* as an L2-normalised float32 array.

    The input is copied into a fresh single-row float32 matrix, which is the
    layout ``faiss.normalize_L2`` operates on, normalised in place, and the
    row is handed back.
    """
    as_float32 = np.array(vector, dtype=np.float32)
    single_row = as_float32.reshape(1, -1)
    # normalize_L2 rewrites the matrix in place and returns nothing.
    faiss.normalize_L2(single_row)
    return single_row[0]
| def _min_max_normalize(scores: Dict[int, float]) -> Dict[int, float]: | |
| if not scores: | |
| return {} | |
| values = list(scores.values()) | |
| lo, hi = min(values), max(values) | |
| if hi - lo < 1e-9: | |
| return {idx: 1.0 for idx in scores} | |
| return {idx: (score - lo) / (hi - lo) for idx, score in scores.items()} | |
class FaissDB:
    """Local FAISS index with chunk metadata and hybrid retrieval.

    Three structures are kept aligned by row position: ``vectors`` (one
    L2-normalised embedding per chunk), ``metadata`` (one dict per chunk) and
    the FAISS / BM25 indexes derived from them.  ``_lock`` is a plain,
    non-reentrant lock; every public method takes it, while the private
    helpers expect the caller to already hold it.
    """

    def __init__(self):
        """Create the data directory if needed and load any persisted state.

        Raises:
            ValueError: if the persisted vectors have the wrong width or do
                not match the persisted metadata row for row.
        """
        self.data_dir = Path(settings.FAISS_DATA_DIR)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.index_file = self.data_dir / "index.faiss"
        self.meta_file = self.data_dir / "metadata.json"
        self.vectors_file = self.data_dir / "vectors.npy"
        self._lock = threading.Lock()
        self.index: faiss.IndexFlatIP = faiss.IndexFlatIP(EMBED_DIM)
        self.metadata: List[Dict[str, Any]] = []
        self.vectors = np.zeros((0, EMBED_DIM), dtype=np.float32)
        self._bm25: Optional[BM25Okapi] = None
        self._load()
        self._sync_index()

    @staticmethod
    def _tmp_path(path: Path) -> Path:
        """Return the sibling temp-file path used while rewriting *path*."""
        return path.with_name(path.name + ".tmp")

    @staticmethod
    def _chunk_view(chunk: Dict[str, Any]) -> Dict[str, Any]:
        """Return the fields of *chunk* that every query result exposes."""
        return {
            "text": chunk.get("text", ""),
            "document_name": chunk.get("document_name", ""),
            "document_id": chunk.get("document_id", ""),
            "page_number": chunk.get("page_number", 0),
            "section": chunk.get("section", ""),
        }

    def _load(self) -> None:
        """Read metadata and vectors from disk and validate their shape.

        The vectors file is preferred.  When only a FAISS index file exists,
        the vectors are reconstructed from it so they are not lost when the
        index is rebuilt from ``self.vectors`` afterwards.  The derived
        indexes are not built here; the caller is expected to run
        ``_sync_index()`` next.

        Raises:
            ValueError: on a wrong embedding width or when the number of
                vectors differs from the number of metadata rows.
        """
        if self.meta_file.exists():
            self.metadata = json.loads(self.meta_file.read_text(encoding="utf-8"))

        if self.vectors_file.exists():
            self.vectors = np.load(self.vectors_file)
        elif self.index_file.exists():
            self.index = faiss.read_index(str(self.index_file))
            if self.index.ntotal:
                self.vectors = self.index.reconstruct_n(0, self.index.ntotal)

        if self.vectors.size == 0:
            self.vectors = np.zeros((0, EMBED_DIM), dtype=np.float32)
        else:
            # FAISS needs contiguous float32 rows.
            self.vectors = np.ascontiguousarray(self.vectors, dtype=np.float32)

        if self.vectors.ndim != 2 or self.vectors.shape[1] != EMBED_DIM:
            raise ValueError(
                f"Stored vectors have shape {self.vectors.shape}; "
                f"expected (n, {EMBED_DIM})"
            )
        if len(self.metadata) != len(self.vectors):
            # Rows are matched purely by position, so a count mismatch means
            # results could be attributed to the wrong chunk.
            raise ValueError(
                f"Store is inconsistent: {len(self.metadata)} metadata rows "
                f"but {len(self.vectors)} vectors in {self.data_dir}"
            )

    def _persist(self) -> None:
        """Write vectors, index and metadata to disk.

        Each file is written to a temporary sibling first and renamed into
        place only once all three have been written, so an interrupted write
        does not leave a truncated file behind.  Caller must hold ``_lock``.
        """
        vectors_tmp = self._tmp_path(self.vectors_file)
        index_tmp = self._tmp_path(self.index_file)
        meta_tmp = self._tmp_path(self.meta_file)

        # Saving through a file handle stops numpy from appending ".npy" to
        # the temporary name.
        with vectors_tmp.open("wb") as handle:
            np.save(handle, self.vectors)
        faiss.write_index(self.index, str(index_tmp))
        meta_tmp.write_text(json.dumps(self.metadata), encoding="utf-8")

        vectors_tmp.replace(self.vectors_file)
        index_tmp.replace(self.index_file)
        meta_tmp.replace(self.meta_file)

    def _sync_index(self) -> None:
        """Rebuild the FAISS index and the BM25 index from current state."""
        self.index = faiss.IndexFlatIP(EMBED_DIM)
        if len(self.vectors):
            self.index.add(self.vectors)
        self._rebuild_bm25()

    def _rebuild_bm25(self) -> None:
        """Rebuild the BM25 index from the ``text`` field of every chunk."""
        corpus = [_tokenize(chunk.get("text", "")) for chunk in self.metadata]
        # rank_bm25 averages IDF over the vocabulary, which fails with a
        # division by zero when no chunk yields a single token; skip the
        # lexical index in that case.
        self._bm25 = BM25Okapi(corpus) if any(corpus) else None

    def upsert_chunks(self, chunks: List[Dict], vectors: List[List[float]]) -> None:
        """Append *chunks* and their embeddings to the store and persist it.

        Chunks are always appended; existing rows are not replaced.  A chunk
        without a truthy ``created_at`` is stamped with the current UTC time
        (the chunk dict passed in is updated in place).

        Args:
            chunks: Metadata dicts, one per chunk.
            vectors: Embeddings, one per chunk, in the same order.

        Raises:
            ValueError: if the two lists differ in length or the embeddings
                do not have ``EMBED_DIM`` components.
        """
        if not chunks:
            return
        if len(chunks) != len(vectors):
            raise ValueError(
                f"Got {len(chunks)} chunks but {len(vectors)} vectors; "
                "they must match one to one"
            )

        now = datetime.now(timezone.utc).isoformat()
        normalized = np.vstack([_normalize_vector(vector) for vector in vectors])
        if normalized.shape[1] != EMBED_DIM:
            raise ValueError(
                f"Vectors have {normalized.shape[1]} dimensions; "
                f"expected {EMBED_DIM}"
            )

        with self._lock:
            for chunk in chunks:
                if not chunk.get("created_at"):
                    chunk["created_at"] = now
            if len(self.vectors):
                self.vectors = np.vstack([self.vectors, normalized])
            else:
                self.vectors = normalized
            self.metadata.extend(chunks)
            self.index.add(normalized)
            self._rebuild_bm25()
            self._persist()
            logger.info(
                "Stored %d chunks (total %d)", len(chunks), len(self.metadata)
            )

    def _active_indices(
        self, document_ids: Optional[List[str]] = None
    ) -> List[int]:
        """Return the row indices eligible for a search.

        With no (or an empty) *document_ids* every row is eligible; otherwise
        only rows whose ``document_id`` is listed.  Does not take ``_lock``.
        """
        if not document_ids:
            return list(range(len(self.metadata)))
        allowed = set(document_ids)
        return [
            i
            for i, chunk in enumerate(self.metadata)
            if chunk.get("document_id") in allowed
        ]

    def hybrid_search(
        self,
        query_vector: List[float],
        query_text: str,
        top_k: int = 6,
        document_ids: Optional[List[str]] = None,
        alpha: Optional[float] = None,
    ) -> List[Dict]:
        """Rank chunks by a blend of vector similarity and BM25 score.

        Both score sets are min-max normalised over the eligible rows and
        combined as ``alpha * vector + (1 - alpha) * bm25``.

        Args:
            query_vector: Query embedding; normalised before scoring.
            query_text: Raw query text for BM25; blank text disables the
                lexical part.
            top_k: Maximum number of results; values <= 0 yield no results.
            document_ids: Optional restriction to these documents.
            alpha: Weight of the vector score; defaults to
                ``settings.HYBRID_ALPHA``.

        Returns:
            Result dicts ordered by descending combined score.
        """
        if top_k <= 0:
            # A negative slice bound would otherwise drop results from the
            # tail instead of returning nothing.
            return []
        blend = alpha if alpha is not None else settings.HYBRID_ALPHA

        # Hold the lock so metadata, vectors and BM25 are read as one
        # consistent snapshot while writers may be running.
        with self._lock:
            active = self._active_indices(document_ids)
            if not active:
                return []

            query_norm = _normalize_vector(query_vector)
            # One matrix-vector product instead of a Python-level dot per row.
            similarities = self.vectors[active] @ query_norm
            vec_norm = _min_max_normalize(dict(zip(active, similarities.tolist())))

            bm25_norm: Dict[int, float] = {}
            if self._bm25 is not None and query_text.strip():
                raw_bm25 = self._bm25.get_scores(_tokenize(query_text))
                bm25_norm = _min_max_normalize(
                    {idx: float(raw_bm25[idx]) for idx in active}
                )

            combined = {
                idx: blend * vec_norm.get(idx, 0.0)
                + (1.0 - blend) * bm25_norm.get(idx, 0.0)
                for idx in active
            }
            ranked = sorted(
                combined.items(), key=lambda item: item[1], reverse=True
            )[:top_k]

            results: List[Dict] = []
            for idx, score in ranked:
                entry = self._chunk_view(self.metadata[idx])
                entry["score"] = score
                results.append(entry)
            return results

    def fetch_chunks_by_document_id(
        self, document_id: str, limit: int = 100
    ) -> List[Dict]:
        """Return up to *limit* chunks of one document in reading order.

        Chunks are ordered by ``(page_number, chunk_index)``.  A *limit* of
        zero or less yields an empty list.
        """
        if limit <= 0:
            return []
        with self._lock:
            chunks = []
            for chunk in self.metadata:
                if chunk.get("document_id") != document_id:
                    continue
                entry = self._chunk_view(chunk)
                entry["chunk_index"] = chunk.get("chunk_index", 0)
                chunks.append(entry)
        chunks.sort(key=lambda c: (c.get("page_number", 0), c.get("chunk_index", 0)))
        return chunks[:limit]

    def delete_document(self, document_id: str) -> None:
        """Remove every chunk of *document_id*, rebuild indexes and persist.

        Does nothing when no chunk belongs to the document.
        """
        with self._lock:
            keep = [
                i
                for i, chunk in enumerate(self.metadata)
                if chunk.get("document_id") != document_id
            ]
            if len(keep) == len(self.metadata):
                return
            self.metadata = [self.metadata[i] for i in keep]
            if keep:
                self.vectors = self.vectors[keep]
            else:
                self.vectors = np.zeros((0, EMBED_DIM), dtype=np.float32)
            self._sync_index()
            self._persist()
            logger.info("Deleted document %s", document_id)

    def list_documents(self) -> List[Dict[str, Any]]:
        """Summarise the stored documents, one entry per ``document_id``.

        Each entry carries the id, a display name (the last non-empty
        ``document_name`` seen, falling back to the id), the chunk count and
        the ``created_at`` of the document's first chunk.  Chunks without a
        ``document_id`` are ignored.
        """
        docs: Dict[str, Dict[str, Any]] = {}
        with self._lock:
            for chunk in self.metadata:
                doc_id = chunk.get("document_id", "")
                if not doc_id:
                    continue
                if doc_id not in docs:
                    docs[doc_id] = {
                        "document_id": doc_id,
                        "document_name": chunk.get("document_name", doc_id),
                        "chunk_count": 0,
                        "created_at": chunk.get("created_at"),
                    }
                docs[doc_id]["chunk_count"] += 1
                if chunk.get("document_name"):
                    docs[doc_id]["document_name"] = chunk["document_name"]
        return list(docs.values())

    def close(self) -> None:
        """Persist the current state to disk."""
        with self._lock:
            self._persist()
            logger.info("FAISS store saved and closed.")