# GeoQuery — backend/core/llm_gateway.py
# Gateway wrapping the google-genai SDK for chat, intent detection,
# text-to-SQL generation/correction, and result explanation.
import asyncio
import json
import os

from dotenv import load_dotenv
from google import genai
from google.genai import types

from backend.core.prompts import (
    SYSTEM_INSTRUCTION,
    INTENT_DETECTION_PROMPT,
    DATA_DISCOVERY_PROMPT,
    SQL_GENERATION_PROMPT,
    EXPLANATION_PROMPT,
    SPATIAL_SQL_PROMPT,
    SQL_CORRECTION_PROMPT,
    LAYER_NAME_PROMPT,
)
class LLMGateway:
def __init__(self, model_name: str = "gemini-3-flash-preview"):
    """Initialize the gateway: resolve the API key and build a genai client.

    If no key is found, ``self.client`` stays ``None`` and every public
    method degrades to a safe fallback instead of raising.
    """
    # Make .env values available even if the caller never loaded them.
    load_dotenv()
    self.api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
    if self.api_key:
        # The SDK reads GEMINI_API_KEY from the environment, so mirror the
        # resolved key there when it was only provided as GOOGLE_API_KEY.
        if "GEMINI_API_KEY" not in os.environ:
            os.environ["GEMINI_API_KEY"] = self.api_key
        self.client = genai.Client()
    else:
        print("WARNING: GEMINI_API_KEY/GOOGLE_API_KEY not found. LLM features will not work.")
        self.client = None
    self.model = model_name
def _build_contents_from_history(self, history: list[dict], current_message: str) -> list:
    """
    Convert chat history plus the new user message into Gemini Content objects.

    History format: [{"role": "user"|"assistant", "content": "..."}]
    """
    def _content(role: str, text: str) -> types.Content:
        return types.Content(role=role, parts=[types.Part.from_text(text=text)])

    # The Gemini API expects "model" for assistant turns, "user" otherwise.
    contents = [
        _content("model" if msg["role"] == "assistant" else "user", msg["content"])
        for msg in history
    ]
    # The message being answered always goes last, as a user turn.
    contents.append(_content("user", current_message))
    return contents
async def generate_response_stream(self, user_query: str, history: list[dict] = None):
    """
    Generate a streaming chat response using conversation history for context.

    Yields event dicts of the form ``{"type": "thought"|"content"|"error",
    "text": ...}``, matching the shape used by the other ``stream_*``
    methods (stream_intent, stream_analytical_sql, stream_explanation).

    BUG FIX: previously thought chunks used the key "content" while content
    chunks used "text", and error paths yielded bare strings — consumers had
    to handle three shapes. All events are now uniform dicts.
    """
    if not self.client:
        yield {"type": "error", "text": "I couldn't generate a response because the API key is missing."}
        return
    if history is None:
        history = []
    try:
        contents = self._build_contents_from_history(history, user_query)
        # Enable thinking mode for general chat as well.
        config = types.GenerateContentConfig(
            system_instruction=SYSTEM_INSTRUCTION,
            thinking_config=types.ThinkingConfig(
                include_thoughts=True  # Enable thought summaries
            )
        )
        stream = await asyncio.to_thread(
            self.client.models.generate_content_stream,
            model=self.model,
            contents=contents,
            config=config,
        )
        for chunk in stream:
            for part in chunk.candidates[0].content.parts:
                if part.thought:
                    yield {"type": "thought", "text": part.text}
                elif part.text:
                    yield {"type": "content", "text": part.text}
    except Exception as e:
        print(f"Error calling Gemini stream: {e}")
        yield {"type": "error", "text": str(e)}
async def generate_response(self, user_query: str, history: list[dict] = None) -> str:
    """
    Generate a single (non-streaming) reply, using history for context.

    Returns a human-readable error string rather than raising.
    """
    if not self.client:
        return "I couldn't generate a response because the API key is missing."
    try:
        contents = self._build_contents_from_history(history or [], user_query)
        cfg = types.GenerateContentConfig(system_instruction=SYSTEM_INSTRUCTION)
        # The SDK call is blocking, so push it onto a worker thread.
        response = await asyncio.to_thread(
            self.client.models.generate_content,
            model=self.model,
            contents=contents,
            config=cfg,
        )
        return response.text
    except Exception as exc:
        print(f"Error calling Gemini: {exc}")
        return f"I encountered an error: {exc}"
