# digital-twin-backend / policy_engine.py
# Last commit: 3c54c83 "Added gemini" (AlexKurian)
"""
Core policy engine: transforms research into structured policies.
Uses LLM with structured prompting and Pydantic validation.
"""
import os
# Enable fast download for HuggingFace (must be set before other imports)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
import json
import re
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, ValidationError, Field
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
# from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import PromptTemplate
import config
# ============================================================================
# PYDANTIC MODELS
# ============================================================================
class PolicyMutation(BaseModel):
    """One change a policy applies to the causal graph.

    ``disable_node`` mutations identify their target through ``node_id``;
    edge mutations (``reduce_edge_weight`` / ``increase_edge_weight``)
    identify it through ``source`` and ``target`` and carry ``new_weight``.
    """

    # --- required: what kind of change this is -------------------------------
    type: str = Field(..., description="disable_node | reduce_edge_weight | increase_edge_weight")

    # --- node-level target ----------------------------------------------------
    node_id: Optional[str] = Field(default=None, description="For disable_node mutations")

    # --- edge-level target and weights ----------------------------------------
    source: Optional[str] = Field(default=None, description="For edge mutations")
    target: Optional[str] = Field(default=None, description="For edge mutations")
    new_weight: Optional[float] = Field(default=None, description="New edge weight [0.0, 1.0]")
    original_weight: Optional[float] = Field(default=None, description="Original weight (optional)")

    # --- rationale and undo flag -----------------------------------------------
    reason: str = Field(..., description="Why this mutation is applied")
    reversible: bool = Field(default=True, description="Can this be undone?")
class SourceResearch(BaseModel):
    """Evidence trail (paper ids, quotes, confidence) that backs a policy."""

    # default_factory gives every instance its own fresh list.
    paper_ids: List[str] = Field(default_factory=list)
    key_quotes: List[str] = Field(default_factory=list)
    confidence: float = Field(default=0.8, description="Confidence in policy [0.0, 1.0]")
class TradeOff(BaseModel):
    """A side effect that implementing a policy has on one sector."""

    # All four fields are required.
    sector: str = Field(...)
    impact: str = Field(..., description="positive | negative | neutral")
    magnitude: str = Field(..., description="mild | moderate | strong")
    description: str = Field(...)
class EstimatedImpact(BaseModel):
    """Projected system-wide effect of a policy, expressed as percentages."""

    # NOTE(review): the LLM prompt uses negative values here to express a
    # worsening (e.g. aqi_improvement_pct = -15.0 for emission increases).
    co2_reduction_pct: float = Field(default=0.0, description="% reduction in CO₂")
    aqi_improvement_pct: float = Field(default=0.0, description="% reduction in AQI")
    confidence: float = 0.7
class Policy(BaseModel):
    """Top-level structured policy, as parsed from the LLM's JSON output."""

    # Identity.
    policy_id: str
    name: str
    description: Optional[str] = Field(default=None)
    # Graph changes and their projected effects.
    mutations: List[PolicyMutation]
    estimated_impacts: EstimatedImpact
    trade_offs: List[TradeOff] = Field(default_factory=list)
    # Provenance.
    source_research: SourceResearch
    timestamp: Optional[str] = Field(default=None)
