File size: 11,158 Bytes
845d5aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import os

# AWS region for all Bedrock/AgentCore calls. This single variable is the
# intended configuration point — the env assignments below must use it
# (the original hard-coded the string twice, so editing `region` had no effect).
region = "us-east-1"  # set this to the AWS region you're using
os.environ["AWS_REGION"] = region
os.environ["AWS_DEFAULT_REGION"] = region

from dataclasses import dataclass, field
from datetime import datetime
import json
import re

from pydantic_ai import Agent, RunContext, Tool
# from pydantic_ai.models.openai import OpenAIChatModel
# from pydantic_ai.providers.ollama import OllamaProvider
from pydantic_ai.models.bedrock import BedrockConverseModel
from pydantic_ai.settings import ModelSettings
from bedrock_agentcore.runtime import BedrockAgentCoreApp
from bedrock_agentcore_starter_toolkit import Runtime
from boto3.session import Session

from config import DATASETS, DATASET_LIST
from schemas import DataSourceTracker, ReportMapOutput
from tools import analyse_and_plot_features_and_nearby_infrastructure,\
    analyse_and_plot_within_op,\
    analyse_using_mcda_then_plot,\
    perform_scenario_analysis_then_plot,\
    get_scenario_weights,\
    geocode_location,\
    assess_seismic_risk_at_location,\
    assess_infrastructure_proximity,\
    calculate_overall_risk_score,\
    create_risk_assessment_map,\
    plan_low_impact_exploration_sites,\
    plan_global_wind_farm_sites,\
    analyze_wind_farm_constraints,\
    get_location_bounds
    

def show_seismic_dataset(run_context: RunContext, n_rows: int = 5):
    """Preview the seismic dataset.

    Args:
        run_context: Pydantic AI run context (unused; kept so the function
            matches the tool-call signature used elsewhere in this module).
        n_rows: Number of leading rows to include in the preview (default 5,
            matching the previous hard-coded ``df.head()`` behavior).

    Return:
        dict: Column-oriented dict of the first ``n_rows`` rows of the
        seismic CSV (NOT the full dataset — the old docstring overstated this).
    """
    # Local imports keep pandas / the dataset registry off the module import
    # path; they are only needed when this tool actually runs.
    import pandas as pd
    from config import get_dataset_path

    df = pd.read_csv(get_dataset_path("seismic"))
    return df.head(n_rows).to_dict()


def get_available_data_sources(run_context: RunContext):
    """List the data sources this agent can draw on.

    Args:
        None

    Return:
        data_source_list: A list of strings, showing the available data sources
    """
    # Debug output: dump the incoming run context, followed by a blank line.
    print(run_context)
    print()

    return DATASET_LIST


# Runtime application object; @app.entrypoint below registers the handler.
app = BedrockAgentCoreApp()

# Model selection: Haiku 3.5 is active; the Sonnet alternatives are kept
# commented out for easy swapping during development.
# model = BedrockConverseModel('us.anthropic.claude-sonnet-4-20250514-v1:0')
# model = BedrockConverseModel('us.anthropic.claude-sonnet-4-5-20250929-v1:0')
model = BedrockConverseModel('us.anthropic.claude-3-5-haiku-20241022-v1:0')

# NOTE(review): `retry_delay` is not a documented ModelSettings field in all
# pydantic_ai versions — confirm it is honored rather than silently ignored.
model_settings = ModelSettings(
    max_retries=6,  # Retry on throttling
    retry_delay=5.0   # Wait between retries
)

