Spaces:
Sleeping
Sleeping
File size: 3,406 Bytes
305c138 bdd4a60 305c138 ea1910c 305c138 ea1910c 305c138 ea1910c 305c138 ea1910c 305c138 ea1910c 305c138 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import re
from loguru import logger
from scientific_rag.application.rag.llm_client import llm_client
from scientific_rag.domain.queries import ExpandedQuery, QueryFilters
from scientific_rag.settings import settings
class QueryProcessor:
    """Turn a raw user query into an ``ExpandedQuery`` for retrieval.

    Processing has two optional stages:

    1. Filter extraction - keyword heuristics pick a ``source`` and a
       ``section`` filter from the query text. The query text itself is
       never modified.
    2. Query expansion - the LLM client is asked for alternative phrasings
       of the query.
    """

    # Section name -> keyword fragments (regex syntax). Order matters: the
    # first section with a matching keyword wins.
    _SECTION_KEYWORDS = (
        ("introduction", ("introduction", "intro")),
        ("methods", ("methods", "methodology", "experiment setup")),
        ("results", ("results", "findings", "performance")),
        ("conclusion", ("conclusion", "summary", "discussion")),
    )

    # One precompiled pattern per section. The leading context words and the
    # trailing "section" are all optional, so in practice a keyword standing
    # on word boundaries is enough to trigger the filter.
    _SECTION_PATTERNS = tuple(
        (
            name,
            re.compile(r"\b(in|from|check|read)?\s*(the)?\s*(?:" + "|".join(keywords) + r")\s*(section)?\b"),
        )
        for name, keywords in _SECTION_KEYWORDS
    )

    def __init__(self, expand_to_n: int = 3):
        """Store the expansion target.

        Args:
            expand_to_n: Total number of queries aimed for, counting the
                original; at most ``expand_to_n - 1`` variations are
                generated. Values of 1 or less disable expansion.
        """
        self.expand_to_n = expand_to_n

    def process(self, query: str, use_expansion: bool = True, extract_filters: bool = True) -> "ExpandedQuery":
        """Run filter extraction and query expansion on ``query``.

        Args:
            query: Raw user query.
            use_expansion: When true (and ``expand_to_n > 1``), request
                alternative phrasings from the LLM.
            extract_filters: When true, infer source/section filters from
                the query text; otherwise both filters are ``"any"``.

        Returns:
            An ``ExpandedQuery`` holding the query, its variations (possibly
            empty) and the filters.
        """
        if extract_filters:
            cleaned_query, filters = self._extract_filters(query)
        else:
            cleaned_query, filters = query, QueryFilters(source="any", section="any")

        variations = self._expand_query(cleaned_query) if use_expansion and self.expand_to_n > 1 else []

        logger.info(
            f"Processed query: '{query}' -> '{cleaned_query}' | Filters: {filters} | Expansion: {len(variations)} vars"
        )
        return ExpandedQuery(original=cleaned_query, variations=variations, filters=filters)

    @staticmethod
    def _detect_source(query_lower: str) -> str:
        """Return ``"arxiv"`` or ``"pubmed"`` if the lowercased query mentions it, else ``"any"``.

        This is a plain substring check; "arxiv" takes precedence when both appear.
        """
        for source in ("arxiv", "pubmed"):
            if source in query_lower:
                return source
        return "any"

    @classmethod
    def _detect_section(cls, query_lower: str) -> str:
        """Return the first section whose keyword pattern matches, else ``"any"``."""
        for name, pattern in cls._SECTION_PATTERNS:
            if pattern.search(query_lower):
                return name
        return "any"

    def _extract_filters(self, query: str) -> "tuple[str, QueryFilters]":
        """Infer source and section filters from the query text.

        Returns:
            A tuple of the original, unmodified ``query`` and the detected
            ``QueryFilters``.
        """
        query_lower = query.lower()
        filters = QueryFilters(
            source=self._detect_source(query_lower),
            section=self._detect_section(query_lower),
        )
        return query, filters

    @staticmethod
    def _parse_variations(content: str, query: str, limit: int) -> list[str]:
        """Split an LLM reply on ``"###"`` into at most ``limit`` distinct variations.

        Blank chunks, chunks equal to ``query`` and repeated chunks are
        dropped; comparisons ignore case and surrounding whitespace. The
        order of first appearance is preserved.
        """
        if limit <= 0:
            return []
        seen = {query.strip().lower()}
        unique: list[str] = []
        for chunk in content.split("###"):
            candidate = chunk.strip()
            key = candidate.lower()
            if not candidate or key in seen:
                continue
            seen.add(key)
            unique.append(candidate)
        return unique[:limit]

    def _expand_query(self, query: str) -> list[str]:
        """Ask the LLM for up to ``expand_to_n - 1`` alternative phrasings of ``query``.

        Expansion is best-effort: an empty list is returned when no
        variations are wanted, when no LLM API key is configured, or when
        the LLM call or the parsing of its reply raises.
        """
        wanted = self.expand_to_n - 1
        if wanted < 1:
            # Nothing to generate; also avoids a negative slice bound below.
            return []

        if not settings.llm_api_key:
            logger.warning("No LLM API Key set. Skipping expansion.")
            return []

        prompt = (
            "\n"
            f"Generate {wanted} different search queries for a scientific database based on the input.\n"
            "The variations should capture the same technical intent but use alternative terminology or keywords.\n"
            'Output ONLY the variations, separated by "###". Do not number them.\n'
            f"Input: {query}\n"
        )

        try:
            content = llm_client.generate(prompt=prompt)
            return self._parse_variations(content, query, wanted)
        except Exception as e:
            # Deliberately broad: a failed expansion must not fail the whole query.
            logger.error(f"Query expansion failed: {e}")
            return []
|