File size: 11,158 Bytes
845d5aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 |
import os
# AWS region used for all Bedrock calls; change this single constant to deploy
# elsewhere. Must be set before boto3/Bedrock clients are constructed below.
region = "us-east-1"
os.environ["AWS_REGION"] = region
os.environ["AWS_DEFAULT_REGION"] = region
from dataclasses import dataclass, field
from datetime import datetime
import json
import re
from pydantic_ai import Agent, RunContext, Tool
# from pydantic_ai.models.openai import OpenAIChatModel
# from pydantic_ai.providers.ollama import OllamaProvider
from pydantic_ai.models.bedrock import BedrockConverseModel
from pydantic_ai.settings import ModelSettings
from bedrock_agentcore.runtime import BedrockAgentCoreApp
from bedrock_agentcore_starter_toolkit import Runtime
from boto3.session import Session
from config import DATASETS, DATASET_LIST
from schemas import DataSourceTracker, ReportMapOutput
from tools import analyse_and_plot_features_and_nearby_infrastructure,\
analyse_and_plot_within_op,\
analyse_using_mcda_then_plot,\
perform_scenario_analysis_then_plot,\
get_scenario_weights,\
geocode_location,\
assess_seismic_risk_at_location,\
assess_infrastructure_proximity,\
calculate_overall_risk_score,\
create_risk_assessment_map,\
plan_low_impact_exploration_sites,\
plan_global_wind_farm_sites,\
analyze_wind_farm_constraints,\
get_location_bounds
def show_seismic_dataset(run_context: RunContext):
    """Return a small preview of the seismic dataset.

    Args:
        run_context: Pydantic AI tool run context (unused; required by the
            tool-calling interface).

    Returns:
        dict: The first five rows of the seismic CSV, keyed by column name
        (the result of ``DataFrame.head().to_dict()``).
    """
    # Imported lazily so module import stays cheap and pandas is only
    # required when this tool is actually invoked.
    import pandas as pd
    from config import get_dataset_path

    df = pd.read_csv(get_dataset_path("seismic"))
    return df.head().to_dict()
def get_available_data_sources(run_context: RunContext):
    """Return the list of available data sources.

    Args:
        run_context: Pydantic AI tool run context (unused; required by the
            tool-calling interface).

    Returns:
        list: Names of the datasets the agent can analyse, as declared in
        ``config.DATASET_LIST``.
    """
    # Leftover debug prints of the run context removed; the list is static
    # configuration, so just return it.
    return DATASET_LIST
