File size: 2,642 Bytes
565a379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import logging
import os
import json
import re
from typing import Dict, Any, Optional

from google import genai
from google.genai import types
import pybreaker
from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

class QueryClassifier:
    """
    Classifies user queries to determine whether they require a web search
    and, if so, what specific information to extract.

    Model calls are retried with randomized exponential backoff and guarded
    by a circuit breaker. ``classify`` does not raise for model or parsing
    failures; it falls back to ``{"requires_web_search": False}`` instead.
    """

    def __init__(self, api_key: Optional[str] = None, model_name: str = "gemini-2.5-flash"):
        """
        Build the client and the circuit breaker.

        Args:
            api_key: API key for the client. When omitted, the
                ``GOOGLE_API_KEY`` environment variable is used; a warning
                is logged if neither is set.
            model_name: Model identifier passed to ``generate_content``.
        """
        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            logger.warning("No API key provided for QueryClassifier.")

        self.client = genai.Client(api_key=self.api_key)
        self.model_name = model_name

        # Robustness: the breaker opens after 5 failed calls and rejects
        # further calls until the 60 second reset timeout has elapsed.
        self.breaker = pybreaker.CircuitBreaker(fail_max=5, reset_timeout=60)

    @staticmethod
    def _fallback() -> Dict[str, Any]:
        """Return a fresh fail-safe result meaning "do not search the web"."""
        return {"requires_web_search": False}

    def classify(self, query: str) -> Dict[str, Any]:
        """
        Classifies the query.

        Args:
            query: The raw user query.

        Returns:
            The classification dict produced by the model. It always
            contains the ``requires_web_search`` key. On any failure
            (API error after retries, open circuit breaker, unusable
            response) the fail-safe ``{"requires_web_search": False}`` is
            returned instead of raising.
        """
        # A blank query cannot need a web search; skip the model call.
        if query is None or not str(query).strip():
            return self._fallback()

        try:
            # The breaker wraps the whole retry cycle: one exhausted cycle
            # counts as one failure, and an open breaker rejects the call
            # immediately without going through any backoff sleeps.
            return self.breaker.call(self._classify_with_retry, query)
        except Exception as e:
            logger.error("Classification failed: %s", e)
            # Fail safe to no search
            return self._fallback()

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(multiplier=1, max=60),
        reraise=True,
    )
    def _classify_with_retry(self, query: str) -> Dict[str, Any]:
        """
        Run one classification, retrying up to 3 attempts on any exception.

        ``reraise=True`` makes the last underlying exception surface (rather
        than tenacity's ``RetryError``) once the attempts are exhausted.
        """
        return self._classify_internal(query)

    def _classify_internal(self, query: str) -> Dict[str, Any]:
        """
        Perform a single model call and parse its JSON answer.

        Errors are deliberately NOT caught here so that the retry policy
        and the circuit breaker can observe them.

        Raises:
            ValueError: If the response is not valid JSON
                (``json.JSONDecodeError``) or is not a JSON object.
            Exception: Whatever the client raises for a failed request.
        """
        prompt = f"""
        Analyze the following user query to determine if it requires external information (web search, live data, specific facts) or if it can be answered by a standard math/logic solver.

        If it requires web search, identify the specific 'extraction_focus' (the exact value or fact needed, e.g., 'stock price', 'release date', 'population').

        Query: "{query}"

        Output JSON format:
        {{
            "requires_web_search": boolean,
            "search_queries": ["list of optimal search queries"],
            "extraction_focus": "keyword or phrase to look for in the page content to find the answer",
            "intent": "general_info" | "specific_value" | "date_lookup"
        }}
        """

        response = self.client.models.generate_content(
            model=self.model_name,
            contents=[prompt],
            config=types.GenerateContentConfig(
                response_mime_type="application/json",
                temperature=0.0
            )
        )

        # An empty answer is treated as "no search needed", not as an error.
        if not response.text:
            return self._fallback()

        result = json.loads(response.text)
        if not isinstance(result, dict):
            # Callers expect a mapping; a bare list/scalar is unusable.
            raise ValueError(
                f"Expected a JSON object from the model, got {type(result).__name__}"
            )

        # Guarantee the key callers branch on, defaulting to the safe value.
        result.setdefault("requires_web_search", False)
        return result