async def detect_intent(self, user_query: str, history: list[dict] = None) -> str:
    """
    Classify the user's query with Gemini thinking mode.

    Returns one of: GENERAL_CHAT, DATA_QUERY, MAP_REQUEST, SPATIAL_OP,
    STAT_QUERY — defaulting to GENERAL_CHAT on any failure or odd output.
    """
    valid_intents = {"GENERAL_CHAT", "DATA_QUERY", "MAP_REQUEST", "SPATIAL_OP", "STAT_QUERY"}
    if not self.client:
        return "GENERAL_CHAT"
    prompt = INTENT_DETECTION_PROMPT.format(user_query=user_query)
    try:
        cfg = types.GenerateContentConfig(
            thinking_config=types.ThinkingConfig(
                thinking_level="medium"  # Balanced thinking for intent detection
            )
        )
        response = await asyncio.to_thread(
            self.client.models.generate_content,
            model=self.model,
            contents=prompt,
            config=cfg,
        )
        label = response.text.strip().upper()
        # Anything outside the known label set falls back to small talk.
        return label if label in valid_intents else "GENERAL_CHAT"
    except Exception as exc:
        print(f"Error detecting intent: {exc}")
        return "GENERAL_CHAT"
async def stream_intent(self, user_query: str, history: list[dict] = None):
    """
    Stream intent detection, yielding thought/content/error event dicts.
    """
    if not self.client:
        yield {"type": "error", "text": "API Key missing"}
        return
    prompt = INTENT_DETECTION_PROMPT.format(user_query=user_query)
    try:
        cfg = types.GenerateContentConfig(
            thinking_config=types.ThinkingConfig(
                thinking_level="medium",
                include_thoughts=True,
            )
        )
        chunk_iter = await asyncio.to_thread(
            self.client.models.generate_content_stream,
            model=self.model,
            contents=prompt,
            config=cfg,
        )
        for chunk in chunk_iter:
            for part in chunk.candidates[0].content.parts:
                # Thought summaries are surfaced separately from answer text.
                if part.thought:
                    yield {"type": "thought", "text": part.text}
                elif part.text:
                    yield {"type": "content", "text": part.text}
    except Exception as exc:
        print(f"Error detecting intent: {exc}")
        yield {"type": "error", "text": str(exc)}
# Legacy generate_sql removed.
async def identify_relevant_tables(self, user_query: str, table_summaries: str) -> list[str]:
    """
    Pick which catalog tables are relevant for the user's query.

    Returns a list of table names, or an empty list on any failure.
    """
    if not self.client:
        return []
    prompt = DATA_DISCOVERY_PROMPT.format(user_query=user_query, table_summaries=table_summaries)
    try:
        response = await asyncio.to_thread(
            self.client.models.generate_content,
            model=self.model,
            contents=prompt,
            config=types.GenerateContentConfig(response_mime_type="application/json"),
        )
        # Defensive: strip markdown fences in case the model wraps the JSON.
        raw = response.text.replace("```json", "").replace("```", "").strip()
        parsed = json.loads(raw)
        return parsed if isinstance(parsed, list) else []
    except Exception as exc:
        print(f"Error identifying tables: {exc}")
        return []