# AgentCore runtime application object; @app.entrypoint below registers the handler.
app = BedrockAgentCoreApp()
# model = BedrockConverseModel('us.anthropic.claude-sonnet-4-20250514-v1:0')
# model = BedrockConverseModel('us.anthropic.claude-sonnet-4-5-20250929-v1:0')
# Haiku 3.5 selected over the Sonnet variants above (lower cost/latency).
model = BedrockConverseModel('us.anthropic.claude-3-5-haiku-20241022-v1:0')
# NOTE(review): ModelSettings is a loosely-typed options bag; confirm that
# 'max_retries' / 'retry_delay' are honoured by the Bedrock provider in the
# installed pydantic_ai version.
model_settings = ModelSettings(
    max_retries=6,  # Retry on throttling
    retry_delay=5.0  # Wait between retries
)
# Agent wiring: lightweight lookup/geo tools are registered here as regular
# tools; the heavyweight plotting/planning functions are supplied later as
# output_type in run_sync so their raw results can be returned unsummarized.
# NOTE: garbled "β" characters in the original prompt (mojibake for arrows)
# are restored to "->" so the model reads the intended tool-chaining syntax.
agent = Agent(
    model=model,
    tools=[
        Tool(
            function=get_available_data_sources,
            takes_ctx=True,
            description="Get list of available datasets"
        ),
        Tool(
            function=geocode_location,
            takes_ctx=True,
            description="Convert a location name or address to latitude and longitude coordinates."
        ),
        Tool(
            function=assess_seismic_risk_at_location,
            takes_ctx=True,
            description="Assess seismic risk by counting earthquake events within a radius of a specific location."
        ),
        Tool(
            function=assess_infrastructure_proximity,
            takes_ctx=True,
            description="Assess infrastructure proximity risk by counting existing infrastructure within a radius."
        ),
        Tool(
            function=calculate_overall_risk_score,
            takes_ctx=True,
            description="Calculate overall risk score from individual risk components and provide recommendations."
        ),
        Tool(
            function=get_location_bounds,
            takes_ctx=True,
            description="Get geographical bounds for any location."
        ),
        # analyze_wind_farm_constraints,
        # calculate_overall_risk_score,
        # create_risk_assessment_map,
    ],
    deps_type=DataSourceTracker,
    system_prompt="""
You're a helpful assistant specialized in energy infrastructure analysis and risk assessment.
IMPORTANT: When users ask for comprehensive analyses like "assess risk", "evaluate location", or "analyze infrastructure", you should use MULTIPLE tools in sequence to provide complete answers.
CRITICAL TOOL CHAINING FOR LOCATION-BASED QUERIES:
When users mention specific locations (regions, countries, seas, coordinates) in their queries, ALWAYS follow this sequence:
1. **FIRST**: Call get_location_bounds to get precise geographical coordinates
2. **THEN**: Use those coordinates in subsequent tools
LOCATION-BASED QUERY PATTERNS:
- "wind farm sites in [LOCATION]" -> get_location_bounds -> plan_global_wind_farm_sites
- "assess risk in [LOCATION]" -> get_location_bounds -> assess_seismic_risk_at_location
- "explore [LOCATION]" -> get_location_bounds -> appropriate planning tool
TOOL CHAINING PATTERNS:
For "wind farm planning" queries:
1. If user mentions a specific location -> get_location_bounds(location="LOCATION")
2. Extract bounds from response: {"bounds": {"min_lat": X, "max_lat": Y, "min_lon": Z, "max_lon": W}}
3. Call plan_global_wind_farm_sites with appropriate parameters:
   - Use explicit bounds: min_lat=X, max_lat=Y, min_lon=Z, max_lon=W, location_name="LOCATION"
   - For scenario-based queries: Add scenario_name and adjustments
   - For constraint relaxation: Use adaptive_constraints=True
WIND FARM SCENARIO SUPPORT:
Available scenarios: "balanced_wind", "wind_resource_focus", "environmental_focus", "economic_focus", "operational_focus"
- "environmental focus" -> scenario_name="environmental_focus"
- "prioritize wind resource" -> scenario_name="wind_resource_focus"
- "economic optimization" -> scenario_name="economic_focus"
- Adjustments: "increase environmental weight" -> adjust_environmental=0.15
For "risk assessment" queries:
1. If location is a name/address -> geocode_location first
2. Then assess_seismic_risk_at_location
3. Then assess_infrastructure_proximity
4. Then calculate_overall_risk_score
5. Finally create_risk_assessment_map for visualization
For "infrastructure analysis":
1. Use appropriate analysis tools (analyse_and_plot_features_and_nearby_infrastructure, etc.)
2. Add mapping/visualization tools when helpful
For "multi-criteria analysis":
1. Use MCDA tools (analyse_using_mcda_then_plot, perform_scenario_analysis_then_plot)
EXAMPLE WORKFLOWS:
User: "Explore potential wind farm sites in Africa with environmental focus"
1. Call: get_location_bounds(location="Africa")
   Result: {"bounds": {"min_lat": -35.0, "max_lat": 37.0, "min_lon": -25.0, "max_lon": 52.0}}
2. Call: plan_global_wind_farm_sites(min_lat=-35.0, max_lat=37.0, min_lon=-25.0, max_lon=52.0, location_name="Africa", scenario_name="environmental_focus")
User: "Plan wind farms in North Sea with economic focus but increase environmental importance"
1. Call: get_location_bounds(location="North Sea")
2. Call: plan_global_wind_farm_sites(..., scenario_name="economic_focus", adjust_environmental=0.15)
CONSTRAINT STRATEGY: Always use adaptive_constraints=True for wind farm planning to ensure good site coverage and let the system optimize constraints for each specific region.
CRITICAL: Don't stop after just getting coordinates. The user expects a complete analysis when they ask for wind farm planning or site assessment.
When you get coordinates from get_location_bounds, immediately use those coordinates in subsequent planning tools.
SYSTEM:
Show your reasoning explicitly in <think>...</think> tags.
Keep it concise and structured.
Continue analysis until you've fully answered the user's question.
"""
)
conversation_histories = {}
def _extract_thinking_and_tool_calls(messages):
    """Collect <think> reasoning and de-duplicated tool calls from agent messages.

    Args:
        messages: Iterable of Pydantic AI messages; each may expose ``.parts``.

    Returns:
        tuple: ``(thinking_log, tool_calls_log)`` where ``thinking_log`` is a
        list of strings extracted from <think>...</think> tags and
        ``tool_calls_log`` is a list of ``{'tool_name': ..., 'args': ...}``
        dicts, one per distinct (tool, args) call.
    """
    thinking_log = []
    tool_calls_log = []
    seen_signatures = set()
    for msg in messages:
        if not hasattr(msg, 'parts'):
            continue
        for part in msg.parts:
            if hasattr(part, 'tool_name'):
                tool_name = part.tool_name
                # Remove the prefix Pydantic AI adds to output-type tool names.
                if tool_name.startswith('final_result_'):
                    tool_name = tool_name.replace('final_result_', '')
                args = getattr(part, 'args', {})
                # Signature covers tool name AND arguments, handling empty args.
                if args:
                    signature = f"{tool_name}:{json.dumps(args, sort_keys=True)}"
                else:
                    signature = f"{tool_name}:no_args"
                # Debug: print what we're seeing
                print(f"Tool detected: {tool_name}, Args: {args}, Signature: {signature}")
                # BUG FIX: the original computed the signature but de-duplicated
                # on the bare tool name, dropping repeat calls of the same tool
                # with different arguments. Key on the full signature instead.
                if signature not in seen_signatures:
                    seen_signatures.add(signature)
                    tool_calls_log.append({'tool_name': tool_name, 'args': args})
                    print("  Added to log")
                else:
                    print("  Skipped (duplicate)")
            # Text content may carry <think> tags per the system prompt.
            elif hasattr(part, 'content') and isinstance(part.content, str):
                for think_content in re.findall(r'<think>(.*?)</think>', part.content, re.DOTALL):
                    thinking_log.append(think_content.strip())
    return thinking_log, tool_calls_log


