Upload folder using huggingface_hub
Browse files- web/agent_wrapper.py +32 -1
web/agent_wrapper.py
CHANGED
|
@@ -216,12 +216,43 @@ class AgentSession:
|
|
| 216 |
await asyncio.sleep(0.5)
|
| 217 |
|
| 218 |
# Collect Arraylake snippet from NEW messages only
|
|
|
|
| 219 |
arraylake_snippets = []
|
| 220 |
-
|
|
|
|
| 221 |
if hasattr(msg, 'tool_calls') and msg.tool_calls:
|
| 222 |
for tc in msg.tool_calls:
|
| 223 |
if tc.get('name') == 'retrieve_era5_data':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
args = tc.get('args', {})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
arraylake_snippets.append(_arraylake_snippet(
|
| 226 |
variable=args.get('variable_id', 'sst'),
|
| 227 |
query_type=args.get('query_type', 'spatial'),
|
|
|
|
| 216 |
await asyncio.sleep(0.5)
|
| 217 |
|
| 218 |
# Collect Arraylake snippet from NEW messages only
|
| 219 |
+
# Only emit ONE snippet per unique (variable, region) — skip failed calls
|
| 220 |
arraylake_snippets = []
|
| 221 |
+
seen_snippet_keys = set()
|
| 222 |
+
for i, msg in enumerate(new_messages):
|
| 223 |
if hasattr(msg, 'tool_calls') and msg.tool_calls:
|
| 224 |
for tc in msg.tool_calls:
|
| 225 |
if tc.get('name') == 'retrieve_era5_data':
|
| 226 |
+
# Check if tool call succeeded by looking at the next message
|
| 227 |
+
# (ToolMessage with same tool_call_id)
|
| 228 |
+
tc_id = tc.get('id', '')
|
| 229 |
+
succeeded = True
|
| 230 |
+
for later_msg in new_messages[i+1:]:
|
| 231 |
+
if (hasattr(later_msg, 'tool_call_id') and
|
| 232 |
+
later_msg.tool_call_id == tc_id):
|
| 233 |
+
content = getattr(later_msg, 'content', '') or ''
|
| 234 |
+
if any(kw in content.lower() for kw in
|
| 235 |
+
['error', 'failed', 'exception', 'limit',
|
| 236 |
+
'exceeded', 'rejected', 'too large']):
|
| 237 |
+
succeeded = False
|
| 238 |
+
break
|
| 239 |
+
|
| 240 |
+
if not succeeded:
|
| 241 |
+
continue
|
| 242 |
+
|
| 243 |
args = tc.get('args', {})
|
| 244 |
+
# Dedup key: variable + rounded region
|
| 245 |
+
dedup_key = (
|
| 246 |
+
args.get('variable_id', 'sst'),
|
| 247 |
+
round(args.get('min_latitude', -90)),
|
| 248 |
+
round(args.get('max_latitude', 90)),
|
| 249 |
+
round(args.get('min_longitude', 0)),
|
| 250 |
+
round(args.get('max_longitude', 360)),
|
| 251 |
+
)
|
| 252 |
+
if dedup_key in seen_snippet_keys:
|
| 253 |
+
continue
|
| 254 |
+
seen_snippet_keys.add(dedup_key)
|
| 255 |
+
|
| 256 |
arraylake_snippets.append(_arraylake_snippet(
|
| 257 |
variable=args.get('variable_id', 'sst'),
|
| 258 |
query_type=args.get('query_type', 'spatial'),
|