async def generate_analytical_sql(self, user_query: str, table_schema: str, history: list[dict] = None) -> str:
    """
    Generate a DuckDB SQL query for analytical/statistical questions about
    geographic data. This is the core of the text-to-SQL system.

    Returns the SQL text, or a "-- Error ..." comment string on failure.
    """
    if not self.client:
        return "-- Error: API Key missing"
    prompt = SQL_GENERATION_PROMPT.format(table_schema=table_schema, user_query=user_query)
    # Single source of truth for the timeout so the log message can't drift.
    timeout_seconds = 120.0
    try:
        # Use thinking mode for complex SQL generation.
        config = types.GenerateContentConfig(
            temperature=1,
            thinking_config=types.ThinkingConfig(
                thinking_level="high"  # Maximum reasoning for SQL generation
            )
        )
        response = await asyncio.wait_for(
            asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=prompt,
                config=config,
            ),
            timeout=timeout_seconds,
        )
        sql = response.text.replace("```sql", "").replace("```", "").strip()
        # Basic validation: must start with SELECT; if the model prefixed
        # prose, salvage everything from the first SELECT onward.
        if not sql.upper().strip().startswith("SELECT") and "-- ERROR" not in sql:
            print(f"Warning: Generated SQL doesn't start with SELECT: {sql[:100]}")
            if "SELECT" in sql.upper():
                sql = sql[sql.upper().find("SELECT"):]
        return sql
    except asyncio.TimeoutError:
        # BUG FIX: message previously claimed 30 seconds; timeout is 120s.
        print(f"Gemini API call timed out after {int(timeout_seconds)} seconds")
        return "-- Error: API call timed out. Please try again."
    except Exception as e:
        print(f"Error calling Gemini for analytical SQL: {e}")
        return f"-- Error generating SQL: {str(e)}"
async def stream_analytical_sql(self, user_query: str, table_schema: str, history: list[dict] = None):
    """
    Stream DuckDB SQL generation, yielding thought/content/error event dicts.
    """
    if not self.client:
        yield {"type": "error", "text": "API Key missing"}
        return
    prompt = SQL_GENERATION_PROMPT.format(table_schema=table_schema, user_query=user_query)
    try:
        cfg = types.GenerateContentConfig(
            temperature=1,
            thinking_config=types.ThinkingConfig(
                thinking_level="high",
                include_thoughts=True,
            ),
        )
        chunk_iter = await asyncio.to_thread(
            self.client.models.generate_content_stream,
            model=self.model,
            contents=prompt,
            config=cfg,
        )
        for chunk in chunk_iter:
            for part in chunk.candidates[0].content.parts:
                # Separate reasoning summaries from the SQL being produced.
                if part.thought:
                    yield {"type": "thought", "text": part.text}
                elif part.text:
                    yield {"type": "content", "text": part.text}
    except Exception as exc:
        print(f"Error streaming SQL: {exc}")
        yield {"type": "error", "text": str(exc)}
async def stream_explanation(self, user_query: str, sql_query: str, data_summary: str, history: list[dict] = None):
    """
    Stream a natural-language explanation of query results, yielding
    thought/content/error event dicts.
    """
    if not self.client:
        yield {"type": "error", "text": "API Key missing"}
        return
    # Fold the tail of the conversation into the prompt for continuity.
    pieces = []
    if history:
        pieces.append("Previous conversation context:\n")
        for msg in history[-4:]:  # Last 4 messages for context
            pieces.append(f"- {msg['role']}: {msg['content'][:100]}...\n")
    context_str = "".join(pieces)
    prompt = EXPLANATION_PROMPT.format(
        context_str=context_str,
        user_query=user_query,
        sql_query=sql_query,
        data_summary=data_summary,
    )
    try:
        cfg = types.GenerateContentConfig(
            system_instruction=SYSTEM_INSTRUCTION,
            thinking_config=types.ThinkingConfig(
                thinking_level="low",
                include_thoughts=True,
            ),
        )
        chunk_iter = await asyncio.to_thread(
            self.client.models.generate_content_stream,
            model=self.model,
            contents=prompt,
            config=cfg,
        )
        for chunk in chunk_iter:
            for part in chunk.candidates[0].content.parts:
                if part.thought:
                    yield {"type": "thought", "text": part.text}
                elif part.text:
                    yield {"type": "content", "text": part.text}
    except Exception as exc:
        print(f"Error generating explanation: {exc}")
        yield {"type": "error", "text": str(exc)}
