"""Optimized retriever with pattern-based expansion and cross-encoder reranking.""" from __future__ import annotations import hashlib import re from typing import Any, Dict, List, Optional, Tuple from langchain.schema import Document from src.config import get_logger, log_step logger = get_logger(__name__) class OptimizedRetriever: """Fast retriever without LLM calls for expansion/reranking. Uses pattern-based query expansion and cross-encoder reranking instead of LLM calls for faster retrieval. """ EXPANSION_PATTERNS = { "budget": ["cost", "investment", "TIV", "capex", "funding", "allocation", "financial"], "location": ["site", "address", "city", "country", "region", "plant", "facility"], "timeline": ["schedule", "milestone", "deadline", "completion", "duration", "phase"], "challenge": ["risk", "issue", "constraint", "problem", "delay", "obstacle", "barrier"], "project": ["plant", "facility", "refinery", "station", "development"], "status": ["progress", "state", "condition", "update"], } def __init__( self, vector_store: Any, reranker: Optional[Any] = None, k_initial: int = 12, k_final: int = 6, use_expansion: bool = True, use_reranking: bool = True, use_cache: bool = True, ) -> None: self.vector_store = vector_store self.k_initial = k_initial self.k_final = k_final self.use_expansion = use_expansion self.use_reranking = use_reranking self.use_cache = use_cache self._cache: Dict[str, List[Document]] = {} self._reranker = reranker self._reranker_loaded = reranker is not None def _get_reranker(self) -> Optional[Any]: if self._reranker_loaded: return self._reranker try: from src.services.reranker import get_reranker self._reranker = get_reranker("fast") self._reranker_loaded = True logger.info("Loaded cross-encoder reranker") except Exception as e: logger.warning(f"Could not load reranker: {e}") self._reranker = None self._reranker_loaded = True return self._reranker def _cache_key(self, query: str) -> str: return hashlib.md5(query.lower().strip().encode()).hexdigest() def _expand_query_fast(self, query: str) -> List[str]: queries = [query] query_lower = query.lower() for keyword, expansions in self.EXPANSION_PATTERNS.items(): if keyword in query_lower: for exp in expansions[:2]: if exp.lower() not in query_lower: variation = re.sub( rf'\b{keyword}\b', exp, query, flags=re.IGNORECASE ) if variation != query and variation not in queries: queries.append(variation) break return queries[:3] def _reciprocal_rank_fusion( self, result_lists: List[List[Tuple[Document, float]]], k: int = 60, ) -> List[Document]: doc_scores: Dict[str, Dict[str, Any]] = {} for results in result_lists: for rank, (doc, _) in enumerate(results): doc_id = hashlib.md5(doc.page_content[:200].encode()).hexdigest() if doc_id not in doc_scores: doc_scores[doc_id] = {"doc": doc, "score": 0} doc_scores[doc_id]["score"] += 1.0 / (k + rank + 1) sorted_items = sorted( doc_scores.values(), key=lambda x: x["score"], reverse=True, ) return [item["doc"] for item in sorted_items] def retrieve(self, question: str) -> List[Document]: with log_step(logger, "Optimized retrieval"): if self.use_cache: cache_key = self._cache_key(question) if cache_key in self._cache: logger.info("Cache hit - returning cached results") return self._cache[cache_key] if self.use_expansion: queries = self._expand_query_fast(question) logger.substep(f"Expanded to {len(queries)} queries") else: queries = [question] all_results: List[List[Tuple[Document, float]]] = [] for i, query in enumerate(queries): try: if hasattr(self.vector_store, 'similarity_search_with_score'): results = self.vector_store.similarity_search_with_score( query, k=self.k_initial ) else: docs = self.vector_store.similarity_search( query, k=self.k_initial ) results = [(doc, 1.0 - j * 0.01) for j, doc in enumerate(docs)] all_results.append(results) except Exception as e: logger.warning(f"Query {i+1} failed: {e}") if not all_results: logger.warning("No results from any query") return [] if len(all_results) > 1: fused_docs = self._reciprocal_rank_fusion(all_results) else: fused_docs = [doc for doc, _ in all_results[0]] fused_docs = fused_docs[:self.k_initial] logger.substep(f"Fused to {len(fused_docs)} documents") if self.use_reranking and len(fused_docs) > self.k_final: reranker = self._get_reranker() if reranker: with log_step(logger, "Cross-encoder reranking"): fused_docs = reranker.rerank(question, fused_docs, self.k_final) final_docs = fused_docs[:self.k_final] if self.use_cache: self._cache[cache_key] = final_docs logger.info(f"Returning {len(final_docs)} documents") return final_docs def clear_cache(self) -> None: self._cache.clear() def get_cache_stats(self) -> Dict[str, int]: return {"cached_queries": len(self._cache)}