File size: 12,013 Bytes
90bfc25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""
utils/gemini_utils.py
---------------------
Shared Gemini setup, data context builder, and model caller.
Used by gemini_insights.py, home.py, and every page's AI summary widget.
"""

import json
import streamlit as st
import google.generativeai as genai


import re

def strip_markdown(text: str) -> str:
    text = re.sub(r'\*\*(.+?)\*\*', r'\1', text)   # **bold**
    text = re.sub(r'\*(.+?)\*',     r'\1', text)   # *italic*
    text = re.sub(r'__(.+?)__',     r'\1', text)   # __bold__
    text = re.sub(r'_(.+?)_',       r'\1', text)   # _italic_
    text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)  # headings
    text = re.sub(r'^\s*[-*β€’]\s+',  '', text, flags=re.MULTILINE)  # bullets
    text = re.sub(r'^\s*\d+\.\s+',  '', text, flags=re.MULTILINE)  # numbered lists
    text = re.sub(r'`(.+?)`',       r'\1', text)   # inline code
    text = re.sub(r'\n{3,}', '\n\n', text)
    return text.strip()

from utils.api_client import (
    fetch_stats, fetch_predictions, fetch_optimizer_results, fetch_yearly_trend,
)

MODEL_NAME = "gemini-2.5-flash-lite"

# ── Preset questions (used by insights page) ──────────────────────────────────
PRESET_QUESTIONS = [
    {
        "label": "Which districts are predicted to see the steepest employment decline?",
        "key": "declining",
        "icon": "πŸ“‰",
    },
    {
        "label": "Which districts offer the best return on additional budget investment?",
        "key": "roi",
        "icon": "πŸ’°",
    },
    {
        "label": "What does the model predict for national employment in the next cycle?",
        "key": "forecast",
        "icon": "πŸ”­",
    },
    {
        "label": "Which states should be prioritised for budget reallocation and why?",
        "key": "realloc",
        "icon": "βš–οΈ",
    },
    {
        "label": "What is the predicted COVID recovery trajectory across districts?",
        "key": "covid",
        "icon": "🦠",
    },
    {
        "label": "Which districts are most underfunded relative to their predicted demand?",
        "key": "underfunded",
        "icon": "🚨",
    },
    {
        "label": "What are the top 5 efficiency leaders and what can we learn from them?",
        "key": "efficiency",
        "icon": "πŸ†",
    },
    {
        "label": "Summarise the overall model prediction results in plain language.",
        "key": "summary",
        "icon": "πŸ“‹",
    },
]

# ── Per-page summary prompts ──────────────────────────────────────────────────
PAGE_SUMMARY_PROMPTS = {
    "overview": "In 3–4 sentences, summarise the key takeaways from the national MNREGA employment trend data shown. Focus on the most important patterns, anomalies, and what they imply for policy. Use specific numbers.",
    "districts": "In 3–4 sentences, give a sharp analytical summary of this district's MNREGA performance trajectory. What is the trend, how did COVID affect it, and what does the model predict? Be specific.",
    "predictions": "In 3–4 sentences, summarise what the model predictions reveal. Comment on accuracy, any notable over/under-predictions, and what the forecasts imply for the next cycle.",
    "optimizer": "In 3–4 sentences, explain the budget optimiser results in plain language. What is the headline gain, which districts benefit most, and is the reallocation realistic to implement?",
    "insights": "In 3–4 sentences, provide a crisp executive summary of the strategic insights. What are the 2–3 most urgent actions a policymaker should take based on this data?",
    "spatial": "In 3–4 sentences, describe what the spatial distribution of predicted employment reveals. Are there regional clusters of high or low performance? What geographic patterns stand out?",
}


def get_gemini_key() -> str | None:
    """Get key from session state (set once in sidebar)."""
    return st.session_state.get("gemini_api_key", "")


def configure_gemini(api_key: str):
    genai.configure(api_key=api_key)
    return genai.GenerativeModel(MODEL_NAME)


@st.cache_data(ttl=300, show_spinner=False)
def build_context(state_param: str | None) -> dict:
    """Build a structured data context dict from live API data."""
    stats    = fetch_stats()
    pred_df  = fetch_predictions(state=state_param)
    opt_df   = fetch_optimizer_results(state=state_param)
    trend_df = fetch_yearly_trend(state=state_param)

    ctx: dict = {}

    ctx["scope"] = state_param or "All India"
    ctx["overview"] = {
        "total_districts": stats.get("total_districts"),
        "total_states": stats.get("total_states"),
        "year_range": stats.get("year_range"),
        "total_persondays_lakhs": round(stats.get("total_persondays_lakhs", 0), 1),
        "covid_spike_pct": stats.get("covid_spike_pct"),
    }

    if not trend_df.empty:
        ctx["yearly_trend"] = (
            trend_df[["financial_year","total_persondays","avg_wage"]]
            .round(2).to_dict(orient="records")
        )

    if not pred_df.empty:
        ly  = int(pred_df["financial_year"].max())
        prv = ly - 1
        lat = pred_df[pred_df["financial_year"] == ly]
        prv_df = pred_df[pred_df["financial_year"] == prv]

        ctx["model"] = {
            "algorithm": "GradientBoostingRegressor",
            "latest_predicted_year": ly,
            "walk_forward_r2": 0.91,
            "note": "2022 West Bengal anomaly excluded from CV",
        }

        if not prv_df.empty:
            mg = lat.merge(
                prv_df[["state","district","person_days_lakhs"]]
                .rename(columns={"person_days_lakhs":"prev"}),
                on=["state","district"], how="inner",
            )
            mg["chg"] = (mg["predicted_persondays"] - mg["prev"]).round(2)
            mg["chg_pct"] = (mg["chg"] / mg["prev"] * 100).round(1)

            ctx["predictions"] = {
                "n_improving": int((mg["chg"] >= 0).sum()),
                "n_declining":  int((mg["chg"] < 0).sum()),
                "top_improving": mg.nlargest(5, "chg")[
                    ["state","district","prev","predicted_persondays","chg","chg_pct"]
                ].to_dict(orient="records"),
                "top_declining": mg.nsmallest(5, "chg")[
                    ["state","district","prev","predicted_persondays","chg","chg_pct"]
                ].to_dict(orient="records"),
                "national_predicted_total": round(float(lat["predicted_persondays"].sum()), 1),
                "national_actual_prev": round(float(prv_df["person_days_lakhs"].sum()), 1),
            }

    if not opt_df.empty and "persondays_gain" in opt_df.columns:
        sq   = float(opt_df["sq_persondays"].sum())
        gain = float(opt_df["persondays_gain"].sum())
        ctx["optimizer"] = {
            "total_budget_lakhs": round(float(opt_df.get("budget_allocated_lakhs", opt_df["sq_persondays"]).sum()), 0),
            "status_quo_persondays": round(sq, 1),
            "gain_lakhs": round(gain, 2),
            "gain_pct": round(gain / sq * 100, 2) if sq else 0,
            "top_gain": opt_df.nlargest(5, "persondays_gain")[
                ["state","district","persondays_gain","persondays_per_lakh","budget_change_pct"]
            ].round(3).to_dict(orient="records"),
            "top_cut": opt_df.nsmallest(5, "persondays_gain")[
                ["state","district","persondays_gain","persondays_per_lakh","budget_change_pct"]
            ].round(3).to_dict(orient="records"),
            "by_state": (
                opt_df.groupby("state")["persondays_gain"]
                .sum().nlargest(8).round(2).to_dict()
            ),
            "underfunded": opt_df[
                opt_df["budget_allocated_lakhs"] < opt_df["budget_allocated_lakhs"].quantile(0.33)
            ].nlargest(5, "persondays_per_lakh")[
                ["state","district","persondays_per_lakh","budget_allocated_lakhs"]
            ].round(3).to_dict(orient="records") if "budget_allocated_lakhs" in opt_df.columns else [],
        }

    return ctx


def call_gemini(api_key: str, prompt: str, temperature: float = 0.35) -> str:
    """Call Gemini and return the text response."""
    try:
        m = configure_gemini(api_key)
        resp = m.generate_content(
            prompt,
            generation_config=genai.types.GenerationConfig(
                temperature=temperature,
                max_output_tokens=1024,
            ),
        )
        return strip_markdown(resp.text)
    except Exception as e:
        return f"⚠️ Gemini error: {e}"


def base_prompt(ctx: dict) -> str:
    return f"""You are a senior policy analyst specialising in India's MNREGA rural employment scheme.
