|
|
import asyncio
import json
import os

from dotenv import load_dotenv
from google import genai
from google.genai import types

from backend.core.prompts import (
    SYSTEM_INSTRUCTION,
    INTENT_DETECTION_PROMPT,
    DATA_DISCOVERY_PROMPT,
    SQL_GENERATION_PROMPT,
    EXPLANATION_PROMPT,
    SPATIAL_SQL_PROMPT,
    SQL_CORRECTION_PROMPT,
    LAYER_NAME_PROMPT,
)
|
|
|
|
|
class LLMGateway:
    """Async gateway to the Google Gemini API for the geo-chat backend.

    Responsibilities:
      * intent detection for routing user messages,
      * text-to-SQL generation (analytical and spatial DuckDB queries),
      * self-correction of failed SQL,
      * natural-language explanation of query results,
      * short name/emoji generation for map layers.

    Every blocking SDK call — including iteration of the SDK's synchronous
    response streams — is dispatched through ``asyncio.to_thread`` so the
    event loop is never stalled by network I/O. When no API key is
    configured, ``self.client`` is ``None`` and every method degrades to a
    harmless fallback value instead of raising.
    """

    # The only intents the downstream router understands; anything else
    # falls back to GENERAL_CHAT.
    _VALID_INTENTS = frozenset(
        {"GENERAL_CHAT", "DATA_QUERY", "MAP_REQUEST", "SPATIAL_OP", "STAT_QUERY"}
    )

    def __init__(self, model_name: str = "gemini-3-flash-preview"):
        """Load credentials from the environment and build the SDK client.

        Args:
            model_name: Gemini model identifier used for all calls.
        """
        load_dotenv()

        self.api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            print("WARNING: GEMINI_API_KEY/GOOGLE_API_KEY not found. LLM features will not work.")
            self.client = None
        else:
            # genai.Client() picks its key up from GEMINI_API_KEY, so mirror
            # a GOOGLE_API_KEY-only setup into that variable before building.
            if "GEMINI_API_KEY" not in os.environ:
                os.environ["GEMINI_API_KEY"] = self.api_key
            self.client = genai.Client()

        self.model = model_name

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _strip_code_fence(text: str, tag: str) -> str:
        """Remove markdown code fences (```<tag> / ```) from model output."""
        return text.replace(f"```{tag}", "").replace("```", "").strip()

    @staticmethod
    def _format_history_context(history: list[dict] | None) -> str:
        """Render the last four turns as a compact context string ("" if none)."""
        if not history:
            return ""
        parts = ["Previous conversation context:\n"]
        for msg in history[-4:]:
            # Truncate each turn to 100 chars to keep the prompt small.
            parts.append(f"- {msg['role']}: {msg['content'][:100]}...\n")
        return "".join(parts)

    @staticmethod
    async def _iter_blocking(stream):
        """Iterate a synchronous SDK stream without blocking the event loop.

        Each ``next()`` call performs network reads, so it is pushed to a
        worker thread. Yields the raw chunks unchanged.
        """
        sentinel = object()
        iterator = iter(stream)
        while True:
            chunk = await asyncio.to_thread(next, iterator, sentinel)
            if chunk is sentinel:
                return
            yield chunk

    async def _relay_parts(self, stream, thought_key: str = "text"):
        """Convert raw stream chunks into thought/content dicts.

        Args:
            stream: a synchronous generate_content_stream result.
            thought_key: dict key used for thought payloads. One legacy
                consumer (``generate_response_stream``) expects thoughts
                under "content"; everything else uses "text".
        """
        async for chunk in self._iter_blocking(stream):
            # Defensive: chunks can arrive without candidates or parts;
            # skip them instead of aborting the whole stream.
            if not chunk.candidates:
                continue
            content = chunk.candidates[0].content
            if not content or not content.parts:
                continue
            for part in content.parts:
                if part.thought:
                    yield {"type": "thought", thought_key: part.text}
                elif part.text:
                    yield {"type": "content", "text": part.text}

    def _build_contents_from_history(self, history: list[dict], current_message: str) -> list:
        """
        Converts conversation history to the format expected by the Gemini API.
        History format: [{"role": "user"|"assistant", "content": "..."}]
        """
        contents = []
        for msg in history:
            # The API uses "model" where our history says "assistant".
            role = "model" if msg["role"] == "assistant" else "user"
            contents.append(
                types.Content(
                    role=role,
                    parts=[types.Part.from_text(text=msg["content"])],
                )
            )

        # The new user message always goes last.
        contents.append(
            types.Content(
                role="user",
                parts=[types.Part.from_text(text=current_message)],
            )
        )
        return contents

    # ------------------------------------------------------------------
    # General chat
    # ------------------------------------------------------------------

    async def generate_response_stream(self, user_query: str, history: list[dict] | None = None):
        """
        Generates a streaming response using conversation history for context.
        Yields chunks of text and thought summaries.

        NOTE: on missing API key or error this yields a bare string (not a
        dict) — kept for backward compatibility with existing consumers.
        """
        if not self.client:
            yield "I couldn't generate a response because the API key is missing."
            return

        try:
            contents = self._build_contents_from_history(history or [], user_query)
            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                thinking_config=types.ThinkingConfig(include_thoughts=True),
            )
            stream = await asyncio.to_thread(
                self.client.models.generate_content_stream,
                model=self.model,
                contents=contents,
                config=config,
            )
            # Legacy format: thought payloads use the "content" key here.
            async for item in self._relay_parts(stream, thought_key="content"):
                yield item

        except Exception as e:
            print(f"Error calling Gemini stream: {e}")
            yield f"Error: {str(e)}"

    async def generate_response(self, user_query: str, history: list[dict] | None = None) -> str:
        """
        Generates a response using conversation history for context.
        """
        if not self.client:
            return "I couldn't generate a response because the API key is missing."

        try:
            contents = self._build_contents_from_history(history or [], user_query)
            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
            )
            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=contents,
                config=config,
            )
            return response.text
        except Exception as e:
            print(f"Error calling Gemini: {e}")
            return f"I encountered an error: {e}"

    # ------------------------------------------------------------------
    # Intent detection
    # ------------------------------------------------------------------

    async def detect_intent(self, user_query: str, history: list[dict] | None = None) -> str:
        """
        Detects the intent of the user's query using Gemini thinking mode.
        Returns: GENERAL_CHAT, DATA_QUERY, MAP_REQUEST, SPATIAL_OP, or STAT_QUERY.
        Falls back to GENERAL_CHAT on any failure or unrecognized answer.
        """
        if not self.client:
            return "GENERAL_CHAT"

        intent_prompt = INTENT_DETECTION_PROMPT.format(user_query=user_query)

        try:
            config = types.GenerateContentConfig(
                thinking_config=types.ThinkingConfig(thinking_level="medium")
            )
            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=intent_prompt,
                config=config,
            )
            intent = response.text.strip().upper()

            if intent in self._VALID_INTENTS:
                return intent
            return "GENERAL_CHAT"
        except Exception as e:
            print(f"Error detecting intent: {e}")
            return "GENERAL_CHAT"

    async def stream_intent(self, user_query: str, history: list[dict] | None = None):
        """
        Streams intent detection, yielding thoughts and the final answer
        as {"type": "thought"|"content"|"error", "text": ...} dicts.
        """
        if not self.client:
            yield {"type": "error", "text": "API Key missing"}
            return

        intent_prompt = INTENT_DETECTION_PROMPT.format(user_query=user_query)

        try:
            config = types.GenerateContentConfig(
                thinking_config=types.ThinkingConfig(
                    thinking_level="medium",
                    include_thoughts=True,
                )
            )
            stream = await asyncio.to_thread(
                self.client.models.generate_content_stream,
                model=self.model,
                contents=intent_prompt,
                config=config,
            )
            async for item in self._relay_parts(stream):
                yield item

        except Exception as e:
            print(f"Error detecting intent: {e}")
            yield {"type": "error", "text": str(e)}

    # ------------------------------------------------------------------
    # Data discovery & SQL generation
    # ------------------------------------------------------------------

    async def identify_relevant_tables(self, user_query: str, table_summaries: str) -> list[str]:
        """
        Identifies which tables are relevant for the user's query from the
        catalog summary. Returns a list of table names ([] on any failure).
        """
        if not self.client:
            return []

        prompt = DATA_DISCOVERY_PROMPT.format(user_query=user_query, table_summaries=table_summaries)

        try:
            config = types.GenerateContentConfig(
                response_mime_type="application/json"
            )
            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=prompt,
                config=config,
            )

            # The model occasionally wraps JSON in markdown fences anyway.
            tables = json.loads(self._strip_code_fence(response.text, "json"))
            return tables if isinstance(tables, list) else []

        except Exception as e:
            print(f"Error identifying tables: {e}")
            return []

    async def generate_analytical_sql(self, user_query: str, table_schema: str, history: list[dict] | None = None) -> str:
        """
        Generates a DuckDB SQL query for analytical/statistical questions about
        geographic data. This is the core of the text-to-SQL system.
        Returns an SQL string, or a "-- Error ..." comment on failure.
        """
        if not self.client:
            return "-- Error: API Key missing"

        prompt = SQL_GENERATION_PROMPT.format(table_schema=table_schema, user_query=user_query)

        try:
            config = types.GenerateContentConfig(
                temperature=1,
                thinking_config=types.ThinkingConfig(thinking_level="high"),
            )
            response = await asyncio.wait_for(
                asyncio.to_thread(
                    self.client.models.generate_content,
                    model=self.model,
                    contents=prompt,
                    config=config,
                ),
                timeout=120.0,
            )

            sql = self._strip_code_fence(response.text, "sql")

            # Salvage: if the model prepended prose, cut down to the SELECT.
            if not sql.upper().strip().startswith("SELECT") and "-- ERROR" not in sql:
                print(f"Warning: Generated SQL doesn't start with SELECT: {sql[:100]}")
                if "SELECT" in sql.upper():
                    start_idx = sql.upper().find("SELECT")
                    sql = sql[start_idx:]

            return sql

        except asyncio.TimeoutError:
            print("Gemini API call timed out after 120 seconds")
            return "-- Error: API call timed out. Please try again."
        except Exception as e:
            print(f"Error calling Gemini for analytical SQL: {e}")
            return f"-- Error generating SQL: {str(e)}"

    async def stream_analytical_sql(self, user_query: str, table_schema: str, history: list[dict] | None = None):
        """
        Streams the generation of DuckDB SQL, yielding thoughts and chunks.
        """
        if not self.client:
            yield {"type": "error", "text": "API Key missing"}
            return

        prompt = SQL_GENERATION_PROMPT.format(table_schema=table_schema, user_query=user_query)

        try:
            config = types.GenerateContentConfig(
                temperature=1,
                thinking_config=types.ThinkingConfig(
                    thinking_level="high",
                    include_thoughts=True,
                ),
            )
            stream = await asyncio.to_thread(
                self.client.models.generate_content_stream,
                model=self.model,
                contents=prompt,
                config=config,
            )
            async for item in self._relay_parts(stream):
                yield item

        except Exception as e:
            print(f"Error streaming SQL: {e}")
            yield {"type": "error", "text": str(e)}

    # ------------------------------------------------------------------
    # Result explanation
    # ------------------------------------------------------------------

    async def stream_explanation(self, user_query: str, sql_query: str, data_summary: str, history: list[dict] | None = None):
        """
        Streams the explanation of query results, yielding thought/content dicts.
        """
        if not self.client:
            yield {"type": "error", "text": "API Key missing"}
            return

        context_str = self._format_history_context(history)
        prompt = EXPLANATION_PROMPT.format(context_str=context_str, user_query=user_query, sql_query=sql_query, data_summary=data_summary)

        try:
            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                thinking_config=types.ThinkingConfig(
                    thinking_level="low",
                    include_thoughts=True,
                ),
            )
            stream = await asyncio.to_thread(
                self.client.models.generate_content_stream,
                model=self.model,
                contents=prompt,
                config=config,
            )
            async for item in self._relay_parts(stream):
                yield item

        except Exception as e:
            print(f"Error generating explanation: {e}")
            yield {"type": "error", "text": str(e)}

    async def generate_explanation(self, user_query: str, sql_query: str, data_summary: str, history: list[dict] | None = None) -> str:
        """
        Explains the results of the query to the user, maintaining conversation
        context. Returns a generic fallback sentence on failure.
        """
        if not self.client:
            return "I couldn't generate an explanation because the API key is missing."

        context_str = self._format_history_context(history)
        prompt = EXPLANATION_PROMPT.format(context_str=context_str, user_query=user_query, sql_query=sql_query, data_summary=data_summary)

        try:
            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                thinking_config=types.ThinkingConfig(thinking_level="low"),
            )
            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=prompt,
                config=config,
            )
            return response.text
        except Exception as e:
            print(f"Error generating explanation: {e}")
            return "Here are the results from the query."

    # ------------------------------------------------------------------
    # Spatial SQL, correction, layer naming
    # ------------------------------------------------------------------

    async def generate_spatial_sql(self, user_query: str, layer_context: str, history: list[dict] | None = None) -> str:
        """
        Generates a DuckDB Spatial SQL query for geometric operations on layers.
        Returns an SQL string, or a "-- Error ..." comment on failure.
        """
        if not self.client:
            return "-- Error: API Key missing"

        prompt = SPATIAL_SQL_PROMPT.format(layer_context=layer_context, user_query=user_query)

        try:
            config = types.GenerateContentConfig(
                temperature=1,
            )
            response = await asyncio.wait_for(
                asyncio.to_thread(
                    self.client.models.generate_content,
                    model=self.model,
                    contents=prompt,
                    config=config,
                ),
                timeout=120.0,
            )
            return self._strip_code_fence(response.text, "sql")

        except asyncio.TimeoutError:
            print("Gemini API call timed out after 120 seconds")
            return "-- Error: API call timed out. Please try again."
        except Exception as e:
            print(f"Error calling Gemini: {e}")
            return f"-- Error generating SQL: {str(e)}"

    async def correct_sql(self, user_query: str, incorrect_sql: str, error_message: str, schema_context: str) -> str:
        """
        Corrects a failed SQL query based on the error message.
        On any failure, returns the original (incorrect) SQL unchanged so the
        caller can decide how to proceed.
        """
        if not self.client:
            return "-- Error: API Key missing"

        prompt = SQL_CORRECTION_PROMPT.format(
            error_message=error_message,
            incorrect_sql=incorrect_sql,
            user_query=user_query,
            schema_context=schema_context,
        )

        try:
            config = types.GenerateContentConfig(
                temperature=1,
            )
            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=prompt,
                config=config,
            )
            return self._strip_code_fence(response.text, "sql")

        except Exception as e:
            print(f"Error correcting SQL: {e}")
            return incorrect_sql

    async def generate_layer_name(self, user_query: str, sql_query: str) -> dict:
        """
        Generates a short, descriptive name, emoji, and point style for a map layer.
        Returns: {"name": str, "emoji": str, "pointStyle": str | None}
        """
        if not self.client:
            return {"name": "New Layer", "emoji": "π", "pointStyle": None}

        prompt = LAYER_NAME_PROMPT.format(user_query=user_query, sql_query=sql_query)

        try:
            config = types.GenerateContentConfig(
                temperature=1,
                response_mime_type="application/json",
            )
            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=prompt,
                config=config,
            )

            result = json.loads(response.text)
            return {
                "name": result.get("name", "Map Layer"),
                "emoji": result.get("emoji", "π"),
                "pointStyle": result.get("pointStyle", None),
            }
        except Exception as e:
            print(f"Error generating layer name: {e}")
            return {"name": "Map Layer", "emoji": "π", "pointStyle": None}
|
|
|