agent = Agent(
    model=model,
    # deps_type=DataSourceTracker,
    tools=[
        # get_available_data_sources,
        Tool(
            function=get_available_data_sources,
            takes_ctx=True, 
            description="Get list of available datasets"
        ),
        Tool(
            function=geocode_location,
            takes_ctx=True, 
            description="Convert a location name or address to latitude and longitude coordinates."
        ),
        Tool(
            function=assess_seismic_risk_at_location,
            takes_ctx=True, 
            description="Assess seismic risk by counting earthquake events within a radius of a specific location."
        ),
        Tool(
            function=assess_infrastructure_proximity,
            takes_ctx=True, 
            description="Assess infrastructure proximity risk by counting existing infrastructure within a radius."
        ),

        Tool(
            function=calculate_overall_risk_score,
            takes_ctx=True, 
            description="Calculate overall risk score from individual risk components and provide recommendations."
        ),

        Tool(
            function=get_location_bounds,
            takes_ctx=True, 
            description="Get geographical bounds for any location."
        ),
        # analyze_wind_farm_constraints,
        
        # calculate_overall_risk_score,
        # create_risk_assessment_map,


    ],
    deps_type=DataSourceTracker,
    system_prompt = """
You're a helpful assistant specialized in energy infrastructure analysis and risk assessment. 

IMPORTANT: When users ask for comprehensive analyses like "assess risk", "evaluate location", or "analyze infrastructure", you should use MULTIPLE tools in sequence to provide complete answers.

CRITICAL TOOL CHAINING FOR LOCATION-BASED QUERIES:
When users mention specific locations (regions, countries, seas, coordinates) in their queries, ALWAYS follow this sequence:

1. **FIRST**: Call get_location_bounds to get precise geographical coordinates
2. **THEN**: Use those coordinates in subsequent tools

LOCATION-BASED QUERY PATTERNS:
- "wind farm sites in [LOCATION]" β†’ get_location_bounds β†’ plan_global_wind_farm_sites
- "assess risk in [LOCATION]" β†’ get_location_bounds β†’ assess_seismic_risk_at_location
- "explore [LOCATION]" β†’ get_location_bounds β†’ appropriate planning tool

TOOL CHAINING PATTERNS:

For "wind farm planning" queries:
1. If user mentions a specific location β†’ get_location_bounds(location="LOCATION")
2. Extract bounds from response: {"bounds": {"min_lat": X, "max_lat": Y, "min_lon": Z, "max_lon": W}}
3. Call plan_global_wind_farm_sites with appropriate parameters:
   - Use explicit bounds: min_lat=X, max_lat=Y, min_lon=Z, max_lon=W, location_name="LOCATION"
   - For scenario-based queries: Add scenario_name and adjustments
   - For constraint relaxation: Use adaptive_constraints=True

WIND FARM SCENARIO SUPPORT:
Available scenarios: "balanced_wind", "wind_resource_focus", "environmental_focus", "economic_focus", "operational_focus"
- "environmental focus" β†’ scenario_name="environmental_focus"
- "prioritize wind resource" β†’ scenario_name="wind_resource_focus"
- "economic optimization" β†’ scenario_name="economic_focus"
- Adjustments: "increase environmental weight" β†’ adjust_environmental=0.15

For "risk assessment" queries:
1. If location is a name/address β†’ geocode_location first
2. Then assess_seismic_risk_at_location 
3. Then assess_infrastructure_proximity
4. Then calculate_overall_risk_score
5. Finally create_risk_assessment_map for visualization

For "infrastructure analysis":
1. Use appropriate analysis tools (analyse_and_plot_features_and_nearby_infrastructure, etc.)
2. Add mapping/visualization tools when helpful

For "multi-criteria analysis":
1. Use MCDA tools (analyse_using_mcda_then_plot, perform_scenario_analysis_then_plot)

EXAMPLE WORKFLOWS:

User: "Explore potential wind farm sites in Africa with environmental focus"
1. Call: get_location_bounds(location="Africa")
   Result: {"bounds": {"min_lat": -35.0, "max_lat": 37.0, "min_lon": -25.0, "max_lon": 52.0}}
2. Call: plan_global_wind_farm_sites(min_lat=-35.0, max_lat=37.0, min_lon=-25.0, max_lon=52.0, location_name="Africa", scenario_name="environmental_focus")

User: "Plan wind farms in North Sea with economic focus but increase environmental importance"
1. Call: get_location_bounds(location="North Sea")
2. Call: plan_global_wind_farm_sites(..., scenario_name="economic_focus", adjust_environmental=0.15)

CONSTRAINT STRATEGY: Always use adaptive_constraints=True for wind farm planning to ensure good site coverage and let the system optimize constraints for each specific region.

CRITICAL: Don't stop after just getting coordinates. The user expects a complete analysis when they ask for wind farm planning or site assessment.

When you get coordinates from get_location_bounds, immediately use those coordinates in subsequent planning tools.

SYSTEM:
Show your reasoning explicitly in <think>...</think> tags.
Keep it concise and structured.
Continue analysis until you've fully answered the user's question.
"""
)

# In-process per-session message histories (session_id -> message list).
# NOTE(review): this only persists for the lifetime of one runtime instance.
conversation_histories = {}


