Spaces:
Running
Running
Pulastya B commited on
Commit ·
d8b38c5
1
Parent(s): dba2298
Fixed conversation pruning logic
Browse files- src/orchestrator.py +43 -14
src/orchestrator.py
CHANGED
|
@@ -2953,24 +2953,41 @@ You receive quality reports from EDA agent and deliver clean data to modeling ag
|
|
| 2953 |
return msg.get('role', '')
|
| 2954 |
return getattr(msg, 'role', '')
|
| 2955 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2956 |
if len(messages) > 10:
|
| 2957 |
# Keep: system prompt [0], user query [1], last valid exchanges
|
| 2958 |
system_msg = messages[0]
|
| 2959 |
user_msg = messages[1]
|
| 2960 |
recent_msgs = messages[-8:]
|
| 2961 |
|
| 2962 |
-
#
|
| 2963 |
-
# Mistral requires: assistant → tool → assistant → user
|
| 2964 |
cleaned_recent = []
|
| 2965 |
-
|
| 2966 |
-
|
| 2967 |
-
|
| 2968 |
-
|
| 2969 |
-
|
| 2970 |
-
|
| 2971 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2972 |
else:
|
|
|
|
| 2973 |
cleaned_recent.append(msg)
|
|
|
|
| 2974 |
|
| 2975 |
messages = [system_msg, user_msg] + cleaned_recent
|
| 2976 |
print(f"✂️ Pruned conversation (keeping last 4 exchanges, ~4K tokens saved)")
|
|
@@ -2986,14 +3003,26 @@ You receive quality reports from EDA agent and deliver clean data to modeling ag
|
|
| 2986 |
user_msg = messages[1]
|
| 2987 |
recent_msgs = messages[-4:]
|
| 2988 |
|
| 2989 |
-
#
|
| 2990 |
cleaned_recent = []
|
| 2991 |
-
|
| 2992 |
-
|
| 2993 |
-
|
| 2994 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2995 |
else:
|
| 2996 |
cleaned_recent.append(msg)
|
|
|
|
| 2997 |
|
| 2998 |
messages = [system_msg, user_msg] + cleaned_recent
|
| 2999 |
print(f"⚠️ Emergency pruning (conversation > 8K tokens)")
|
|
|
|
| 2953 |
return msg.get('role', '')
|
| 2954 |
return getattr(msg, 'role', '')
|
| 2955 |
|
| 2956 |
+
# Helper to check if message has tool_calls
|
| 2957 |
+
def has_tool_calls(msg):
|
| 2958 |
+
if isinstance(msg, dict):
|
| 2959 |
+
return bool(msg.get('tool_calls'))
|
| 2960 |
+
return bool(getattr(msg, 'tool_calls', None))
|
| 2961 |
+
|
| 2962 |
if len(messages) > 10:
|
| 2963 |
# Keep: system prompt [0], user query [1], last valid exchanges
|
| 2964 |
system_msg = messages[0]
|
| 2965 |
user_msg = messages[1]
|
| 2966 |
recent_msgs = messages[-8:]
|
| 2967 |
|
| 2968 |
+
# CRITICAL: Keep complete tool call/response groups together
|
| 2969 |
+
# Mistral requires: assistant (with tool_calls) → tool responses → assistant → user
|
| 2970 |
cleaned_recent = []
|
| 2971 |
+
i = 0
|
| 2972 |
+
while i < len(recent_msgs):
|
| 2973 |
+
msg = recent_msgs[i]
|
| 2974 |
+
role = get_role(msg)
|
| 2975 |
+
|
| 2976 |
+
if role == 'assistant' and has_tool_calls(msg):
|
| 2977 |
+
# This assistant has tool calls - must keep it AND all following tool responses
|
| 2978 |
+
cleaned_recent.append(msg)
|
| 2979 |
+
i += 1
|
| 2980 |
+
# Collect all consecutive tool responses
|
| 2981 |
+
while i < len(recent_msgs) and get_role(recent_msgs[i]) == 'tool':
|
| 2982 |
+
cleaned_recent.append(recent_msgs[i])
|
| 2983 |
+
i += 1
|
| 2984 |
+
elif role == 'tool':
|
| 2985 |
+
# Orphaned tool message (no preceding assistant with tool_calls) - skip it
|
| 2986 |
+
i += 1
|
| 2987 |
else:
|
| 2988 |
+
# Regular message (assistant without tool_calls, user, system)
|
| 2989 |
cleaned_recent.append(msg)
|
| 2990 |
+
i += 1
|
| 2991 |
|
| 2992 |
messages = [system_msg, user_msg] + cleaned_recent
|
| 2993 |
print(f"✂️ Pruned conversation (keeping last 4 exchanges, ~4K tokens saved)")
|
|
|
|
| 3003 |
user_msg = messages[1]
|
| 3004 |
recent_msgs = messages[-4:]
|
| 3005 |
|
| 3006 |
+
# CRITICAL: Keep complete tool call/response groups together
|
| 3007 |
cleaned_recent = []
|
| 3008 |
+
i = 0
|
| 3009 |
+
while i < len(recent_msgs):
|
| 3010 |
+
msg = recent_msgs[i]
|
| 3011 |
+
role = get_role(msg)
|
| 3012 |
+
|
| 3013 |
+
if role == 'assistant' and has_tool_calls(msg):
|
| 3014 |
+
# Keep assistant with tool calls AND all its tool responses
|
| 3015 |
+
cleaned_recent.append(msg)
|
| 3016 |
+
i += 1
|
| 3017 |
+
while i < len(recent_msgs) and get_role(recent_msgs[i]) == 'tool':
|
| 3018 |
+
cleaned_recent.append(recent_msgs[i])
|
| 3019 |
+
i += 1
|
| 3020 |
+
elif role == 'tool':
|
| 3021 |
+
# Skip orphaned tool message
|
| 3022 |
+
i += 1
|
| 3023 |
else:
|
| 3024 |
cleaned_recent.append(msg)
|
| 3025 |
+
i += 1
|
| 3026 |
|
| 3027 |
messages = [system_msg, user_msg] + cleaned_recent
|
| 3028 |
print(f"⚠️ Emergency pruning (conversation > 8K tokens)")
|