Spaces:
Running
Running
Pulastya B committed on
Commit ·
fad6660
1
Parent(s): 84d9eaa
feat: Enhanced summary with metrics and plot URLs
Browse files - Extract actual model metrics (R², RMSE, MAE) from workflow
- Collect all generated plots with accessible URLs
- Return structured artifacts (models, reports, data files)
- Build detailed summary with hyperlinks to resources
- Show cross-validation and tuning results
- Fixes superficial summaries with actual performance data
- src/orchestrator.py +223 -2
src/orchestrator.py
CHANGED
|
@@ -864,6 +864,217 @@ You are a DOER. Complete workflows based on user intent."""
|
|
| 864 |
|
| 865 |
return next_steps.get(stuck_tool, "generate_eda_plots OR train_baseline_models")
|
| 866 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 867 |
def _execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
| 868 |
"""
|
| 869 |
Execute a single tool function.
|
|
@@ -1912,15 +2123,25 @@ You are a DOER. Complete workflows based on user intent."""
|
|
| 1912 |
# Final response
|
| 1913 |
final_summary = final_content or "Analysis completed"
|
| 1914 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1915 |
# 🧠 Save conversation to session memory
|
| 1916 |
if self.session:
|
| 1917 |
-
self.session.add_conversation(task_description,
|
| 1918 |
self.session_store.save(self.session)
|
| 1919 |
print(f"\n✅ Session saved: {self.session.session_id}")
|
| 1920 |
|
| 1921 |
result = {
|
| 1922 |
"status": "success",
|
| 1923 |
-
"summary":
|
|
|
|
|
|
|
|
|
|
| 1924 |
"workflow_history": workflow_history,
|
| 1925 |
"iterations": iteration,
|
| 1926 |
"api_calls": self.api_calls_made,
|
|
|
|
| 864 |
|
| 865 |
return next_steps.get(stuck_tool, "generate_eda_plots OR train_baseline_models")
|
| 866 |
|
| 867 |
+
def _generate_enhanced_summary(
|
| 868 |
+
self,
|
| 869 |
+
workflow_history: List[Dict],
|
| 870 |
+
llm_summary: str,
|
| 871 |
+
task_description: str
|
| 872 |
+
) -> Dict[str, Any]:
|
| 873 |
+
"""
|
| 874 |
+
Generate an enhanced summary with extracted metrics, plots, and artifacts.
|
| 875 |
+
|
| 876 |
+
Args:
|
| 877 |
+
workflow_history: List of executed workflow steps
|
| 878 |
+
llm_summary: Original summary from LLM
|
| 879 |
+
task_description: User's original request
|
| 880 |
+
|
| 881 |
+
Returns:
|
| 882 |
+
Dictionary with enhanced summary text, metrics, and artifacts
|
| 883 |
+
"""
|
| 884 |
+
metrics = {}
|
| 885 |
+
artifacts = {
|
| 886 |
+
"models": [],
|
| 887 |
+
"reports": [],
|
| 888 |
+
"data_files": []
|
| 889 |
+
}
|
| 890 |
+
plots = []
|
| 891 |
+
|
| 892 |
+
# Extract information from workflow history
|
| 893 |
+
for step in workflow_history:
|
| 894 |
+
tool = step.get("tool", "")
|
| 895 |
+
result = step.get("result", {})
|
| 896 |
+
|
| 897 |
+
# Skip failed steps
|
| 898 |
+
if not result.get("success", True):
|
| 899 |
+
continue
|
| 900 |
+
|
| 901 |
+
# Extract nested result if present
|
| 902 |
+
nested_result = result.get("result", result)
|
| 903 |
+
|
| 904 |
+
# === EXTRACT MODEL METRICS ===
|
| 905 |
+
if tool == "train_baseline_models":
|
| 906 |
+
if "models" in nested_result:
|
| 907 |
+
models_data = nested_result["models"]
|
| 908 |
+
if models_data:
|
| 909 |
+
# Find best model
|
| 910 |
+
best_model_name = nested_result.get("best_model", "")
|
| 911 |
+
best_model_data = models_data.get(best_model_name, {})
|
| 912 |
+
|
| 913 |
+
metrics["best_model"] = {
|
| 914 |
+
"name": best_model_name,
|
| 915 |
+
"r2_score": best_model_data.get("r2", 0),
|
| 916 |
+
"rmse": best_model_data.get("rmse", 0),
|
| 917 |
+
"mae": best_model_data.get("mae", 0)
|
| 918 |
+
}
|
| 919 |
+
|
| 920 |
+
# All models comparison
|
| 921 |
+
metrics["all_models"] = {
|
| 922 |
+
name: {
|
| 923 |
+
"r2": data.get("r2", 0),
|
| 924 |
+
"rmse": data.get("rmse", 0),
|
| 925 |
+
"mae": data.get("mae", 0)
|
| 926 |
+
}
|
| 927 |
+
for name, data in models_data.items()
|
| 928 |
+
}
|
| 929 |
+
|
| 930 |
+
# Extract model artifacts
|
| 931 |
+
if "model_path" in nested_result:
|
| 932 |
+
artifacts["models"].append({
|
| 933 |
+
"name": nested_result.get("best_model", "model"),
|
| 934 |
+
"path": nested_result["model_path"],
|
| 935 |
+
"url": f"/outputs/models/{nested_result['model_path'].split('/')[-1]}"
|
| 936 |
+
})
|
| 937 |
+
|
| 938 |
+
# Extract performance plots
|
| 939 |
+
if "performance_plots" in nested_result:
|
| 940 |
+
for plot_path in nested_result["performance_plots"]:
|
| 941 |
+
plots.append({
|
| 942 |
+
"title": plot_path.split("/")[-1].replace("_", " ").replace(".png", "").title(),
|
| 943 |
+
"path": plot_path,
|
| 944 |
+
"url": f"/outputs/{plot_path.replace('./outputs/', '')}"
|
| 945 |
+
})
|
| 946 |
+
|
| 947 |
+
if "feature_importance_plot" in nested_result:
|
| 948 |
+
plot_path = nested_result["feature_importance_plot"]
|
| 949 |
+
plots.append({
|
| 950 |
+
"title": "Feature Importance",
|
| 951 |
+
"path": plot_path,
|
| 952 |
+
"url": f"/outputs/{plot_path.replace('./outputs/', '')}"
|
| 953 |
+
})
|
| 954 |
+
|
| 955 |
+
# === HYPERPARAMETER TUNING METRICS ===
|
| 956 |
+
elif tool == "hyperparameter_tuning":
|
| 957 |
+
if "best_score" in nested_result:
|
| 958 |
+
metrics["tuned_model"] = {
|
| 959 |
+
"best_score": nested_result["best_score"],
|
| 960 |
+
"best_params": nested_result.get("best_params", {}),
|
| 961 |
+
"model_type": nested_result.get("model_type", "unknown")
|
| 962 |
+
}
|
| 963 |
+
|
| 964 |
+
if "model_path" in nested_result:
|
| 965 |
+
artifacts["models"].append({
|
| 966 |
+
"name": f"{nested_result.get('model_type', 'model')}_tuned",
|
| 967 |
+
"path": nested_result["model_path"],
|
| 968 |
+
"url": f"/outputs/models/{nested_result['model_path'].split('/')[-1]}"
|
| 969 |
+
})
|
| 970 |
+
|
| 971 |
+
# === CROSS-VALIDATION METRICS ===
|
| 972 |
+
elif tool == "perform_cross_validation":
|
| 973 |
+
if "mean_score" in nested_result:
|
| 974 |
+
metrics["cross_validation"] = {
|
| 975 |
+
"mean_score": nested_result["mean_score"],
|
| 976 |
+
"std_score": nested_result.get("std_score", 0),
|
| 977 |
+
"scores": nested_result.get("scores", [])
|
| 978 |
+
}
|
| 979 |
+
|
| 980 |
+
# === COLLECT REPORT FILES ===
|
| 981 |
+
elif "report" in tool.lower() or "dashboard" in tool.lower():
|
| 982 |
+
if "output_path" in nested_result:
|
| 983 |
+
report_path = nested_result["output_path"]
|
| 984 |
+
artifacts["reports"].append({
|
| 985 |
+
"name": tool.replace("_", " ").title(),
|
| 986 |
+
"path": report_path,
|
| 987 |
+
"url": f"/outputs/{report_path.replace('./outputs/', '')}"
|
| 988 |
+
})
|
| 989 |
+
|
| 990 |
+
# === COLLECT PLOT FILES ===
|
| 991 |
+
if "plot_paths" in nested_result:
|
| 992 |
+
for plot_path in nested_result["plot_paths"]:
|
| 993 |
+
plots.append({
|
| 994 |
+
"title": plot_path.split("/")[-1].replace("_", " ").replace(".png", "").title(),
|
| 995 |
+
"path": plot_path,
|
| 996 |
+
"url": f"/outputs/{plot_path.replace('./outputs/', '')}"
|
| 997 |
+
})
|
| 998 |
+
|
| 999 |
+
# === COLLECT DATA FILES ===
|
| 1000 |
+
if "output_path" in nested_result and nested_result["output_path"].endswith(".csv"):
|
| 1001 |
+
artifacts["data_files"].append({
|
| 1002 |
+
"name": nested_result["output_path"].split("/")[-1],
|
| 1003 |
+
"path": nested_result["output_path"],
|
| 1004 |
+
"url": f"/outputs/{nested_result['output_path'].replace('./outputs/', '')}"
|
| 1005 |
+
})
|
| 1006 |
+
|
| 1007 |
+
# Build enhanced text summary
|
| 1008 |
+
summary_lines = [
|
| 1009 |
+
f"## 📊 Analysis Complete: {task_description}",
|
| 1010 |
+
"",
|
| 1011 |
+
llm_summary,
|
| 1012 |
+
""
|
| 1013 |
+
]
|
| 1014 |
+
|
| 1015 |
+
# Add model metrics if available
|
| 1016 |
+
if "best_model" in metrics:
|
| 1017 |
+
best = metrics["best_model"]
|
| 1018 |
+
summary_lines.extend([
|
| 1019 |
+
"### 🏆 Best Model Performance",
|
| 1020 |
+
f"- **Model**: {best['name']}",
|
| 1021 |
+
f"- **R² Score**: {best['r2_score']:.4f}",
|
| 1022 |
+
f"- **RMSE**: {best['rmse']:.4f}",
|
| 1023 |
+
f"- **MAE**: {best['mae']:.4f}",
|
| 1024 |
+
""
|
| 1025 |
+
])
|
| 1026 |
+
|
| 1027 |
+
if "tuned_model" in metrics:
|
| 1028 |
+
tuned = metrics["tuned_model"]
|
| 1029 |
+
summary_lines.extend([
|
| 1030 |
+
"### ⚙️ Hyperparameter Tuning",
|
| 1031 |
+
f"- **Model Type**: {tuned['model_type']}",
|
| 1032 |
+
f"- **Best Score**: {tuned['best_score']:.4f}",
|
| 1033 |
+
""
|
| 1034 |
+
])
|
| 1035 |
+
|
| 1036 |
+
if "cross_validation" in metrics:
|
| 1037 |
+
cv = metrics["cross_validation"]
|
| 1038 |
+
summary_lines.extend([
|
| 1039 |
+
"### ✅ Cross-Validation Results",
|
| 1040 |
+
f"- **Mean Score**: {cv['mean_score']:.4f} (± {cv['std_score']:.4f})",
|
| 1041 |
+
""
|
| 1042 |
+
])
|
| 1043 |
+
|
| 1044 |
+
# Add artifact links
|
| 1045 |
+
if artifacts["models"]:
|
| 1046 |
+
summary_lines.append("### 💾 Trained Models")
|
| 1047 |
+
for model in artifacts["models"]:
|
| 1048 |
+
summary_lines.append(f"- [{model['name']}]({model['url']})")
|
| 1049 |
+
summary_lines.append("")
|
| 1050 |
+
|
| 1051 |
+
if artifacts["reports"]:
|
| 1052 |
+
summary_lines.append("### 📄 Generated Reports")
|
| 1053 |
+
for report in artifacts["reports"]:
|
| 1054 |
+
summary_lines.append(f"- [{report['name']}]({report['url']})")
|
| 1055 |
+
summary_lines.append("")
|
| 1056 |
+
|
| 1057 |
+
if plots:
|
| 1058 |
+
summary_lines.append(f"### 📈 Visualizations ({len(plots)} plots generated)")
|
| 1059 |
+
for plot in plots[:5]: # Show first 5
|
| 1060 |
+
summary_lines.append(f"- [{plot['title']}]({plot['url']})")
|
| 1061 |
+
if len(plots) > 5:
|
| 1062 |
+
summary_lines.append(f"- ... and {len(plots) - 5} more")
|
| 1063 |
+
summary_lines.append("")
|
| 1064 |
+
|
| 1065 |
+
summary_lines.extend([
|
| 1066 |
+
"---",
|
| 1067 |
+
f"**Workflow Steps**: {len([s for s in workflow_history if s.get('result', {}).get('success', True)])} completed",
|
| 1068 |
+
f"**Iterations**: {len(workflow_history)}",
|
| 1069 |
+
])
|
| 1070 |
+
|
| 1071 |
+
return {
|
| 1072 |
+
"text": "\n".join(summary_lines),
|
| 1073 |
+
"metrics": metrics,
|
| 1074 |
+
"artifacts": artifacts,
|
| 1075 |
+
"plots": plots
|
| 1076 |
+
}
|
| 1077 |
+
|
| 1078 |
def _execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
| 1079 |
"""
|
| 1080 |
Execute a single tool function.
|
|
|
|
| 2123 |
# Final response
|
| 2124 |
final_summary = final_content or "Analysis completed"
|
| 2125 |
|
| 2126 |
+
# 🎯 ENHANCED SUMMARY: Extract metrics and artifacts from workflow
|
| 2127 |
+
enhanced_summary = self._generate_enhanced_summary(
|
| 2128 |
+
workflow_history,
|
| 2129 |
+
final_summary,
|
| 2130 |
+
task_description
|
| 2131 |
+
)
|
| 2132 |
+
|
| 2133 |
# 🧠 Save conversation to session memory
|
| 2134 |
if self.session:
|
| 2135 |
+
self.session.add_conversation(task_description, enhanced_summary["text"])
|
| 2136 |
self.session_store.save(self.session)
|
| 2137 |
print(f"\n✅ Session saved: {self.session.session_id}")
|
| 2138 |
|
| 2139 |
result = {
|
| 2140 |
"status": "success",
|
| 2141 |
+
"summary": enhanced_summary["text"],
|
| 2142 |
+
"metrics": enhanced_summary.get("metrics", {}),
|
| 2143 |
+
"artifacts": enhanced_summary.get("artifacts", {}),
|
| 2144 |
+
"plots": enhanced_summary.get("plots", []),
|
| 2145 |
"workflow_history": workflow_history,
|
| 2146 |
"iterations": iteration,
|
| 2147 |
"api_calls": self.api_calls_made,
|