Spaces:
Sleeping
Sleeping
| """ | |
| Reranking module for improving retrieval precision. | |
| Uses cross-encoder models to reorder initial retrieval results | |
| based on query-document relevance. | |
| """ | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from dataclasses import dataclass | |
# Module-level cache so the cross-encoder is loaded at most once per model
# name (populated lazily by _get_cross_encoder); both are None until first use.
_reranker_model = None
_reranker_model_name = None
@dataclass
class RerankResult:
    """Result from a reranking operation.

    Without the @dataclass decorator this class had no __init__, so every
    call site (e.g. ``RerankResult(chunks=[], model_used="none",
    reranked=False)``) raised TypeError; ``dataclass`` was already imported
    for this purpose.
    """
    # Chunks in (possibly reranked) order; each may carry a "rerank_score" key.
    chunks: List[Dict[str, Any]]
    # Model identifier used for reranking, "none", or a fallback description.
    model_used: str
    # True only when reranking actually happened (not a fallback path).
    reranked: bool
def _get_cross_encoder(model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
    """Lazily load and cache a cross-encoder model.

    Args:
        model_name: HuggingFace model name for the cross-encoder.

    Returns:
        CrossEncoder model instance, cached across calls per model name.

    Raises:
        ImportError: If sentence-transformers is not installed.
        RuntimeError: If the model fails to load.
    """
    global _reranker_model, _reranker_model_name
    # Fast path: the cached model already matches the requested name.
    if _reranker_model is not None and _reranker_model_name == model_name:
        return _reranker_model
    try:
        from sentence_transformers import CrossEncoder
        _reranker_model = CrossEncoder(model_name)
        _reranker_model_name = model_name
    except ImportError:
        raise ImportError(
            "sentence-transformers not installed. "
            "Install with: pip install sentence-transformers"
        )
    except Exception as e:
        raise RuntimeError(f"Failed to load cross-encoder model: {e}")
    return _reranker_model
def rerank_chunks(
    query: str,
    chunks: List[Dict[str, Any]],
    top_k: int = 5,
    model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
) -> RerankResult:
    """
    Rerank chunks using a cross-encoder model.

    Cross-encoders process query and document together, enabling
    more nuanced relevance scoring than bi-encoders.

    Args:
        query: User query.
        chunks: List of chunks to rerank.
        top_k: Number of top results to return.
        model_name: Cross-encoder model to use.

    Returns:
        RerankResult with reranked chunks; on any failure the original
        order is returned with reranked=False.
    """
    # Zero or one candidate: nothing to reorder.
    if not chunks:
        return RerankResult(chunks=[], model_used="none", reranked=False)
    if len(chunks) <= 1:
        return RerankResult(chunks=chunks, model_used="none", reranked=False)
    try:
        model = _get_cross_encoder(model_name)
        # Build (query, text) pairs; fall back to metadata-held text when the
        # top-level "text" field is missing or empty.
        pairs = []
        for candidate in chunks:
            body = candidate.get("text", "")
            if not body and "metadata" in candidate:
                body = candidate["metadata"].get("text", "")
            pairs.append((query, body))
        scores = model.predict(pairs)
        # Order candidates by descending relevance.
        ranked = sorted(zip(chunks, scores), key=lambda pair: pair[1], reverse=True)
        # Annotate copies of the winners with their rerank score.
        results = []
        for candidate, relevance in ranked[:top_k]:
            annotated = candidate.copy()
            annotated["rerank_score"] = float(relevance)
            results.append(annotated)
        return RerankResult(
            chunks=results,
            model_used=model_name,
            reranked=True
        )
    except Exception as e:
        # Degrade gracefully: keep the original ordering when reranking fails.
        return RerankResult(
            chunks=chunks[:top_k],
            model_used=f"fallback (error: {str(e)[:50]})",
            reranked=False
        )
def rerank_with_llm(
    query: str,
    chunks: List[Dict[str, Any]],
    top_k: int = 5
) -> RerankResult:
    """
    Rerank chunks using LLM-based scoring (fallback method).

    More expensive than a cross-encoder but works without additional models.

    Args:
        query: User query.
        chunks: List of chunks to rerank.
        top_k: Number of top results to return.

    Returns:
        RerankResult with reranked chunks; falls back to the original order
        (reranked=False) if the LLM call or score parsing fails.
    """
    if not chunks:
        return RerankResult(chunks=[], model_used="none", reranked=False)
    if len(chunks) <= 1:
        return RerankResult(chunks=chunks, model_used="none", reranked=False)
    try:
        from src.llm_providers import call_llm
        # Build the scoring prompt. Mirror rerank_chunks' text lookup so
        # chunks that store text under metadata are not scored against empty
        # strings (previously only the top-level "text" key was checked).
        chunk_texts = []
        for i, chunk in enumerate(chunks):
            text = chunk.get("text", "")
            if not text and "metadata" in chunk:
                text = chunk["metadata"].get("text", "")
            chunk_texts.append(f"[{i}] {text[:500]}")  # Truncate for prompt size
        prompt = f"""Rate the relevance of each document to the query on a scale of 0-10.
Return ONLY a comma-separated list of scores in order (e.g., "8,3,7,5").
Query: {query}
Documents:
{chr(10).join(chunk_texts)}
Scores:"""
        response = call_llm(prompt=prompt, temperature=0.0, max_tokens=100)
        scores_text = response.get("text", "").strip()
        # Parse scores; any malformed response falls back to original order.
        try:
            scores = [float(s.strip()) for s in scores_text.split(",")]
            if len(scores) != len(chunks):
                raise ValueError("Score count mismatch")
        except (ValueError, AttributeError):
            return RerankResult(
                chunks=chunks[:top_k],
                model_used="llm_parse_failed",
                reranked=False
            )
        # Sort by descending LLM relevance and attach scores to copies.
        scored_chunks = list(zip(chunks, scores))
        scored_chunks.sort(key=lambda x: x[1], reverse=True)
        results = []
        for chunk, score in scored_chunks[:top_k]:
            result_chunk = chunk.copy()
            result_chunk["rerank_score"] = float(score)
            results.append(result_chunk)
        return RerankResult(
            chunks=results,
            model_used="llm",
            reranked=True
        )
    except Exception as e:
        # Any other failure (LLM unavailable, import error) degrades to the
        # original ordering instead of raising.
        return RerankResult(
            chunks=chunks[:top_k],
            model_used=f"llm_fallback (error: {str(e)[:50]})",
            reranked=False
        )