"""
Core policy engine: transforms research into structured policies.
Uses LLM with structured prompting and Pydantic validation.
"""

import os

# Enable fast download for HuggingFace (must be set before other imports)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import json
import re
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, ValidationError, Field

from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
# from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import PromptTemplate

import config


# ============================================================================
# PYDANTIC MODELS
# ============================================================================

class PolicyMutation(BaseModel):
    """Represents a single mutation to the causal graph."""

    type: str = Field(..., description="disable_node | reduce_edge_weight | increase_edge_weight")
    node_id: Optional[str] = Field(None, description="For disable_node mutations")
    source: Optional[str] = Field(None, description="For edge mutations")
    target: Optional[str] = Field(None, description="For edge mutations")
    new_weight: Optional[float] = Field(None, description="New edge weight [0.0, 1.0]")
    original_weight: Optional[float] = Field(None, description="Original weight (optional)")
    reason: str = Field(..., description="Why this mutation is applied")
    reversible: bool = Field(True, description="Can this be undone?")


class SourceResearch(BaseModel):
    """Research evidence backing the policy."""

    paper_ids: List[str] = Field(default_factory=list)
    key_quotes: List[str] = Field(default_factory=list)
    confidence: float = Field(0.8, description="Confidence in policy [0.0, 1.0]")


class TradeOff(BaseModel):
    """Trade-off from policy implementation."""

    sector: str
    impact: str = Field(..., description="positive | negative | neutral")
    magnitude: str = Field(..., description="mild | moderate | strong")
    description: str


class EstimatedImpact(BaseModel):
    """Estimated system-wide impacts."""

    # A negative value means emissions go UP: PolicyEngine stores an
    # LLM-reported "co2_increase_pct" here with its sign flipped.
    co2_reduction_pct: float = Field(0.0, description="% reduction in CO₂")
    aqi_improvement_pct: float = Field(0.0, description="% reduction in AQI")
    confidence: float = Field(0.7)


class Policy(BaseModel):
    """Complete structured policy JSON."""

    policy_id: str
    name: str
    description: Optional[str] = None
    mutations: List[PolicyMutation]
    estimated_impacts: EstimatedImpact
    trade_offs: List[TradeOff] = Field(default_factory=list)
    source_research: SourceResearch
    timestamp: Optional[str] = None


# ============================================================================
# POLICY ENGINE
# ============================================================================

