Pulastya B committed on
Commit
230b136
·
1 Parent(s): c206ce6

Fix compression crash: Add defensive checks for model data structure

Browse files

- Filter models list to only valid dicts before max() operation
- Add try-except wrapper around entire compression function
- Graceful fallback if compression fails (prevents workflow crash)
- Log compression errors without stopping execution

This fixes the AttributeError raised when the models list contains strings instead of dicts, which was causing a complete workflow failure mid-execution.

Files changed (1) hide show
  1. src/orchestrator.py +40 -18
src/orchestrator.py CHANGED
@@ -1125,14 +1125,15 @@ You are a DOER. Complete workflows based on user intent."""
1125
  Returns:
1126
  Compressed result dict (typically 100-500 tokens vs 5K-10K)
1127
  """
1128
- if not result.get("success", True):
1129
- # Keep full error info (critical for debugging)
1130
- return result
1131
-
1132
- compressed = {
1133
- "success": True,
1134
- "tool": tool_name
1135
- }
 
1136
 
1137
  # Tool-specific compression rules
1138
  if tool_name == "profile_dataset":
@@ -1173,15 +1174,26 @@ You are a DOER. Complete workflows based on user intent."""
1173
  elif tool_name == "train_baseline_models":
1174
  r = result.get("result", {})
1175
  models = r.get("models", [])
1176
- if models:
1177
- best = max(models, key=lambda m: m.get("test_score", 0))
1178
- compressed["summary"] = {
1179
- "best_model": best.get("model"),
1180
- "test_score": round(best.get("test_score", 0), 4),
1181
- "train_score": round(best.get("train_score", 0), 4),
1182
- "task_type": r.get("task_type"),
1183
- "models_trained": len(models)
1184
- }
 
 
 
 
 
 
 
 
 
 
 
1185
  compressed["next_steps"] = ["hyperparameter_tuning", "generate_combined_eda_report"]
1186
 
1187
  elif tool_name in ["generate_plotly_dashboard", "generate_ydata_profiling_report", "generate_combined_eda_report"]:
@@ -1217,7 +1229,17 @@ You are a DOER. Complete workflows based on user intent."""
1217
  compressed["summary"] = {"result": str(r)[:200] if r else "completed"}
1218
  compressed["next_steps"] = ["Continue workflow"]
1219
 
1220
- return compressed
 
 
 
 
 
 
 
 
 
 
1221
 
1222
 
1223
  def _parse_text_tool_calls(self, text_response: str) -> List[Dict[str, Any]]:
 
1125
  Returns:
1126
  Compressed result dict (typically 100-500 tokens vs 5K-10K)
1127
  """
1128
+ try:
1129
+ if not result.get("success", True):
1130
+ # Keep full error info (critical for debugging)
1131
+ return result
1132
+
1133
+ compressed = {
1134
+ "success": True,
1135
+ "tool": tool_name
1136
+ }
1137
 
1138
  # Tool-specific compression rules
1139
  if tool_name == "profile_dataset":
 
1174
  elif tool_name == "train_baseline_models":
1175
  r = result.get("result", {})
1176
  models = r.get("models", [])
1177
+ if models and isinstance(models, list) and len(models) > 0:
1178
+ # Filter to only dict entries (defensive)
1179
+ valid_models = [m for m in models if isinstance(m, dict) and "test_score" in m]
1180
+ if valid_models:
1181
+ best = max(valid_models, key=lambda m: m.get("test_score", 0))
1182
+ compressed["summary"] = {
1183
+ "best_model": best.get("model"),
1184
+ "test_score": round(best.get("test_score", 0), 4),
1185
+ "train_score": round(best.get("train_score", 0), 4),
1186
+ "task_type": r.get("task_type"),
1187
+ "models_trained": len(valid_models)
1188
+ }
1189
+ else:
1190
+ # Fallback if no valid models
1191
+ compressed["summary"] = {
1192
+ "task_type": r.get("task_type"),
1193
+ "status": "No valid models trained"
1194
+ }
1195
+ else:
1196
+ compressed["summary"] = {"status": "No models found"}
1197
  compressed["next_steps"] = ["hyperparameter_tuning", "generate_combined_eda_report"]
1198
 
1199
  elif tool_name in ["generate_plotly_dashboard", "generate_ydata_profiling_report", "generate_combined_eda_report"]:
 
1229
  compressed["summary"] = {"result": str(r)[:200] if r else "completed"}
1230
  compressed["next_steps"] = ["Continue workflow"]
1231
 
1232
+ return compressed
1233
+
1234
+ except Exception as e:
1235
+ # If compression fails, return minimal safe result
1236
+ print(f"⚠️ Compression failed for {tool_name}: {str(e)}")
1237
+ return {
1238
+ "success": result.get("success", True),
1239
+ "tool": tool_name,
1240
+ "summary": {"status": "completed (compression failed)"},
1241
+ "result": result.get("result", {}) if isinstance(result.get("result"), dict) else {}
1242
+ }
1243
 
1244
 
1245
  def _parse_text_tool_calls(self, text_response: str) -> List[Dict[str, Any]]: