dmpantiu commited on
Commit
608d4a4
·
verified ·
1 Parent(s): 63fd495

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. web/agent_wrapper.py +32 -1
web/agent_wrapper.py CHANGED
@@ -216,12 +216,43 @@ class AgentSession:
216
  await asyncio.sleep(0.5)
217
 
218
  # Collect Arraylake snippet from NEW messages only
 
219
  arraylake_snippets = []
220
- for msg in new_messages:
 
221
  if hasattr(msg, 'tool_calls') and msg.tool_calls:
222
  for tc in msg.tool_calls:
223
  if tc.get('name') == 'retrieve_era5_data':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  args = tc.get('args', {})
 
 
 
 
 
 
 
 
 
 
 
 
225
  arraylake_snippets.append(_arraylake_snippet(
226
  variable=args.get('variable_id', 'sst'),
227
  query_type=args.get('query_type', 'spatial'),
 
216
  await asyncio.sleep(0.5)
217
 
218
  # Collect Arraylake snippet from NEW messages only
219
+ # Only emit ONE snippet per unique (variable, region) — skip failed calls
220
  arraylake_snippets = []
221
+ seen_snippet_keys = set()
222
+ for i, msg in enumerate(new_messages):
223
  if hasattr(msg, 'tool_calls') and msg.tool_calls:
224
  for tc in msg.tool_calls:
225
  if tc.get('name') == 'retrieve_era5_data':
226
+ # Check if tool call succeeded by looking at the next message
227
+ # (ToolMessage with same tool_call_id)
228
+ tc_id = tc.get('id', '')
229
+ succeeded = True
230
+ for later_msg in new_messages[i+1:]:
231
+ if (hasattr(later_msg, 'tool_call_id') and
232
+ later_msg.tool_call_id == tc_id):
233
+ content = getattr(later_msg, 'content', '') or ''
234
+ if any(kw in content.lower() for kw in
235
+ ['error', 'failed', 'exception', 'limit',
236
+ 'exceeded', 'rejected', 'too large']):
237
+ succeeded = False
238
+ break
239
+
240
+ if not succeeded:
241
+ continue
242
+
243
  args = tc.get('args', {})
244
+ # Dedup key: variable + rounded region
245
+ dedup_key = (
246
+ args.get('variable_id', 'sst'),
247
+ round(args.get('min_latitude', -90)),
248
+ round(args.get('max_latitude', 90)),
249
+ round(args.get('min_longitude', 0)),
250
+ round(args.get('max_longitude', 360)),
251
+ )
252
+ if dedup_key in seen_snippet_keys:
253
+ continue
254
+ seen_snippet_keys.add(dedup_key)
255
+
256
  arraylake_snippets.append(_arraylake_snippet(
257
  variable=args.get('variable_id', 'sst'),
258
  query_type=args.get('query_type', 'spatial'),