Scope: {ctx.get('scope', 'All India')}

Live data from SchemeImpactNet (GradientBoostingRegressor, walk-forward CV RΒ²β‰ˆ0.91):
{json.dumps(ctx, indent=2)}

Rules:
- Person-days in lakhs (1 lakh = 100,000). Budget in β‚Ή lakhs.
- 2020: COVID surge (reverse migration drove demand spike).
- 2022: West Bengal data anomaly (-93% to -98%) β€” not a real decline.
- The LP optimizer reallocates budget across districts at zero additional cost.
- Base every claim on the numbers above. Name specific districts and states.
- Be direct, analytical, and avoid generic statements.

"""


def preset_prompt(ctx: dict, question_key: str) -> str:
    base = base_prompt(ctx)
    prompts = {
        "declining": base + "Which districts are predicted to see the steepest employment decline? Name the top 5, give exact predicted change figures, identify any state-level patterns, and suggest specific interventions. (~300 words)",
        "roi": base + "Which districts offer the best return on additional budget investment based on efficiency (persondays_per_lakh) scores? Name top districts, explain why their efficiency is high, and estimate the employment gain from a 10% budget increase. (~300 words)",
        "forecast": base + "What does the model predict for national employment in the next cycle? Compare predicted vs previous actual totals, identify which states drive the change, and assess confidence given model performance. (~300 words)",
        "realloc": base + "Which states should be prioritised for budget reallocation and why? Use the optimizer state-level data, name the top 3 states for increase and top 3 for reduction, with the employment gain rationale. (~300 words)",
        "covid": base + "What is the predicted COVID recovery trajectory? Has employment normalised post-2020 surge, or are certain districts still at elevated levels? What does this imply for future demand planning? (~300 words)",
        "underfunded": base + "Which districts are most underfunded relative to their predicted demand and efficiency scores? Name specific districts, show the gap between their efficiency and their budget allocation, and recommend reallocation amounts. (~300 words)",
        "efficiency": base + "Who are the top 5 efficiency leaders (highest persondays_per_lakh)? What structural factors likely explain their high efficiency? What can other districts learn and replicate? (~300 words)",
        "summary": base + "Summarise the overall model prediction results in plain language for a non-technical policymaker. Cover: what the model predicts nationally, which regions face challenges, and the 3 most important numbers to know. (~300 words)",
    }
    return prompts.get(question_key, base + "Provide a strategic analysis of the MNREGA data.")


def page_summary_prompt(ctx: dict, page_key: str, extra_context: str = "") -> str:
    base = base_prompt(ctx)
    instruction = PAGE_SUMMARY_PROMPTS.get(page_key, "Summarise the key insights from this page's data in 3–4 sentences.")
    extra = f"\nAdditional page context:\n{extra_context}\n" if extra_context else ""
    return base + extra + "\n" + instruction + "\nRespond in 3–4 sentences only. Be precise and use specific numbers."