# ============================================================================
# POLICY ENGINE
# ============================================================================
class PolicyEngine:
    """Converts research insights (or a direct user query) into validated policies via an LLM."""

    # Mutation types that act on an (source, target) edge rather than a node.
    _EDGE_MUTATION_TYPES = ("reduce_edge_weight", "increase_edge_weight")

    # Whole-word cues that the user wants emissions to go UP. Words are matched
    # whole (not as substrings) so that e.g. "highway" is not read as "high".
    _INCREASE_WORDS = frozenset({
        "increase", "increases", "increased", "increasing",
        "raise", "raises", "raised", "raising",
        "boost", "boosts", "boosted", "boosting",
        "expand", "expands", "expanded", "expanding",
        "grow", "grows", "grew", "growing", "growth",
        "worsen", "worsens", "worsened", "worsening",
        "high", "higher", "highest",
    })
    # Whole-word cues for the default goal of REDUCING emissions.
    _REDUCE_WORDS = frozenset({
        "reduce", "reduces", "reduced", "reducing", "reduction",
        "cut", "cuts", "cutting",
        "lower", "lowers", "lowered", "lowering",
        "decrease", "decreases", "decreased", "decreasing",
        "mitigate", "mitigates", "mitigating", "mitigation",
        "curb", "curbs", "curbing",
    })

    def __init__(self):
        """Create the Gemini LLM client; the FAISS/embeddings RAG layer is disabled."""
        # RAG is intentionally skipped for instant startup and to avoid a 400MB+
        # embeddings download on Spaces. To re-enable it, build the
        # HuggingFaceEmbeddings / FAISS.load_local index here (config.EMBEDDINGS_MODEL)
        # and assign it to self.embeddings / self.db; query_research picks it up.
        print("\n[INFO] RAG/Embeddings initialization SKIPPED by configuration.")
        print("[INFO] Policy Engine running in 'Direct Query' mode (LLM only).")
        self.embeddings = None
        self.db = None

        if not config.GEMINI_API_KEY:
            print("WARNING: GEMINI_API_KEY is not set. AI features will fail.")
        self.llm = ChatGoogleGenerativeAI(
            # Overridable via config.GEMINI_MODEL; falls back to the original default.
            model=getattr(config, "GEMINI_MODEL", "gemini-2.5-flash"),
            temperature=0.5,
            google_api_key=config.GEMINI_API_KEY,
        )

    def query_research(self, question: str, k: Optional[int] = None) -> tuple[List[str], bool]:
        """
        Retrieve research chunks from FAISS.

        Args:
            question: Search query.
            k: Number of results; a falsy value means config.FAISS_K_SEARCH.

        Returns:
            ``(chunks, False)`` from the FAISS similarity search, or
            ``([question], True)`` when no FAISS index is loaded (direct-query mode).
        """
        if not self.db:
            print(f"FAISS DB not initialized, using direct query: {question}")
            # Hand the user's query straight to the LLM when FAISS is unavailable.
            return [question], True
        k = k or config.FAISS_K_SEARCH
        results = self.db.similarity_search(question, k=k)
        return [r.page_content for r in results], False

    def extract_policy(
        self,
        research_chunks: List[str],
        graph_context: Dict[str, Any],
        is_direct_query: bool = False,
        user_query: str = ""
    ) -> "Policy":
        """
        Use the LLM to extract a structured policy from research.

        Args:
            research_chunks: Research excerpts, or ``[user_query]`` in direct mode.
                Nested lists are flattened and every item is stringified.
            graph_context: Graph description with either ``nodes`` (list of dicts
                with ``id`` and optional ``enabled``) or ``node_ids`` (list of ids),
                plus ``edges`` (dicts with ``source``, ``target``, optional ``weight``).
            is_direct_query: True if research_chunks holds the user query (no FAISS).
            user_query: The original user query (optional but recommended).

        Returns:
            A validated Policy object.

        Raises:
            ValueError: If the LLM response has no/invalid JSON, or a mutation
                references an unknown node/edge or an out-of-range weight.
            pydantic.ValidationError: If the JSON does not match the Policy schema.
        """
        flat_chunks = self._flatten_chunks(research_chunks)
        # Prefer the explicit query; in direct mode the first chunk *is* the query.
        query_text = user_query or (flat_chunks[0] if flat_chunks else "")
        increase_emissions = self._wants_increase(query_text)

        prompt = self._build_prompt(
            query_text, flat_chunks, is_direct_query, graph_context, increase_emissions
        )

        response = self.llm.invoke(prompt)
        response_text = self._response_text(response)

        # Greedy match from the first "{" to the last "}" so that markdown code
        # fences and any chatter around the JSON object are skipped.
        json_match = re.search(r'\{[\s\S]*\}', response_text)
        if not json_match:
            raise ValueError(f"No JSON found in LLM response: {response_text}")
        policy_dict = json.loads(json_match.group())

        policy = Policy(**self._normalize_policy_dict(policy_dict))
        self._log_policy(policy)

        # Mutations must reference nodes/edges that actually exist in the graph.
        self._validate_mutations(policy, graph_context)
        return policy

    @staticmethod
    def _flatten_chunks(research_chunks) -> List[str]:
        """Flatten one level of nested lists and stringify every chunk."""
        flat: List[str] = []
        for chunk in research_chunks or []:
            if isinstance(chunk, list):
                flat.extend(str(c) for c in chunk)
            else:
                flat.append(str(chunk))
        return flat

    @classmethod
    def _wants_increase(cls, query_text: str) -> bool:
        """Return True if the query asks to INCREASE emissions.

        NOTE: keyword heuristic. The first direction cue found (whole words,
        left to right) decides, so "reduce emissions from high-traffic roads"
        is a reduction while "increase emissions by cutting EV subsidies" is an
        increase. Queries with no cue default to reduction.
        """
        for word in re.findall(r"[a-z]+", (query_text or "").lower()):
            if word in cls._REDUCE_WORDS:
                return False
            if word in cls._INCREASE_WORDS:
                return True
        return False

    @staticmethod
    def _node_ids(graph_context: Dict[str, Any]) -> List[str]:
        """Return node ids from either the ``nodes`` (dict list) or ``node_ids`` format."""
        nodes = graph_context.get("nodes")
        if isinstance(nodes, list):
            return [n["id"] for n in nodes]
        return list(graph_context.get("node_ids", []))

    @classmethod
    def _build_prompt(
        cls,
        query_text: str,
        flat_chunks: List[str],
        is_direct_query: bool,
        graph_context: Dict[str, Any],
        increase_emissions: bool,
    ) -> str:
        """Assemble the LLM prompt for the requested emission direction."""
        if is_direct_query:
            research_section = (
                f"USER QUERY: {query_text}\n\n"
                "Use your knowledge to create a policy addressing this query."
            )
        else:
            research_section = (
                f"USER QUERY: {query_text}\n\nRESEARCH FINDINGS:\n" + "\n---\n".join(flat_chunks)
            )

        formatted_nodes = ", ".join(cls._node_ids(graph_context))
        nodes = graph_context.get("nodes")
        # Only the dict-list format carries an 'enabled' flag.
        disabled_nodes = (
            [n["id"] for n in nodes if not n.get("enabled", True)]
            if isinstance(nodes, list)
            else []
        )
        formatted_edges = "\n".join(
            f" {e['source']}->{e['target']} (current weight: {e.get('weight', 0.5)})"
            for e in graph_context.get("edges", [])
        )

        disabled_section = ""
        if disabled_nodes:
            disabled_section = "\n".join([
                "",
                "DISABLED SECTORS (User has manually disconnected these):",
                ", ".join(disabled_nodes),
                "IMPORTANT: The above sectors are DISCONNECTED.",
                "- Do NOT try to modify edges coming FROM these sectors, as they have no effect.",
                "- You should acknowledge that they are disabled in your reasoning.",
                "- Focus policies on the remaining ACTIVE sectors to achieve the goal.",
            ])

        # Direction-specific wording. The "wrong" examples must move the weight
        # OPPOSITE to the task, otherwise the prompt contradicts itself.
        if increase_emissions:
            task = "INCREASE emissions"
            mutation_type = "increase_edge_weight"
            estimated_field = "co2_increase_pct"
            mechanics = [
                "TO INCREASE EMISSIONS:",
                "- MUST increase the edge weight to a LARGER number",
                "- Example: 0.7 → 0.9 (increases flow by 28%)",
                "- Example: 0.5 → 0.75 (increases flow by 50%)",
                "- Example: 0.4 → 0.7 (increases flow by 75%)",
                "CRITICAL: new_weight MUST BE GREATER THAN current weight. Do not decrease!",
            ]
            wrong_examples = [
                " ✗ Change 0.7 to 0.5 (wrong direction)",
                " ✗ Change 0.8 to 0.4 (wrong direction)",
                " ✗ Change 0.9 to 0.6 (wrong direction)",
            ]
            correct_examples = [
                " ✓ Change transport→co2 from 0.7 to 0.9 (increases flow)",
                " ✓ Change energy→co2 from 0.8 to 0.95 (increases propagation)",
            ]
        else:
            task = "REDUCE emissions"
            mutation_type = "reduce_edge_weight"
            estimated_field = "co2_reduction_pct"
            mechanics = [
                "TO REDUCE EMISSIONS:",
                "- MUST decrease the edge weight to a SMALLER number",
                "- Example: 0.7 → 0.35 (50% reduction in flow)",
                "- Example: 0.5 → 0.25 (50% reduction in flow)",
                "- Example: 0.8 → 0.4 (50% reduction in flow)",
                "CRITICAL: new_weight MUST BE LESS THAN current weight. Do not increase!",
            ]
            wrong_examples = [
                " ✗ Change 0.7 to 0.75 (wrong direction)",
                " ✗ Change 0.5 to 0.8 (wrong direction)",
                " ✗ Change 0.6 to 0.7 (wrong direction)",
            ]
            correct_examples = [
                " ✓ Change transport→co2 from 0.7 to 0.35 (cuts in half)",
                " ✓ Change transport→co2 from 0.7 to 0.49 (30% reduction)",
                " ✓ Change energy→co2 from 0.8 to 0.48 (40% reduction)",
            ]

        example_policy = {
            "policy_id": "policy-slug",
            "name": "Policy Name",
            "description": "Description",
            "mutations": [{
                "type": mutation_type,
                "source": "transport",
                "target": "vehicle-emissions",
                "new_weight": 0.9 if increase_emissions else 0.35,
                "original_weight": 0.7,
                "reason": "Research shows...",
            }],
            "estimated_impacts": {
                estimated_field: 12.0 if increase_emissions else 15.0,
                "aqi_improvement_pct": -15.0 if increase_emissions else 18.0,
                "confidence": 0.8,
            },
            "trade_offs": [],
            "source_research": {"paper_ids": [], "key_quotes": [], "confidence": 0.85},
        }
        focus = query_text[:100] if query_text else "emissions policy"

        lines = [
            f"You are a climate policy expert. Your task is to design policies that {task}.",
            research_section,
            disabled_section,
            "CURRENT SYSTEM EDGES (with current weights):",
            formatted_edges,
            f"Available nodes: {formatted_nodes}",
            "EMISSION MECHANICS - READ CAREFULLY:",
            "The system propagates emissions through edges. Each edge has a weight (0.0 to 1.0):",
            "- target_value = source_value × weight",
            "- A weight of 0.7 means 70% of the source value flows to the target",
            "- A weight of 0.3 means 30% of the source value flows to the target",
            *mechanics,
            "CRITICAL INSTRUCTION:",
            f"1. You MUST follow the TASK direction ({task}).",
            "2. USE ONLY THE EDGES LISTED ABOVE. Do not hallucinate connections.",
            "3. If a node (e.g., 'transport') is NOT in the \"Available nodes\" list, "
            "YOU CANNOT CREATE A POLICY FOR IT.",
            "4. If a node is listed in DISABLED SECTORS, do not attempt to change its edges "
            "(it is already off).",
            "5. If the user asks to INCREASE emissions, you MUST generate a policy that INCREASES them.",
            f"6. Ignore research advice if it contradicts the goal to {task}.",
            "7. Use the research only for context on *what* to modify, "
            "but reverse the action if needed to match the goal.",
            "WRONG EXAMPLES (DO NOT DO THIS):",
            *wrong_examples,
            "CORRECT EXAMPLES:",
            *correct_examples,
            f'TASK: Create ONE policy to {task} by addressing: "{focus}..."',
            'If the research suggests "Promoting EV", and your task is to INCREASE emissions, '
            'your policy should be "Tax EVs / Promote Gas Cars".',
            "Return ONLY valid JSON:",
            json.dumps(example_policy, indent=2, ensure_ascii=False),
        ]
        return "\n".join(lines)

    @staticmethod
    def _response_text(response) -> str:
        """Return the text of an LLM response whose content is a str or a list of parts."""
        content = getattr(response, "content", response)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # Chat models may return content blocks; keep only the text parts.
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    parts.append(part["text"])
            return "".join(parts)
        return str(content)

    @staticmethod
    def _normalize_policy_dict(policy_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Repair common LLM deviations from the Policy schema (mutates and returns the dict)."""
        if "trade_offs" in policy_dict:
            raw = policy_dict["trade_offs"]
            if isinstance(raw, (str, dict)):
                raw = [raw]
            cleaned = []
            for item in raw or []:
                if isinstance(item, str):
                    # The LLM sometimes returns bare strings instead of TradeOff objects.
                    cleaned.append({
                        'sector': 'general',
                        'impact': 'neutral',
                        'magnitude': 'mild',
                        'description': item
                    })
                elif isinstance(item, dict):
                    cleaned.append(item)
            policy_dict["trade_offs"] = cleaned

        # The increase-mode prompt asks for "co2_increase_pct", which is not a
        # field of EstimatedImpact and would otherwise be silently dropped.
        # Express it as a negative reduction (same convention as aqi_improvement_pct).
        impacts = policy_dict.get("estimated_impacts")
        if isinstance(impacts, dict) and "co2_reduction_pct" not in impacts:
            increase = impacts.get("co2_increase_pct")
            if isinstance(increase, (int, float)) and not isinstance(increase, bool):
                impacts["co2_reduction_pct"] = -float(increase)
        return policy_dict

    @classmethod
    def _log_policy(cls, policy: "Policy") -> None:
        """Print a human-readable summary of the generated policy for debugging."""
        print("\n[Policy Generated]")
        print(f"Policy: {policy.name}")
        print(f"Description: {policy.description}")
        print(f"Mutations: {len(policy.mutations)}")
        for i, mut in enumerate(policy.mutations, start=1):
            if mut.type in cls._EDGE_MUTATION_TYPES:
                print(f" {i}. {mut.type}: {mut.source} -> {mut.target}")
                print(f" New weight: {mut.new_weight} (reason: {mut.reason})")
        print(f"Estimated CO₂ reduction: {policy.estimated_impacts.co2_reduction_pct}%")
        print(f"Estimated AQI improvement: {policy.estimated_impacts.aqi_improvement_pct}%\n")

    def _validate_mutations(self, policy: "Policy", graph_context: Dict) -> None:
        """Ensure mutations reference real nodes/edges and in-range weights.

        Raises:
            ValueError: On an unknown node, unknown edge, or a missing/out-of-range
                ``new_weight`` for an edge mutation. Other mutation types are not checked.
        """
        node_ids = set(self._node_ids(graph_context))
        edge_pairs = {(e['source'], e['target']) for e in graph_context.get("edges", [])}
        for mutation in policy.mutations:
            if mutation.type == "disable_node":
                if mutation.node_id not in node_ids:
                    raise ValueError(f"Unknown node: {mutation.node_id}")
            elif mutation.type in self._EDGE_MUTATION_TYPES:
                if (mutation.source, mutation.target) not in edge_pairs:
                    raise ValueError(f"Unknown edge: {mutation.source} -> {mutation.target}")
                if mutation.new_weight is None or not (0.0 <= mutation.new_weight <= 1.0):
                    raise ValueError(f"Invalid weight: {mutation.new_weight}")
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def get_graph_context_from_file(filepath: str) -> Dict[str, Any]:
    """
    Load graph context (nodes, edges) from a JSON snapshot file.

    Used for validation during policy extraction. The snapshot is read as a
    JSON object with a ``nodes`` list (items carry an ``id``) and an ``edges``
    list (items carry ``source``/``target`` and an optional ``data.weight``,
    defaulting to 0.5 when absent).

    Args:
        filepath: Path to the UTF-8 JSON snapshot.

    Returns:
        ``{'node_ids': [...], 'edges': [{'source', 'target', 'weight'}, ...],
        'full_graph': <parsed JSON>}``. If the file cannot be read or parsed,
        the error is printed and a small built-in default graph (without
        ``full_graph``) is returned so callers can keep working.
    """
    try:
        # JSON is UTF-8 by specification; don't depend on the platform default.
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        node_ids = [n['id'] for n in data.get('nodes', [])]
        edges = [
            {
                'source': e['source'],
                'target': e['target'],
                # 'data' may be missing or explicitly null in a snapshot.
                'weight': (e.get('data') or {}).get('weight', 0.5),
            }
            for e in data.get('edges', [])
        ]
        return {
            'node_ids': node_ids,
            'edges': edges,
            'full_graph': data
        }
    except Exception as e:
        # Deliberate best-effort: any load/parse problem falls back to a default graph.
        print(f"Could not load graph context: {e}")
        return {
            # Every edge endpoint below is listed here, so the prompt's
            # "Available nodes" list agrees with the edge list.
            'node_ids': [
                'industries', 'transport', 'energy', 'infrastructure',
                'moves-goods', 'uses-power', 'fuels-transport',
                'vehicle-emissions', 'co2', 'aqi'
            ],
            'edges': [
                {'source': 'industries', 'target': 'transport', 'weight': 0.6},
                {'source': 'transport', 'target': 'vehicle-emissions', 'weight': 0.7},
                {'source': 'energy', 'target': 'co2', 'weight': 0.8},
                {'source': 'co2', 'target': 'aqi', 'weight': 0.9}
            ]
        }