class PolicyEngine:
    """Converts research insights into structured policies via LLM."""

    # Mutation types that change an edge weight (as opposed to disabling a node).
    _EDGE_MUTATION_TYPES = ("reduce_edge_weight", "increase_edge_weight")

    # Substrings whose presence in the query flips the task to "increase emissions".
    # NOTE(review): this is plain substring matching, so "reduce high emissions"
    # or "highway" also match "high" -- confirm whether that is intended.
    _INCREASE_KEYWORDS = ("increase", "raise", "boost", "expand", "grow", "worsen", "high")

    def __init__(self):
        """Initialize the LLM; the FAISS index and embeddings are deliberately skipped."""
        # RAG/embeddings are disabled on purpose: loading them triggers a 400MB+
        # model download on Spaces and the user asked for instant startup.
        # The original logic was roughly:
        #   snapshot_download(...)
        #   self.embeddings = HuggingFaceEmbeddings(...)
        #   self.db = FAISS.load_local(...)
        print("\n[INFO] RAG/Embeddings initialization SKIPPED by configuration.")
        print("[INFO] Policy Engine running in 'Direct Query' mode (LLM only).")
        self.embeddings = None
        self.db = None

        # Previous provider, kept for reference:
        # self.llm = ChatGroq(
        #     model=config.LLM_MODEL,
        #     temperature=0.5,
        #     api_key=config.GROQ_API_KEY
        # )
        if not config.GEMINI_API_KEY:
            print("WARNING: GEMINI_API_KEY is not set. AI features will fail.")

        self.llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash",
            temperature=0.5,
            google_api_key=config.GEMINI_API_KEY
        )

    def query_research(self, question: str, k: Optional[int] = None) -> tuple[List[str], bool]:
        """
        Retrieve research chunks from FAISS.

        Args:
            question: Search query
            k: Number of results (default from config)

        Returns:
            Tuple of (research chunks or [question], is_direct_query)
        """
        if not self.db:
            print(f"FAISS DB not initialized, using direct query: {question}")
            # Return the user's query directly when FAISS is not available
            return [question], True

        k = k or config.FAISS_K_SEARCH
        results = self.db.similarity_search(question, k=k)
        return [r.page_content for r in results], False

    @staticmethod
    def _coerce_response_text(content: Any) -> str:
        """Flatten a chat-model message ``content`` into a single string.

        ``content`` is usually a plain string, but it may also be a list of
        parts (strings or dicts carrying a ``"text"`` entry). Non-text parts
        are skipped.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    parts.append(part["text"])
            return "".join(parts)
        return str(content)

    @staticmethod
    def _normalize_policy_dict(policy_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Repair common LLM deviations from the Policy schema, in place.

        - ``trade_offs``: ``null`` becomes an empty list, a single string/dict
          is wrapped in a list, bare strings become neutral/mild TradeOff
          dicts, and entries of any other type are dropped.
        - ``estimated_impacts.co2_increase_pct``: the prompt asks for this key
          when the task is to increase emissions, but the schema only has
          ``co2_reduction_pct``; the value is stored there with its sign
          flipped so it is not silently lost.

        Returns:
            The same dict object, for convenience.
        """
        if 'trade_offs' in policy_dict:
            raw_trade_offs = policy_dict['trade_offs']
            if isinstance(raw_trade_offs, (str, dict)):
                # A lone item instead of a list; iterating it directly would
                # walk characters / keys.
                raw_trade_offs = [raw_trade_offs]
            elif not isinstance(raw_trade_offs, (list, tuple)):
                raw_trade_offs = []

            cleaned_trade_offs = []
            for item in raw_trade_offs:
                if isinstance(item, str):
                    # Convert string to proper TradeOff object
                    cleaned_trade_offs.append({
                        'sector': 'general',
                        'impact': 'neutral',
                        'magnitude': 'mild',
                        'description': item
                    })
                elif isinstance(item, dict):
                    cleaned_trade_offs.append(item)
            policy_dict['trade_offs'] = cleaned_trade_offs

        impacts = policy_dict.get('estimated_impacts')
        if isinstance(impacts, dict) and 'co2_increase_pct' in impacts:
            increase = impacts.pop('co2_increase_pct')
            if 'co2_reduction_pct' not in impacts:
                try:
                    impacts['co2_reduction_pct'] = -float(increase)
                except (TypeError, ValueError):
                    # Not numeric: fall back to the schema default rather than
                    # failing the whole policy over an estimate.
                    pass

        return policy_dict

    def extract_policy(
        self,
        research_chunks: List[str],
        graph_context: Dict[str, Any],
        is_direct_query: bool = False,
        user_query: str = ""
    ) -> Policy:
        """
        Use LLM to extract structured policy from research.

        Args:
            research_chunks: List of research excerpts or [user_query] if direct
            graph_context: Dict with node/edge structure for validation
            is_direct_query: True if research_chunks contains user query (no FAISS)
            user_query: The original user query string (optional but recommended)

        Returns:
            Validated Policy object

        Raises:
            ValueError: If the LLM response contains no JSON object, the JSON is
                malformed, or a mutation references an unknown node/edge or an
                out-of-range weight.
            pydantic.ValidationError: If the JSON does not match the Policy schema.
        """
        # Ensure all chunks are strings (handle potential nested lists)
        flat_chunks: List[str] = []
        for chunk in research_chunks or []:
            if isinstance(chunk, list):
                flat_chunks.extend(str(c) for c in chunk)
            else:
                flat_chunks.append(str(chunk))

        # Determine intent from user_query if available, otherwise check chunks
        query_text = user_query if user_query else (flat_chunks[0] if flat_chunks else "")
        query_lower = query_text.lower()
        increase_emissions = any(word in query_lower for word in self._INCREASE_KEYWORDS)

        if is_direct_query:
            research_section = (
                f"USER QUERY: {query_text}\n\n"
                "Use your knowledge to create a policy addressing this query."
            )
        else:
            research_section = (
                f"USER QUERY: {query_text}\n\nRESEARCH FINDINGS:\n"
                + "\n---\n".join(flat_chunks)
            )

        disabled_nodes: List[str] = []
        # Handle new graph_context structure (list of dicts) vs old (list of ids)
        if 'nodes' in graph_context and isinstance(graph_context['nodes'], list):
            formatted_nodes = ", ".join([n['id'] for n in graph_context['nodes']])
            disabled_nodes = [
                n['id'] for n in graph_context['nodes'] if not n.get('enabled', True)
            ]
        else:
            formatted_nodes = ", ".join(graph_context.get("node_ids", []))

        formatted_edges = "\n".join([
            f"  {e['source']}->{e['target']} (current weight: {e.get('weight', 0.5)})"
            for e in graph_context.get("edges", [])
        ])

        disabled_section = ""
        if disabled_nodes:
            disabled_section = f"""
DISABLED SECTORS (User has manually disconnected these): {', '.join(disabled_nodes)}
IMPORTANT: The above sectors are DISCONNECTED.
- Do NOT try to modify edges coming FROM these sectors, as they have no effect.
- You should acknowledge that they are disabled in your reasoning.
- Focus policies on the remaining ACTIVE sectors to achieve the goal."""

        # Determine policy direction. Every direction-dependent piece of the
        # prompt (including the WRONG examples) is chosen here so the prompt
        # never contradicts the task.
        if increase_emissions:
            task_description = "INCREASE emissions"
            mutation_type = "increase_edge_weight"
            mechanics_section = """TO INCREASE EMISSIONS:
- MUST increase the edge weight to a LARGER number
- Example: 0.7 → 0.9 (increases flow by 28%)
- Example: 0.5 → 0.75 (increases flow by 50%)
- Example: 0.4 → 0.7 (increases flow by 75%)

CRITICAL: new_weight MUST BE GREATER THAN current weight. Do not decrease!"""
            wrong_examples = """✗ Change 0.7 to 0.5 (wrong direction)
✗ Change 0.8 to 0.4 (wrong direction)
✗ Change 0.6 to 0.3 (wrong direction)"""
            correct_examples = """✓ Change transport→co2 from 0.7 to 0.9 (increases flow)
✓ Change energy→co2 from 0.8 to 0.95 (increases propagation)"""
            estimated_field = "co2_increase_pct"
            example_new_weight = "0.9"
            example_co2_pct = "12.0"
            example_aqi_pct = "-15.0"
        else:
            task_description = "REDUCE emissions"
            mutation_type = "reduce_edge_weight"
            mechanics_section = """TO REDUCE EMISSIONS:
- MUST decrease the edge weight to a SMALLER number
- Example: 0.7 → 0.35 (50% reduction in flow)
- Example: 0.5 → 0.25 (50% reduction in flow)
- Example: 0.8 → 0.4 (50% reduction in flow)

CRITICAL: new_weight MUST BE LESS THAN current weight. Do not increase!"""
            wrong_examples = """✗ Change 0.7 to 0.75 (wrong direction)
✗ Change 0.5 to 0.8 (wrong direction)
✗ Change 0.6 to 0.7 (wrong direction)"""
            correct_examples = """✓ Change transport→co2 from 0.7 to 0.35 (cuts in half)
✓ Change transport→co2 from 0.7 to 0.49 (30% reduction)
✓ Change energy→co2 from 0.8 to 0.48 (40% reduction)"""
            estimated_field = "co2_reduction_pct"
            example_new_weight = "0.35"
            example_co2_pct = "15.0"
            example_aqi_pct = "18.0"

        topic = flat_chunks[0][:100] if flat_chunks else 'emissions policy'

        prompt = f"""You are a climate policy expert. Your task is to design policies that {task_description}.

{research_section}
{disabled_section}

CURRENT SYSTEM EDGES (with current weights):
{formatted_edges}

Available nodes: {formatted_nodes}

EMISSION MECHANICS - READ CAREFULLY:
The system propagates emissions through edges. Each edge has a weight (0.0 to 1.0):
- target_value = source_value × weight
- A weight of 0.7 means 70% of the source value flows to the target
- A weight of 0.3 means 30% of the source value flows to the target

{mechanics_section}

CRITICAL INSTRUCTION:
1. You MUST follow the TASK direction ({task_description}).
2. USE ONLY THE EDGES LISTED ABOVE. Do not hallucinate connections.
3. If a node (e.g., 'transport') is NOT in the "Available nodes" list, YOU CANNOT CREATE A POLICY FOR IT.
4. If a node is listed in DISABLED SECTORS, do not attempt to change its edges (it is already off).
5. If the user asks to INCREASE emissions, you MUST generate a policy that INCREASES them.
6. Ignore research advice if it contradicts the goal to {task_description}.
7. Use the research only for context on *what* to modify, but reverse the action if needed to match the goal.

WRONG EXAMPLES (DO NOT DO THIS):
{wrong_examples}

CORRECT EXAMPLES:
{correct_examples}

TASK: Create ONE policy to {task_description} by addressing: "{topic}..."

If the research suggests "Promoting EV", and your task is to INCREASE emissions, your policy should be "Tax EVs / Promote Gas Cars".

Return ONLY valid JSON:
{{
  "policy_id": "policy-slug",
  "name": "Policy Name",
  "description": "Description",
  "mutations": [
    {{
      "type": "{mutation_type}",
      "source": "transport",
      "target": "vehicle-emissions",
      "new_weight": {example_new_weight},
      "original_weight": 0.7,
      "reason": "Research shows..."
    }}
  ],
  "estimated_impacts": {{
    "{estimated_field}": {example_co2_pct},
    "aqi_improvement_pct": {example_aqi_pct},
    "confidence": 0.8
  }},
  "trade_offs": [],
  "source_research": {{
    "paper_ids": [],
    "key_quotes": [],
    "confidence": 0.85
  }}
}}"""

        # Call LLM
        response = self.llm.invoke(prompt)
        response_text = self._coerce_response_text(response.content)

        # Extract JSON (handle markdown code blocks): take everything from the
        # first "{" to the last "}".
        json_match = re.search(r'\{[\s\S]*\}', response_text)
        if not json_match:
            raise ValueError(f"No JSON found in LLM response: {response_text}")

        # json.JSONDecodeError is a ValueError subclass, so malformed JSON
        # surfaces to callers as ValueError as well.
        policy_dict = json.loads(json_match.group())

        # Clean up fields the LLM tends to get wrong before schema validation
        self._normalize_policy_dict(policy_dict)

        # Validate against schema
        policy = Policy(**policy_dict)

        # Log the mutations for debugging
        print("\n[Policy Generated]")
        print(f"Policy: {policy.name}")
        print(f"Description: {policy.description}")
        print(f"Mutations: {len(policy.mutations)}")
        for index, mut in enumerate(policy.mutations, start=1):
            if mut.type in self._EDGE_MUTATION_TYPES:
                print(f"  {index}. {mut.type}: {mut.source} -> {mut.target}")
                print(f"     New weight: {mut.new_weight} (reason: {mut.reason})")
        # A negative reduction here means the policy increases emissions.
        print(f"Estimated CO₂ reduction: {policy.estimated_impacts.co2_reduction_pct}%")
        print(f"Estimated AQI improvement: {policy.estimated_impacts.aqi_improvement_pct}%\n")

        # Validate mutations reference real nodes/edges
        self._validate_mutations(policy, graph_context)

        return policy

    def _validate_mutations(self, policy: Policy, graph_context: Dict) -> None:
        """Ensure mutations reference real nodes/edges.

        Args:
            policy: Parsed policy whose mutations are checked.
            graph_context: Dict with either ``nodes`` (list of dicts with an
                ``id``) or ``node_ids`` (list of ids), plus ``edges``.

        Raises:
            ValueError: On an unknown node, an unknown edge, or an edge weight
                that is missing or outside [0.0, 1.0].
        """
        # Handle both old (node_ids list) and new (nodes dict list) formats
        if 'nodes' in graph_context and isinstance(graph_context['nodes'], list):
            node_ids = {n['id'] for n in graph_context['nodes']}
        else:
            node_ids = set(graph_context.get("node_ids", []))

        edge_pairs = {(e['source'], e['target']) for e in graph_context.get("edges", [])}

        # NOTE(review): mutation types other than the three handled below pass
        # through unchecked -- confirm whether they should be rejected.
        for mutation in policy.mutations:
            if mutation.type == "disable_node":
                if mutation.node_id not in node_ids:
                    raise ValueError(f"Unknown node: {mutation.node_id}")
            elif mutation.type in self._EDGE_MUTATION_TYPES:
                if (mutation.source, mutation.target) not in edge_pairs:
                    raise ValueError(f"Unknown edge: {mutation.source} -> {mutation.target}")
                if mutation.new_weight is None or not (0.0 <= mutation.new_weight <= 1.0):
                    raise ValueError(f"Invalid weight: {mutation.new_weight}")


# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def get_graph_context_from_file(filepath: str) -> Dict[str, Any]:
    """
    Load graph context (nodes, edges) from snapshot file.
    Used for validation during policy extraction.

    Args:
        filepath: Path to a JSON snapshot with ``nodes`` and ``edges`` lists.

    Returns:
        Dict with ``node_ids``, ``edges`` (source/target/weight) and
        ``full_graph`` (the raw snapshot). If the file cannot be read or
        parsed, a built-in default graph is returned instead; that fallback
        has no ``full_graph`` key.
    """
    try:
        # Read explicitly as UTF-8 instead of relying on the platform default.
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Extract node IDs
        node_ids = [n['id'] for n in data.get('nodes', [])]

        # Extract edges with weights (weight lives under the edge's "data" dict)
        edges = [
            {
                'source': e['source'],
                'target': e['target'],
                'weight': e.get('data', {}).get('weight', 0.5)
            }
            for e in data.get('edges', [])
        ]

        return {
            'node_ids': node_ids,
            'edges': edges,
            'full_graph': data
        }
    except Exception as e:
        # Deliberate best-effort: any load/parse problem falls back to the
        # default graph below so policy extraction can still run.
        print(f"Could not load graph context: {e}")
        # NOTE(review): the fallback edge transport -> vehicle-emissions targets
        # a node that is not listed in node_ids -- confirm the intended graph.
        return {
            'node_ids': [
                'industries', 'transport', 'energy', 'infrastructure',
                'moves-goods', 'uses-power', 'fuels-transport',
                'co2', 'aqi'
            ],
            'edges': [
                {'source': 'industries', 'target': 'transport', 'weight': 0.6},
                {'source': 'transport', 'target': 'vehicle-emissions', 'weight': 0.7},
                {'source': 'energy', 'target': 'co2', 'weight': 0.8},
                {'source': 'co2', 'target': 'aqi', 'weight': 0.9}
            ]
        }