# NOTE(review): removed extraction/build-tool residue ("Spaces:" / "Build error")
# that preceded the module and was not valid Python.
"""
AnswerGenerator: orchestrates retrieval, re-ranking, and answer generation.

This module contains:
- Reranker: cross-encoder based re-ranking of candidate chunks
- AnswerGenerator: ties together retrieval, re-ranking, and LLM generation

The hybrid BM25 + dense Retriever is imported from src.retriever.
Each component is modular and can be swapped or extended (e.g., add a HyDE retriever).
"""
| import os | |
| from typing import List, Dict, Any, Tuple | |
| from sentence_transformers import SentenceTransformer | |
| from rank_bm25 import BM25Okapi | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| from src import sanitize_html | |
| from src.utils import LLMClient, logger | |
| from src.retriever import Retriever, RetrieverConfig | |
class RerankerConfig:
    """Static configuration for the cross-encoder re-ranker."""

    # Hugging Face model id; override via the RERANKER_MODEL environment variable.
    MODEL_NAME = os.environ.get('RERANKER_MODEL', 'BAAI/bge-reranker-v2-Gemma')
    # Prefer the GPU whenever torch can see one.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
class Reranker:
    """
    Cross-encoder re-ranker using a transformer-based sequence classification model.

    Scores (query, passage) pairs jointly in one forward pass and keeps the
    highest-scoring candidates.
    """

    def __init__(self, config: "RerankerConfig"):
        """Load the tokenizer/model named in *config* and move it to config.DEVICE.

        Raises:
            Exception: re-raised after logging if the model cannot be loaded.
        """
        # Remember the device so rerank() sends inputs to the same device the
        # model lives on. (The original read RerankerConfig.DEVICE as a class
        # attribute inside rerank(), which could diverge from the config
        # instance actually passed in here.)
        self.device = config.DEVICE
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
            self.model = AutoModelForSequenceClassification.from_pretrained(config.MODEL_NAME)
            self.model.to(self.device)
            # Inference-only component: disable dropout / batch-norm updates.
            self.model.eval()
        except Exception as e:
            logger.error(f'Failed to load reranker model: {e}')
            raise

    def rerank(self, query: str, candidates: List[Dict[str, Any]], top_k: int) -> List[Dict[str, Any]]:
        """Score each candidate against *query* and return the top_k most relevant.

        Args:
            query: The user question.
            candidates: Chunk dicts; the text scored is each chunk's 'narration'.
            top_k: Maximum number of chunks to keep.

        Returns:
            Candidates sorted by descending relevance, truncated to top_k.
            On any scoring failure the first top_k candidates are returned
            unchanged (deliberate best-effort fallback, not a crash).
        """
        if not candidates:
            logger.warning('No candidates provided to rerank.')
            return []
        try:
            inputs = self.tokenizer(
                [query] * len(candidates),
                [c.get('narration', '') for c in candidates],
                padding=True,
                truncation=True,
                return_tensors='pt'
            ).to(self.device)
            with torch.no_grad():
                out = self.model(**inputs)
            logits = out.logits
            if logits.ndim == 2 and logits.shape[1] == 1:
                logits = logits.squeeze(-1)  # only squeeze if it's (batch, 1)
            probs = torch.sigmoid(logits).cpu().numpy().flatten()  # always a 1D array
            # Pair each candidate with its relevance probability, then sort.
            paired = [(c, float(probs[idx])) for idx, c in enumerate(candidates)]
            ranked = sorted(paired, key=lambda x: x[1], reverse=True)
            return [c for c, _ in ranked[:top_k]]
        except Exception as e:
            logger.error(f'Reranking failed: {e}')
            return candidates[:top_k]
class AnswerGenerator:
    """
    Main interface: initializes Retriever + Reranker once, then
    answers multiple questions without re-loading models each time.
    """

    def __init__(self, chunks: List[Dict[str, Any]]):
        """Build the retriever and reranker over *chunks* up front.

        Args:
            chunks: Parsed chunk dicts; each is expected to carry a
                'narration' text field (used by rerank/answer).
        """
        self.chunks = chunks
        self.retriever = Retriever(chunks, RetrieverConfig)
        self.reranker = Reranker(RerankerConfig)
        # Keep half of the retrieved pool after re-ranking, but never zero:
        # TOP_K == 1 would otherwise floor-divide to 0 and drop all context.
        self.top_k = max(1, RetrieverConfig.TOP_K // 2)

    def answer(
        self, question: str
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """Retrieve, re-rank, and generate an answer for *question*.

        Returns:
            (answer_text, top_chunks): the LLM's answer string and the
            re-ranked chunks that were supplied as context.
        """
        candidates = self.retriever.retrieve(question)
        top_chunks = self.reranker.rerank(question, candidates, self.top_k)
        context = "\n\n".join(f"- {c['narration']}" for c in top_chunks)
        # FIX: the original concatenated the dash separators, the context, and
        # the instruction text with no newlines between the literals, producing
        # one unbroken line; explicit "\n" added around each section.
        prompt = (
            "You are a knowledgeable assistant. Use the following snippets to answer."
            "\n\nContext information is below: \n"
            "------------------------------------\n"
            f"{context}\n"
            "------------------------------------\n"
            "Given the context information above I want you \n"
            "to think step by step to answer the query in a crisp \n"
            "manner, incase you don't have enough information, \n"
            "just say I don't know!. \n\n"
            f"\n\nQuestion: {question} \n"
            "Answer:"
        )
        answer = LLMClient.generate(prompt)
        return answer, top_chunks