Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Reranking module for RAG system. | |
| Handles BM25 reranking with metadata-based scoring. | |
| """ | |
| import re | |
| import time | |
| import logging | |
| from typing import List | |
| from collections import defaultdict | |
| from langchain_core.documents import Document | |
| from langchain_community.retrievers import BM25Retriever | |
| # Metadata field weights for scoring | |
| METADATA_WEIGHTS = { | |
| "title_info_primary_tsi": 1.5, | |
| "name_role_tsim": 1.4, | |
| "date_tsim": 1.3, | |
| "abstract_tsi": 1.0, | |
| "note_tsim": 0.8, | |
| "subject_geographic_sim": 0.5, | |
| "genre_basic_ssim": 0.5, | |
| "genre_specific_ssim": 0.5, | |
| } | |
| def extract_years_from_query(query: str) -> List[str]: | |
| """ | |
| Extract 4-digit years from query string. | |
| Args: | |
| query: User query string | |
| Returns: | |
| List of year strings found in query | |
| """ | |
| return re.findall(r"\b(1[5-9]\d{2}|20\d{2}|21\d{2}|22\d{2}|23\d{2})\b", query) | |
| def rerank(docs: List[Document], query: str, top_k: int = 10) -> List[Document]: | |
| """ | |
| Rerank documents using BM25 and metadata-based scoring. | |
| Process: | |
| 1. Merge chunks by document_id | |
| 2. Apply BM25 lexical reranking | |
| 3. Boost scores based on metadata field presence | |
| 4. Add large boost for exact year matches | |
| 5. Return top-k documents | |
| Args: | |
| docs: List of Document objects to rerank | |
| query: User query string | |
| top_k: Number of top documents to return | |
| Returns: | |
| List of top-k reranked Document objects | |
| """ | |
| if not docs: | |
| logging.warning("⚠️ No documents provided for reranking.") | |
| return [] | |
| logging.info("⚖️ Starting BM25 reranking...") | |
| start = time.time() | |
| # Extract years from query for date matching | |
| query_years = extract_years_from_query(query) | |
| # Group chunks by document_id and merge | |
| grouped = defaultdict(list) | |
| for doc in docs: | |
| grouped[doc.metadata.get("source")].append(doc) | |
| merged_docs = [] | |
| for src, chunks in grouped.items(): | |
| text = " ".join(c.page_content for c in chunks if c.page_content) | |
| merged_docs.append(Document(page_content=text, metadata=chunks[0].metadata)) | |
| # Apply BM25 ranking | |
| bm25 = BM25Retriever.from_documents(merged_docs, k=len(merged_docs)) | |
| ranked = bm25.invoke(query) | |
| # Apply metadata-based scoring | |
| final_ranked = [] | |
| for d in ranked: | |
| score = 1.0 | |
| # Add weight for each present metadata field | |
| for field, weight in METADATA_WEIGHTS.items(): | |
| if field in d.metadata and d.metadata[field]: | |
| score += weight | |
| # Large boost for exact year matches in date field | |
| date_field = str(d.metadata.get("date_tsim", "")) | |
| for y in query_years: | |
| if re.search(rf"\b{y}\b", date_field): | |
| score += 50 | |
| break | |
| final_ranked.append((d, score)) | |
| # Sort by score and return top-k | |
| final_ranked.sort(key=lambda x: x[1], reverse=True) | |
| logging.info(f"✅ Reranked {len(final_ranked)} documents in {time.time() - start:.2f}s.") | |
| return [doc for doc, _ in final_ranked[:top_k]] | |