Pulastya B commited on
Commit
d8b38c5
·
1 Parent(s): dba2298

Fixed conversation pruning logic

Browse files
Files changed (1) hide show
  1. src/orchestrator.py +43 -14
src/orchestrator.py CHANGED
@@ -2953,24 +2953,41 @@ You receive quality reports from EDA agent and deliver clean data to modeling ag
2953
  return msg.get('role', '')
2954
  return getattr(msg, 'role', '')
2955
 
 
 
 
 
 
 
2956
  if len(messages) > 10:
2957
  # Keep: system prompt [0], user query [1], last valid exchanges
2958
  system_msg = messages[0]
2959
  user_msg = messages[1]
2960
  recent_msgs = messages[-8:]
2961
 
2962
- # Ensure no orphaned tool messages after pruning
2963
- # Mistral requires: assistant → tool → assistant → user (never tool after user)
2964
  cleaned_recent = []
2965
- for i, msg in enumerate(recent_msgs):
2966
- # Skip tool messages that aren't preceded by assistant
2967
- if get_role(msg) == 'tool':
2968
- # Check if previous message is assistant
2969
- if i > 0 and get_role(recent_msgs[i-1]) == 'assistant':
2970
- cleaned_recent.append(msg)
2971
- # Otherwise skip this orphaned tool message
 
 
 
 
 
 
 
 
 
2972
  else:
 
2973
  cleaned_recent.append(msg)
 
2974
 
2975
  messages = [system_msg, user_msg] + cleaned_recent
2976
  print(f"✂️ Pruned conversation (keeping last 4 exchanges, ~4K tokens saved)")
@@ -2986,14 +3003,26 @@ You receive quality reports from EDA agent and deliver clean data to modeling ag
2986
  user_msg = messages[1]
2987
  recent_msgs = messages[-4:]
2988
 
2989
- # Clean orphaned tool messages
2990
  cleaned_recent = []
2991
- for i, msg in enumerate(recent_msgs):
2992
- if get_role(msg) == 'tool':
2993
- if i > 0 and get_role(recent_msgs[i-1]) == 'assistant':
2994
- cleaned_recent.append(msg)
 
 
 
 
 
 
 
 
 
 
 
2995
  else:
2996
  cleaned_recent.append(msg)
 
2997
 
2998
  messages = [system_msg, user_msg] + cleaned_recent
2999
  print(f"⚠️ Emergency pruning (conversation > 8K tokens)")
 
2953
  return msg.get('role', '')
2954
  return getattr(msg, 'role', '')
2955
 
2956
+ # Helper to check if message has tool_calls
2957
+ def has_tool_calls(msg):
2958
+ if isinstance(msg, dict):
2959
+ return bool(msg.get('tool_calls'))
2960
+ return bool(getattr(msg, 'tool_calls', None))
2961
+
2962
  if len(messages) > 10:
2963
  # Keep: system prompt [0], user query [1], last valid exchanges
2964
  system_msg = messages[0]
2965
  user_msg = messages[1]
2966
  recent_msgs = messages[-8:]
2967
 
2968
+ # CRITICAL: Keep complete tool call/response groups together
2969
+ # Mistral requires: assistant (with tool_calls) → tool responses → assistant → user
2970
  cleaned_recent = []
2971
+ i = 0
2972
+ while i < len(recent_msgs):
2973
+ msg = recent_msgs[i]
2974
+ role = get_role(msg)
2975
+
2976
+ if role == 'assistant' and has_tool_calls(msg):
2977
+ # This assistant has tool calls - must keep it AND all following tool responses
2978
+ cleaned_recent.append(msg)
2979
+ i += 1
2980
+ # Collect all consecutive tool responses
2981
+ while i < len(recent_msgs) and get_role(recent_msgs[i]) == 'tool':
2982
+ cleaned_recent.append(recent_msgs[i])
2983
+ i += 1
2984
+ elif role == 'tool':
2985
+ # Orphaned tool message (no preceding assistant with tool_calls) - skip it
2986
+ i += 1
2987
  else:
2988
+ # Regular message (assistant without tool_calls, user, system)
2989
  cleaned_recent.append(msg)
2990
+ i += 1
2991
 
2992
  messages = [system_msg, user_msg] + cleaned_recent
2993
  print(f"✂️ Pruned conversation (keeping last 4 exchanges, ~4K tokens saved)")
 
3003
  user_msg = messages[1]
3004
  recent_msgs = messages[-4:]
3005
 
3006
+ # CRITICAL: Keep complete tool call/response groups together
3007
  cleaned_recent = []
3008
+ i = 0
3009
+ while i < len(recent_msgs):
3010
+ msg = recent_msgs[i]
3011
+ role = get_role(msg)
3012
+
3013
+ if role == 'assistant' and has_tool_calls(msg):
3014
+ # Keep assistant with tool calls AND all its tool responses
3015
+ cleaned_recent.append(msg)
3016
+ i += 1
3017
+ while i < len(recent_msgs) and get_role(recent_msgs[i]) == 'tool':
3018
+ cleaned_recent.append(recent_msgs[i])
3019
+ i += 1
3020
+ elif role == 'tool':
3021
+ # Skip orphaned tool message
3022
+ i += 1
3023
  else:
3024
  cleaned_recent.append(msg)
3025
+ i += 1
3026
 
3027
  messages = [system_msg, user_msg] + cleaned_recent
3028
  print(f"⚠️ Emergency pruning (conversation > 8K tokens)")