""" Flask API server for the policy engine. Endpoints for generating policies, applying them, and analyzing impacts. """ from flask import Flask, request, jsonify from flask_cors import CORS from datetime import datetime import json import traceback from policy_engine import PolicyEngine, get_graph_context_from_file from graph_engine import GraphState, ImpactAnalyzer from health_analyzer import HealthImpactAnalyzer from explainability import generate_policy_explanation from explainability import generate_policy_explanation from aqi import register_aqi_routes from emission_forecast import register_emission_routes import config import os # ============================================================================ # SETUP # ============================================================================ app = Flask(__name__) CORS(app) # Enable CORS for frontend requests # Initialize engines policy_engine = PolicyEngine() health_analyzer = HealthImpactAnalyzer() # Register AQI routes register_aqi_routes(app) register_emission_routes(app) from aqi_history import register_aqi_history_routes register_aqi_history_routes(app) # Pre-initialize AQI Data & Model to prevent timeout on first request from aqi_history import get_aqi_history print("Initializing AQI History & Model...") get_aqi_history() from aqi_map import generate_heatmap_html, generate_hotspots_html, generate_forecast_hotspots_html, render_map_to_png from emission_map import generate_emission_heatmap_html, generate_emission_hotspots_html, generate_forecast_emission_hotspots_html, render_map_to_png as render_emission_map_to_png import threading # Ensure map directory exists STATIC_DIR = os.path.join(os.getcwd(), 'static') os.makedirs(STATIC_DIR, exist_ok=True) HEATMAP_PATH = os.path.join(STATIC_DIR, 'aqi_map_heatmap.png') HOTSPOTS_PATH = os.path.join(STATIC_DIR, 'aqi_map_hotspots.png') # Emission map paths EMISSION_HEATMAP_PATH = os.path.join(STATIC_DIR, 'emission_map_heatmap.png') EMISSION_HOTSPOTS_PATH = os.path.join(STATIC_DIR, 'emission_map_hotspots.png') # Default map path (alias to heatmap for backward compat) MAP_IMAGE_PATH = HEATMAP_PATH # ============================================================================ # STARTUP CLEANUP - Remove orphaned temp files from crashed runs # ============================================================================ def cleanup_temp_files(): """Remove orphaned temp_*.html files from previous crashed runs.""" import glob backend_dir = os.path.dirname(os.path.abspath(__file__)) patterns = [ os.path.join(backend_dir, 'temp_map_*.html'), os.path.join(backend_dir, 'temp_emission_map_*.html'), ] cleaned = 0 for pattern in patterns: for temp_file in glob.glob(pattern): try: os.remove(temp_file) cleaned += 1 print(f"Cleaned up orphaned temp file: {os.path.basename(temp_file)}") except Exception as e: print(f"Failed to cleanup {temp_file}: {e}") if cleaned > 0: print(f"Startup cleanup: Removed {cleaned} orphaned temp file(s)") # Run cleanup on startup cleanup_temp_files() @app.route('/api/aqi-map', methods=['GET']) def get_aqi_map_html(): """Returns the Interactive Heatmap HTML (Default)""" return generate_heatmap_html() @app.route('/api/aqi-map/heatmap', methods=['GET']) def get_heatmap_html(): """Explicit endpoint for Heatmap HTML""" return generate_heatmap_html() @app.route('/api/aqi-map/hotspots', methods=['GET']) def get_hotspots_html(): """Returns the Interactive Hotspots HTML. Accepts optional 'year' param for forecast.""" year = request.args.get('year', type=int) # Validate year if provided if year and year in [2026, 2027, 2028]: return generate_forecast_hotspots_html(year) return generate_hotspots_html() @app.route('/api/aqi-map.png', methods=['GET']) def get_aqi_map_png(): """Default map image (Heatmap)""" return get_heatmap_png() @app.route('/api/aqi-map/heatmap.png', methods=['GET']) def get_heatmap_png(): """Returns the Heatmap Grid PNG""" from flask import send_file # Simple caching/regeneration logic if not os.path.exists(HEATMAP_PATH) or request.args.get('refresh'): html = generate_heatmap_html() render_map_to_png(html, HEATMAP_PATH) return send_file(HEATMAP_PATH, mimetype='image/png') @app.route('/api/aqi-map/hotspots.png', methods=['GET']) def get_hotspots_png(): """Returns the Hotspots PNG. Accepts optional 'year' param for forecast (2026-2028).""" from flask import send_file year = request.args.get('year', type=int) # Validate year if provided if year and year not in [2026, 2027, 2028]: year = None if year: # Generate forecast-based hotspots for the selected year cache_path = os.path.join(STATIC_DIR, f'aqi_map_hotspots_{year}.png') if not os.path.exists(cache_path) or request.args.get('refresh'): html = generate_forecast_hotspots_html(year) render_map_to_png(html, cache_path) return send_file(cache_path, mimetype='image/png') else: # Original baseline hotspots if not os.path.exists(HOTSPOTS_PATH) or request.args.get('refresh'): html = generate_hotspots_html() render_map_to_png(html, HOTSPOTS_PATH) return send_file(HOTSPOTS_PATH, mimetype='image/png') # ============================================================================ # EMISSION MAPS - CO2 Visualization # ============================================================================ @app.route('/api/emission-map', methods=['GET']) def get_emission_map_html(): """Returns the Interactive Emission Heatmap HTML (Default)""" return generate_emission_heatmap_html() @app.route('/api/emission-map/heatmap', methods=['GET']) def get_emission_heatmap_html(): """Explicit endpoint for Emission Heatmap HTML""" return generate_emission_heatmap_html() @app.route('/api/emission-map/hotspots', methods=['GET']) def get_emission_hotspots_html(): """Returns the Interactive Emission Hotspots HTML. Accepts optional 'year' param.""" year = request.args.get('year', type=int) if year and year in [2026, 2027, 2028]: return generate_forecast_emission_hotspots_html(year) return generate_emission_hotspots_html() @app.route('/api/emission-map/heatmap.png', methods=['GET']) def get_emission_heatmap_png(): """Returns the Emission Heatmap Grid PNG""" from flask import send_file if not os.path.exists(EMISSION_HEATMAP_PATH) or request.args.get('refresh'): html = generate_emission_heatmap_html() render_emission_map_to_png(html, EMISSION_HEATMAP_PATH) return send_file(EMISSION_HEATMAP_PATH, mimetype='image/png') @app.route('/api/emission-map/hotspots.png', methods=['GET']) def get_emission_hotspots_png(): """Returns the Emission Hotspots PNG. Accepts optional 'year' param (2026-2028).""" from flask import send_file year = request.args.get('year', type=int) if year and year not in [2026, 2027, 2028]: year = None if year: cache_path = os.path.join(STATIC_DIR, f'emission_map_hotspots_{year}.png') if not os.path.exists(cache_path) or request.args.get('refresh'): html = generate_forecast_emission_hotspots_html(year) render_emission_map_to_png(html, cache_path) return send_file(cache_path, mimetype='image/png') else: if not os.path.exists(EMISSION_HOTSPOTS_PATH) or request.args.get('refresh'): html = generate_emission_hotspots_html() render_emission_map_to_png(html, EMISSION_HOTSPOTS_PATH) return send_file(EMISSION_HOTSPOTS_PATH, mimetype='image/png') # ============================================================================ # SECTOR-SPECIFIC EMISSION MAPS # ============================================================================ VALID_SECTORS = ['Industry', 'Transport', 'Power', 'Residential', 'Aviation', 'Commercial'] @app.route('/api/emission-map/sector/heatmap', methods=['GET']) def get_sector_heatmap_html(): """Returns sector-specific emission heatmap HTML.""" from emission_map import generate_sector_heatmap_html sector = request.args.get('sector', 'Industry') if sector not in VALID_SECTORS: sector = 'Industry' return generate_sector_heatmap_html(sector) @app.route('/api/emission-map/sector/heatmap.png', methods=['GET']) def get_sector_heatmap_png(): """Returns sector-specific emission heatmap PNG.""" from flask import send_file from emission_map import generate_sector_heatmap_html, render_map_to_png sector = request.args.get('sector', 'Industry') if sector not in VALID_SECTORS: sector = 'Industry' cache_path = os.path.join(STATIC_DIR, f'sector_heatmap_{sector.lower()}.png') if not os.path.exists(cache_path) or request.args.get('refresh'): html = generate_sector_heatmap_html(sector) render_map_to_png(html, cache_path) return send_file(cache_path, mimetype='image/png') @app.route('/api/emission-map/sector/hotspots', methods=['GET']) def get_sector_hotspots_html(): """Returns sector-specific emission hotspots HTML.""" from emission_map import generate_sector_hotspots_html sector = request.args.get('sector', 'Industry') year = request.args.get('year', type=int) if sector not in VALID_SECTORS: sector = 'Industry' if year and year not in [2026, 2027, 2028]: year = None return generate_sector_hotspots_html(sector, year) @app.route('/api/emission-map/sector/hotspots.png', methods=['GET']) def get_sector_hotspots_png(): """Returns sector-specific emission hotspots PNG.""" from flask import send_file from emission_map import generate_sector_hotspots_html, render_map_to_png sector = request.args.get('sector', 'Industry') year = request.args.get('year', type=int) if sector not in VALID_SECTORS: sector = 'Industry' if year and year not in [2026, 2027, 2028]: year = None cache_key = f'sector_hotspots_{sector.lower()}' if year: cache_key += f'_{year}' cache_path = os.path.join(STATIC_DIR, f'{cache_key}.png') if not os.path.exists(cache_path) or request.args.get('refresh'): html = generate_sector_hotspots_html(sector, year) render_map_to_png(html, cache_path) return send_file(cache_path, mimetype='image/png') # Background task to refresh map periodically (optional) def refresh_map_periodically(): while True: try: print("Refreshing AQI Map image...") html = generate_aqi_map_html() render_map_to_png(html, MAP_IMAGE_PATH) except Exception as e: print(f"Error refreshing map: {e}") time.sleep(3600) # Every hour # Start background thread for map refresh # threading.Thread(target=refresh_map_periodically, daemon=True).start() # ============================================================================ # HELPER FUNCTIONS # ============================================================================ def get_graph_state(): """Load current graph state.""" try: return GraphState.from_file(str(config.GRAPH_STATE_PATH)) except Exception as e: print(f"Error loading graph: {e}") return None # ============================================================================ # API ENDPOINTS # ============================================================================ @app.route('/', methods=['GET']) def index(): """Root endpoint to show server status.""" return jsonify({ 'message': 'Digital Twin Policy Engine API is running', 'endpoints': { 'health': '/api/health', 'aqi': '/api/aqi?lat=28.7041&lon=77.1025' }, 'status': 'active' }) @app.route('/api/health', methods=['GET']) def health_check(): """Health check endpoint.""" return jsonify({ 'status': 'ok', 'timestamp': datetime.now().isoformat(), 'service': 'policy-engine' }) @app.route('/api/graph-state', methods=['GET']) def get_current_graph_state(): """ Get current graph state (nodes, edges, values). Used by frontend for validation and context. """ try: graph = get_graph_state() if not graph: return jsonify({'error': 'Could not load graph'}), 500 return jsonify({ 'status': 'success', 'graph': graph.to_dict(), 'timestamp': datetime.now().isoformat() }) except Exception as e: return jsonify({'error': str(e)}), 500 @app.route('/api/generate-policy', methods=['POST']) def generate_policy(): """ Generate policy from research query. Request body: { "research_query": "How to reduce transport emissions?" } Returns: { "policy": { Policy JSON }, "research_evidence": ["chunk1", "chunk2", ...], "status": "success" } """ try: data = request.json query = data.get('research_query') if not query: return jsonify({'error': 'Missing research_query'}), 400 # Retrieve research research_chunks = policy_engine.query_research(query, k=3) # Get graph context for validation graph_context = data.get('graph_context') if not graph_context: print("Using static graph context from file") graph_context = get_graph_context_from_file(str(config.GRAPH_STATE_PATH)) else: print("Using dynamic graph context from frontend") # Extract policy via LLM policy = policy_engine.extract_policy(research_chunks, graph_context, user_query=query) return jsonify({ 'status': 'success', 'policy': policy.dict(), 'research_evidence': research_chunks, 'timestamp': datetime.now().isoformat() }) except Exception as e: print(f"Error in generate_policy: {e}") traceback.print_exc() return jsonify({'error': str(e)}), 500 @app.route('/api/apply-policy', methods=['POST']) def apply_policy(): """ Apply policy to graph and calculate impact. Request body: { "policy": { Policy JSON from generate-policy } } Returns: { "snapshot": { "scenario_id": "id", "policy_id": "id", "impact": { CO2 and AQI changes }, "cascade_analysis": { affected nodes }, ... }, "status": "success" } """ try: data = request.json policy_dict = data.get('policy') if not policy_dict: return jsonify({'error': 'Missing policy'}), 400 # Load baseline graph_context = data.get('graph_context') if graph_context: print("Using dynamic graph context for baseline") # Reconstruct GraphState from context # Context has { node_ids: [], edges: [{source, target, weight}] } # We need to map this back to the GraphState structure try: # Load full default graph to get default node data (values, labels) which might be missing in context list file_graph = get_graph_state() # Reconstruct nodes with enabled status context_nodes = graph_context.get('nodes', []) if context_nodes and isinstance(context_nodes, list): # New format: list of dicts with enabled status reconstructed_nodes = [] for n_ctx in context_nodes: # Find original node data to preserve other fields (x, y, label, etc) original = next((on for on in file_graph.nodes if on['id'] == n_ctx['id']), None) if original: new_node = original.copy() if 'data' not in new_node: new_node['data'] = {} new_node['data']['enabled'] = n_ctx.get('enabled', True) reconstructed_nodes.append(new_node) baseline_nodes = reconstructed_nodes else: # Old format: list of IDs valid_node_ids = set(graph_context.get('node_ids', context_nodes)) # fallback if context_nodes is list of strings baseline_nodes = [n for n in file_graph.nodes if n['id'] in valid_node_ids] # Reconstruct edges context_edges = graph_context.get('edges', []) reconstructed_edges = [] for e in context_edges: reconstructed_edges.append({ 'source': e['source'], 'target': e['target'], 'data': {'weight': e['weight']} }) baseline = GraphState(baseline_nodes, reconstructed_edges) except Exception as e: print(f"Error reconstructing dynamic baseline: {e}") baseline = get_graph_state() else: baseline = get_graph_state() if not baseline: return jsonify({'error': 'Could not load baseline graph'}), 500 # Create post-policy state (deep copy baseline) import copy post_policy = GraphState( copy.deepcopy(baseline.nodes), copy.deepcopy(baseline.edges) ) # Apply mutations mutation_results = post_policy.apply_policy(policy_dict) # Log mutations with before/after values print(f"\n[Policy Applied]") # Print transport value if it exists, safely trans_node = baseline.get_node('transport') if trans_node: print(f"Baseline Transport value: {trans_node['data'].get('value', 'N/A')}") print(f"Baseline CO2 value: {baseline.get_node('co2')['data']['value']}") print(f"Baseline AQI value: {baseline.get_node('aqi')['data']['value']}") for i, mut in enumerate(mutation_results.get('mutations_applied', [])): print(f"Mutation {i+1}: {mut['type']}") if 'before' in mut and 'after' in mut: print(f" Before: {mut['before']}") print(f" After: {mut['after']}") # Calculate impact analyzer = ImpactAnalyzer(baseline, post_policy) impact = analyzer.calculate_impact() if trans_node: post_trans = post_policy.get_node('transport') print(f"Post-policy Transport value: {post_trans['data'].get('value', 'N/A')}") print(f"Post-policy CO2 value: {post_policy.get_node('co2')['data']['value']}") print(f"Post-policy AQI value: {post_policy.get_node('aqi')['data']['value']}") print(f"Impact: CO₂ change {impact['co2']['change_pct']:.1f}%, AQI change {impact['aqi']['change_pct']:.1f}%\n") # Create snapshot snapshot = { 'scenario_id': f"snap-{datetime.now().strftime('%Y%m%d-%H%M%S')}", 'policy_id': policy_dict.get('policy_id'), 'policy_name': policy_dict.get('name'), 'baseline_graph': baseline.to_dict(), 'post_policy_graph': post_policy.to_dict(), 'mutations_applied': mutation_results['mutations_applied'], 'impact': impact, 'timestamp': datetime.now().isoformat() } return jsonify({ 'snapshot': snapshot, 'status': 'success' }) except Exception as e: print(f"Error in apply_policy: {e}") traceback.print_exc() return jsonify({'error': str(e)}), 500 @app.route('/api/compare-scenarios', methods=['POST']) def compare_scenarios(): """ Compare multiple scenarios side-by-side. Request body: { "scenarios": [ { "name": "Scenario 1", "policy": { Policy JSON } }, { "name": "Scenario 2", "policy": { Policy JSON } }, ... ] } Returns: { "comparison": [ Snapshots for each scenario ], "ranking": { "best_co2_reduction": "Scenario name", "best_aqi_improvement": "Scenario name" } } """ try: data = request.json scenarios = data.get('scenarios', []) if not scenarios: return jsonify({'error': 'Missing scenarios'}), 400 results = [] for scenario in scenarios: # Apply each policy baseline = get_graph_state() if not baseline: continue import copy post_policy = GraphState( copy.deepcopy(baseline.nodes), copy.deepcopy(baseline.edges) ) post_policy.apply_policy(scenario.get('policy', {})) analyzer = ImpactAnalyzer(baseline, post_policy) impact = analyzer.calculate_impact() results.append({ 'name': scenario.get('name'), 'impact': impact }) # Rank by impact best_co2 = max(results, key=lambda r: abs(r['impact']['co2']['change_pct']), default={}) best_aqi = max(results, key=lambda r: abs(r['impact']['aqi']['change_pct']), default={}) return jsonify({ 'status': 'success', 'comparison': results, 'ranking': { 'best_co2_reduction': best_co2.get('name'), 'best_aqi_improvement': best_aqi.get('name') }, 'timestamp': datetime.now().isoformat() }) except Exception as e: print(f"Error in compare_scenarios: {e}") traceback.print_exc() return jsonify({'error': str(e)}), 500 @app.route('/api/explain-policy', methods=['POST']) def explain_policy(): """ Generate explanation for a policy. Request body: { "policy": { Policy JSON }, "research_evidence": ["chunk1", "chunk2", ...] } Returns: { "explanation": { "policy_id": "...", "narrative_intro": "...", "mutations": [ { narrative, supporting_research, stakeholders } ], "overall_narrative": "..." } } """ try: data = request.json policy = data.get('policy') research_evidence = data.get('research_evidence', []) if not policy: return jsonify({'error': 'Missing policy'}), 400 # Generate explanation explanation = generate_policy_explanation(policy, research_evidence) return jsonify({ 'status': 'success', 'explanation': explanation, 'timestamp': datetime.now().isoformat() }) except Exception as e: print(f"Error in explain_policy: {e}") traceback.print_exc() return jsonify({'error': str(e)}), 500 @app.route('/api/analyze-aqi-health', methods=['POST']) def analyze_aqi_health(): """ Analyze health impacts based on AQI data. Request body: { "aqi_data": { "aqi": 410, "city": "Delhi", "pm2_5": 250, "pm10": 400, "no2": 85, "o3": 45, "so2": 15, "co": 1.5 } } Response: { "health_impact": { Health analysis from Gemini }, "status": "success", "timestamp": "..." } """ try: data = request.json aqi_data = data.get('aqi_data') if not aqi_data: return jsonify({'error': 'Missing aqi_data'}), 400 # Analyze health impact health_impact = health_analyzer.analyze_aqi_health(aqi_data) return jsonify({ 'health_impact': health_impact, 'status': 'success', 'timestamp': datetime.now().isoformat() }) except Exception as e: print(f"Error analyzing AQI health: {e}") print(traceback.format_exc()) return jsonify({'error': str(e)}), 500 @app.route('/api/chat-health', methods=['POST']) def chat_health(): """ Chat with health expert. Request: { "message": "Should I wear a mask?", "context": { "aqi": 350, "city": "Delhi", ... } } """ try: data = request.json message = data.get('message') context = data.get('context', {}) if not message: return jsonify({'error': 'Missing message'}), 400 response = health_analyzer.chat_with_health_expert(message, context) return jsonify({ 'response': response, 'status': 'success' }) except Exception as e: print(f"Error in chat_health: {e}") return jsonify({'error': str(e)}), 500 # ============================================================================ # ERROR HANDLERS # ============================================================================ @app.errorhandler(404) def not_found(error): return jsonify({'error': 'Not found'}), 404 @app.errorhandler(500) def internal_error(error): return jsonify({'error': 'Internal server error'}), 500 # ============================================================================ # MAIN # ============================================================================ if __name__ == "__main__": is_hf = bool(os.getenv("HF_SPACE_ID")) port = 7860 if is_hf else config.FLASK_PORT debug = False if is_hf else config.FLASK_DEBUG app.run( host="0.0.0.0", port=port, debug=debug )