@app.entrypoint
def pydantic_bedrock_claude_main(payload):
    """
    Invoke the agent with a payload.

    Args:
        payload: dict with "prompt" (the user's message) and an optional
            "session_id" keying the in-memory conversation history
            (defaults to "default").

    Returns:
        dict with keys "output", "thinking", "tool_calls" and "data_sources".
    """
    print("========== ENTRYPOINT CALLED ==========")
    print(f"Payload: {payload}")
    user_input = payload.get("prompt")
    session_id = payload.get("session_id", "default")
    history = conversation_histories.setdefault(session_id, [])
    deps = DataSourceTracker()
    # Currently, Pydantic AI does not officially support returning the results
    # of a called tool directly (without summarizing), so the plotting/planning
    # functions are passed as output_type per this workaround:
    # https://github.com/pydantic/pydantic-ai/pull/142#issuecomment-3158974832
    result = agent.run_sync(
        user_input,
        deps=deps,
        output_type=[
            analyse_and_plot_features_and_nearby_infrastructure,
            analyse_and_plot_within_op,
            analyse_using_mcda_then_plot,
            perform_scenario_analysis_then_plot,
            create_risk_assessment_map,
            plan_low_impact_exploration_sites,
            plan_global_wind_farm_sites,
            str,  # plain-text answers remain allowed
        ],
        model_settings=model_settings,
        message_history=history,
    )
    thinking_log, tool_calls_log = _extract_thinking_and_tool_calls(result.all_messages())
    conversation_histories[session_id] = result.all_messages()
    data_sources = deps.get_sources()
    print(result.output)
    # Structured response with reasoning and tool-call traces for the client.
    return {
        "output": result.output,
        "thinking": thinking_log,
        "tool_calls": tool_calls_log,
        "data_sources": data_sources,
    }
if __name__ == "__main__":
    # Start the AgentCore runtime HTTP server when executed directly.
    app.run()
|