"""Retrieval knowledge base of SQL-writing hints for ``reservoir_reports`` queries.

``KB_CONFIG`` holds short rule notes and optional SQL playbook templates.
``KnowledgeManager`` embeds them into an in-memory Chroma collection and returns
the closest entries for a free-text question as a bullet list of hints.
"""

from __future__ import annotations

from typing import Any

import chromadb
from sentence_transformers import SentenceTransformer

# Sentinel stored in "sql_template" for entries that only carry a rule note.
_NOTE_ONLY = "NOTE_ONLY"

# Each entry has:
#   category      - short identifier, surfaced to the caller as the hint label
#   keywords      - words prepended to the embedded document text
#   context_note  - the human-readable rule that is returned as the hint
#   sql_template  - example query, or _NOTE_ONLY when there is no example
# Placeholders such as {dam_name}, {start}, {end} and {date} are stored verbatim;
# nothing in this module substitutes them.
KB_CONFIG: list[dict[str, Any]] = [
    {
        "category": "entity_resolution",
        "keywords": ["dam", "reservoir", "name", "which dam", "id", "dam id"],
        "context_note": (
            "When the user names a dam, match it using case-insensitive `ILIKE` on the dam name. "
            "Prefer `SELECT DISTINCT name_of_dam` first if the spelling may vary."
        ),
        "sql_template": _NOTE_ONLY,
    },
    {
        "category": "rules_lookup_distinct",
        "keywords": ["district", "revenue region", "taluka", "village", "location", "where"],
        "context_note": (
            "For lookup/location questions, always use `DISTINCT` and only return the needed "
            "columns to avoid duplicates."
        ),
        "sql_template": _NOTE_ONLY,
    },
    {
        "category": "rules_dates",
        "keywords": ["date", "time", "today", "latest", "reading"],
        "context_note": (
            "The `lake_level_reading_time` column is a string date+time. "
            "Always parse using `TO_TIMESTAMP(lake_level_reading_time, 'DD/MM/YYYY HH:MI AM')` "
            "for filtering and ordering."
        ),
        "sql_template": _NOTE_ONLY,
    },
    {
        "category": "rules_numeric_cast",
        "keywords": ["sum", "avg", "tmc", "mcum", "percent", "%", "aggregate"],
        "context_note": (
            "If a numeric field is stored as text or contains blanks/commas, use safe numeric casting: "
            "`NULLIF(REPLACE(TRIM(col::text), ',', ''), '')::numeric`. "
            "Use `col::text` before TRIM/regex."
        ),
        "sql_template": _NOTE_ONLY,
    },
    {
        "category": "metrics",
        "keywords": ["level", "storage", "capacity", "mcum"],
        "context_note": (
            "For storage and levels, choose the correct unit column (MCUM vs TMC) and keep units "
            "consistent in the answer. "
            "Dates must be cast: `TO_TIMESTAMP(lake_level_reading_time, 'DD/MM/YYYY HH:MI AM')`."
        ),
        "sql_template": _NOTE_ONLY,
    },
    {
        "category": "intent_analytics",
        "keywords": ["trend", "plot", "chart", "graph", "fluctuation", "over time"],
        "context_note": (
            "For trends, select the date column (casted) and the metric. "
            "Order by date ASC."
        ),
        "sql_template": _NOTE_ONLY,
    },
    {
        "category": "rules_state_timeseries",
        "keywords": ["trend", "daily", "each day", "time series", "live storage", "lake level"],
        "context_note": (
            "Storage/level are state values. "
            "For daily trends, use the latest reading per day "
            "(ROW_NUMBER() OVER (PARTITION BY date ORDER BY ts DESC)=1). "
            "Do NOT use AVG for daily values."
        ),
        "sql_template": _NOTE_ONLY,
    },
    {
        "category": "rules_limit_safety",
        "keywords": ["may", "month", "between", "date range", "from", "to"],
        "context_note": (
            "Do not use LIMIT for strictly time-bounded daily/monthly trend queries "
            "(it can silently truncate days). "
            "LIMIT is fine for lookups/inventory but not for complete date windows."
        ),
        "sql_template": _NOTE_ONLY,
    },
    {
        "category": "intent_inventory",
        "keywords": ["inventory", "details", "show all", "table", "gated", "non-gated"],
        "context_note": (
            "For 'inventory' style tables, return a limited set of human-meaningful columns "
            "and include Gated/Non-Gated when relevant. "
            "LIMIT 100."
        ),
        "sql_template": _NOTE_ONLY,
    },
    {
        "category": "playbook_lookup_location",
        "keywords": ["district", "revenue region", "taluka", "village", "where is", "location"],
        "context_note": (
            "For location/lookup questions, use DISTINCT and return only needed fields "
            "to avoid duplicates."
        ),
        "sql_template": (
            "SELECT DISTINCT \"District\", \"Revenue Region\" "
            "FROM reservoir_reports "
            "WHERE name_of_dam ILIKE '%{dam_name}%' "
            "LIMIT 100;"
        ),
    },
    {
        "category": "playbook_latest_reading_for_dam",
        "keywords": ["latest", "most recent", "recent reading", "timestamp", "last update"],
        "context_note": (
            "To fetch the latest reading for a dam, order by parsed timestamp DESC and LIMIT 1."
        ),
        "sql_template": (
            "SELECT * FROM reservoir_reports "
            "WHERE name_of_dam ILIKE '%{dam_name}%' "
            "ORDER BY TO_TIMESTAMP(lake_level_reading_time, 'DD/MM/YYYY HH:MI AM') DESC "
            "LIMIT 1;"
        ),
    },
    {
        "category": "playbook_daily_state_trend",
        "keywords": [
            "trend",
            "daily",
            "may",
            "month",
            "each day",
            "live storage",
            "tmc",
            "mcum",
            "lake level",
        ],
        "context_note": (
            "Canonical daily trend for state metrics: parse timestamp once, compute report_date, "
            "rank rows per day by ts DESC and pick rn=1."
        ),
        "sql_template": (
            "WITH parsed AS ("
            "SELECT TO_TIMESTAMP(lake_level_reading_time, 'DD/MM/YYYY HH:MI AM') AS ts, "
            "TO_TIMESTAMP(lake_level_reading_time, 'DD/MM/YYYY HH:MI AM')::date AS report_date, "
            "NULLIF(REPLACE(TRIM(live_storage_tmc::text), ',', ''), '')::numeric AS value "
            "FROM reservoir_reports "
            "WHERE name_of_dam ILIKE '%{dam_name}%'"
            "), ranked AS ("
            "SELECT report_date, value, "
            "ROW_NUMBER() OVER (PARTITION BY report_date ORDER BY ts DESC) AS rn "
            "FROM parsed "
            "WHERE ts >= TO_DATE('{start}','YYYY-MM-DD') "
            "AND ts < TO_DATE('{end}','YYYY-MM-DD') "
            "AND value IS NOT NULL"
            ") "
            "SELECT report_date AS date, value "
            "FROM ranked "
            "WHERE rn = 1 "
            "ORDER BY report_date;"
        ),
    },
    {
        "category": "playbook_districtwise_stock",
        "keywords": [
            "district-wise",
            "useful water stock",
            "projected",
            "designed",
            "current",
            "last year",
        ],
        "context_note": (
            "District-wise stock by revenue region typically needs sums of "
            "current/designed/last-year in TMC with safe numeric casting and date filtering."
        ),
        # NOTE(review): this template is inconsistent with the other playbooks and is
        # reproduced unchanged because the real schema is not visible here:
        #   * the quoted identifiers are padded with spaces (" Revenue Region ",
        #     " District ") whereas playbook_lookup_location uses "District" and
        #     "Revenue Region";
        #   * {date} is parsed as DD-MM-YYYY whereas playbook_daily_state_trend parses
        #     its bounds as YYYY-MM-DD;
        #   * the note mentions last-year sums but the query only selects current and
        #     designed.
        # Confirm against the table definition before relying on it.
        "sql_template": (
            "SELECT \" Revenue Region \" AS revenue_region, \" District \" AS district, "
            "SUM(NULLIF(REPLACE(TRIM(\"live_storage_tmc\"::text), ',', ''), '')::numeric) AS current_tmc, "
            "SUM(NULLIF(REPLACE(TRIM(\"designed_live_tmc\"::text), ',', ''), '')::numeric) AS designed_tmc "
            "FROM reservoir_reports "
            "WHERE TO_TIMESTAMP(lake_level_reading_time, 'DD/MM/YYYY HH:MI AM')::date = "
            "TO_DATE('{date}','DD-MM-YYYY') "
            "GROUP BY 1,2 ORDER BY 1,2 LIMIT 100;"
        ),
    },
    {
        "category": "playbook_fortnight_compare",
        "keywords": ["fortnight", "water year", "10-year", "average", "compare"],
        "context_note": (
            "Fortnight comparison should avoid interval casting. "
            "Prefer `EXTRACT(doy FROM ts)` and compute fortnight = ((doy-1)/14)+1, "
            "then compare current water_year vs last 10 water_years."
        ),
        "sql_template": (
            "WITH base AS ("
            "SELECT TO_TIMESTAMP(lake_level_reading_time, 'DD/MM/YYYY HH:MI AM') AS ts, "
            "NULLIF(REPLACE(TRIM(live_storage_mcum::text), ',', ''), '')::numeric AS live_storage_mcum "
            "FROM reservoir_reports "
            "WHERE lake_level_reading_time IS NOT NULL"
            "), water AS ("
            "SELECT ts, live_storage_mcum, "
            "EXTRACT(year FROM ts) AS yr, "
            "EXTRACT(month FROM ts) AS mon, "
            "EXTRACT(doy FROM ts) AS doy, "
            "CASE WHEN EXTRACT(month FROM ts) >= 6 "
            "THEN EXTRACT(year FROM ts)::int + 1 "
            "ELSE EXTRACT(year FROM ts)::int END AS water_year, "
            "((EXTRACT(doy FROM ts)::int - 1) / 14) + 1 AS fortnight "
            "FROM base"
            ") "
            "SELECT fortnight, "
            "AVG(CASE WHEN water_year = (SELECT MAX(water_year) FROM water) "
            "THEN live_storage_mcum END) AS current_year_avg, "
            "AVG(CASE WHEN water_year BETWEEN (SELECT MAX(water_year) FROM water)-10 "
            "AND (SELECT MAX(water_year) FROM water)-1 "
            "THEN live_storage_mcum END) AS ten_year_avg "
            "FROM water "
            "GROUP BY fortnight ORDER BY fortnight LIMIT 100;"
        ),
    },
]


