# pravah/app/ai_engine.py
# (uploaded by triflix — "Upload 17 files", commit 5a52b4f)
import json
import time
import pandas as pd
from groq import Groq
from .config import settings
from .db import DatabaseHandler
from .knowledge import KnowledgeManager
class AIEngine:
    """Two-stage LLM engine for Pravah (planner -> SQL builder) backed by Groq.

    Holds a DatabaseHandler for schema summaries and an optional
    KnowledgeManager for RAG context retrieval.  Every LLM call requests a
    JSON-object response and parses it with json.loads.
    """

    def __init__(self, db: DatabaseHandler, kb: KnowledgeManager | None = None):
        # The client stays None when no API key is configured;
        # is_configured() exposes that state to callers.
        self.client: Groq | None = None
        self.db = db
        self.kb = kb
        if settings.groq_api_key:
            self.client = Groq(api_key=settings.groq_api_key)

    def is_configured(self) -> bool:
        """Return True when the Groq client was created (API key was present)."""
        return self.client is not None

    def _column_semantics(self) -> str:
        """Return a static prompt fragment describing column meanings and safe
        Postgres casting patterns (embedded into the system prompt)."""
        return """
- `name_of_dam`: Dam name.
- `dam_id`: Unique identifier.
- `lake_level_rl`: Lake level reading (RL).
- `lake_level_reading_time`: Date+time string for the reading. Parse with `TO_TIMESTAMP(lake_level_reading_time, 'DD/MM/YYYY HH:MI AM')`.
- Live storage may exist in multiple units: `live_storage_mcum` (MCUM) and `live_storage_tmc` (TMC). These represent the same measure in different units.
- Designed/gross storage represent capacity-type values (total/design capacity).
- `District`, `Revenue Region`, `Taluka`, `Village`: administrative/lookup fields.
- `Gated/Non-Gated`: important dam attribute.
SQL typing guidance:
- Some numeric-looking columns may be stored as text. When aggregating, cast safely: `SUM(NULLIF(col, '')::numeric)`.
- When doing date math, subtracting timestamps yields an INTERVAL; convert using `EXTRACT(EPOCH FROM (t2 - t1))` or `EXTRACT(day FROM (t2 - t1))`.
Safe casting patterns (use exactly these to avoid type errors):
- If you need TRIM/REPLACE/regex, first convert to text: `col::text`.
- Safe numeric: `NULLIF(REPLACE(TRIM(col::text), ',', ''), '')::numeric`.
- Safe regex check: `(col::text) ~ '^[0-9]+(\\.[0-9]+)?$'`.
"""

    def _call_llm_json(self, messages: list[dict]) -> tuple[dict, dict]:
        """Send *messages* to Groq in JSON mode; return (parsed_json, metrics).

        NOTE(review): assumes self.client is non-None — all current callers
        gate on is_configured()/self.client first.  json.loads will raise if
        the model emits malformed JSON despite json_object mode.
        """
        start_time = time.time()
        completion = self.client.chat.completions.create(
            model=settings.groq_model_id,
            messages=messages,
            response_format={"type": "json_object"},
            temperature=0.1,  # near-deterministic output for planning/SQL
        )
        response_json = json.loads(completion.choices[0].message.content)
        usage = completion.usage
        # getattr defaults guard against SDK responses that omit usage fields.
        metrics = {
            "latency": round(time.time() - start_time, 2),
            "tokens": getattr(usage, "total_tokens", None),
            "prompt_tokens": getattr(usage, "prompt_tokens", None),
            "completion_tokens": getattr(usage, "completion_tokens", None),
            "request_id": getattr(completion, "id", None),
            "model": settings.groq_model_id,
        }
        return response_json, metrics

    def _planner_prompt(self, rag_context: str) -> str:
        """Build the stage-1 planner system prompt (intent classification)."""
        # Keep this prompt small to minimize tokens.
        return f"""
You are the Planner for a hydrology analytics assistant.
Goal: Decide whether the user needs SQL/data or just a chat answer, and what the query should do.
Use these hints if relevant:
{rag_context}
Return JSON ONLY:
{{
"intent": "chat" | "sql" | "analytics",
"task": "short description of what to do",
"needs_date": true|false,
"chart": {{"type":"line|bar|area|pie|table"}} | null
}}
Rules:
- Prefer `chat` if no DB lookup is needed.
- Prefer `sql` for lookup tables (district, taluka, inventory, IDs).
- Prefer `analytics` for charts/trends/comparisons.
"""

    def _sql_builder_prompt(self, schema_context: str, rag_context: str) -> str:
        """Build the stage-2 system prompt: SQL generation plus an optional
        chart_config (analytics intents only)."""
        return f"""
You are the SQL Builder for Pravah AI.
You will generate valid Postgres SQL and (only if analytics) a chart_config.
DATABASE SCHEMA (copy-paste exact quoted identifiers):
{schema_context}
RAG RULES / PLAYBOOKS:
{rag_context}
STRICT OUTPUT JSON ONLY:
{{
"message": "Final user-facing answer in plain English (no raw column names unless user asked about columns)",
"sql": "SQL query or empty string",
"chart_config": {{"type":"line|bar|area|pie|table", "x":"...", "y":"...", "title":"..."}} | null
}}
SQL RULES:
- Read-only.
- Use LIMIT 100 for lookups/inventory only. For complete time-bounded daily/monthly trend queries, do NOT use LIMIT.
- Use `DISTINCT` for lookup questions to avoid duplicates.
- Dates: parse lake_level_reading_time using `TO_TIMESTAMP(..., 'DD/MM/YYYY HH:MI AM')`.
- If you use TRIM/REPLACE/regex, do it on `col::text`.
- Safe numeric: `NULLIF(REPLACE(TRIM(col::text), ',', ''), '')::numeric`.
PRODUCTION RULES (WRD/audit-grade):
- Storage/level/"usage" (live storage) are state values. If user asks a daily trend ("each day", "daily", a month like "May 2023"), do NOT use AVG.
- Instead select the latest reading per day using `ROW_NUMBER() OVER (PARTITION BY report_date ORDER BY ts DESC) = 1`.
- Avoid repeated `TO_TIMESTAMP` calls: parse once in a CTE (`parsed`).
- Prefer half-open ranges for month windows: `ts >= 'YYYY-MM-01' AND ts < 'YYYY-MM-01' + INTERVAL '1 month'`.
"""

    def _system_prompt(self, schema_context: str, rag_context: str) -> str:
        """Build the single-shot system prompt (schema + RAG + column
        semantics + rules); currently used by repair_sql."""
        # NOTE(review): rule 3's inclusive BETWEEN '2025-05-31' can exclude
        # May-31 readings with a time-of-day component, and it disagrees with
        # the half-open month-range guidance in _sql_builder_prompt — confirm
        # which convention is intended before unifying.
        return f"""
You are 'Pravah AI', an expert Water Resources Analytics Assistant.
### DATABASE SCHEMA:
{schema_context}
### DOMAIN KNOWLEDGE (RAG):
{rag_context}
### COLUMN SEMANTICS / UNITS:
{self._column_semantics()}
### STRICT RULES:
1. Read-only SQL only.
2. Date Handling: Dates are often strings. CAST THEM: `TO_TIMESTAMP(col_name, 'DD/MM/YYYY HH:MI AM')`.
3. Date Filters: For "May 2025", use `BETWEEN '2025-05-01' AND '2025-05-31'`.
4. String Matching: Use `ILIKE` for names (e.g., `name_of_dam ILIKE '%Koyna%'`).
5. Limit: Always `LIMIT 100` unless specified.
6. Output: JSON ONLY.
7. The `message` field must contain the final user-facing answer in plain English. Do NOT only say "Below is the SQL".
8. For lookup questions (e.g., district/revenue region of a dam), return a SQL query that selects ONLY the needed columns and use `DISTINCT` to avoid duplicate rows.
9. Only set `chart_config` when `intent` is `analytics`. For `sql` intent, `chart_config` should be `null`.
### OUTPUT FORMAT (JSON):
{{
"intent": "chat" | "sql" | "analytics",
"message": "Conversational reply",
"sql": "Valid SQL query (only for sql/analytics)",
"chart_config": {{ "type": "line|bar|area|pie|table", "x": "col_name", "y": "col_name", "title": "..." }} | null
}}
"""

    def repair_sql(self, user_query: str, sql: str, db_error: str) -> str | None:
        """One-shot LLM repair of a SQL statement that failed on Postgres.

        Returns the fixed SQL string, or None when the engine is
        unconfigured, the LLM call/parse fails, or the reply carries no
        non-empty `sql` string.  Best-effort by design: all exceptions are
        swallowed so callers can fall back gracefully.
        """
        if not self.client:
            return None
        schema_context = self.db.get_schema_summary()
        rag_context = self.kb.retrieve_context(user_query) if self.kb else ""
        messages = [
            {
                "role": "system",
                "content": (
                    self._system_prompt(schema_context, rag_context)
                    + "\n\nYou are now a SQL repair assistant. Return JSON ONLY with key `sql`."
                ),
            },
            {
                "role": "user",
                "content": (
                    "The following SQL failed to run on Postgres. Fix the SQL while preserving intent. "
                    "Use exact quoted identifiers for columns as shown in schema. "
                    "If aggregating numeric-looking text, cast to numeric with NULLIF. "
                    "Return JSON only: {\"sql\": \"...\"}.\n\n"
                    f"USER_QUESTION:\n{user_query}\n\nFAILED_SQL:\n{sql}\n\nDB_ERROR:\n{db_error}\n"
                ),
            },
        ]
        try:
            completion = self.client.chat.completions.create(
                model=settings.groq_model_id,
                messages=messages,
                response_format={"type": "json_object"},
                temperature=0.0,  # repair should be deterministic
            )
            repaired = json.loads(completion.choices[0].message.content)
            fixed_sql = repaired.get("sql") if isinstance(repaired, dict) else None
            if isinstance(fixed_sql, str) and fixed_sql.strip():
                return fixed_sql
            return None
        except Exception:
            # Any failure (network, JSON parse, schema fetch already done
            # above) simply means "no fix available".
            return None

    def process_query(self, user_query: str, history: list[tuple[str, str]] | None = None) -> tuple[dict, dict]:
        """Run the planner -> builder pipeline for one user turn.

        *history* is a list of (user_message, assistant_reply) tuples; only
        the most recent turns are forwarded to keep token usage low.

        Returns (response_json, metrics).  response_json always contains the
        keys intent / message / sql / chart_config.

        Raises RuntimeError when no Groq client is configured.
        """
        if not self.client:
            raise RuntimeError("Groq is not configured. Set GROQ_API_KEY in HuggingFace Secrets/env.")
        history = history or []
        rag_context = self.kb.retrieve_context(user_query) if self.kb else ""
        # Step 1: Planner (low token)
        planner_messages = [{"role": "system", "content": self._planner_prompt(rag_context)}]
        for h in history[-2:]:
            planner_messages.append({"role": "user", "content": h[0]})
            if h[1]:
                planner_messages.append({"role": "assistant", "content": h[1]})
        planner_messages.append({"role": "user", "content": user_query})
        plan, m1 = self._call_llm_json(planner_messages)
        intent = plan.get("intent") if isinstance(plan, dict) else None
        if intent not in ["chat", "sql", "analytics"]:
            # Unknown or malformed plan falls back to plain chat.
            intent = "chat"
        # Step 2: SQL/Analytics Builder (uses schema + playbooks)
        if intent == "chat":
            # Provide a minimal, natural answer. (No SQL.)
            # Keep compatibility with existing contract.
            msg = plan.get("task") if isinstance(plan, dict) else None
            response_json = {
                "intent": "chat",
                "message": (msg if isinstance(msg, str) and msg.strip() else ""),
                "sql": "",
                "chart_config": None,
            }
            return response_json, m1
        schema_context = self.db.get_schema_summary()
        builder_messages = [{"role": "system", "content": self._sql_builder_prompt(schema_context, rag_context)}]
        for h in history[-3:]:
            builder_messages.append({"role": "user", "content": h[0]})
            if h[1]:
                builder_messages.append({"role": "assistant", "content": h[1]})
        task = plan.get("task") if isinstance(plan, dict) else ""
        chart = plan.get("chart") if isinstance(plan, dict) else None
        # The builder receives the planner's output as one structured JSON
        # user message.
        builder_user = {
            "intent": intent,
            "task": task,
            "user_query": user_query,
            "chart": chart,
        }
        builder_messages.append({"role": "user", "content": json.dumps(builder_user)})
        built, m2 = self._call_llm_json(builder_messages)
        message = built.get("message") if isinstance(built, dict) else ""
        sql = built.get("sql") if isinstance(built, dict) else ""
        chart_config = built.get("chart_config") if isinstance(built, dict) else None
        # Enforce chart_config only for analytics
        if intent != "analytics":
            chart_config = None
        response_json = {
            "intent": intent,
            "message": message or "",
            "sql": sql or "",
            "chart_config": chart_config,
        }
        # Aggregate metrics
        # Token counts sum across both calls but stay None when neither call
        # reported them (usage fields may be absent).
        metrics = {
            "latency": round((m1.get("latency") or 0) + (m2.get("latency") or 0), 2),
            "tokens": (m1.get("tokens") or 0) + (m2.get("tokens") or 0) if m1.get("tokens") or m2.get("tokens") else None,
            "prompt_tokens": (m1.get("prompt_tokens") or 0) + (m2.get("prompt_tokens") or 0)
            if m1.get("prompt_tokens") or m2.get("prompt_tokens")
            else None,
            "completion_tokens": (m1.get("completion_tokens") or 0) + (m2.get("completion_tokens") or 0)
            if m1.get("completion_tokens") or m2.get("completion_tokens")
            else None,
            "request_id": m2.get("request_id") or m1.get("request_id"),
            "model": settings.groq_model_id,
        }
        return response_json, metrics
def df_to_records(df: pd.DataFrame, limit: int = 200) -> list[dict]:
    """Convert the first *limit* rows of *df* into a list of row dicts.

    Returns an empty list when *df* is None or contains no rows; otherwise
    each dict maps column name -> cell value for one row.
    """
    if df is None:
        return []
    if df.empty:
        return []
    truncated = df.iloc[:limit]
    return truncated.to_dict(orient="records")