# Spaces: Sleeping
import json
import logging
import math
import os
import re
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from groq import Groq
from supabase import Client, create_client
# Populate os.environ from a local .env file (a no-op when none exists) so the
# chatbot's os.getenv lookups for GROQ_API_KEY / SUPABASE_URL / SUPABASE_KEY work.
load_dotenv()

logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class GroqRAGChatbot:
    """Retrieval-augmented chatbot over the ``groundwater_data`` Supabase table.

    ``chat`` drives the pipeline: a small Groq model classifies the question
    into JSON, that classification is turned into a PostgREST query built with
    the Supabase client's query methods (never raw SQL), the rows are fetched,
    and a larger Groq model explains them. A short summary, follow-up
    questions, a chart spec and rule-based insights are derived from the same
    rows.
    """

    # Placeholder / generic tokens the intent model tends to echo back as
    # "entities" (the intent prompt's own example text is one of them).
    _ENTITY_BLACKLIST = frozenset({
        'district', 'districts', 'state', 'states',
        'district names mentioned', 'district names', 'unknown', 'india', 'indian',
    })
    # Label/geometry columns: selectable, but not meaningful to rank by or to
    # plot on a y-axis.
    _NON_NUMERIC_COLUMNS = frozenset({'district', 'state', 'geometry'})
    # "top 5" / "top5" in the user's text overrides the default row limit.
    _TOP_N_RE = re.compile(r"\btop\s*(\d+)\b")
    # Whole-word cues for descending order. Plain substring checks let words
    # such as "highlight", "almost" or "stop" flip the sort order.
    _DESC_RE = re.compile(r"\b(?:high(?:er|est|ly)?|most|maximum)\b|\btop(?:\s*\d+)?\b")
    # Whole-word cues for the crude stage_of_development filter; the word
    # boundaries keep "flow", "allow", "unless", ... from triggering it.
    _HIGH_FILTER_RE = re.compile(r"\b(?:high(?:er|est|ly)?|greater)\b")
    _LOW_FILTER_RE = re.compile(r"\b(?:low(?:er|est)?|below|less(?:er)?)\b")

    def __init__(self):
        """Create the Groq and Supabase clients and describe the data schema.

        Credentials are read from the ``GROQ_API_KEY``, ``SUPABASE_URL`` and
        ``SUPABASE_KEY`` environment variables and passed through unchecked.
        """
        self.groq_client = Groq(api_key=os.getenv('GROQ_API_KEY'))
        # One model per pipeline stage. Only 'intent_analyzer' and
        # 'response_generator' are referenced elsewhere in this class.
        self.models = {
            'intent_analyzer': 'llama-3.1-8b-instant',
            'query_builder': 'llama-3.3-70b-versatile',
            'response_generator': 'llama-3.3-70b-versatile'
        }
        self.supabase_url = os.getenv('SUPABASE_URL')
        self.supabase_key = os.getenv('SUPABASE_KEY')
        self.supabase: Client = create_client(self.supabase_url, self.supabase_key)
        # The key_columns names are offered to the intent model as the allowed
        # columns, and build_supabase_query only selects columns listed here.
        self.schema_info = {
            'table_name': 'groundwater_data',
            'key_columns': {
                'district': 'District name (VARCHAR) - ALWAYS REQUIRED - lowercase',
                'state': 'State name (VARCHAR) - ALWAYS REQUIRED - lowercase',
                'annual_gw_draft_total': 'Total groundwater draft in hectare meters (DECIMAL)',
                'annual_replenishable_gw_resource': 'Replenishable groundwater resource (DECIMAL)',
                'stage_of_development': 'Development stage percentage (DECIMAL)',
                'net_gw_availability': 'Net groundwater availability (DECIMAL)',
                'annual_draft_irrigation': 'Irrigation draft (DECIMAL)',
                'st_area_shape': 'Underground water coverage area in square meters (DOUBLE PRECISION)',
                'st_length_shape': 'Underground water perimeter in meters (DOUBLE PRECISION)',
                'geometry': 'Geographic boundaries for underground water mapping (TEXT)'
            }
        }

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _default_intent() -> Dict[str, Any]:
        """Return a fresh copy of the intent used when intent analysis fails."""
        return {
            "intent_type": "ranking",
            "entities": [],
            "target_columns": ["annual_gw_draft_total"],
            "needs_visualization": True,
            "requires_geography": False,
            "underground_focus": True
        }

    @staticmethod
    def _parse_json_object(text: Any) -> Dict[str, Any]:
        """Decode a JSON object from an LLM reply.

        Tolerates Markdown code fences or prose around the object by falling
        back to the span between the first ``{`` and the last ``}``.

        Raises:
            ValueError: if ``text`` is empty/not a string, no object can be
                decoded (``json.JSONDecodeError`` is a ValueError subclass), or
                the decoded value is not a JSON object.
        """
        if not isinstance(text, str) or not text.strip():
            raise ValueError("empty model response")
        cleaned = text.strip()
        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            start, end = cleaned.find('{'), cleaned.rfind('}')
            if start == -1 or end <= start:
                raise
            parsed = json.loads(cleaned[start:end + 1])
        if not isinstance(parsed, dict):
            raise ValueError("model response is not a JSON object")
        return parsed

    @staticmethod
    def _to_float(value: Any) -> Optional[float]:
        """Convert ``value`` to a finite float; None if missing, unparseable, NaN or infinite."""
        if value is None or value == "":
            return None
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        return number if math.isfinite(number) else None

    @staticmethod
    def _district_label(row: Dict[str, Any]) -> str:
        """Title-cased district name of ``row``, or 'Unknown' when missing/None."""
        return str(row.get('district') or 'unknown').title()

    @classmethod
    def _sanitize_entities(cls, entities: Any) -> List[str]:
        """Turn the intent model's ``entities`` into safe, lower-cased search tokens.

        A single string is treated as one entity; other non-sequence values
        yield no tokens. Each entity is split on ``,``, ``(`` and ``)`` because
        those characters delimit PostgREST ``or`` filter expressions, so
        "Pune, Maharashtra" becomes two tokens. Double quotes and backslashes
        are removed. Placeholders, tokens mentioning "district"/"state", and
        tokens shorter than 3 characters are dropped. Duplicates are removed
        (first occurrence wins).
        """
        if isinstance(entities, str):
            entities = [entities]
        elif not isinstance(entities, (list, tuple, set)):
            return []
        tokens: List[str] = []
        for raw in entities:
            try:
                text = str(raw)
            except Exception:
                continue
            for part in re.split(r"[,()]", text):
                token = part.replace('"', '').replace('\\', '').strip().lower()
                if not token:
                    continue
                # ignore placeholders or generic tokens containing admin unit words
                if token in cls._ENTITY_BLACKLIST or 'district' in token or 'state' in token:
                    continue
                # ignore extremely short tokens
                if len(token) < 3:
                    continue
                tokens.append(token)
        return list(dict.fromkeys(tokens))

    def _valid_target_columns(self, target_columns: Any) -> List[str]:
        """Keep only model-suggested column names that exist in ``schema_info``.

        Accepts a list/tuple or a single string. Names are stripped and
        lower-cased, and duplicates dropped. Unknown names are removed because
        selecting or ordering by a non-existent column would make PostgREST
        reject the whole request.
        """
        if isinstance(target_columns, str):
            target_columns = [target_columns]
        if not isinstance(target_columns, (list, tuple)):
            return []
        known = self.schema_info['key_columns']
        valid: List[str] = []
        for col in target_columns:
            if not isinstance(col, str):
                continue
            name = col.strip().lower()
            if name in known and name not in valid:
                valid.append(name)
        return valid

    @staticmethod
    def _format_result_row(row: Dict[str, Any]) -> str:
        """Render one result row as comma-separated ``key: value`` text for the LLM prompt.

        ``geometry`` and None values are skipped. Area and perimeter get
        descriptive labels and thousands separators when numeric.
        """
        parts = []
        for k, v in row.items():
            if v is None or k == 'geometry':
                continue
            if k == 'st_area_shape':
                try:
                    parts.append(f"Underground Coverage Area: {float(v):,.0f} sq.m")
                except (ValueError, TypeError):
                    parts.append(f"Underground Coverage Area: {v}")
            elif k == 'st_length_shape':
                try:
                    parts.append(f"Underground Perimeter: {float(v):,.0f} m")
                except (ValueError, TypeError):
                    parts.append(f"Underground Perimeter: {v}")
            else:
                parts.append(f"{k}: {v}")
        return ", ".join(parts)

    # ------------------------------------------------------------- data access

    def get_db_connection(self):
        """Return True if a one-row probe of the data table succeeds, else False.

        Despite the name, no connection object is returned; this is a health check.
        """
        try:
            self.supabase.table(self.schema_info['table_name']).select('*').limit(1).execute()
            return True
        except Exception as e:
            logger.error(f"Supabase connection error: {e}")
            return False

    def analyze_user_intent(self, user_query: str) -> Dict[str, Any]:
        """Classify ``user_query`` with the intent model and return its JSON as a dict.

        The prompt asks for ``intent_type``, ``entities``, ``target_columns``,
        ``needs_visualization``, ``requires_geography`` and
        ``underground_focus``. The model is not guaranteed to supply all of
        them, so callers use ``.get`` with defaults. On any failure (API error,
        unparseable or non-object JSON) ``_default_intent()`` is returned.
        """
        try:
            prompt = f"""Analyze this user query and respond with JSON only:
Query: "{user_query}"
Available columns: {', '.join(self.schema_info['key_columns'].keys())}
IMPORTANT: For underground water analysis, always consider st_area_shape (coverage area) and st_length_shape (perimeter).
Response format:
{{
"intent_type": "comparison|ranking|statistics|filter|geographic",
"entities": ["district names mentioned"],
"target_columns": ["relevant column names"],
"needs_visualization": true|false,
"requires_geography": true|false,
"underground_focus": true|false
}}"""
            response = self.groq_client.chat.completions.create(
                messages=[
                    {"role": "system", "content": "You are a query analyzer. Respond only with valid JSON. Always include district name and underground water metrics."},
                    {"role": "user", "content": prompt}
                ],
                model=self.models['intent_analyzer'],
                temperature=0.1,
                max_tokens=200
            )
            # Small models often wrap JSON in ``` fences; parse tolerantly.
            return self._parse_json_object(response.choices[0].message.content)
        except Exception as e:
            logger.error(f"Intent analysis error: {e}")
            return self._default_intent()

    def build_supabase_query(self, user_query: str, intent_analysis: Dict[str, Any]) -> Any:
        """
        Build a Supabase query using the client's query methods instead of generating raw SQL.

        - Sanitized entities become case-insensitive ``*substring*`` matches,
          OR-ed across ``district`` and ``state``.
        - A ``filter`` intent with no usable entities gets a crude
          ``stage_of_development`` threshold (> 80 or < 40) from keywords.
        - ``ranking`` intents are ordered by a numeric column, and
          ``geographic`` intents by ``st_area_shape`` descending.
        - A row limit is always applied: "top N" in the text wins, otherwise 50
          or 200, clamped to 1..500.

        Returns the unexecuted query builder. On any error a plain
        ``select('*').limit(10)`` query is returned instead.
        """
        try:
            intent_type = intent_analysis.get('intent_type', 'ranking')
            entities = self._sanitize_entities(intent_analysis.get('entities', []))
            target_columns = self._valid_target_columns(
                intent_analysis.get('target_columns', ['annual_gw_draft_total'])
            )
            ql = (user_query or '').lower()

            # Infer top N from free text
            top_match = self._TOP_N_RE.search(ql)
            inferred_limit: Optional[int] = int(top_match.group(1)) if top_match else None

            # Always include these columns; dict.fromkeys de-duplicates while keeping order.
            mandatory_columns = ['district', 'state', 'st_area_shape', 'st_length_shape', 'geometry']
            selected_columns = list(dict.fromkeys(mandatory_columns + target_columns))
            query = self.supabase.table(self.schema_info['table_name']).select(','.join(selected_columns))

            if entities:
                # OR across district/state for each entity, PostgREST ilike with *wildcards*
                or_clauses = []
                for entity in entities:
                    pattern = f"*{entity}*"
                    or_clauses.append(f"district.ilike.{pattern}")
                    or_clauses.append(f"state.ilike.{pattern}")
                try:
                    query = query.or_(','.join(or_clauses))
                except Exception:
                    # Fallback: chain the first entity as a single ilike filter
                    try:
                        query = query.ilike('district', f"*{entities[0]}*")
                    except Exception:
                        pass
            elif intent_type == "filter":
                # Simple keyword thresholds; this runs when no real entities were
                # found, including when the model only echoed placeholders.
                if self._HIGH_FILTER_RE.search(ql):
                    query = query.gt('stage_of_development', 80)
                elif self._LOW_FILTER_RE.search(ql):
                    query = query.lt('stage_of_development', 40)

            # Map "water level high" style questions to underground coverage area.
            # NOTE(review): "groundwater" appears in most questions about this table,
            # so this preference overrides the model's target column for nearly every
            # ranking query; confirm that is intended.
            metric_preference = None
            if any(k in ql for k in ["water level", "waterlevel", "underground", "groundwater", "coverage"]):
                metric_preference = 'st_area_shape'

            if intent_type == "ranking":
                numeric_targets = [c for c in target_columns if c not in self._NON_NUMERIC_COLUMNS]
                column = metric_preference or (numeric_targets[0] if numeric_targets else 'annual_gw_draft_total')
                query = query.order(column, desc=bool(self._DESC_RE.search(ql)))
            elif intent_type == "geographic":
                query = query.order('st_area_shape', desc=True)

            # Return more rows to power Knowledge/Insights; named entities and
            # ranking/geographic views use the tighter default.
            default_limit = 50 if entities or intent_type in ("ranking", "geographic") else 200
            final_limit = inferred_limit if inferred_limit and inferred_limit > 0 else default_limit
            # Clamp reasonable bounds (1..500)
            final_limit = max(1, min(500, final_limit))
            return query.limit(final_limit)
        except Exception as e:
            logger.error(f"Supabase query building error: {e}")
            # Fallback to a simple query
            return self.supabase.table(self.schema_info['table_name']).select('*').limit(10)

    def execute_supabase_query(self, query) -> Optional[List[Dict[str, Any]]]:
        """
        Execute the Supabase query and return its rows.

        String values are lower-cased in place, except ``geometry``. Returns an
        empty list if execution fails or the response carries no data.
        """
        try:
            result = query.execute()
            # Supabase-py v2 returns a PostgrestResponse with .data
            rows = getattr(result, 'data', None)
            if rows is None:
                rows = []
            for row in rows:
                for k, v in list(row.items()):
                    if isinstance(v, str) and k != 'geometry':
                        row[k] = v.lower()
            logger.info(f"Supabase query returned {len(rows)} results")
            return rows
        except Exception as e:
            logger.error(f"Supabase query execution error: {e}")
            return []

    def get_quick_stats(self) -> Dict[str, Any]:
        """Return headline numbers for the whole table.

        Keys:
            total_districts: exact row count (``count="exact"``).
            avg_development: mean ``stage_of_development`` over values in
                0..500, clamped to 0..200 and rounded to one decimal.
            over_exploited: number of values above 100.
            critical: number of values in 80..100.

        Any failure yields all-zero stats.
        """
        try:
            table = self.schema_info['table_name']
            total_result = self.supabase.table(table).select("district", count="exact").limit(1).execute()
            total_districts = getattr(total_result, 'count', None)
            if total_districts is None:
                # No exact count came back; only the (limited) returned rows can be counted.
                total_districts = len(total_result.data) if total_result.data else 0

            # NOTE(review): a single unpaginated request. If the server caps rows per
            # response, these stats silently cover only the first page; confirm table size.
            all_data = self.supabase.table(table).select("stage_of_development").execute()
            developments = []
            for row in all_data.data or []:
                value = self._to_float(row.get('stage_of_development'))
                # Sanitize: ignore missing/invalid/negative values and extreme outliers
                if value is not None and 0 <= value <= 500:
                    developments.append(value)

            if developments:
                avg_development = sum(developments) / len(developments)
                # Clamp to sensible range
                avg_development = max(0.0, min(200.0, avg_development))
                over_exploited = sum(1 for d in developments if d > 100)
                critical = sum(1 for d in developments if 80 <= d <= 100)
            else:
                avg_development = 0
                over_exploited = 0
                critical = 0
            return {
                "total_districts": total_districts,
                "avg_development": round(float(avg_development), 1),
                "over_exploited": over_exploited,
                "critical": critical
            }
        except Exception as e:
            logger.error(f"Stats query error: {e}")
            return {
                "total_districts": 0,
                "avg_development": 0,
                "over_exploited": 0,
                "critical": 0
            }

    # --------------------------------------------------------------- answering

    def generate_response(self, user_query: str, query_results: List[Dict[str, Any]]) -> str:
        """Ask the response model to explain ``query_results`` in light of ``user_query``.

        Only the first five rows (geometry omitted) go into the prompt, but the
        total count is mentioned. If the Groq call fails or returns no text, a
        templated sentence naming up to three districts is returned instead.
        """
        try:
            if not query_results:
                return "No data found matching your query. Please try rephrasing your question or check if the district names are correct."
            results_text = "\n".join(self._format_result_row(r) for r in query_results[:5])
            prompt = f"""Analyze Indian groundwater data results with focus on underground water availability.
User Question: {user_query}
Results ({len(query_results)} total):
{results_text}
IMPORTANT CONTEXT:
- st_area_shape represents underground water coverage area (larger = more underground water extent)
- st_length_shape represents underground water perimeter (longer = more complex underground water boundaries)
- These metrics help assess underground water availability and distribution
Provide analysis with:
1. Direct answer to the user's question
2. District names with specific numbers
3. Underground water coverage insights using st_area_shape and st_length_shape
4. Practical implications for water management
5. Which districts have better underground water availability based on area/perimeter metrics
Keep response informative and highlight underground water aspects."""
            response = self.groq_client.chat.completions.create(
                messages=[
                    {"role": "system", "content": "You are an Indian groundwater expert specializing in underground water analysis. Provide insights using area and perimeter metrics for underground water availability."},
                    {"role": "user", "content": prompt}
                ],
                model=self.models['response_generator'],
                temperature=0.3,
                max_tokens=500
            )
            content = response.choices[0].message.content
            if not content:
                # Route an empty completion to the templated fallback below.
                raise ValueError("empty model response")
            return content
        except Exception as e:
            logger.error(f"Response generation error: {e}")
            if query_results:
                # `or` (not a .get default) so a present-but-None district cannot break join().
                districts = [str(r.get('district') or 'Unknown') for r in query_results[:3]]
                return f"Found underground water data for {len(query_results)} districts including {', '.join(districts)}. Check the map visualization for underground water coverage areas and detailed results below."
            return "Unable to generate detailed analysis, but query executed successfully."

    def generate_summary_and_followups(self, user_query: str, query_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Generate a concise summary and up to 3 follow-up questions to deepen analysis.

        Returns ``{"summary": str, "follow_ups": list[str]}``. On failure the
        summary is empty and three canned follow-up questions are returned.
        """
        try:
            # Build compact, token-light context from the first five rows
            summary_keys = {
                'district', 'state', 'annual_gw_draft_total', 'stage_of_development',
                'net_gw_availability', 'st_area_shape', 'st_length_shape'
            }
            top_rows = [
                {k: v for k, v in r.items() if k in summary_keys and v is not None}
                for r in (query_results or [])[:5]
            ]
            prompt = (
                "You are an assistant that outputs strict JSON. Given a user query and a small set "
                "of Indian groundwater results (with underground coverage metrics), produce: "
                "1) a one-paragraph summary (<= 80 words), 2) three concise follow-up questions.\n\n"
                f"User Query: {user_query}\n\n"
                f"Results Sample: {json.dumps(top_rows)}\n\n"
                "Respond ONLY as JSON with keys 'summary' and 'follow_ups' (array of 3 strings)."
            )
            response = self.groq_client.chat.completions.create(
                messages=[
                    {"role": "system", "content": "Output valid JSON only."},
                    {"role": "user", "content": prompt}
                ],
                model=self.models['intent_analyzer'],
                temperature=0.2,
                max_tokens=200
            )
            data = self._parse_json_object(response.choices[0].message.content)
            summary = data.get('summary') or ""
            follow_ups = data.get('follow_ups') or []
            # A bare string would otherwise be split into single characters.
            if isinstance(follow_ups, str):
                follow_ups = [follow_ups]
            elif not isinstance(follow_ups, (list, tuple)):
                follow_ups = []
            # Ensure at most 3 non-empty questions
            follow_ups = [str(q) for q in follow_ups if q][:3]
            return {"summary": str(summary), "follow_ups": follow_ups}
        except Exception as e:
            logger.warning(f"Summary/follow-ups generation failed: {e}")
            # Sensible fallback
            fallback = [
                "Do you want to compare two or more districts?",
                "Should I filter by over-exploited or critical status?",
                "Would you like a geographic view of underground coverage?"
            ]
            return {"summary": "", "follow_ups": fallback}

    def build_visualization_spec(self, user_query: str, intent_analysis: Dict[str, Any], query_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Derive a lightweight chart spec (type, axes, title, top_n) from intent and rows.

        ``user_query`` is accepted for interface stability but not used.
        Returns ``{"enabled": False}`` when there are no rows or on error.
        """
        try:
            if not query_results:
                return {"enabled": False}
            intent_type = intent_analysis.get("intent_type", "ranking")
            target_columns = intent_analysis.get("target_columns", ["annual_gw_draft_total"]) or ["annual_gw_draft_total"]
            if isinstance(target_columns, str):
                target_columns = [target_columns]
            primary = target_columns[0]
            # Prefer known numeric metrics
            numeric_preferences = [
                "annual_gw_draft_total",
                "stage_of_development",
                "net_gw_availability",
                "annual_replenishable_gw_resource",
                "annual_draft_irrigation",
                "st_area_shape",
                "st_length_shape"
            ]
            # A label column such as "district" makes a meaningless y-axis, so the
            # model's first target only leads when it is one of the numeric metrics.
            candidates = ([primary] if primary in numeric_preferences else []) + numeric_preferences
            metric = next((c for c in candidates if any(c in r for r in query_results)), primary)
            spec: Dict[str, Any] = {
                "enabled": True,
                "chart_type": "bar",
                "x": "district" if any("district" in r for r in query_results) else None,
                "y": metric,
                "title": "",
                "top_n": 10,
            }
            if intent_type == "comparison":
                spec["title"] = f"Comparison of {metric.replace('_',' ').title()}"
                spec["chart_type"] = "bar"
            elif intent_type == "ranking":
                spec["title"] = f"Ranking by {metric.replace('_',' ').title()}"
                spec["chart_type"] = "bar"
            elif intent_type == "statistics":
                spec["title"] = f"Distribution of {metric.replace('_',' ').title()}"
                spec["chart_type"] = "histogram"
                spec["x"] = metric
                spec["y"] = None
            elif intent_type == "geographic":
                spec["title"] = "Underground Coverage by District"
                spec["chart_type"] = "bar"
                spec["y"] = "st_area_shape" if any("st_area_shape" in r for r in query_results) else metric
            return spec
        except Exception:
            return {"enabled": False}

    def compute_insights(self, query_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Compute actionable insights from a result set without additional API calls.

        Returns up to eight ``{"title", "detail"}`` dicts suitable for display.
        Numeric fields that are missing, unparseable or non-finite count as
        absent for that field only, so one bad value does not discard the row.
        Returns an empty list on empty input or any unexpected error.
        """
        try:
            if not query_results:
                return []
            # Short aliases -> source columns
            field_map = {
                'stage': 'stage_of_development',
                'draft_total': 'annual_gw_draft_total',
                'availability': 'net_gw_availability',
                'replenishable': 'annual_replenishable_gw_resource',
                'draft_irrigation': 'annual_draft_irrigation',
                'area': 'st_area_shape',
                'perimeter': 'st_length_shape',
            }
            rows = []
            for r in query_results:
                if not isinstance(r, dict):
                    continue
                row = {'district': r.get('district')}
                row.update({alias: self._to_float(r.get(col)) for alias, col in field_map.items()})
                rows.append(row)
            if not rows:
                return []

            def names_of(subset: List[Dict[str, Any]]) -> str:
                return ", ".join(self._district_label(r) for r in subset)

            insights: List[Dict[str, Any]] = []

            # Over-exploited and critical counts
            over_ex = [r for r in rows if r['stage'] is not None and r['stage'] > 100]
            critical = [r for r in rows if r['stage'] is not None and 80 <= r['stage'] <= 100]
            if over_ex:
                top_over = sorted(over_ex, key=lambda x: x['stage'], reverse=True)[:3]
                insights.append({
                    "title": "Over‑exploited hotspots",
                    "detail": f"{len(over_ex)} districts >100% development. Top: {names_of(top_over)}."
                })
            if critical:
                insights.append({
                    "title": "Critical watchlist",
                    "detail": f"{len(critical)} districts between 80–100% development; prioritize monitoring."
                })

            # Highest draft
            with_draft = [r for r in rows if r['draft_total'] is not None]
            if with_draft:
                top_draft = sorted(with_draft, key=lambda x: x['draft_total'], reverse=True)[:3]
                insights.append({
                    "title": "Top pressure points",
                    "detail": f"Highest total draft in: {names_of(top_draft)}. Target demand management here first."
                })

            # Supply-demand gap: only a positive draft-minus-availability is a deficit
            gap_rows = [r for r in rows if r['availability'] is not None and r['draft_total'] is not None]
            if gap_rows:
                worst = max(gap_rows, key=lambda x: x['draft_total'] - x['availability'])
                if worst['draft_total'] - worst['availability'] > 0:
                    insights.append({
                        "title": "Availability gap",
                        "detail": f"Largest draft minus availability gap in {self._district_label(worst)}."
                    })

            # Recharge potential: big underground coverage areas
            with_area = [r for r in rows if r['area'] is not None]
            if with_area:
                top_area = sorted(with_area, key=lambda x: x['area'], reverse=True)[:3]
                insights.append({
                    "title": "Recharge potential",
                    "detail": f"Large underground coverage in: {names_of(top_area)}. Consider MAR sites."
                })

            # Complex boundaries: high perimeter relative to area (shape complexity)
            complex_rows = [r for r in rows if r['perimeter'] and r['area'] and r['area'] > 0]
            if complex_rows:
                # Complexity ~ perimeter / sqrt(area)
                ranked = sorted(complex_rows, key=lambda x: x['perimeter'] / max(1.0, x['area'] ** 0.5), reverse=True)[:3]
                insights.append({
                    "title": "Boundary complexity",
                    "detail": f"Complex underground boundaries in: {names_of(ranked)}. Densify observation wells."
                })

            # Pad toward 5 insights with simple data-backed benchmarks
            if len(insights) < 5 and with_draft:
                avg_draft = sum(r['draft_total'] for r in with_draft) / len(with_draft)
                insights.append({
                    "title": "Average draft benchmark",
                    "detail": f"Avg annual draft across results is ~{avg_draft:,.0f} HM."
                })
            if len(insights) < 5 and with_area:
                areas = sorted(r['area'] for r in with_area)
                # Upper median for even-length lists
                mid = areas[len(areas) // 2]
                insights.append({
                    "title": "Coverage benchmark",
                    "detail": f"Median underground coverage area is ~{mid:,.0f} sq.m."
                })
            return insights[:8]
        except Exception:
            return []

    def chat(self, user_query: str) -> Dict[str, Any]:
        """Run the full pipeline for one question and return a result dict.

        On success the dict has ``response``, ``intent_analysis``, ``results``,
        ``results_count``, ``success=True``, ``visualization``, ``summary``,
        ``follow_ups`` and ``insights``. When no rows come back, or an
        unexpected error occurs, a reduced dict with ``success=False`` is
        returned.
        """
        logger.info(f"Processing query: {user_query}")
        try:
            intent_analysis = self.analyze_user_intent(user_query)
            logger.info(f"Intent analysis: {intent_analysis}")
            # Build and execute the Supabase query directly
            query = self.build_supabase_query(user_query, intent_analysis)
            query_results = self.execute_supabase_query(query)
            if not query_results:
                return {
                    "response": "Unable to retrieve data. This could be due to incorrect district names or database connectivity issues. Please try rephrasing your query.",
                    "intent_analysis": intent_analysis,
                    "results": [],
                    "results_count": 0,
                    "success": False
                }
            response = self.generate_response(user_query, query_results)
            viz_spec = self.build_visualization_spec(user_query, intent_analysis, query_results)
            aux = self.generate_summary_and_followups(user_query, query_results)
            insights = self.compute_insights(query_results)
            return {
                "response": response,
                "intent_analysis": intent_analysis,
                "results": query_results,
                "results_count": len(query_results),
                "success": True,
                "visualization": viz_spec,
                "summary": aux.get("summary", ""),
                "follow_ups": aux.get("follow_ups", []),
                "insights": insights
            }
        except Exception as e:
            logger.error(f"Chat processing error: {e}")
            return {
                "response": f"An error occurred while processing your query: {str(e)}",
                "intent_analysis": {"error": str(e)},
                "results": [],
                "results_count": 0,
                "success": False
            }