import json
import time

import pandas as pd
from groq import Groq

from .config import settings
from .db import DatabaseHandler
from .knowledge import KnowledgeManager


class AIEngine:
    """Two-stage LLM engine for Pravah AI: a planner pass decides intent, a builder pass writes SQL."""

    def __init__(self, db: DatabaseHandler, kb: KnowledgeManager | None = None):
        self.client = None
        self.db = db
        self.kb = kb

        if settings.groq_api_key:
            self.client = Groq(api_key=settings.groq_api_key)

    def is_configured(self) -> bool:
        return self.client is not None

    def _column_semantics(self) -> str:
        return """
- `name_of_dam`: Dam name.
- `dam_id`: Unique identifier.
- `lake_level_rl`: Lake level reading (RL).
- `lake_level_reading_time`: Date+time string for the reading. Parse with `TO_TIMESTAMP(lake_level_reading_time, 'DD/MM/YYYY HH:MI AM')`.
- Live storage may exist in multiple units: `live_storage_mcum` (MCUM) and `live_storage_tmc` (TMC). These represent the same measure in different units.
- Designed/gross storage represent capacity-type values (total/design capacity).
- `District`, `Revenue Region`, `Taluka`, `Village`: administrative/lookup fields.
- `Gated/Non-Gated`: important dam attribute.

SQL typing guidance:
- Some numeric-looking columns may be stored as text. When aggregating, cast safely: `SUM(NULLIF(col, '')::numeric)`.
- When doing date math, subtracting timestamps yields an INTERVAL; convert using `EXTRACT(EPOCH FROM (t2 - t1))` or `EXTRACT(day FROM (t2 - t1))`.

Safe casting patterns (use exactly these to avoid type errors):
- If you need TRIM/REPLACE/regex, first convert to text: `col::text`.
- Safe numeric: `NULLIF(REPLACE(TRIM(col::text), ',', ''), '')::numeric`.
- Safe regex check: `(col::text) ~ '^[0-9]+(\\.[0-9]+)?$'`.
"""
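
    # Illustrative only: how the safe-cast and date-parse patterns documented
    # above combine in one query. `dam_data` is a placeholder table name (the
    # real name comes from DatabaseHandler.get_schema_summary()).
    #
    #   SELECT name_of_dam,
    #          NULLIF(REPLACE(TRIM(live_storage_tmc::text), ',', ''), '')::numeric AS storage_tmc
    #   FROM dam_data
    #   WHERE TO_TIMESTAMP(lake_level_reading_time, 'DD/MM/YYYY HH:MI AM') >= TIMESTAMP '2023-05-01'
    #   LIMIT 100;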

    def _call_llm_json(self, messages: list[dict]) -> tuple[dict, dict]:
        start_time = time.time()
        completion = self.client.chat.completions.create(
            model=settings.groq_model_id,
            messages=messages,
            response_format={"type": "json_object"},
            temperature=0.1,
        )
        response_json = json.loads(completion.choices[0].message.content)
        usage = completion.usage
        metrics = {
            "latency": round(time.time() - start_time, 2),
            "tokens": getattr(usage, "total_tokens", None),
            "prompt_tokens": getattr(usage, "prompt_tokens", None),
            "completion_tokens": getattr(usage, "completion_tokens", None),
            "request_id": getattr(completion, "id", None),
            "model": settings.groq_model_id,
        }
        return response_json, metrics
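
    # Shape of a successful call (sketch; values illustrative):
    #
    #   plan, metrics = self._call_llm_json([{"role": "system", "content": "..."},
    #                                        {"role": "user", "content": "..."}])
    #   # plan    -> dict parsed from the model's JSON-mode response
    #   # metrics -> {"latency": 0.42, "tokens": 512, "prompt_tokens": 400,
    #   #             "completion_tokens": 112, "request_id": "...", "model": "..."}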

    def _planner_prompt(self, rag_context: str) -> str:
        return f"""
You are the Planner for a hydrology analytics assistant.

Goal: Decide whether the user needs SQL/data or just a chat answer, and what the query should do.

Use these hints if relevant:
{rag_context}

Return JSON ONLY:
{{
"intent": "chat" | "sql" | "analytics",
"task": "short description of what to do",
"needs_date": true|false,
"chart": {{"type":"line|bar|area|pie|table"}} | null
}}

Rules:
- Prefer `chat` if no DB lookup is needed.
- Prefer `sql` for lookup tables (district, taluka, inventory, IDs).
- Prefer `analytics` for charts/trends/comparisons.
"""
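
    # Example planner output (illustrative, matching the schema above):
    #
    #   {"intent": "analytics",
    #    "task": "daily live storage trend for Koyna dam in May 2023",
    #    "needs_date": true,
    #    "chart": {"type": "line"}}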

    def _sql_builder_prompt(self, schema_context: str, rag_context: str) -> str:
        return f"""
You are the SQL Builder for Pravah AI.

You will generate valid Postgres SQL and (only if analytics) a chart_config.

DATABASE SCHEMA (copy-paste exact quoted identifiers):
{schema_context}

RAG RULES / PLAYBOOKS:
{rag_context}

STRICT OUTPUT JSON ONLY:
{{
"message": "Final user-facing answer in plain English (no raw column names unless user asked about columns)",
"sql": "SQL query or empty string",
"chart_config": {{"type":"line|bar|area|pie|table", "x":"...", "y":"...", "title":"..."}} | null
}}

SQL RULES:
- Read-only.
- Use LIMIT 100 for lookups/inventory only. For complete time-bounded daily/monthly trend queries, do NOT use LIMIT.
- Use `DISTINCT` for lookup questions to avoid duplicates.
- Dates: parse lake_level_reading_time using `TO_TIMESTAMP(..., 'DD/MM/YYYY HH:MI AM')`.
- If you use TRIM/REPLACE/regex, do it on `col::text`.
- Safe numeric: `NULLIF(REPLACE(TRIM(col::text), ',', ''), '')::numeric`.

PRODUCTION RULES (WRD/audit-grade):
- Storage/level/"usage" (live storage) are state values. If the user asks for a daily trend ("each day", "daily", a month like "May 2023"), do NOT use AVG.
- Instead select the latest reading per day with `ROW_NUMBER() OVER (PARTITION BY report_date ORDER BY ts DESC)` and keep rows where the row number equals 1 (filter in an outer query; window functions cannot appear in WHERE).
- Avoid repeated `TO_TIMESTAMP` calls: parse once in a CTE (`parsed`).
- Prefer half-open ranges for month windows: `ts >= 'YYYY-MM-01' AND ts < 'YYYY-MM-01' + INTERVAL '1 month'`.
"""
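
    # Sketch of the "latest reading per day" pattern the builder prompt asks
    # for. Illustrative only: `dam_data` is a placeholder table name.
    #
    #   WITH parsed AS (
    #       SELECT name_of_dam,
    #              TO_TIMESTAMP(lake_level_reading_time, 'DD/MM/YYYY HH:MI AM') AS ts,
    #              NULLIF(REPLACE(TRIM(live_storage_tmc::text), ',', ''), '')::numeric AS storage_tmc
    #       FROM dam_data
    #   ), ranked AS (
    #       SELECT *, ROW_NUMBER() OVER (PARTITION BY name_of_dam, ts::date ORDER BY ts DESC) AS rn
    #       FROM parsed
    #       WHERE ts >= '2023-05-01' AND ts < DATE '2023-05-01' + INTERVAL '1 month'
    #   )
    #   SELECT ts::date AS report_date, name_of_dam, storage_tmc
    #   FROM ranked WHERE rn = 1 ORDER BY report_date;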

    def _system_prompt(self, schema_context: str, rag_context: str) -> str:
        return f"""
You are 'Pravah AI', an expert Water Resources Analytics Assistant.

### DATABASE SCHEMA:
{schema_context}

### DOMAIN KNOWLEDGE (RAG):
{rag_context}

### COLUMN SEMANTICS / UNITS:
{self._column_semantics()}

### STRICT RULES:
1. Read-only SQL only.
2. Date Handling: Dates are often strings. CAST THEM: `TO_TIMESTAMP(col_name, 'DD/MM/YYYY HH:MI AM')`.
3. Date Filters: For "May 2025", use a half-open range: `>= '2025-05-01' AND < '2025-06-01'` (a BETWEEN ending at '2025-05-31' silently drops readings taken during May 31).
4. String Matching: Use `ILIKE` for names (e.g., `name_of_dam ILIKE '%Koyna%'`).
5. Limit: Always `LIMIT 100` unless specified.
6. Output: JSON ONLY.
7. The `message` field must contain the final user-facing answer in plain English. Do NOT only say "Below is the SQL".
8. For lookup questions (e.g., district/revenue region of a dam), return a SQL query that selects ONLY the needed columns and use `DISTINCT` to avoid duplicate rows.
9. Only set `chart_config` when `intent` is `analytics`. For `sql` intent, `chart_config` should be `null`.

### OUTPUT FORMAT (JSON):
{{
"intent": "chat" | "sql" | "analytics",
"message": "Conversational reply",
"sql": "Valid SQL query (only for sql/analytics)",
"chart_config": {{ "type": "line|bar|area|pie|table", "x": "col_name", "y": "col_name", "title": "..." }} | null
}}
"""

    def repair_sql(self, user_query: str, sql: str, db_error: str) -> str | None:
        if not self.client:
            return None

        schema_context = self.db.get_schema_summary()
        rag_context = self.kb.retrieve_context(user_query) if self.kb else ""

        messages = [
            {
                "role": "system",
                "content": (
                    self._system_prompt(schema_context, rag_context)
                    + "\n\nYou are now a SQL repair assistant. Return JSON ONLY with key `sql`."
                ),
            },
            {
                "role": "user",
                "content": (
                    "The following SQL failed to run on Postgres. Fix the SQL while preserving intent. "
                    "Use exact quoted identifiers for columns as shown in schema. "
                    "If aggregating numeric-looking text, cast to numeric with NULLIF. "
                    "Return JSON only: {\"sql\": \"...\"}.\n\n"
                    f"USER_QUESTION:\n{user_query}\n\nFAILED_SQL:\n{sql}\n\nDB_ERROR:\n{db_error}\n"
                ),
            },
        ]

        try:
            completion = self.client.chat.completions.create(
                model=settings.groq_model_id,
                messages=messages,
                response_format={"type": "json_object"},
                temperature=0.0,
            )
            repaired = json.loads(completion.choices[0].message.content)
            fixed_sql = repaired.get("sql") if isinstance(repaired, dict) else None
            if isinstance(fixed_sql, str) and fixed_sql.strip():
                return fixed_sql
            return None
        except Exception:
            return None
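
    # Typical retry flow (sketch; `run_sql` is a hypothetical DatabaseHandler
    # method, not confirmed by this module):
    #
    #   try:
    #       df = db.run_sql(sql)
    #   except Exception as err:
    #       fixed = engine.repair_sql(user_query, sql, str(err))
    #       if fixed:
    #           df = db.run_sql(fixed)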

    def process_query(self, user_query: str, history: list[tuple[str, str]] | None = None) -> tuple[dict, dict]:
        if not self.client:
            raise RuntimeError("Groq is not configured. Set GROQ_API_KEY in HuggingFace Secrets/env.")

        history = history or []
        rag_context = self.kb.retrieve_context(user_query) if self.kb else ""

        # Stage 1: the planner decides intent (chat / sql / analytics).
        planner_messages = [{"role": "system", "content": self._planner_prompt(rag_context)}]
        for h in history[-2:]:
            planner_messages.append({"role": "user", "content": h[0]})
            if h[1]:
                planner_messages.append({"role": "assistant", "content": h[1]})
        planner_messages.append({"role": "user", "content": user_query})

        plan, m1 = self._call_llm_json(planner_messages)
        intent = plan.get("intent") if isinstance(plan, dict) else None
        if intent not in ["chat", "sql", "analytics"]:
            intent = "chat"

        if intent == "chat":
            # Chat-only answers skip the SQL builder; the planner's short
            # `task` text doubles as the reply.
            msg = plan.get("task") if isinstance(plan, dict) else None
            response_json = {
                "intent": "chat",
                "message": (msg if isinstance(msg, str) and msg.strip() else ""),
                "sql": "",
                "chart_config": None,
            }
            return response_json, m1

        # Stage 2: the SQL builder turns the plan into a query (and chart config).
        schema_context = self.db.get_schema_summary()
        builder_messages = [{"role": "system", "content": self._sql_builder_prompt(schema_context, rag_context)}]
        for h in history[-3:]:
            builder_messages.append({"role": "user", "content": h[0]})
            if h[1]:
                builder_messages.append({"role": "assistant", "content": h[1]})

        task = plan.get("task") if isinstance(plan, dict) else ""
        chart = plan.get("chart") if isinstance(plan, dict) else None
        builder_user = {
            "intent": intent,
            "task": task,
            "user_query": user_query,
            "chart": chart,
        }
        builder_messages.append({"role": "user", "content": json.dumps(builder_user)})
        built, m2 = self._call_llm_json(builder_messages)

        message = built.get("message") if isinstance(built, dict) else ""
        sql = built.get("sql") if isinstance(built, dict) else ""
        chart_config = built.get("chart_config") if isinstance(built, dict) else None

        # Charts are only meaningful for analytics answers.
        if intent != "analytics":
            chart_config = None

        response_json = {
            "intent": intent,
            "message": message or "",
            "sql": sql or "",
            "chart_config": chart_config,
        }

        # Merge the metrics of both LLM calls.
        metrics = {
            "latency": round((m1.get("latency") or 0) + (m2.get("latency") or 0), 2),
            "tokens": (m1.get("tokens") or 0) + (m2.get("tokens") or 0) if m1.get("tokens") or m2.get("tokens") else None,
            "prompt_tokens": (m1.get("prompt_tokens") or 0) + (m2.get("prompt_tokens") or 0)
            if m1.get("prompt_tokens") or m2.get("prompt_tokens")
            else None,
            "completion_tokens": (m1.get("completion_tokens") or 0) + (m2.get("completion_tokens") or 0)
            if m1.get("completion_tokens") or m2.get("completion_tokens")
            else None,
            "request_id": m2.get("request_id") or m1.get("request_id"),
            "model": settings.groq_model_id,
        }
        return response_json, metrics


def df_to_records(df: pd.DataFrame, limit: int = 200) -> list[dict]:
    """Convert a DataFrame to a JSON-serializable list of row dicts, capped at `limit` rows."""
    if df is None or df.empty:
        return []
    return df.head(limit).to_dict(orient="records")
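

# Minimal usage sketch. Assumptions: DatabaseHandler and KnowledgeManager can be
# constructed with no arguments (their real signatures live in .db/.knowledge),
# and the module is run as part of its package (python -m <package>.<module>)
# so the relative imports resolve.
if __name__ == "__main__":
    db = DatabaseHandler()
    engine = AIEngine(db, kb=KnowledgeManager())
    if engine.is_configured():
        response, metrics = engine.process_query("Which district is Koyna dam in?")
        print(json.dumps(response, indent=2))
        print(metrics)
    else:
        print("Set GROQ_API_KEY to enable the engine.")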