Mirrowel committed on
Commit
f925da4
Β·
1 Parent(s): b9d9895

fix(gemini): πŸ› add robust JSON parsing and tool response grouping for Gemini CLI

Browse files

This commit addresses multiple issues in the Gemini CLI provider related to malformed tool responses and conversation history management:

- Add `_recursively_parse_json_strings()` to handle JSON-stringified tool arguments and malformed double-encoded JSON
- Implement selective control character unescaping (preserving intentional escapes like \" and \\)
- Add `_fix_tool_response_grouping()` to properly pair function calls with their responses
- Implement ID-based pairing with fallback recovery strategies (name matching, order-based matching, placeholder insertion)
- Add comprehensive logging for debugging ID mismatches and orphaned responses
- Parse tool response content as JSON before wrapping in result object
- Add fallback handling for missing tool_call_id mappings (can occur after context compaction)

The grouping fix prevents API errors caused by linear tool response format and ensures responses are correctly matched to their corresponding function calls even when IDs are lost during context processing.

src/rotator_library/providers/gemini_cli_provider.py CHANGED
@@ -183,6 +183,98 @@ FINISH_REASON_MAP = {
183
  }
184
 
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  def _env_bool(key: str, default: bool = False) -> bool:
187
  """Get boolean from environment variable."""
188
  return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes")
@@ -840,23 +932,39 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
840
  elif role == "tool":
841
  tool_call_id = msg.get("tool_call_id")
842
  function_name = tool_call_id_to_name.get(tool_call_id)
843
- if function_name:
844
- # Add prefix for Gemini 3
845
- if is_gemini_3 and self._enable_gemini3_tool_fix:
846
- function_name = f"{self._gemini3_tool_prefix}{function_name}"
847
-
848
- # Wrap the tool response in a 'result' object
849
- response_content = {"result": content}
850
- # Accumulate tool responses - they'll be combined into one user message
851
- pending_tool_parts.append(
852
- {
853
- "functionResponse": {
854
- "name": function_name,
855
- "response": response_content,
856
- "id": tool_call_id,
857
- }
858
- }
859
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
860
  # Don't add parts here - tool responses are handled via pending_tool_parts
861
  continue
862
 
@@ -872,6 +980,216 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
872
 
873
  return system_instruction, gemini_contents
874
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
875
  def _handle_reasoning_parameters(
876
  self, payload: Dict[str, Any], model: str
877
  ) -> Optional[Dict[str, Any]]:
@@ -991,9 +1309,12 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
991
  # Get current tool index from accumulator (default 0) and increment
992
  current_tool_idx = accumulator.get("tool_idx", 0) if accumulator else 0
993
 
994
- # Get args and strip _confirm ONLY if it's the sole parameter
 
 
 
 
995
  # This ensures we only strip our injection, not legitimate user params
996
- tool_args = function_call.get("args", {})
997
  if isinstance(tool_args, dict) and "_confirm" in tool_args:
998
  if len(tool_args) == 1:
999
  # _confirm is the only param - this was our injection
@@ -1578,6 +1899,9 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
1578
  system_instruction, contents = self._transform_messages(
1579
  kwargs.get("messages", []), model_name
1580
  )
 
 
 
1581
  request_payload = {
1582
  "model": model_name,
1583
  "project": project_id,
@@ -1865,6 +2189,8 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
1865
 
1866
  # Transform messages to Gemini format
1867
  system_instruction, contents = self._transform_messages(messages)
 
 
1868
 
1869
  # Build request payload
1870
  request_payload = {
 
183
  }
184
 
185
 
186
+ def _recursively_parse_json_strings(obj: Any) -> Any:
187
+ """
188
+ Recursively parse JSON strings in nested data structures.
189
+
190
+ Gemini sometimes returns tool arguments with JSON-stringified values:
191
+ {"files": "[{...}]"} instead of {"files": [{...}]}.
192
+
193
+ Additionally handles:
194
+ - Malformed double-encoded JSON (extra trailing '}' or ']')
195
+ - Escaped string content (\n, \t, etc.)
196
+ """
197
+ if isinstance(obj, dict):
198
+ return {k: _recursively_parse_json_strings(v) for k, v in obj.items()}
199
+ elif isinstance(obj, list):
200
+ return [_recursively_parse_json_strings(item) for item in obj]
201
+ elif isinstance(obj, str):
202
+ stripped = obj.strip()
203
+
204
+ # Check if string contains control character escape sequences that need unescaping
205
+ # This handles cases where diff content has literal \n or \t instead of actual newlines/tabs
206
+ #
207
+ # IMPORTANT: We intentionally do NOT unescape strings containing \" or \\
208
+ # because these are typically intentional escapes in code/config content
209
+ # (e.g., JSON embedded in YAML: BOT_NAMES_JSON: '["mirrobot", ...]')
210
+ # Unescaping these would corrupt the content and cause issues like
211
+ # oldString and newString becoming identical when they should differ.
212
+ has_control_char_escapes = "\\n" in obj or "\\t" in obj
213
+ has_intentional_escapes = '\\"' in obj or "\\\\" in obj
214
+
215
+ if has_control_char_escapes and not has_intentional_escapes:
216
+ try:
217
+ # Use json.loads with quotes to properly unescape the string
218
+ # This converts \n -> newline, \t -> tab
219
+ unescaped = json.loads(f'"{obj}"')
220
+ # Log the fix with a snippet for debugging
221
+ snippet = obj[:80] + "..." if len(obj) > 80 else obj
222
+ lib_logger.debug(
223
+ f"[GeminiCli] Unescaped control chars in string: "
224
+ f"{len(obj) - len(unescaped)} chars changed. Snippet: {snippet!r}"
225
+ )
226
+ return unescaped
227
+ except (json.JSONDecodeError, ValueError):
228
+ # If unescaping fails, continue with original processing
229
+ pass
230
+
231
+ # Check if it looks like JSON (starts with { or [)
232
+ if stripped and stripped[0] in ("{", "["):
233
+ # Try standard parsing first
234
+ if (stripped.startswith("{") and stripped.endswith("}")) or (
235
+ stripped.startswith("[") and stripped.endswith("]")
236
+ ):
237
+ try:
238
+ parsed = json.loads(obj)
239
+ return _recursively_parse_json_strings(parsed)
240
+ except (json.JSONDecodeError, ValueError):
241
+ pass
242
+
243
+ # Handle malformed JSON: array that doesn't end with ]
244
+ # e.g., '[{"path": "..."}]}' instead of '[{"path": "..."}]'
245
+ if stripped.startswith("[") and not stripped.endswith("]"):
246
+ try:
247
+ # Find the last ] and truncate there
248
+ last_bracket = stripped.rfind("]")
249
+ if last_bracket > 0:
250
+ cleaned = stripped[: last_bracket + 1]
251
+ parsed = json.loads(cleaned)
252
+ lib_logger.warning(
253
+ f"[GeminiCli] Auto-corrected malformed JSON string: "
254
+ f"truncated {len(stripped) - len(cleaned)} extra chars"
255
+ )
256
+ return _recursively_parse_json_strings(parsed)
257
+ except (json.JSONDecodeError, ValueError):
258
+ pass
259
+
260
+ # Handle malformed JSON: object that doesn't end with }
261
+ if stripped.startswith("{") and not stripped.endswith("}"):
262
+ try:
263
+ # Find the last } and truncate there
264
+ last_brace = stripped.rfind("}")
265
+ if last_brace > 0:
266
+ cleaned = stripped[: last_brace + 1]
267
+ parsed = json.loads(cleaned)
268
+ lib_logger.warning(
269
+ f"[GeminiCli] Auto-corrected malformed JSON string: "
270
+ f"truncated {len(stripped) - len(cleaned)} extra chars"
271
+ )
272
+ return _recursively_parse_json_strings(parsed)
273
+ except (json.JSONDecodeError, ValueError):
274
+ pass
275
+ return obj
276
+
277
+
278
  def _env_bool(key: str, default: bool = False) -> bool:
279
  """Get boolean from environment variable."""
280
  return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes")
 
932
  elif role == "tool":
933
  tool_call_id = msg.get("tool_call_id")
934
  function_name = tool_call_id_to_name.get(tool_call_id)
935
+
936
+ # Log warning if tool_call_id not found in mapping (can happen after context compaction)
937
+ if not function_name:
938
+ lib_logger.warning(
939
+ f"[ID Mismatch] Tool response has ID '{tool_call_id}' which was not found in tool_id_to_name map. "
940
+ f"Available IDs: {list(tool_call_id_to_name.keys())}. Using 'unknown_function' as fallback."
 
 
 
 
 
 
 
 
 
 
941
  )
942
+ function_name = "unknown_function"
943
+
944
+ # Add prefix for Gemini 3
945
+ if is_gemini_3 and self._enable_gemini3_tool_fix:
946
+ function_name = f"{self._gemini3_tool_prefix}{function_name}"
947
+
948
+ # Try to parse content as JSON first, fall back to string
949
+ try:
950
+ parsed_content = (
951
+ json.loads(content) if isinstance(content, str) else content
952
+ )
953
+ except (json.JSONDecodeError, TypeError):
954
+ parsed_content = content
955
+
956
+ # Wrap the tool response in a 'result' object
957
+ response_content = {"result": parsed_content}
958
+ # Accumulate tool responses - they'll be combined into one user message
959
+ pending_tool_parts.append(
960
+ {
961
+ "functionResponse": {
962
+ "name": function_name,
963
+ "response": response_content,
964
+ "id": tool_call_id,
965
+ }
966
+ }
967
+ )
968
  # Don't add parts here - tool responses are handled via pending_tool_parts
969
  continue
970
 
 
980
 
981
  return system_instruction, gemini_contents
982
 
983
+ def _fix_tool_response_grouping(
984
+ self, contents: List[Dict[str, Any]]
985
+ ) -> List[Dict[str, Any]]:
986
+ """
987
+ Group function calls with their responses for Gemini CLI compatibility.
988
+
989
+ Converts linear format (call, response, call, response)
990
+ to grouped format (model with calls, user with all responses).
991
+
992
+ IMPORTANT: Preserves ID-based pairing to prevent mismatches.
993
+ When IDs don't match, attempts recovery by:
994
+ 1. Matching by function name first
995
+ 2. Matching by order if names don't match
996
+ 3. Inserting placeholder responses if responses are missing
997
+ 4. Inserting responses at the CORRECT position (after their corresponding call)
998
+ """
999
+ new_contents = []
1000
+ # Each pending group tracks:
1001
+ # - ids: expected response IDs
1002
+ # - func_names: expected function names (for orphan matching)
1003
+ # - insert_after_idx: position in new_contents where model message was added
1004
+ pending_groups = []
1005
+ collected_responses = {} # Dict mapping ID -> response_part
1006
+
1007
+ for content in contents:
1008
+ role = content.get("role")
1009
+ parts = content.get("parts", [])
1010
+
1011
+ response_parts = [p for p in parts if "functionResponse" in p]
1012
+
1013
+ if response_parts:
1014
+ # Collect responses by ID (ignore duplicates - keep first occurrence)
1015
+ for resp in response_parts:
1016
+ resp_id = resp.get("functionResponse", {}).get("id", "")
1017
+ if resp_id:
1018
+ if resp_id in collected_responses:
1019
+ lib_logger.warning(
1020
+ f"[Grouping] Duplicate response ID detected: {resp_id}. "
1021
+ f"Ignoring duplicate - this may indicate malformed conversation history."
1022
+ )
1023
+ continue
1024
+ collected_responses[resp_id] = resp
1025
+
1026
+ # Try to satisfy pending groups (newest first)
1027
+ for i in range(len(pending_groups) - 1, -1, -1):
1028
+ group = pending_groups[i]
1029
+ group_ids = group["ids"]
1030
+
1031
+ # Check if we have ALL responses for this group
1032
+ if all(gid in collected_responses for gid in group_ids):
1033
+ # Extract responses in the same order as the function calls
1034
+ group_responses = [
1035
+ collected_responses.pop(gid) for gid in group_ids
1036
+ ]
1037
+ new_contents.append({"parts": group_responses, "role": "user"})
1038
+ pending_groups.pop(i)
1039
+ break
1040
+ continue
1041
+
1042
+ if role == "model":
1043
+ func_calls = [p for p in parts if "functionCall" in p]
1044
+ new_contents.append(content)
1045
+ if func_calls:
1046
+ call_ids = [
1047
+ fc.get("functionCall", {}).get("id", "") for fc in func_calls
1048
+ ]
1049
+ call_ids = [cid for cid in call_ids if cid] # Filter empty IDs
1050
+
1051
+ # Also extract function names for orphan matching
1052
+ func_names = [
1053
+ fc.get("functionCall", {}).get("name", "") for fc in func_calls
1054
+ ]
1055
+
1056
+ if call_ids:
1057
+ pending_groups.append(
1058
+ {
1059
+ "ids": call_ids,
1060
+ "func_names": func_names,
1061
+ "insert_after_idx": len(new_contents) - 1,
1062
+ }
1063
+ )
1064
+ else:
1065
+ new_contents.append(content)
1066
+
1067
+ # Handle remaining groups (shouldn't happen in well-formed conversations)
1068
+ # Attempt recovery by matching orphans to unsatisfied calls
1069
+ # Process in REVERSE order of insert_after_idx so insertions don't shift indices
1070
+ pending_groups.sort(key=lambda g: g["insert_after_idx"], reverse=True)
1071
+
1072
+ for group in pending_groups:
1073
+ group_ids = group["ids"]
1074
+ group_func_names = group.get("func_names", [])
1075
+ insert_idx = group["insert_after_idx"] + 1
1076
+ group_responses = []
1077
+
1078
+ lib_logger.debug(
1079
+ f"[Grouping Recovery] Processing unsatisfied group: "
1080
+ f"ids={group_ids}, names={group_func_names}, insert_at={insert_idx}"
1081
+ )
1082
+
1083
+ for i, expected_id in enumerate(group_ids):
1084
+ expected_name = group_func_names[i] if i < len(group_func_names) else ""
1085
+
1086
+ if expected_id in collected_responses:
1087
+ # Direct ID match
1088
+ group_responses.append(collected_responses.pop(expected_id))
1089
+ lib_logger.debug(
1090
+ f"[Grouping Recovery] Direct ID match for '{expected_id}'"
1091
+ )
1092
+ elif collected_responses:
1093
+ # Try to find orphan with matching function name first
1094
+ matched_orphan_id = None
1095
+
1096
+ # First pass: match by function name
1097
+ for orphan_id, orphan_resp in collected_responses.items():
1098
+ orphan_name = orphan_resp.get("functionResponse", {}).get(
1099
+ "name", ""
1100
+ )
1101
+ # Match if names are equal
1102
+ if orphan_name == expected_name:
1103
+ matched_orphan_id = orphan_id
1104
+ lib_logger.debug(
1105
+ f"[Grouping Recovery] Matched orphan '{orphan_id}' by name '{orphan_name}'"
1106
+ )
1107
+ break
1108
+
1109
+ # Second pass: if no name match, try "unknown_function" orphans
1110
+ if not matched_orphan_id:
1111
+ for orphan_id, orphan_resp in collected_responses.items():
1112
+ orphan_name = orphan_resp.get("functionResponse", {}).get(
1113
+ "name", ""
1114
+ )
1115
+ if orphan_name == "unknown_function":
1116
+ matched_orphan_id = orphan_id
1117
+ lib_logger.debug(
1118
+ f"[Grouping Recovery] Matched unknown_function orphan '{orphan_id}' "
1119
+ f"to expected '{expected_name}'"
1120
+ )
1121
+ break
1122
+
1123
+ # Third pass: if still no match, take first available (order-based)
1124
+ if not matched_orphan_id:
1125
+ matched_orphan_id = next(iter(collected_responses))
1126
+ lib_logger.debug(
1127
+ f"[Grouping Recovery] No name match, using first available orphan '{matched_orphan_id}'"
1128
+ )
1129
+
1130
+ if matched_orphan_id:
1131
+ orphan_resp = collected_responses.pop(matched_orphan_id)
1132
+
1133
+ # Fix the ID in the response to match the call
1134
+ old_id = orphan_resp["functionResponse"].get("id", "")
1135
+ orphan_resp["functionResponse"]["id"] = expected_id
1136
+
1137
+ # Fix the name if it was "unknown_function"
1138
+ if (
1139
+ orphan_resp["functionResponse"].get("name")
1140
+ == "unknown_function"
1141
+ and expected_name
1142
+ ):
1143
+ orphan_resp["functionResponse"]["name"] = expected_name
1144
+ lib_logger.info(
1145
+ f"[Grouping Recovery] Fixed function name from 'unknown_function' to '{expected_name}'"
1146
+ )
1147
+
1148
+ lib_logger.warning(
1149
+ f"[Grouping] Auto-repaired ID mismatch: mapped response '{old_id}' "
1150
+ f"to call '{expected_id}' (function: {expected_name})"
1151
+ )
1152
+ group_responses.append(orphan_resp)
1153
+ else:
1154
+ # No responses available - create placeholder
1155
+ placeholder_resp = {
1156
+ "functionResponse": {
1157
+ "name": expected_name or "unknown_function",
1158
+ "response": {
1159
+ "result": {
1160
+ "error": "Tool response was lost during context processing. "
1161
+ "This is a recovered placeholder.",
1162
+ "recovered": True,
1163
+ }
1164
+ },
1165
+ "id": expected_id,
1166
+ }
1167
+ }
1168
+ lib_logger.warning(
1169
+ f"[Grouping Recovery] Created placeholder response for missing tool: "
1170
+ f"id='{expected_id}', name='{expected_name}'"
1171
+ )
1172
+ group_responses.append(placeholder_resp)
1173
+
1174
+ if group_responses:
1175
+ # Insert at the correct position (right after the model message with the calls)
1176
+ new_contents.insert(
1177
+ insert_idx, {"parts": group_responses, "role": "user"}
1178
+ )
1179
+ lib_logger.info(
1180
+ f"[Grouping Recovery] Inserted {len(group_responses)} responses at position {insert_idx} "
1181
+ f"(expected {len(group_ids)})"
1182
+ )
1183
+
1184
+ # Warn about unmatched responses
1185
+ if collected_responses:
1186
+ lib_logger.warning(
1187
+ f"[Grouping] {len(collected_responses)} unmatched responses remaining: "
1188
+ f"ids={list(collected_responses.keys())}"
1189
+ )
1190
+
1191
+ return new_contents
1192
+
1193
  def _handle_reasoning_parameters(
1194
  self, payload: Dict[str, Any], model: str
1195
  ) -> Optional[Dict[str, Any]]:
 
1309
  # Get current tool index from accumulator (default 0) and increment
1310
  current_tool_idx = accumulator.get("tool_idx", 0) if accumulator else 0
1311
 
1312
+ # Get args, recursively parse any JSON strings, and strip _confirm if sole param
1313
+ raw_args = function_call.get("args", {})
1314
+ tool_args = _recursively_parse_json_strings(raw_args)
1315
+
1316
+ # Strip _confirm ONLY if it's the sole parameter
1317
  # This ensures we only strip our injection, not legitimate user params
 
1318
  if isinstance(tool_args, dict) and "_confirm" in tool_args:
1319
  if len(tool_args) == 1:
1320
  # _confirm is the only param - this was our injection
 
1899
  system_instruction, contents = self._transform_messages(
1900
  kwargs.get("messages", []), model_name
1901
  )
1902
+ # Fix tool response grouping (handles ID mismatches, missing responses)
1903
+ contents = self._fix_tool_response_grouping(contents)
1904
+
1905
  request_payload = {
1906
  "model": model_name,
1907
  "project": project_id,
 
2189
 
2190
  # Transform messages to Gemini format
2191
  system_instruction, contents = self._transform_messages(messages)
2192
+ # Fix tool response grouping (handles ID mismatches, missing responses)
2193
+ contents = self._fix_tool_response_grouping(contents)
2194
 
2195
  # Build request payload
2196
  request_payload = {