@app.entrypoint
def pydantic_bedrock_claude_main(payload):
    """Invoke the agent with a payload.

    Args:
        payload: dict with "prompt" (user message, required) and an optional
            "session_id" used to keep per-session conversation history.

    Returns:
        dict with keys:
            output:       the agent's final answer
            thinking:     contents of any <think>...</think> tags emitted
            tool_calls:   deduplicated list of {"tool_name", "args"} dicts
            data_sources: datasets recorded on the DataSourceTracker deps
    """
    print("========== ENTRYPOINT CALLED ==========")
    print(f"Payload: {payload}")

    user_input = payload.get("prompt")
    if not user_input:
        # Fail fast with a structured response instead of crashing deep
        # inside agent.run_sync with a missing prompt.
        return {
            "output": "No prompt provided.",
            "thinking": [],
            "tool_calls": [],
            "data_sources": [],
        }

    session_id = payload.get("session_id", "default")
    conversation_histories.setdefault(session_id, [])

    deps = DataSourceTracker()

    # Currently, Pydantic AI does not officially support returning the results of
    # called tool directly (without summarizing). So I followed this workaround:
    # https://github.com/pydantic/pydantic-ai/pull/142#issuecomment-3158974832
    result = agent.run_sync(
        user_input,
        deps=deps,
        output_type=[
            analyse_and_plot_features_and_nearby_infrastructure,
            analyse_and_plot_within_op,
            analyse_using_mcda_then_plot,
            perform_scenario_analysis_then_plot,
            create_risk_assessment_map,
            plan_low_impact_exploration_sites,
            plan_global_wind_farm_sites,
            str,  # plain-text answers remain allowed
        ],
        model_settings=model_settings,
        message_history=conversation_histories[session_id],
    )

    thinking_log, tool_calls_log = _extract_logs(result.all_messages())

    conversation_histories[session_id] = result.all_messages()
    data_sources = deps.get_sources()

    print(result.output)

    # Return structured response with thinking and tool calls
    return {
        "output": result.output,
        "thinking": thinking_log,
        "tool_calls": tool_calls_log,
        "data_sources": data_sources,
    }


def _extract_logs(messages):
    """Collect <think> contents and deduplicated tool calls from agent messages.

    Deduplication keys on tool name AND JSON-serialized arguments.
    BUG FIX: the previous version computed this combined signature but then
    deduplicated on the bare tool name, silently dropping repeat calls to the
    same tool made with *different* arguments.

    Args:
        messages: list of pydantic_ai messages (result.all_messages()).

    Returns:
        (thinking_log, tool_calls_log): list of <think> strings, and list of
        {"tool_name", "args"} dicts in first-seen order.
    """
    thinking_log = []
    tool_calls_log = []
    seen_tool_calls = set()

    for msg in messages:
        for part in getattr(msg, 'parts', []):
            # Tool-call parts carry a tool_name attribute.
            if hasattr(part, 'tool_name'):
                tool_name = part.tool_name

                # Remove Pydantic AI added prefix in tool names
                if tool_name.startswith('final_result_'):
                    tool_name = tool_name.replace('final_result_', '')

                args = getattr(part, 'args', {}) or {}

                # Signature includes args so the same tool invoked with
                # different parameters is logged once per distinct call.
                try:
                    tool_signature = f"{tool_name}:{json.dumps(args, sort_keys=True)}"
                except TypeError:
                    # args may be a non-JSON-serializable object; fall back
                    # to its repr rather than losing the call entirely.
                    tool_signature = f"{tool_name}:{args!r}"

                # Debug: print what we're seeing
                print(f"🔍 Tool detected: {tool_name}, Args: {args}, Signature: {tool_signature}")

                if tool_signature not in seen_tool_calls:
                    seen_tool_calls.add(tool_signature)
                    tool_calls_log.append({'tool_name': tool_name, 'args': args})
                    print(f"  ✅ Added to log")
                else:
                    print(f"  ⏭️ Skipped (duplicate)")

            # Text parts may embed <think> tags per the system prompt.
            elif hasattr(part, 'content') and isinstance(part.content, str):
                think_matches = re.findall(r'<think>(.*?)</think>', part.content, re.DOTALL)
                for think_content in think_matches:
                    thinking_log.append(think_content.strip())

    return thinking_log, tool_calls_log


if __name__ == "__main__":
    app.run()