class KnowledgeManager:
    """Embed knowledge-base entries into Chroma and retrieve the closest hints.

    Construction loads the ``all-MiniLM-L6-v2`` sentence-transformer model, opens
    an in-memory Chroma client, gets or creates the ``reservoir_kb`` collection
    and indexes ``config_data`` into it if the collection is still empty.
    """

    def __init__(self, config_data: list[dict[str, Any]]):
        """Store the config, create the embedder and collection, then index.

        Args:
            config_data: Entries shaped like those in ``KB_CONFIG``.
        """
        self.config = config_data
        self.embed_model = SentenceTransformer("all-MiniLM-L6-v2")

        # In-memory Chroma client (HF Spaces friendly)
        self.client = chromadb.Client()
        self.collection = self.client.get_or_create_collection(name="reservoir_kb")

        self._index_data()

    @staticmethod
    def _build_document(item: dict[str, Any]) -> str:
        """Return the text that is embedded for one config entry."""
        # Store both the human rule note and (optionally) a playbook template for retrieval.
        template = item.get("sql_template", "")
        suffix = f"SQL_TEMPLATE: {template}" if template and template != _NOTE_ONLY else ""
        keywords = " ".join(item.get("keywords", []))
        return f"{keywords} {item.get('context_note', '')} {suffix}".strip()

    @staticmethod
    def _format_hint(meta: dict[str, Any]) -> str:
        """Render one stored metadata record as a hint line plus optional example."""
        hint = f"- [Hint: {meta.get('cat','')}] {meta.get('note','')}\n"
        template = meta.get("sql_template")
        if template and template != _NOTE_ONLY:
            hint += f" Example: {template}\n"
        return hint

    def _index_data(self) -> None:
        """Embed every config entry and add it to the collection.

        Does nothing when the collection already holds records, or when there is
        no config to index (an empty batch would otherwise be passed to
        ``collection.add``).
        """
        if self.collection.count() > 0:
            return
        if not self.config:
            return

        # Positional index is the record id, so ids are stable for a given config order.
        ids = [str(position) for position, _ in enumerate(self.config)]
        docs = [self._build_document(item) for item in self.config]
        metas: list[dict[str, Any]] = [
            {
                "note": item.get("context_note", ""),
                "cat": item.get("category", ""),
                "sql_template": item.get("sql_template", ""),
            }
            for item in self.config
        ]

        embeddings = self.embed_model.encode(docs).tolist()
        self.collection.add(ids=ids, embeddings=embeddings, documents=docs, metadatas=metas)

    def retrieve_context(self, query: str, n_results: int = 4) -> str:
        """Return the hints closest to ``query`` as newline-terminated bullet lines.

        Args:
            query: Free-text question to embed and search with.
            n_results: Maximum number of hints to return. Defaults to 4 and is
                capped at the number of indexed records.

        Returns:
            One ``- [Hint: <category>] <note>`` line per match, each followed by
            an `` Example: <sql>`` line when the entry has a template. Empty
            string when nothing is indexed, ``n_results`` is not positive, or the
            query yields no matches.
        """
        # Never ask Chroma for more neighbours than it holds.
        limit = min(n_results, self.collection.count())
        if limit <= 0:
            return ""

        query_emb = self.embed_model.encode([query]).tolist()
        results = self.collection.query(query_embeddings=query_emb, n_results=limit)

        if not results.get("documents"):
            return ""

        # Results are nested per query; only one query was sent, so take the first row.
        metadata_rows = results.get("metadatas") or [[]]
        return "".join(self._format_hint(meta or {}) for meta in metadata_rows[0])