import json
import time

import pandas as pd
from groq import Groq

from .config import settings
from .db import DatabaseHandler
from .knowledge import KnowledgeManager


class AIEngine:
    """LLM-backed analytics engine.

    Two-stage pipeline: a small "planner" prompt classifies the user's
    intent (chat / sql / analytics), then a "builder" prompt generates
    Postgres SQL (and optionally a chart config). Also provides a one-shot
    SQL repair pass for queries that failed at execution time.
    """

    def __init__(self, db: DatabaseHandler, kb: KnowledgeManager | None = None):
        """Create the engine.

        Args:
            db: Database handler used for schema introspection.
            kb: Optional knowledge manager supplying RAG context.
        """
        # Client stays None when no API key is configured; is_configured()
        # and the guards below make that state explicit to callers.
        self.client: Groq | None = None
        self.db = db
        self.kb = kb
        if settings.groq_api_key:
            self.client = Groq(api_key=settings.groq_api_key)

    def is_configured(self) -> bool:
        """Return True when a Groq client was created (API key present)."""
        return self.client is not None

    def _column_semantics(self) -> str:
        """Static column/unit/typing guidance injected into the system prompt."""
        return """
- `name_of_dam`: Dam name.
- `dam_id`: Unique identifier.
- `lake_level_rl`: Lake level reading (RL).
- `lake_level_reading_time`: Date+time string for the reading. Parse with `TO_TIMESTAMP(lake_level_reading_time, 'DD/MM/YYYY HH:MI AM')`.
- Live storage may exist in multiple units: `live_storage_mcum` (MCUM) and `live_storage_tmc` (TMC). These represent the same measure in different units.
- Designed/gross storage represent capacity-type values (total/design capacity).
- `District`, `Revenue Region`, `Taluka`, `Village`: administrative/lookup fields.
- `Gated/Non-Gated`: important dam attribute.

SQL typing guidance:
- Some numeric-looking columns may be stored as text. When aggregating, cast safely: `SUM(NULLIF(col, '')::numeric)`.
- When doing date math, subtracting timestamps yields an INTERVAL; convert using `EXTRACT(EPOCH FROM (t2 - t1))` or `EXTRACT(day FROM (t2 - t1))`.

Safe casting patterns (use exactly these to avoid type errors):
- If you need TRIM/REPLACE/regex, first convert to text: `col::text`.
- Safe numeric: `NULLIF(REPLACE(TRIM(col::text), ',', ''), '')::numeric`.
- Safe regex check: `(col::text) ~ '^[0-9]+(\\.[0-9]+)?$'`.
"""

    def _call_llm_json(self, messages: list[dict]) -> tuple[dict, dict]:
        """Call the chat model in JSON mode.

        Args:
            messages: Chat messages in OpenAI-compatible format.

        Returns:
            Tuple of (parsed JSON response, per-call metrics dict with
            latency, token counts, request id and model name).

        Raises:
            RuntimeError: If no Groq client is configured.
        """
        # Explicit guard: without it a missing API key surfaced as an
        # opaque AttributeError on `None.chat`.
        if self.client is None:
            raise RuntimeError("Groq client is not configured.")
        start_time = time.time()
        completion = self.client.chat.completions.create(
            model=settings.groq_model_id,
            messages=messages,
            response_format={"type": "json_object"},
            temperature=0.1,
        )
        response_json = json.loads(completion.choices[0].message.content)
        usage = completion.usage
        # getattr with None defaults: usage fields may be absent on some
        # provider responses.
        metrics = {
            "latency": round(time.time() - start_time, 2),
            "tokens": getattr(usage, "total_tokens", None),
            "prompt_tokens": getattr(usage, "prompt_tokens", None),
            "completion_tokens": getattr(usage, "completion_tokens", None),
            "request_id": getattr(completion, "id", None),
            "model": settings.groq_model_id,
        }
        return response_json, metrics

    def _planner_prompt(self, rag_context: str) -> str:
        """System prompt for the intent-classification (planner) stage."""
        # Keep this prompt small to minimize tokens.
        return f"""
You are the Planner for a hydrology analytics assistant.
Goal: Decide whether the user needs SQL/data or just a chat answer, and what the query should do.
Use these hints if relevant:
{rag_context}
Return JSON ONLY:
{{
"intent": "chat" | "sql" | "analytics",
"task": "short description of what to do",
"needs_date": true|false,
"chart": {{"type":"line|bar|area|pie|table"}} | null
}}
Rules:
- Prefer `chat` if no DB lookup is needed.
- Prefer `sql` for lookup tables (district, taluka, inventory, IDs).
- Prefer `analytics` for charts/trends/comparisons.
"""

    def _sql_builder_prompt(self, schema_context: str, rag_context: str) -> str:
        """System prompt for the SQL-generation (builder) stage."""
        return f"""
You are the SQL Builder for Pravah AI.
You will generate valid Postgres SQL and (only if analytics) a chart_config.
DATABASE SCHEMA (copy-paste exact quoted identifiers):
{schema_context}
RAG RULES / PLAYBOOKS:
{rag_context}
STRICT OUTPUT JSON ONLY:
{{
"message": "Final user-facing answer in plain English (no raw column names unless user asked about columns)",
"sql": "SQL query or empty string",
"chart_config": {{"type":"line|bar|area|pie|table", "x":"...", "y":"...", "title":"..."}} | null
}}
SQL RULES:
- Read-only.
- Use LIMIT 100 for lookups/inventory only. For complete time-bounded daily/monthly trend queries, do NOT use LIMIT.
- Use `DISTINCT` for lookup questions to avoid duplicates.
- Dates: parse lake_level_reading_time using `TO_TIMESTAMP(..., 'DD/MM/YYYY HH:MI AM')`.
- If you use TRIM/REPLACE/regex, do it on `col::text`.
- Safe numeric: `NULLIF(REPLACE(TRIM(col::text), ',', ''), '')::numeric`.
PRODUCTION RULES (WRD/audit-grade):
- Storage/level/"usage" (live storage) are state values. If user asks a daily trend ("each day", "daily", a month like "May 2023"), do NOT use AVG.
- Instead select the latest reading per day using `ROW_NUMBER() OVER (PARTITION BY report_date ORDER BY ts DESC) = 1`.
- Avoid repeated `TO_TIMESTAMP` calls: parse once in a CTE (`parsed`).
- Prefer half-open ranges for month windows: `ts >= 'YYYY-MM-01' AND ts < 'YYYY-MM-01' + INTERVAL '1 month'`.
"""

    def _system_prompt(self, schema_context: str, rag_context: str) -> str:
        """Combined system prompt (schema + RAG + semantics + output contract).

        Used by the single-shot repair path; the normal flow uses the
        planner/builder prompts instead.
        """
        return f"""
You are 'Pravah AI', an expert Water Resources Analytics Assistant.
### DATABASE SCHEMA:
{schema_context}
### DOMAIN KNOWLEDGE (RAG):
{rag_context}
### COLUMN SEMANTICS / UNITS:
{self._column_semantics()}
### STRICT RULES:
1. Read-only SQL only.
2. Date Handling: Dates are often strings. CAST THEM: `TO_TIMESTAMP(col_name, 'DD/MM/YYYY HH:MI AM')`.
3. Date Filters: For "May 2025", use `BETWEEN '2025-05-01' AND '2025-05-31'`.
4. String Matching: Use `ILIKE` for names (e.g., `name_of_dam ILIKE '%Koyna%'`).
5. Limit: Always `LIMIT 100` unless specified.
6. Output: JSON ONLY.
7. The `message` field must contain the final user-facing answer in plain English. Do NOT only say "Below is the SQL".
8. For lookup questions (e.g., district/revenue region of a dam), return a SQL query that selects ONLY the needed columns and use `DISTINCT` to avoid duplicate rows.
9. Only set `chart_config` when `intent` is `analytics`. For `sql` intent, `chart_config` should be `null`.
### OUTPUT FORMAT (JSON):
{{
"intent": "chat" | "sql" | "analytics",
"message": "Conversational reply",
"sql": "Valid SQL query (only for sql/analytics)",
"chart_config": {{ "type": "line|bar|area|pie|table", "x": "col_name", "y": "col_name", "title": "..." }} | null
}}
"""

    def repair_sql(self, user_query: str, sql: str, db_error: str) -> str | None:
        """Ask the model to fix a failed SQL query, preserving intent.

        Args:
            user_query: Original natural-language question.
            sql: The SQL statement that failed.
            db_error: The Postgres error message.

        Returns:
            A repaired non-empty SQL string, or None if repair was not
            possible (unconfigured client, model error, or empty result).
        """
        if not self.client:
            return None
        schema_context = self.db.get_schema_summary()
        rag_context = self.kb.retrieve_context(user_query) if self.kb else ""
        messages = [
            {
                "role": "system",
                "content": (
                    self._system_prompt(schema_context, rag_context)
                    + "\n\nYou are now a SQL repair assistant. Return JSON ONLY with key `sql`."
                ),
            },
            {
                "role": "user",
                "content": (
                    "The following SQL failed to run on Postgres. Fix the SQL while preserving intent. "
                    "Use exact quoted identifiers for columns as shown in schema. "
                    "If aggregating numeric-looking text, cast to numeric with NULLIF. "
                    "Return JSON only: {\"sql\": \"...\"}.\n\n"
                    f"USER_QUESTION:\n{user_query}\n\nFAILED_SQL:\n{sql}\n\nDB_ERROR:\n{db_error}\n"
                ),
            },
        ]
        try:
            completion = self.client.chat.completions.create(
                model=settings.groq_model_id,
                messages=messages,
                response_format={"type": "json_object"},
                temperature=0.0,
            )
            repaired = json.loads(completion.choices[0].message.content)
            fixed_sql = repaired.get("sql") if isinstance(repaired, dict) else None
            if isinstance(fixed_sql, str) and fixed_sql.strip():
                return fixed_sql
            return None
        except Exception:
            # Best-effort repair: any model/parse failure degrades to "no fix".
            return None

    @staticmethod
    def _history_messages(history: list[tuple[str, str]], limit: int) -> list[dict]:
        """Flatten the last `limit` (user, assistant) turns into chat messages.

        Assistant turns with falsy content are skipped, mirroring the
        original inline loops.
        """
        messages: list[dict] = []
        for user_msg, assistant_msg in history[-limit:]:
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        return messages

    @staticmethod
    def _sum_optional(a: int | None, b: int | None) -> int | None:
        """Sum two optional counters; None when neither is truthy.

        Preserves the original truthiness semantics (two zeros -> None).
        """
        if not a and not b:
            return None
        return (a or 0) + (b or 0)

    def process_query(self, user_query: str, history: list[tuple[str, str]] | None = None) -> tuple[dict, dict]:
        """Run the planner/builder pipeline for one user question.

        Args:
            user_query: Natural-language question.
            history: Optional list of (user, assistant) turns for context.

        Returns:
            Tuple of (response dict with intent/message/sql/chart_config,
            aggregated metrics dict).

        Raises:
            RuntimeError: If the Groq client is not configured.
        """
        if not self.client:
            raise RuntimeError("Groq is not configured. Set GROQ_API_KEY in HuggingFace Secrets/env.")
        history = history or []
        rag_context = self.kb.retrieve_context(user_query) if self.kb else ""

        # Step 1: Planner (low token) — last 2 turns of history only.
        planner_messages = [{"role": "system", "content": self._planner_prompt(rag_context)}]
        planner_messages.extend(self._history_messages(history, 2))
        planner_messages.append({"role": "user", "content": user_query})
        plan, m1 = self._call_llm_json(planner_messages)
        intent = plan.get("intent") if isinstance(plan, dict) else None
        if intent not in ("chat", "sql", "analytics"):
            # Unknown/malformed intent degrades safely to chat.
            intent = "chat"

        # Step 2: SQL/Analytics Builder (uses schema + playbooks)
        if intent == "chat":
            # Provide a minimal, natural answer. (No SQL.)
            # Keep compatibility with existing contract.
            msg = plan.get("task") if isinstance(plan, dict) else None
            response_json = {
                "intent": "chat",
                "message": (msg if isinstance(msg, str) and msg.strip() else ""),
                "sql": "",
                "chart_config": None,
            }
            return response_json, m1

        schema_context = self.db.get_schema_summary()
        builder_messages = [{"role": "system", "content": self._sql_builder_prompt(schema_context, rag_context)}]
        builder_messages.extend(self._history_messages(history, 3))
        task = plan.get("task") if isinstance(plan, dict) else ""
        chart = plan.get("chart") if isinstance(plan, dict) else None
        builder_user = {
            "intent": intent,
            "task": task,
            "user_query": user_query,
            "chart": chart,
        }
        builder_messages.append({"role": "user", "content": json.dumps(builder_user)})
        built, m2 = self._call_llm_json(builder_messages)
        message = built.get("message") if isinstance(built, dict) else ""
        sql = built.get("sql") if isinstance(built, dict) else ""
        chart_config = built.get("chart_config") if isinstance(built, dict) else None
        # Enforce chart_config only for analytics
        if intent != "analytics":
            chart_config = None
        response_json = {
            "intent": intent,
            "message": message or "",
            "sql": sql or "",
            "chart_config": chart_config,
        }

        # Aggregate metrics across both LLM calls.
        metrics = {
            "latency": round((m1.get("latency") or 0) + (m2.get("latency") or 0), 2),
            "tokens": self._sum_optional(m1.get("tokens"), m2.get("tokens")),
            "prompt_tokens": self._sum_optional(m1.get("prompt_tokens"), m2.get("prompt_tokens")),
            "completion_tokens": self._sum_optional(m1.get("completion_tokens"), m2.get("completion_tokens")),
            "request_id": m2.get("request_id") or m1.get("request_id"),
            "model": settings.groq_model_id,
        }
        return response_json, metrics


def df_to_records(df: pd.DataFrame, limit: int = 200) -> list[dict]:
    """Convert the first `limit` rows of a DataFrame to a list of dicts.

    Returns an empty list for None or empty input.
    """
    if df is None or df.empty:
        return []
    return df.head(limit).to_dict(orient="records")