Spaces:
Sleeping
Sleeping
| """Hybrid retriever: BM25 (sparse) + FAISS/BGE (dense) with Reciprocal Rank Fusion.""" | |
| import json | |
| import logging | |
| import re | |
| import faiss | |
| import numpy as np | |
| from rank_bm25 import BM25Okapi | |
| from sentence_transformers import SentenceTransformer | |
| logger = logging.getLogger(__name__) | |
| def _tokenize(text: str) -> list[str]: | |
| return re.findall(r"\w+", text.lower()) | |
def reciprocal_rank_fusion(
    ranked_lists: list[list[int]], k: int = 60
) -> list[tuple[int, float]]:
    """Fuse several ranked id lists with Reciprocal Rank Fusion.

    Every id contributes ``1 / (k + rank)`` per list it appears in, where
    ``rank`` is 1-based. Returns ``(id, fused_score)`` pairs sorted with the
    best score first (ties keep first-seen order, since the sort is stable).
    """
    fused: dict[int, float] = {}
    for ranking in ranked_lists:
        for position, doc_id in enumerate(ranking, start=1):
            fused[doc_id] = fused.get(doc_id, 0.0) + 1.0 / (k + position)
    return sorted(fused.items(), key=lambda pair: pair[1], reverse=True)
| class Retriever: | |
| def __init__( | |
| self, | |
| faiss_index_path: str = "faiss.index", | |
| chunks_meta_path: str = "chunks_meta.jsonl", | |
| embedding_model: str = "BAAI/bge-small-en-v1.5", | |
| top_k: int = 5, | |
| ): | |
| self.top_k = top_k | |
| logger.info("Loading embedding model: %s", embedding_model) | |
| self.embed_model = SentenceTransformer(embedding_model) | |
| logger.info("Loading FAISS index: %s", faiss_index_path) | |
| self.index = faiss.read_index(faiss_index_path) | |
| logger.info("Loading chunk metadata: %s", chunks_meta_path) | |
| self.chunks: list[dict] = [] | |
| with open(chunks_meta_path, encoding="utf-8") as f: | |
| for line in f: | |
| line = line.strip() | |
| if line: | |
| self.chunks.append(json.loads(line)) | |
| logger.info("Building BM25 index over %d chunks...", len(self.chunks)) | |
| corpus_tokens = [_tokenize(c["text"]) for c in self.chunks] | |
| self.bm25 = BM25Okapi(corpus_tokens) | |
| logger.info("Retriever ready: %d vectors, %d chunks", self.index.ntotal, len(self.chunks)) | |
| def retrieve(self, query: str, top_k: int | None = None) -> list[dict]: | |
| k = top_k or self.top_k | |
| candidates_k = min(k * 20, self.index.ntotal) | |
| dense_ranked = self._dense_search(query, candidates_k) | |
| sparse_ranked = self._sparse_search(query, candidates_k) | |
| fused = reciprocal_rank_fusion([dense_ranked, sparse_ranked]) | |
| results = [] | |
| for idx, rrf_score in fused: | |
| if idx < 0 or idx >= len(self.chunks): | |
| continue | |
| chunk = self.chunks[idx].copy() | |
| chunk["score"] = float(rrf_score) | |
| results.append(chunk) | |
| for r in results: | |
| if r.get("is_faq"): | |
| r["score"] = r["score"] * 1.2 | |
| results.sort(key=lambda x: x["score"], reverse=True) | |
| return results[:k] | |
| def _dense_search(self, query: str, k: int) -> list[int]: | |
| prefixed = f"Represent this sentence for searching relevant passages: {query}" | |
| qvec = self.embed_model.encode([prefixed], normalize_embeddings=True) | |
| qvec = np.array(qvec, dtype=np.float32) | |
| scores, indices = self.index.search(qvec, k) | |
| return [int(i) for i in indices[0] if i >= 0] | |
| def _sparse_search(self, query: str, k: int) -> list[int]: | |
| tokens = _tokenize(query) | |
| if not tokens: | |
| return [] | |
| bm25_scores = self.bm25.get_scores(tokens) | |
| top_indices = np.argsort(bm25_scores)[::-1][:k] | |
| return [int(i) for i in top_indices if bm25_scores[i] > 0] | |
| def format_context(self, results: list[dict]) -> str: | |
| parts = [] | |
| for i, r in enumerate(reversed(results), 1): | |
| source_label = f"[{r['source'].upper()}]" if r.get("source") else "" | |
| title_label = f" - {r['title']}" if r.get("title") else "" | |
| parts.append(f"--- Source {i} {source_label}{title_label} ---\n{r['text']}") | |
| return "\n\n".join(parts) | |
| def format_sources_markdown(self, results: list[dict]) -> str: | |
| if not results: | |
| return "" | |
| lines = ["\n---\n**Sources:**"] | |
| for i, r in enumerate(results, 1): | |
| tag = "FAQ" if r.get("is_faq") else r.get("source", "").upper() | |
| title = r.get("title", "Untitled")[:80] | |
| score = r.get("score", 0) | |
| preview = r["text"][:150].replace("\n", " ") | |
| lines.append(f"{i}. **[{tag}]** {title} (score: {score:.4f})\n _{preview}..._") | |
| return "\n".join(lines) | |