Spaces:
Sleeping
Sleeping
| import re | |
| from loguru import logger | |
| from scientific_rag.application.rag.llm_client import llm_client | |
| from scientific_rag.domain.queries import ExpandedQuery, QueryFilters | |
| from scientific_rag.settings import settings | |
class QueryProcessor:
    """Turn a raw user query into an ``ExpandedQuery`` for retrieval.

    Processing has two optional stages:

    * filter extraction: keyword detection of a source (``arxiv`` /
      ``pubmed``) and a paper section (introduction, methods, results,
      conclusion). The query text itself is returned unchanged.
    * expansion: asking the LLM for up to ``expand_to_n - 1`` alternative
      phrasings of the query.
    """

    # Section name -> keywords that select it. Insertion order matters: the
    # first section (in this order) whose keywords match the query wins.
    _SECTION_KEYWORDS: dict[str, tuple[str, ...]] = {
        "introduction": ("introduction", "intro"),
        "methods": ("methods", "methodology", "experiment setup"),
        "results": ("results", "findings", "performance"),
        "conclusion": ("conclusion", "summary", "discussion"),
    }

    # One pattern per section, compiled once instead of on every call.
    # NOTE: every group around the keyword is optional, so a bare keyword in
    # the query (e.g. "performance", "summary") is enough to select the
    # section. The "in/from/check/read", "the" and "section" words only widen
    # what the match spans; they are not required context.
    _SECTION_REGEXES: tuple[tuple[str, re.Pattern[str]], ...] = tuple(
        (
            section_name,
            re.compile(
                r"\b(in|from|check|read)?\s*(the)?\s*(?:"
                + "|".join(map(re.escape, keywords))
                + r")\s*(section)?\b"
            ),
        )
        for section_name, keywords in _SECTION_KEYWORDS.items()
    )

    def __init__(self, expand_to_n: int = 3):
        # Total number of query forms wanted (original + LLM variations).
        # Values <= 1 disable expansion.
        self.expand_to_n = expand_to_n

    def process(self, query: str, use_expansion: bool = True, extract_filters: bool = True) -> ExpandedQuery:
        """Build the ``ExpandedQuery`` used for retrieval.

        Args:
            query: Raw user query.
            use_expansion: Ask the LLM for alternative phrasings. This only
                takes effect when ``expand_to_n > 1``.
            extract_filters: Detect source/section filters from keywords in
                the query. When False, both filters are ``"any"``.

        Returns:
            An ``ExpandedQuery`` holding the (unmodified) query as
            ``original``, the LLM variations (possibly empty) and the filters.
        """
        if extract_filters:
            cleaned_query, filters = self._extract_filters(query)
        else:
            cleaned_query = query
            filters = QueryFilters(source="any", section="any")

        variations: list[str] = []
        if use_expansion and self.expand_to_n > 1:
            variations = self._expand_query(cleaned_query)

        logger.info(
            f"Processed query: '{query}' -> '{cleaned_query}' | Filters: {filters} | Expansion: {len(variations)} vars"
        )
        return ExpandedQuery(original=cleaned_query, variations=variations, filters=filters)

    def _extract_filters(self, query: str) -> tuple[str, QueryFilters]:
        """Detect source and section filters from keywords in ``query``.

        The query is returned unchanged; keywords only choose the filters.
        Matching is case-insensitive.

        * Source: if both "arxiv" and "pubmed" appear, "arxiv" wins.
        * Section: the first match in ``_SECTION_REGEXES`` order wins.

        Anything not detected is reported as "any".
        """
        query_lower = query.lower()

        # Plain substring checks, so e.g. "arxiv.org" also selects arxiv.
        if "arxiv" in query_lower:
            source = "arxiv"
        elif "pubmed" in query_lower:
            source = "pubmed"
        else:
            source = "any"

        section = next(
            (name for name, regex in self._SECTION_REGEXES if regex.search(query_lower)),
            "any",
        )

        return query, QueryFilters(source=source, section=section)

    def _expand_query(self, query: str) -> list[str]:
        """Ask the LLM for up to ``expand_to_n - 1`` rephrasings of ``query``.

        This is best-effort and returns ``[]`` in three cases:

        * expansion is disabled (``expand_to_n <= 1``);
        * no LLM API key is configured;
        * the LLM call or the parsing of its reply fails.

        In each case retrieval can proceed with the original query alone.
        Variations equal to the query, or to an earlier variation, are
        dropped (compared case-insensitively after stripping whitespace).
        """
        n_wanted = self.expand_to_n - 1
        # Guard direct calls: with expand_to_n <= 0 the count would be
        # negative, which as a slice bound would keep "all but the last".
        if n_wanted <= 0:
            return []

        if not settings.llm_api_key:
            logger.warning("No LLM API Key set. Skipping expansion.")
            return []

        prompt = f"""
        Generate {n_wanted} different search queries for a scientific database based on the input.
        The variations should capture the same technical intent but use alternative terminology or keywords.
        Output ONLY the variations, separated by "###". Do not number them.
        Input: {query}
        """
        try:
            content = llm_client.generate(prompt=prompt)
            # Keys already used: the query itself plus every accepted variation.
            seen = {query.strip().lower()}
            variations: list[str] = []
            for chunk in content.split("###"):
                candidate = chunk.strip()
                key = candidate.lower()
                if not candidate or key in seen:
                    continue
                seen.add(key)
                variations.append(candidate)
                if len(variations) == n_wanted:
                    break
            return variations
        except Exception as e:
            # Expansion is optional: record the failure with its traceback and
            # fall back to no variations rather than failing the whole query.
            logger.exception(f"Query expansion failed: {e}")
            return []