# UTN-Student-Chatbot / retriever.py
# Uploaded by saeedbenadeeb via huggingface_hub (commit 79ad8ab, verified)
"""Hybrid retriever: BM25 (sparse) + FAISS/BGE (dense) with Reciprocal Rank Fusion."""
import json
import logging
import re
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
logger = logging.getLogger(__name__)
def _tokenize(text: str) -> list[str]:
return re.findall(r"\w+", text.lower())
def reciprocal_rank_fusion(
    ranked_lists: list[list[int]], k: int = 60
) -> list[tuple[int, float]]:
    """Fuse several rankings with Reciprocal Rank Fusion.

    A document appearing at 1-based position ``p`` in one ranking
    contributes ``1 / (k + p)`` to that document's fused score; scores
    accumulate across all rankings.

    Returns ``(doc_index, fused_score)`` pairs sorted by score, descending.
    """
    fused: dict[int, float] = {}
    for ranking in ranked_lists:
        for position, doc_id in enumerate(ranking, start=1):
            fused[doc_id] = fused.get(doc_id, 0.0) + 1.0 / (k + position)
    return sorted(fused.items(), key=lambda pair: pair[1], reverse=True)
class Retriever:
def __init__(
self,
faiss_index_path: str = "faiss.index",
chunks_meta_path: str = "chunks_meta.jsonl",
embedding_model: str = "BAAI/bge-small-en-v1.5",
top_k: int = 5,
):
self.top_k = top_k
logger.info("Loading embedding model: %s", embedding_model)
self.embed_model = SentenceTransformer(embedding_model)
logger.info("Loading FAISS index: %s", faiss_index_path)
self.index = faiss.read_index(faiss_index_path)
logger.info("Loading chunk metadata: %s", chunks_meta_path)
self.chunks: list[dict] = []
with open(chunks_meta_path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
self.chunks.append(json.loads(line))
logger.info("Building BM25 index over %d chunks...", len(self.chunks))
corpus_tokens = [_tokenize(c["text"]) for c in self.chunks]
self.bm25 = BM25Okapi(corpus_tokens)
logger.info("Retriever ready: %d vectors, %d chunks", self.index.ntotal, len(self.chunks))
def retrieve(self, query: str, top_k: int | None = None) -> list[dict]:
k = top_k or self.top_k
candidates_k = min(k * 20, self.index.ntotal)
dense_ranked = self._dense_search(query, candidates_k)
sparse_ranked = self._sparse_search(query, candidates_k)
fused = reciprocal_rank_fusion([dense_ranked, sparse_ranked])
results = []
for idx, rrf_score in fused:
if idx < 0 or idx >= len(self.chunks):
continue
chunk = self.chunks[idx].copy()
chunk["score"] = float(rrf_score)
results.append(chunk)
for r in results:
if r.get("is_faq"):
r["score"] = r["score"] * 1.2
results.sort(key=lambda x: x["score"], reverse=True)
return results[:k]
def _dense_search(self, query: str, k: int) -> list[int]:
prefixed = f"Represent this sentence for searching relevant passages: {query}"
qvec = self.embed_model.encode([prefixed], normalize_embeddings=True)
qvec = np.array(qvec, dtype=np.float32)
scores, indices = self.index.search(qvec, k)
return [int(i) for i in indices[0] if i >= 0]
def _sparse_search(self, query: str, k: int) -> list[int]:
tokens = _tokenize(query)
if not tokens:
return []
bm25_scores = self.bm25.get_scores(tokens)
top_indices = np.argsort(bm25_scores)[::-1][:k]
return [int(i) for i in top_indices if bm25_scores[i] > 0]
def format_context(self, results: list[dict]) -> str:
parts = []
for i, r in enumerate(reversed(results), 1):
source_label = f"[{r['source'].upper()}]" if r.get("source") else ""
title_label = f" - {r['title']}" if r.get("title") else ""
parts.append(f"--- Source {i} {source_label}{title_label} ---\n{r['text']}")
return "\n\n".join(parts)
def format_sources_markdown(self, results: list[dict]) -> str:
if not results:
return ""
lines = ["\n---\n**Sources:**"]
for i, r in enumerate(results, 1):
tag = "FAQ" if r.get("is_faq") else r.get("source", "").upper()
title = r.get("title", "Untitled")[:80]
score = r.get("score", 0)
preview = r["text"][:150].replace("\n", " ")
lines.append(f"{i}. **[{tag}]** {title} (score: {score:.4f})\n _{preview}..._")
return "\n".join(lines)