async def generate_explanation(self, user_query: str, sql_query: str, data_summary: str, history: list[dict] = None) -> str:
    """
    Explain the query results to the user, maintaining conversation context.

    Falls back to a generic sentence when the model call fails.
    """
    if not self.client:
        return "I couldn't generate an explanation because the API key is missing."
    # Fold the tail of the conversation into the prompt for continuity.
    pieces = []
    if history:
        pieces.append("Previous conversation context:\n")
        for msg in history[-4:]:  # Last 4 messages for context
            pieces.append(f"- {msg['role']}: {msg['content'][:100]}...\n")
    prompt = EXPLANATION_PROMPT.format(
        context_str="".join(pieces),
        user_query=user_query,
        sql_query=sql_query,
        data_summary=data_summary,
    )
    try:
        cfg = types.GenerateContentConfig(
            system_instruction=SYSTEM_INSTRUCTION,
            thinking_config=types.ThinkingConfig(
                thinking_level="low"  # Fast response for explanations
            )
        )
        response = await asyncio.to_thread(
            self.client.models.generate_content,
            model=self.model,
            contents=prompt,
            config=cfg,
        )
        return response.text
    except Exception as exc:
        print(f"Error generating explanation: {exc}")
        return "Here are the results from the query."
async def generate_spatial_sql(self, user_query: str, layer_context: str, history: list[dict] = None) -> str:
    """
    Generate a DuckDB Spatial SQL query for geometric operations on layers.

    Returns the SQL text, or a "-- Error ..." comment string on failure.
    """
    if not self.client:
        return "-- Error: API Key missing"
    prompt = SPATIAL_SQL_PROMPT.format(layer_context=layer_context, user_query=user_query)
    # Single source of truth for the timeout so the log message can't drift.
    timeout_seconds = 120.0
    try:
        config = types.GenerateContentConfig(
            temperature=1,
        )
        # Add timeout to prevent indefinite hangs.
        response = await asyncio.wait_for(
            asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=prompt,
                config=config,
            ),
            timeout=timeout_seconds,
        )
        sql = response.text.replace("```sql", "").replace("```", "").strip()
        return sql
    except asyncio.TimeoutError:
        # BUG FIX: message previously claimed 30 seconds; timeout is 120s.
        print(f"Gemini API call timed out after {int(timeout_seconds)} seconds")
        return "-- Error: API call timed out. Please try again."
    except Exception as e:
        print(f"Error calling Gemini: {e}")
        return f"-- Error generating SQL: {str(e)}"
async def correct_sql(self, user_query: str, incorrect_sql: str, error_message: str, schema_context: str) -> str:
    """
    Ask the model to repair a failed SQL query using the error message.

    On any failure, returns ``incorrect_sql`` unchanged so the caller can
    surface the original error.
    """
    if not self.client:
        return "-- Error: API Key missing"
    prompt = SQL_CORRECTION_PROMPT.format(
        error_message=error_message,
        incorrect_sql=incorrect_sql,
        user_query=user_query,
        schema_context=schema_context,
    )
    try:
        response = await asyncio.to_thread(
            self.client.models.generate_content,
            model=self.model,
            contents=prompt,
            config=types.GenerateContentConfig(temperature=1),
        )
        return response.text.replace("```sql", "").replace("```", "").strip()
    except Exception as exc:
        print(f"Error correcting SQL: {exc}")
        return incorrect_sql
async def generate_layer_name(self, user_query: str, sql_query: str) -> dict:
    """
    Generate a short, descriptive name, emoji, and point style for a map layer.

    Returns: {"name": str, "emoji": str, "pointStyle": str | None}
    """
    if not self.client:
        return {"name": "New Layer", "emoji": "πŸ“", "pointStyle": None}
    prompt = LAYER_NAME_PROMPT.format(user_query=user_query, sql_query=sql_query)
    try:
        # Plain (non-streaming) call; the payload is tiny.
        response = await asyncio.to_thread(
            self.client.models.generate_content,
            model=self.model,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=1,
                response_mime_type="application/json",
            ),
        )
        payload = json.loads(response.text)
        return {
            "name": payload.get("name", "Map Layer"),
            "emoji": payload.get("emoji", "πŸ“"),
            "pointStyle": payload.get("pointStyle", None),
        }
    except Exception as exc:
        print(f"Error generating layer name: {exc}")
        return {"name": "Map Layer", "emoji": "πŸ“", "pointStyle": None}