File size: 3,406 Bytes
305c138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdd4a60
305c138
 
 
 
 
 
 
 
 
 
ea1910c
305c138
 
 
 
 
ea1910c
305c138
 
 
 
 
 
 
 
 
 
 
 
ea1910c
 
305c138
ea1910c
305c138
 
 
 
 
ea1910c
 
305c138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import re

from loguru import logger

from scientific_rag.application.rag.llm_client import llm_client
from scientific_rag.domain.queries import ExpandedQuery, QueryFilters
from scientific_rag.settings import settings


class QueryProcessor:
    """Turn a raw user query into an ``ExpandedQuery`` (original text, variations, filters).

    Two independent, optional steps are applied:

    * filter extraction: keyword detection of a target source ("arxiv" /
      "pubmed") and a paper section (introduction / methods / results /
      conclusion). The query text itself is returned unchanged.
    * query expansion: asking the LLM for ``expand_to_n - 1`` alternative
      phrasings of the query.
    """

    # Section name -> trigger keywords. Dict order is the priority order: the
    # first section with a matching keyword wins.
    _SECTION_KEYWORDS: dict[str, tuple[str, ...]] = {
        "introduction": ("introduction", "intro"),
        "methods": ("methods", "methodology", "experiment setup"),
        "results": ("results", "findings", "performance"),
        "conclusion": ("conclusion", "summary", "discussion"),
    }

    # Flattened (section, compiled pattern) pairs in priority order, compiled
    # once at class creation instead of on every query.
    # NOTE(review): every context group ("in"/"from"/"check"/"read", "the",
    # "section") is optional, so in practice any whole-word mention of a
    # keyword (e.g. "performance", "summary") sets the section filter, not
    # only explicit phrases like "in the methods section". Make a context
    # group mandatory if only explicit section references should count.
    _SECTION_PATTERNS: tuple[tuple[str, re.Pattern[str]], ...] = tuple(
        (section_name, re.compile(rf"\b(in|from|check|read)?\s*(the)?\s*{keyword}\s*(section)?\b"))
        for section_name, keywords in _SECTION_KEYWORDS.items()
        for keyword in keywords
    )

    def __init__(self, expand_to_n: int = 3):
        """
        Args:
            expand_to_n: Total number of query forms wanted, counting the
                original. Expansion requests ``expand_to_n - 1`` variations
                and is disabled when this is <= 1.
        """
        self.expand_to_n = expand_to_n

    def process(self, query: str, use_expansion: bool = True, extract_filters: bool = True) -> ExpandedQuery:
        """Extract filters from *query* and optionally expand it.

        Args:
            query: Raw user query.
            use_expansion: Request LLM variations (only if ``expand_to_n > 1``).
            extract_filters: Detect source/section filters. When False both
                filters are ``"any"``.

        Returns:
            ``ExpandedQuery`` with the (unmodified) query as ``original``, the
            variations (possibly empty), and the detected filters.
        """
        if extract_filters:
            cleaned_query, filters = self._extract_filters(query)
        else:
            cleaned_query = query
            filters = QueryFilters(source="any", section="any")

        variations: list[str] = []
        if use_expansion and self.expand_to_n > 1:
            variations = self._expand_query(cleaned_query)

        logger.info(
            f"Processed query: '{query}' -> '{cleaned_query}' | Filters: {filters} | Expansion: {len(variations)} vars"
        )

        return ExpandedQuery(original=cleaned_query, variations=variations, filters=filters)

    @classmethod
    def _detect_section(cls, query_lower: str) -> str:
        """Return the first section whose keyword pattern matches *query_lower*, else ``"any"``."""
        return next(
            (section_name for section_name, pattern in cls._SECTION_PATTERNS if pattern.search(query_lower)),
            "any",
        )

    def _extract_filters(self, query: str) -> tuple[str, QueryFilters]:
        """Detect source and section filters by keyword; the query text is returned unmodified.

        Source detection is a plain substring check; "arxiv" takes precedence
        over "pubmed" when both appear.
        """
        query_lower = query.lower()

        if "arxiv" in query_lower:
            source = "arxiv"
        elif "pubmed" in query_lower:
            source = "pubmed"
        else:
            source = "any"

        section = self._detect_section(query_lower)

        return query, QueryFilters(source=source, section=section)

    def _expand_query(self, query: str) -> list[str]:
        """Ask the LLM for up to ``expand_to_n - 1`` alternative phrasings of *query*.

        Best-effort: returns ``[]`` when no variations are requested, when no
        LLM API key is configured, when the LLM returns nothing, or when the
        call/parsing fails (the failure is logged). Variations equal to the
        query, or to an earlier variation (case-insensitively), are dropped.
        """
        n_wanted = self.expand_to_n - 1
        # Guard: with n_wanted <= 0 the old slice `[: n_wanted]` silently
        # dropped items and the prompt asked for a non-positive count.
        if n_wanted <= 0:
            return []

        if not settings.llm_api_key:
            logger.warning("No LLM API Key set. Skipping expansion.")
            return []

        prompt = f"""
        Generate {n_wanted} different search queries for a scientific database based on the input.
        The variations should capture the same technical intent but use alternative terminology or keywords.

        Output ONLY the variations, separated by "###". Do not number them.

        Input: {query}
        """

        try:
            content = llm_client.generate(prompt=prompt)
            if not content:
                logger.warning("LLM returned an empty expansion response.")
                return []

            # Seed with the query itself so echoes of the input are discarded too.
            seen = {query.lower()}
            unique_variations: list[str] = []
            for raw in content.split("###"):
                candidate = raw.strip()
                key = candidate.lower()
                if not candidate or key in seen:
                    continue
                seen.add(key)
                unique_variations.append(candidate)
                if len(unique_variations) == n_wanted:
                    break
            return unique_variations

        except Exception as e:
            # Deliberate best-effort boundary: expansion failure must not break
            # retrieval. logger.exception keeps the traceback for debugging.
            logger.exception(f"Query expansion failed: {e}")
            return []