import asyncio
import json
import os

from dotenv import load_dotenv
from google import genai
from google.genai import types

from backend.core.prompts import (
    SYSTEM_INSTRUCTION,
    INTENT_DETECTION_PROMPT,
    DATA_DISCOVERY_PROMPT,
    SQL_GENERATION_PROMPT,
    EXPLANATION_PROMPT,
    SPATIAL_SQL_PROMPT,
    SQL_CORRECTION_PROMPT,
    LAYER_NAME_PROMPT,
)

# Single source of truth for the blocking-call timeout so the error messages
# always agree with the value actually passed to asyncio.wait_for.
# (The original code passed 120.0 but reported "30 seconds" in its messages.)
_API_TIMEOUT_SECONDS = 120.0

# Closed set of intents detect_intent() is allowed to return.
_VALID_INTENTS = {"GENERAL_CHAT", "DATA_QUERY", "MAP_REQUEST", "SPATIAL_OP", "STAT_QUERY"}


class LLMGateway:
    """Async wrapper around the Gemini SDK.

    Provides chat (plain and streaming), intent detection, table discovery,
    analytical/spatial SQL generation, SQL correction, result explanation,
    and layer-name generation. All Gemini calls are executed in a worker
    thread via ``asyncio.to_thread`` because the SDK client is synchronous.

    If no API key is found the client is ``None`` and every method degrades
    to a harmless fallback value instead of raising.
    """

    def __init__(self, model_name: str = "gemini-3-flash-preview"):
        # Load environment variables if not already loaded.
        load_dotenv()
        self.api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            print("WARNING: GEMINI_API_KEY/GOOGLE_API_KEY not found. LLM features will not work.")
            self.client = None
        else:
            # The SDK picks up GEMINI_API_KEY from the environment, so make
            # sure it is set even when the key came from GOOGLE_API_KEY.
            if "GEMINI_API_KEY" not in os.environ and self.api_key:
                os.environ["GEMINI_API_KEY"] = self.api_key
            self.client = genai.Client()
        self.model = model_name

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_contents_from_history(self, history: list[dict], current_message: str) -> list:
        """Convert conversation history into Gemini API ``Content`` objects.

        History format: ``[{"role": "user"|"assistant", "content": "..."}]``.
        The 'assistant' role is mapped to Gemini's 'model' role; the current
        message is appended last as a 'user' turn.
        """
        contents = []
        for msg in history:
            role = "model" if msg["role"] == "assistant" else "user"
            contents.append(
                types.Content(
                    role=role,
                    parts=[types.Part.from_text(text=msg["content"])],
                )
            )
        contents.append(
            types.Content(
                role="user",
                parts=[types.Part.from_text(text=current_message)],
            )
        )
        return contents

    @staticmethod
    def _emit_parts(chunk, thought_key: str = "text"):
        """Yield event dicts for every part in one streamed response chunk.

        Thought parts become ``{"type": "thought", thought_key: text}``;
        plain text parts become ``{"type": "content", "text": text}``.
        ``thought_key`` exists because generate_response_stream historically
        emitted thoughts under the "content" key while every other stream
        uses "text" — preserved for frontend compatibility.
        """
        for part in chunk.candidates[0].content.parts:
            if part.thought:
                yield {"type": "thought", thought_key: part.text}
            elif part.text:
                yield {"type": "content", "text": part.text}

    @staticmethod
    def _strip_fences(text: str, tag: str) -> str:
        """Remove a markdown code fence (e.g. ```sql) from model output."""
        return text.replace(f"```{tag}", "").replace("```", "").strip()

    # ------------------------------------------------------------------
    # Chat
    # ------------------------------------------------------------------

    async def generate_response_stream(self, user_query: str, history: list[dict] = None):
        """Stream a chat response, yielding content chunks and thought summaries.

        NOTE(review): on the missing-key and error paths this generator yields
        bare strings, while the success path yields dicts — callers appear to
        handle both; preserved as-is.
        """
        if not self.client:
            yield "I couldn't generate a response because the API key is missing."
            return
        if history is None:
            history = []
        try:
            contents = self._build_contents_from_history(history, user_query)
            # Enable thinking mode for general chat as well.
            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                thinking_config=types.ThinkingConfig(
                    include_thoughts=True  # Enable thought summaries
                ),
            )
            stream = await asyncio.to_thread(
                self.client.models.generate_content_stream,
                model=self.model,
                contents=contents,
                config=config,
            )
            for chunk in stream:
                # Legacy event shape: thoughts carry a "content" key here.
                for event in self._emit_parts(chunk, thought_key="content"):
                    yield event
        except Exception as e:
            print(f"Error calling Gemini stream: {e}")
            yield f"Error: {str(e)}"

    async def generate_response(self, user_query: str, history: list[dict] = None) -> str:
        """Generate a (non-streaming) response using conversation history."""
        if not self.client:
            return "I couldn't generate a response because the API key is missing."
        if history is None:
            history = []
        try:
            contents = self._build_contents_from_history(history, user_query)
            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
            )
            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=contents,
                config=config,
            )
            return response.text
        except Exception as e:
            print(f"Error calling Gemini: {e}")
            return f"I encountered an error: {e}"

    # ------------------------------------------------------------------
    # Intent detection
    # ------------------------------------------------------------------

    async def detect_intent(self, user_query: str, history: list[dict] = None) -> str:
        """Classify the user's query.

        Returns one of GENERAL_CHAT, DATA_QUERY, MAP_REQUEST, SPATIAL_OP,
        STAT_QUERY; falls back to GENERAL_CHAT on any failure or unexpected
        model output.
        """
        if not self.client:
            return "GENERAL_CHAT"
        intent_prompt = INTENT_DETECTION_PROMPT.format(user_query=user_query)
        try:
            # Balanced thinking level for intent classification.
            config = types.GenerateContentConfig(
                thinking_config=types.ThinkingConfig(thinking_level="medium")
            )
            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=intent_prompt,
                config=config,
            )
            intent = response.text.strip().upper()
            if intent in _VALID_INTENTS:
                return intent
            return "GENERAL_CHAT"
        except Exception as e:
            print(f"Error detecting intent: {e}")
            return "GENERAL_CHAT"

    async def stream_intent(self, user_query: str, history: list[dict] = None):
        """Stream intent detection, yielding thought and content events."""
        if not self.client:
            yield {"type": "error", "text": "API Key missing"}
            return
        intent_prompt = INTENT_DETECTION_PROMPT.format(user_query=user_query)
        try:
            config = types.GenerateContentConfig(
                thinking_config=types.ThinkingConfig(
                    thinking_level="medium",
                    include_thoughts=True,
                )
            )
            stream = await asyncio.to_thread(
                self.client.models.generate_content_stream,
                model=self.model,
                contents=intent_prompt,
                config=config,
            )
            for chunk in stream:
                for event in self._emit_parts(chunk):
                    yield event
        except Exception as e:
            print(f"Error detecting intent: {e}")
            yield {"type": "error", "text": str(e)}

    # Legacy generate_sql removed.

    # ------------------------------------------------------------------
    # Data discovery & SQL generation
    # ------------------------------------------------------------------

    async def identify_relevant_tables(self, user_query: str, table_summaries: str) -> list[str]:
        """Pick the catalog tables relevant to the query.

        Returns a list of table names parsed from the model's JSON output,
        or an empty list on any failure.
        """
        if not self.client:
            return []
        prompt = DATA_DISCOVERY_PROMPT.format(
            user_query=user_query, table_summaries=table_summaries
        )
        try:
            config = types.GenerateContentConfig(response_mime_type="application/json")
            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=prompt,
                config=config,
            )
            # Defensive: strip fences even though JSON mime type is requested.
            tables = json.loads(self._strip_fences(response.text, "json"))
            return tables if isinstance(tables, list) else []
        except Exception as e:
            print(f"Error identifying tables: {e}")
            return []

    async def generate_analytical_sql(
        self, user_query: str, table_schema: str, history: list[dict] = None
    ) -> str:
        """Generate a DuckDB SQL query for analytical/statistical questions.

        Core of the text-to-SQL system. Returns the SQL text, or a
        ``-- Error ...`` comment string on failure/timeout.
        """
        if not self.client:
            return "-- Error: API Key missing"
        prompt = SQL_GENERATION_PROMPT.format(table_schema=table_schema, user_query=user_query)
        try:
            # Maximum reasoning effort for SQL generation.
            config = types.GenerateContentConfig(
                temperature=1,
                thinking_config=types.ThinkingConfig(thinking_level="high"),
            )
            response = await asyncio.wait_for(
                asyncio.to_thread(
                    self.client.models.generate_content,
                    model=self.model,
                    contents=prompt,
                    config=config,
                ),
                timeout=_API_TIMEOUT_SECONDS,
            )
            sql = self._strip_fences(response.text, "sql")
            # Basic validation: must start with SELECT; if the model emitted
            # preamble text, trim everything before the first SELECT.
            if not sql.upper().strip().startswith("SELECT") and "-- ERROR" not in sql:
                print(f"Warning: Generated SQL doesn't start with SELECT: {sql[:100]}")
                if "SELECT" in sql.upper():
                    start_idx = sql.upper().find("SELECT")
                    sql = sql[start_idx:]
            return sql
        except asyncio.TimeoutError:
            print(f"Gemini API call timed out after {_API_TIMEOUT_SECONDS:.0f} seconds")
            return "-- Error: API call timed out. Please try again."
        except Exception as e:
            print(f"Error calling Gemini for analytical SQL: {e}")
            return f"-- Error generating SQL: {str(e)}"

    async def stream_analytical_sql(
        self, user_query: str, table_schema: str, history: list[dict] = None
    ):
        """Stream DuckDB SQL generation, yielding thought and content events."""
        if not self.client:
            yield {"type": "error", "text": "API Key missing"}
            return
        prompt = SQL_GENERATION_PROMPT.format(table_schema=table_schema, user_query=user_query)
        try:
            config = types.GenerateContentConfig(
                temperature=1,
                thinking_config=types.ThinkingConfig(
                    thinking_level="high",
                    include_thoughts=True,
                ),
            )
            stream = await asyncio.to_thread(
                self.client.models.generate_content_stream,
                model=self.model,
                contents=prompt,
                config=config,
            )
            for chunk in stream:
                for event in self._emit_parts(chunk):
                    yield event
        except Exception as e:
            print(f"Error streaming SQL: {e}")
            yield {"type": "error", "text": str(e)}

    # ------------------------------------------------------------------
    # Explanations
    # ------------------------------------------------------------------

    @staticmethod
    def _history_context(history: list[dict] | None) -> str:
        """Summarize the last few turns of history for an explanation prompt."""
        if not history:
            return ""
        context_str = "Previous conversation context:\n"
        for msg in history[-4:]:  # Last 4 messages for context
            context_str += f"- {msg['role']}: {msg['content'][:100]}...\n"
        return context_str

    async def stream_explanation(
        self, user_query: str, sql_query: str, data_summary: str, history: list[dict] = None
    ):
        """Stream an explanation of query results, with conversation context."""
        if not self.client:
            yield {"type": "error", "text": "API Key missing"}
            return
        prompt = EXPLANATION_PROMPT.format(
            context_str=self._history_context(history),
            user_query=user_query,
            sql_query=sql_query,
            data_summary=data_summary,
        )
        try:
            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                thinking_config=types.ThinkingConfig(
                    thinking_level="low",
                    include_thoughts=True,
                ),
            )
            stream = await asyncio.to_thread(
                self.client.models.generate_content_stream,
                model=self.model,
                contents=prompt,
                config=config,
            )
            for chunk in stream:
                for event in self._emit_parts(chunk):
                    yield event
        except Exception as e:
            print(f"Error generating explanation: {e}")
            yield {"type": "error", "text": str(e)}

    async def generate_explanation(
        self, user_query: str, sql_query: str, data_summary: str, history: list[dict] = None
    ) -> str:
        """Explain query results to the user, maintaining conversation context."""
        if not self.client:
            return "I couldn't generate an explanation because the API key is missing."
        prompt = EXPLANATION_PROMPT.format(
            context_str=self._history_context(history),
            user_query=user_query,
            sql_query=sql_query,
            data_summary=data_summary,
        )
        try:
            # Low thinking level: fast responses for explanations.
            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                thinking_config=types.ThinkingConfig(thinking_level="low"),
            )
            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=prompt,
                config=config,
            )
            return response.text
        except Exception as e:
            print(f"Error generating explanation: {e}")
            return "Here are the results from the query."

    # ------------------------------------------------------------------
    # Spatial SQL, correction, layer naming
    # ------------------------------------------------------------------

    async def generate_spatial_sql(
        self, user_query: str, layer_context: str, history: list[dict] = None
    ) -> str:
        """Generate a DuckDB Spatial SQL query for geometric operations on layers."""
        if not self.client:
            return "-- Error: API Key missing"
        prompt = SPATIAL_SQL_PROMPT.format(layer_context=layer_context, user_query=user_query)
        try:
            config = types.GenerateContentConfig(temperature=1)
            # Timeout prevents indefinite hangs on slow API calls.
            response = await asyncio.wait_for(
                asyncio.to_thread(
                    self.client.models.generate_content,
                    model=self.model,
                    contents=prompt,
                    config=config,
                ),
                timeout=_API_TIMEOUT_SECONDS,
            )
            return self._strip_fences(response.text, "sql")
        except asyncio.TimeoutError:
            print(f"Gemini API call timed out after {_API_TIMEOUT_SECONDS:.0f} seconds")
            return "-- Error: API call timed out. Please try again."
        except Exception as e:
            print(f"Error calling Gemini: {e}")
            return f"-- Error generating SQL: {str(e)}"

    async def correct_sql(
        self, user_query: str, incorrect_sql: str, error_message: str, schema_context: str
    ) -> str:
        """Correct a failed SQL query based on the database error message.

        On any failure the original (incorrect) SQL is returned unchanged so
        the caller's retry loop can decide what to do.
        """
        if not self.client:
            return "-- Error: API Key missing"
        prompt = SQL_CORRECTION_PROMPT.format(
            error_message=error_message,
            incorrect_sql=incorrect_sql,
            user_query=user_query,
            schema_context=schema_context,
        )
        try:
            config = types.GenerateContentConfig(temperature=1)
            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=prompt,
                config=config,
            )
            return self._strip_fences(response.text, "sql")
        except Exception as e:
            print(f"Error correcting SQL: {e}")
            return incorrect_sql

    async def generate_layer_name(self, user_query: str, sql_query: str) -> dict:
        """Generate a short name, emoji, and point style for a map layer.

        Returns: ``{"name": str, "emoji": str, "pointStyle": str | None}``,
        with sensible defaults on any failure.
        """
        if not self.client:
            return {"name": "New Layer", "emoji": "\U0001F4CD", "pointStyle": None}
        prompt = LAYER_NAME_PROMPT.format(user_query=user_query, sql_query=sql_query)
        try:
            config = types.GenerateContentConfig(
                temperature=1,
                response_mime_type="application/json",
            )
            # Simple (non-streaming) call; the payload is tiny.
            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=prompt,
                config=config,
            )
            result = json.loads(response.text)
            return {
                "name": result.get("name", "Map Layer"),
                "emoji": result.get("emoji", "\U0001F4CD"),
                "pointStyle": result.get("pointStyle", None),
            }
        except Exception as e:
            print(f"Error generating layer name: {e}")
            return {"name": "Map Layer", "emoji": "\U0001F4CD", "pointStyle": None}