Spaces:
Running
Running
File size: 2,642 Bytes
import json
import logging
import os
import re
from typing import Any, Dict, Optional

import pybreaker
from google import genai
from google.genai import types
from tenacity import (
    retry,
    retry_if_exception_type,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)
# Module-level logger, named after this module so its level/handlers can be configured per module.
logger: logging.Logger = logging.getLogger(name=__name__)
class QueryClassifier:
    """
    Classifies user queries to determine if they require a web search
    and what specific information to extract.

    Classification is delegated to a Gemini model that is asked to answer with a
    JSON object (``requires_web_search``, ``search_queries``,
    ``extraction_focus``, ``intent``). ``classify`` is fail-safe: on failure it
    returns ``{"requires_web_search": False}`` instead of raising.
    """

    def __init__(self, api_key: Optional[str] = None, model_name: str = "gemini-2.5-flash"):
        """Create the classifier.

        Args:
            api_key: Google API key; falls back to the ``GOOGLE_API_KEY``
                environment variable when omitted.
            model_name: Gemini model used for classification.
        """
        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            logger.warning("No API key provided for QueryClassifier.")
        self.client = genai.Client(api_key=self.api_key)
        self.model_name = model_name
        # Robustness: after 5 failures the breaker opens and calls fail fast
        # for 60 s before a trial call is allowed again.
        self.breaker = pybreaker.CircuitBreaker(fail_max=5, reset_timeout=60)

    def classify(self, query: str) -> Dict[str, Any]:
        """Classify *query*.

        Model/API errors are retried (up to 3 attempts, randomized exponential
        back-off) behind the circuit breaker. If all attempts fail, or the
        breaker is open, the error is logged and the fail-safe result is
        returned, so callers never see an exception from here.

        Returns:
            The JSON object produced by the model, always containing a
            ``requires_web_search`` key; ``{"requires_web_search": False}`` on
            failure or for an empty query.
        """
        if not query or not query.strip():
            # Nothing to classify; avoid a pointless API round-trip.
            return self._fail_safe()
        try:
            return self._classify_with_retry(query)
        except Exception as e:  # top-level boundary: fail safe to "no search"
            logger.error("Classification failed: %s", e)
            return self._fail_safe()

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(multiplier=1, max=60),
        # An open breaker fails immediately; retrying it would only add delay.
        retry=retry_if_not_exception_type(pybreaker.CircuitBreakerError),
        reraise=True,
    )
    def _classify_with_retry(self, query: str) -> Dict[str, Any]:
        """Run one breaker-guarded classification attempt (retried by tenacity)."""
        return self.breaker.call(self._classify_internal, query)

    def _classify_internal(self, query: str) -> Dict[str, Any]:
        """Call the model once and parse its answer.

        API errors are deliberately NOT caught here so that the retry decorator
        and the circuit breaker can observe them.
        """
        # json.dumps quotes the query and escapes embedded quotes/newlines so the
        # user text cannot break out of the "Query:" line.
        quoted_query = json.dumps(query, ensure_ascii=False)
        prompt = f"""
Analyze the following user query to determine if it requires external information (web search, live data, specific facts) or if it can be answered by a standard math/logic solver.
If it requires web search, identify the specific 'extraction_focus' (the exact value or fact needed, e.g., 'stock price', 'release date', 'population').
Query: {quoted_query}
Output JSON format:
{{
"requires_web_search": boolean,
"search_queries": ["list of optimal search queries"],
"extraction_focus": "keyword or phrase to look for in the page content to find the answer",
"intent": "general_info" | "specific_value" | "date_lookup"
}}
"""
        response = self.client.models.generate_content(
            model=self.model_name,
            contents=[prompt],
            config=types.GenerateContentConfig(
                response_mime_type="application/json",
                temperature=0.0,
            ),
        )
        return self._parse_response(response.text)

    def _parse_response(self, text: Optional[str]) -> Dict[str, Any]:
        """Turn the model's raw text into a result dict.

        An empty, non-JSON, or non-object answer is logged (where applicable)
        and mapped to the fail-safe result; the service itself responded, so this
        is not treated as an outage for the breaker.
        """
        if not text:
            return self._fail_safe()
        # Defensive: tolerate a ```json ... ``` fenced answer should the model add one.
        cleaned = re.sub(r"^\s*```(?:json)?\s*|\s*```\s*$", "", text)
        try:
            result = json.loads(cleaned)
        except json.JSONDecodeError as e:
            logger.error("Classification returned invalid JSON: %s", e)
            return self._fail_safe()
        if not isinstance(result, dict):
            logger.error("Classification returned non-object JSON: %r", result)
            return self._fail_safe()
        # Callers branch on this key; guarantee it is present.
        result.setdefault("requires_web_search", False)
        return result

    @staticmethod
    def _fail_safe() -> Dict[str, Any]:
        """Return a fresh "no web search needed" result (never a shared dict)."""
        return {"requires_web_search": False}
|