| import os |
| import asyncio |
| import json |
| from google import genai |
| from google.genai import types |
| from dotenv import load_dotenv |
|
|
| from backend.core.prompts import ( |
| SYSTEM_INSTRUCTION, |
| INTENT_DETECTION_PROMPT, |
| DATA_DISCOVERY_PROMPT, |
| SQL_GENERATION_PROMPT, |
| EXPLANATION_PROMPT, |
| SPATIAL_SQL_PROMPT, |
| SPATIAL_SQL_PROMPT, |
| SQL_CORRECTION_PROMPT, |
| LAYER_NAME_PROMPT |
| ) |
|
|
class LLMGateway:
    """Async facade over the Gemini API used for chat, intent detection and SQL generation.

    Every public method checks ``self.client`` first and returns/yields a
    fallback value when no API key was configured. Errors raised while building
    prompts or talking to the model are caught, printed, and turned into
    fallback values as well.
    """

    # Timeout, in seconds, applied to the analytical and spatial SQL generation calls.
    _REQUEST_TIMEOUT_S = 120.0

    # Labels accepted from the intent-detection prompt; anything else maps to GENERAL_CHAT.
    _INTENTS = ("GENERAL_CHAT", "DATA_QUERY", "MAP_REQUEST", "SPATIAL_OP", "STAT_QUERY")

    # NOTE(review): "π" looks like a mis-encoded emoji (a UTF-8 emoji decoded through
    # a legacy code page) — confirm the intended glyph before changing it.
    _DEFAULT_LAYER_EMOJI = "π"

    def __init__(self, model_name: str = "gemini-3-flash-preview"):
        """Load credentials from the environment (and a .env file) and create the client.

        Args:
            model_name: Gemini model identifier used for every request.

        If neither GEMINI_API_KEY nor GOOGLE_API_KEY is set, a warning is printed
        and ``self.client`` is None, which puts every method into fallback mode.
        """
        load_dotenv()

        self.api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        self.model = model_name

        if not self.api_key:
            print("WARNING: GEMINI_API_KEY/GOOGLE_API_KEY not found. LLM features will not work.")
            self.client = None
            return

        # Kept for compatibility with code that may read GEMINI_API_KEY from the environment.
        os.environ.setdefault("GEMINI_API_KEY", self.api_key)
        # Pass the key explicitly so the client uses exactly the key resolved above
        # instead of re-resolving it from the environment by its own precedence rules.
        self.client = genai.Client(api_key=self.api_key)

    # ------------------------------------------------------------------ helpers

    async def _generate(self, contents, config, timeout: float = None):
        """Run a blocking ``generate_content`` call in a worker thread.

        Args:
            contents: Prompt string or list of ``Content`` objects.
            config: ``GenerateContentConfig`` forwarded to the SDK.
            timeout: Optional limit in seconds; on expiry ``asyncio.TimeoutError``
                is raised. The worker thread itself cannot be interrupted and keeps
                running until the SDK call returns.

        Returns:
            The SDK response object.
        """
        call = asyncio.to_thread(
            self.client.models.generate_content,
            model=self.model,
            contents=contents,
            config=config,
        )
        if timeout is None:
            return await call
        return await asyncio.wait_for(call, timeout=timeout)

    @staticmethod
    def _chunk_parts(chunk) -> list:
        """Return the content parts of one streamed chunk, or ``[]`` if it carries none.

        Guards against chunks whose ``candidates``, ``content`` or ``parts`` is
        missing/None, which would otherwise raise in the middle of a stream.
        """
        candidates = getattr(chunk, "candidates", None)
        if not candidates:
            return []
        content = getattr(candidates[0], "content", None)
        return list(getattr(content, "parts", None) or [])

    async def _stream_parts(self, contents, config):
        """Yield response parts from ``generate_content_stream`` without blocking the event loop.

        The SDK stream is a synchronous iterator whose network reads happen while
        it is advanced, so each ``next()`` call runs in a worker thread.
        """
        stream = await asyncio.to_thread(
            self.client.models.generate_content_stream,
            model=self.model,
            contents=contents,
            config=config,
        )
        iterator = iter(stream)
        exhausted = object()
        while True:
            # Use next()'s default instead of letting StopIteration cross the thread/future boundary.
            chunk = await asyncio.to_thread(next, iterator, exhausted)
            if chunk is exhausted:
                return
            for part in self._chunk_parts(chunk):
                yield part

    async def _stream_events(self, contents, config):
        """Stream parts as ``{"type": "thought"|"content", "text": str}`` events.

        Parts flagged as thoughts become "thought" events; other parts with
        non-empty text become "content" events; anything else is skipped.
        """
        async for part in self._stream_parts(contents, config):
            if part.thought:
                yield {"type": "thought", "text": part.text}
            elif part.text:
                yield {"type": "content", "text": part.text}

    @staticmethod
    def _response_text(response) -> str:
        """Return ``response.text``, raising ValueError when the model returned no text."""
        text = getattr(response, "text", None)
        if not text:
            raise ValueError("Gemini returned an empty response")
        return text

    @staticmethod
    def _strip_code_fences(text: str, language: str = "") -> str:
        """Remove Markdown code fences (optionally tagged with ``language``) and trim whitespace."""
        if language:
            text = text.replace(f"```{language}", "")
        return text.replace("```", "").strip()

    @staticmethod
    def _trim_to_sql_statement(sql: str) -> str:
        """Drop any prose the model put before the SQL statement.

        SQL that already starts with SELECT or WITH (a CTE), or that carries an
        explicit "-- ERROR" marker, is returned unchanged. Otherwise a warning is
        printed and the text is cut at the first line starting with WITH/SELECT,
        falling back to the first occurrence of SELECT anywhere.
        """
        import re

        head = sql.strip().upper()
        if head.startswith(("SELECT", "WITH")) or "-- ERROR" in sql:
            return sql

        print(f"Warning: Generated SQL doesn't start with SELECT: {sql[:100]}")
        match = re.search(r"^\s*(?:WITH|SELECT)\b", sql, re.IGNORECASE | re.MULTILINE)
        if match:
            return sql[match.start():].strip()
        start = sql.upper().find("SELECT")
        return sql[start:] if start != -1 else sql

    @classmethod
    def _parse_intent(cls, raw: str) -> str:
        """Normalize a model reply to one of ``_INTENTS``; unknown labels map to GENERAL_CHAT."""
        intent = raw.strip().strip("`'\".").strip().upper()
        return intent if intent in cls._INTENTS else "GENERAL_CHAT"

    @staticmethod
    def _format_history_context(history: list[dict], max_messages: int = 4, max_chars: int = 100) -> str:
        """Summarize the last ``max_messages`` history entries for the explanation prompt.

        Each entry reads ``msg["role"]`` and the first ``max_chars`` characters of
        ``msg["content"]``. Returns "" when there is no history.
        """
        if not history:
            return ""
        lines = ["Previous conversation context:"]
        for msg in history[-max_messages:]:
            lines.append(f"- {msg['role']}: {msg['content'][:max_chars]}...")
        return "\n".join(lines) + "\n"

    def _build_contents_from_history(self, history: list[dict], current_message: str) -> list:
        """Convert chat history plus the new user message into Gemini ``Content`` objects.

        Args:
            history: Messages shaped ``{"role": "user" | "assistant", "content": str}``.
                Any role other than "assistant" is sent as "user".
            current_message: The new user message, appended last.

        Returns:
            List of ``types.Content`` in conversation order.
        """
        turns = [
            ("model" if msg["role"] == "assistant" else "user", msg["content"])
            for msg in history
        ]
        turns.append(("user", current_message))
        return [
            types.Content(role=role, parts=[types.Part.from_text(text=text)])
            for role, text in turns
        ]

    # ------------------------------------------------------------- chat

    async def generate_response_stream(self, user_query: str, history: list[dict] = None):
        """Stream a chat reply (with thought summaries) using prior conversation as context.

        Yields:
            ``{"type": "thought", "content": str}`` for thought summaries,
            ``{"type": "content", "text": str}`` for answer text, or a plain ``str``
            message when the API key is missing or the call fails.
        """
        # NOTE(review): unlike the stream_* methods, thoughts here use the key "content"
        # and errors are plain strings; kept as-is for existing callers.
        if not self.client:
            yield "I couldn't generate a response because the API key is missing."
            return

        try:
            contents = self._build_contents_from_history(history or [], user_query)
            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                thinking_config=types.ThinkingConfig(include_thoughts=True),
            )
            async for part in self._stream_parts(contents, config):
                if part.thought:
                    yield {"type": "thought", "content": part.text}
                elif part.text:
                    yield {"type": "content", "text": part.text}
        except Exception as e:
            print(f"Error calling Gemini stream: {e}")
            yield f"Error: {str(e)}"

    async def generate_response(self, user_query: str, history: list[dict] = None) -> str:
        """Return a complete chat reply using prior conversation as context.

        Returns a fixed message when the API key is missing, and an
        "I encountered an error: ..." message when the call fails.
        """
        if not self.client:
            return "I couldn't generate a response because the API key is missing."

        try:
            contents = self._build_contents_from_history(history or [], user_query)
            config = types.GenerateContentConfig(system_instruction=SYSTEM_INSTRUCTION)
            response = await self._generate(contents, config)
            return self._response_text(response)
        except Exception as e:
            print(f"Error calling Gemini: {e}")
            return f"I encountered an error: {e}"

    # ----------------------------------------------------------- intent

    async def detect_intent(self, user_query: str, history: list[dict] = None) -> str:
        """Classify the user's query with the intent-detection prompt.

        Returns one of GENERAL_CHAT, DATA_QUERY, MAP_REQUEST, SPATIAL_OP or
        STAT_QUERY; GENERAL_CHAT is also the fallback for missing keys, errors
        and unrecognized replies. ``history`` is accepted but not used.
        """
        if not self.client:
            return "GENERAL_CHAT"

        try:
            intent_prompt = INTENT_DETECTION_PROMPT.format(user_query=user_query)
            config = types.GenerateContentConfig(
                thinking_config=types.ThinkingConfig(thinking_level="medium")
            )
            response = await self._generate(intent_prompt, config)
            return self._parse_intent(self._response_text(response))
        except Exception as e:
            print(f"Error detecting intent: {e}")
            return "GENERAL_CHAT"

    async def stream_intent(self, user_query: str, history: list[dict] = None):
        """Stream intent detection as thought/content events.

        Yields ``{"type": "thought"|"content"|"error", "text": str}`` dicts.
        ``history`` is accepted but not used.
        """
        if not self.client:
            yield {"type": "error", "text": "API Key missing"}
            return

        try:
            intent_prompt = INTENT_DETECTION_PROMPT.format(user_query=user_query)
            config = types.GenerateContentConfig(
                thinking_config=types.ThinkingConfig(thinking_level="medium", include_thoughts=True)
            )
            async for event in self._stream_events(intent_prompt, config):
                yield event
        except Exception as e:
            print(f"Error detecting intent: {e}")
            yield {"type": "error", "text": str(e)}

    # --------------------------------------------------------- discovery

    async def identify_relevant_tables(self, user_query: str, table_summaries: str) -> list[str]:
        """Ask the model which catalog tables are relevant to the query.

        Returns the parsed JSON list from the model, or ``[]`` when the key is
        missing, the call fails, or the reply is not a JSON list.
        """
        if not self.client:
            return []

        try:
            prompt = DATA_DISCOVERY_PROMPT.format(user_query=user_query, table_summaries=table_summaries)
            config = types.GenerateContentConfig(response_mime_type="application/json")
            response = await self._generate(prompt, config)
            tables = json.loads(self._strip_code_fences(self._response_text(response), "json"))
            return tables if isinstance(tables, list) else []
        except Exception as e:
            print(f"Error identifying tables: {e}")
            return []

    # --------------------------------------------------------------- SQL

    async def generate_analytical_sql(self, user_query: str, table_schema: str, history: list[dict] = None) -> str:
        """Generate a DuckDB SQL query answering an analytical question about the given schema.

        Returns the SQL text with code fences and leading prose removed, or a
        string starting with "-- Error" on missing key, timeout or failure.
        ``history`` is accepted but not used.
        """
        if not self.client:
            return "-- Error: API Key missing"

        try:
            prompt = SQL_GENERATION_PROMPT.format(table_schema=table_schema, user_query=user_query)
            config = types.GenerateContentConfig(
                temperature=1,
                thinking_config=types.ThinkingConfig(thinking_level="high"),
            )
            response = await self._generate(prompt, config, timeout=self._REQUEST_TIMEOUT_S)
            sql = self._strip_code_fences(self._response_text(response), "sql")
            return self._trim_to_sql_statement(sql)
        except asyncio.TimeoutError:
            print(f"Gemini API call timed out after {self._REQUEST_TIMEOUT_S:g} seconds")
            return "-- Error: API call timed out. Please try again."
        except Exception as e:
            print(f"Error calling Gemini for analytical SQL: {e}")
            return f"-- Error generating SQL: {str(e)}"

    async def stream_analytical_sql(self, user_query: str, table_schema: str, history: list[dict] = None):
        """Stream DuckDB SQL generation as thought/content events.

        Yields ``{"type": "thought"|"content"|"error", "text": str}`` dicts; the
        content text is raw model output (fences are not stripped here).
        ``history`` is accepted but not used.
        """
        if not self.client:
            yield {"type": "error", "text": "API Key missing"}
            return

        try:
            prompt = SQL_GENERATION_PROMPT.format(table_schema=table_schema, user_query=user_query)
            config = types.GenerateContentConfig(
                temperature=1,
                thinking_config=types.ThinkingConfig(thinking_level="high", include_thoughts=True),
            )
            async for event in self._stream_events(prompt, config):
                yield event
        except Exception as e:
            print(f"Error streaming SQL: {e}")
            yield {"type": "error", "text": str(e)}

    # ------------------------------------------------------- explanation

    async def stream_explanation(self, user_query: str, sql_query: str, data_summary: str, history: list[dict] = None):
        """Stream an explanation of query results as thought/content events.

        The last few history messages are summarized into the prompt. Yields
        ``{"type": "thought"|"content"|"error", "text": str}`` dicts.
        """
        if not self.client:
            yield {"type": "error", "text": "API Key missing"}
            return

        try:
            prompt = EXPLANATION_PROMPT.format(
                context_str=self._format_history_context(history),
                user_query=user_query,
                sql_query=sql_query,
                data_summary=data_summary,
            )
            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                thinking_config=types.ThinkingConfig(thinking_level="low", include_thoughts=True),
            )
            async for event in self._stream_events(prompt, config):
                yield event
        except Exception as e:
            print(f"Error generating explanation: {e}")
            yield {"type": "error", "text": str(e)}

    async def generate_explanation(self, user_query: str, sql_query: str, data_summary: str, history: list[dict] = None) -> str:
        """Explain query results to the user, including recent conversation context.

        Returns a fixed fallback sentence when the key is missing or the call fails.
        """
        if not self.client:
            return "I couldn't generate an explanation because the API key is missing."

        try:
            prompt = EXPLANATION_PROMPT.format(
                context_str=self._format_history_context(history),
                user_query=user_query,
                sql_query=sql_query,
                data_summary=data_summary,
            )
            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                thinking_config=types.ThinkingConfig(thinking_level="low"),
            )
            response = await self._generate(prompt, config)
            return self._response_text(response)
        except Exception as e:
            print(f"Error generating explanation: {e}")
            return "Here are the results from the query."

    # ----------------------------------------------------------- spatial

    async def generate_spatial_sql(self, user_query: str, layer_context: str, history: list[dict] = None) -> str:
        """Generate a DuckDB Spatial SQL query for geometric operations on layers.

        Returns fence-stripped SQL, or a string starting with "-- Error" on
        missing key, timeout or failure. ``history`` is accepted but not used.
        """
        if not self.client:
            return "-- Error: API Key missing"

        try:
            prompt = SPATIAL_SQL_PROMPT.format(layer_context=layer_context, user_query=user_query)
            config = types.GenerateContentConfig(temperature=1)
            response = await self._generate(prompt, config, timeout=self._REQUEST_TIMEOUT_S)
            return self._strip_code_fences(self._response_text(response), "sql")
        except asyncio.TimeoutError:
            print(f"Gemini API call timed out after {self._REQUEST_TIMEOUT_S:g} seconds")
            return "-- Error: API call timed out. Please try again."
        except Exception as e:
            print(f"Error calling Gemini: {e}")
            return f"-- Error generating SQL: {str(e)}"

    async def correct_sql(self, user_query: str, incorrect_sql: str, error_message: str, schema_context: str) -> str:
        """Ask the model to fix a failed SQL query given the database error message.

        Returns the corrected, fence-stripped SQL; on failure returns
        ``incorrect_sql`` unchanged so callers can decide how to proceed.
        """
        if not self.client:
            return "-- Error: API Key missing"

        try:
            prompt = SQL_CORRECTION_PROMPT.format(
                error_message=error_message,
                incorrect_sql=incorrect_sql,
                user_query=user_query,
                schema_context=schema_context,
            )
            config = types.GenerateContentConfig(temperature=1)
            response = await self._generate(prompt, config)
            return self._strip_code_fences(self._response_text(response), "sql")
        except Exception as e:
            print(f"Error correcting SQL: {e}")
            return incorrect_sql

    # ------------------------------------------------------------ layers

    async def generate_layer_name(self, user_query: str, sql_query: str) -> dict:
        """Generate a short name, emoji and point style for a map layer.

        Returns:
            ``{"name": str, "emoji": str, "pointStyle": str | None}``; a default
            dict is returned when the key is missing, the call fails, or the reply
            is not a JSON object.
        """
        if not self.client:
            return {"name": "New Layer", "emoji": self._DEFAULT_LAYER_EMOJI, "pointStyle": None}

        try:
            prompt = LAYER_NAME_PROMPT.format(user_query=user_query, sql_query=sql_query)
            config = types.GenerateContentConfig(
                temperature=1,
                response_mime_type="application/json",
            )
            response = await self._generate(prompt, config)
            result = json.loads(self._strip_code_fences(self._response_text(response), "json"))
            if not isinstance(result, dict):
                raise ValueError(f"expected a JSON object, got {type(result).__name__}")
            return {
                "name": result.get("name", "Map Layer"),
                "emoji": result.get("emoji", self._DEFAULT_LAYER_EMOJI),
                "pointStyle": result.get("pointStyle"),
            }
        except Exception as e:
            print(f"Error generating layer name: {e}")
            return {"name": "Map Layer", "emoji": self._DEFAULT_LAYER_EMOJI, "pointStyle": None}
|