Pulastya B committed on
Commit
fad6660
·
1 Parent(s): 84d9eaa

feat: Enhanced summary with metrics and plot URLs

Browse files

- Extract actual model metrics (R², RMSE, MAE) from workflow
- Collect all generated plots with accessible URLs
- Return structured artifacts (models, reports, data files)
- Build detailed summary with hyperlinks to resources
- Show cross-validation and tuning results
- Fixes superficial summaries by including actual performance data

Files changed (1) hide show
  1. src/orchestrator.py +223 -2
src/orchestrator.py CHANGED
@@ -864,6 +864,217 @@ You are a DOER. Complete workflows based on user intent."""
864
 
865
  return next_steps.get(stuck_tool, "generate_eda_plots OR train_baseline_models")
866
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
867
  def _execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
868
  """
869
  Execute a single tool function.
@@ -1912,15 +2123,25 @@ You are a DOER. Complete workflows based on user intent."""
1912
  # Final response
1913
  final_summary = final_content or "Analysis completed"
1914
 
 
 
 
 
 
 
 
1915
  # 🧠 Save conversation to session memory
1916
  if self.session:
1917
- self.session.add_conversation(task_description, final_summary)
1918
  self.session_store.save(self.session)
1919
  print(f"\n✅ Session saved: {self.session.session_id}")
1920
 
1921
  result = {
1922
  "status": "success",
1923
- "summary": final_summary,
 
 
 
1924
  "workflow_history": workflow_history,
1925
  "iterations": iteration,
1926
  "api_calls": self.api_calls_made,
 
864
 
865
  return next_steps.get(stuck_tool, "generate_eda_plots OR train_baseline_models")
866
 
867
+ def _generate_enhanced_summary(
868
+ self,
869
+ workflow_history: List[Dict],
870
+ llm_summary: str,
871
+ task_description: str
872
+ ) -> Dict[str, Any]:
873
+ """
874
+ Generate an enhanced summary with extracted metrics, plots, and artifacts.
875
+
876
+ Args:
877
+ workflow_history: List of executed workflow steps
878
+ llm_summary: Original summary from LLM
879
+ task_description: User's original request
880
+
881
+ Returns:
882
+ Dictionary with enhanced summary text, metrics, and artifacts
883
+ """
884
+ metrics = {}
885
+ artifacts = {
886
+ "models": [],
887
+ "reports": [],
888
+ "data_files": []
889
+ }
890
+ plots = []
891
+
892
+ # Extract information from workflow history
893
+ for step in workflow_history:
894
+ tool = step.get("tool", "")
895
+ result = step.get("result", {})
896
+
897
+ # Skip failed steps
898
+ if not result.get("success", True):
899
+ continue
900
+
901
+ # Extract nested result if present
902
+ nested_result = result.get("result", result)
903
+
904
+ # === EXTRACT MODEL METRICS ===
905
+ if tool == "train_baseline_models":
906
+ if "models" in nested_result:
907
+ models_data = nested_result["models"]
908
+ if models_data:
909
+ # Find best model
910
+ best_model_name = nested_result.get("best_model", "")
911
+ best_model_data = models_data.get(best_model_name, {})
912
+
913
+ metrics["best_model"] = {
914
+ "name": best_model_name,
915
+ "r2_score": best_model_data.get("r2", 0),
916
+ "rmse": best_model_data.get("rmse", 0),
917
+ "mae": best_model_data.get("mae", 0)
918
+ }
919
+
920
+ # All models comparison
921
+ metrics["all_models"] = {
922
+ name: {
923
+ "r2": data.get("r2", 0),
924
+ "rmse": data.get("rmse", 0),
925
+ "mae": data.get("mae", 0)
926
+ }
927
+ for name, data in models_data.items()
928
+ }
929
+
930
+ # Extract model artifacts
931
+ if "model_path" in nested_result:
932
+ artifacts["models"].append({
933
+ "name": nested_result.get("best_model", "model"),
934
+ "path": nested_result["model_path"],
935
+ "url": f"/outputs/models/{nested_result['model_path'].split('/')[-1]}"
936
+ })
937
+
938
+ # Extract performance plots
939
+ if "performance_plots" in nested_result:
940
+ for plot_path in nested_result["performance_plots"]:
941
+ plots.append({
942
+ "title": plot_path.split("/")[-1].replace("_", " ").replace(".png", "").title(),
943
+ "path": plot_path,
944
+ "url": f"/outputs/{plot_path.replace('./outputs/', '')}"
945
+ })
946
+
947
+ if "feature_importance_plot" in nested_result:
948
+ plot_path = nested_result["feature_importance_plot"]
949
+ plots.append({
950
+ "title": "Feature Importance",
951
+ "path": plot_path,
952
+ "url": f"/outputs/{plot_path.replace('./outputs/', '')}"
953
+ })
954
+
955
+ # === HYPERPARAMETER TUNING METRICS ===
956
+ elif tool == "hyperparameter_tuning":
957
+ if "best_score" in nested_result:
958
+ metrics["tuned_model"] = {
959
+ "best_score": nested_result["best_score"],
960
+ "best_params": nested_result.get("best_params", {}),
961
+ "model_type": nested_result.get("model_type", "unknown")
962
+ }
963
+
964
+ if "model_path" in nested_result:
965
+ artifacts["models"].append({
966
+ "name": f"{nested_result.get('model_type', 'model')}_tuned",
967
+ "path": nested_result["model_path"],
968
+ "url": f"/outputs/models/{nested_result['model_path'].split('/')[-1]}"
969
+ })
970
+
971
+ # === CROSS-VALIDATION METRICS ===
972
+ elif tool == "perform_cross_validation":
973
+ if "mean_score" in nested_result:
974
+ metrics["cross_validation"] = {
975
+ "mean_score": nested_result["mean_score"],
976
+ "std_score": nested_result.get("std_score", 0),
977
+ "scores": nested_result.get("scores", [])
978
+ }
979
+
980
+ # === COLLECT REPORT FILES ===
981
+ elif "report" in tool.lower() or "dashboard" in tool.lower():
982
+ if "output_path" in nested_result:
983
+ report_path = nested_result["output_path"]
984
+ artifacts["reports"].append({
985
+ "name": tool.replace("_", " ").title(),
986
+ "path": report_path,
987
+ "url": f"/outputs/{report_path.replace('./outputs/', '')}"
988
+ })
989
+
990
+ # === COLLECT PLOT FILES ===
991
+ if "plot_paths" in nested_result:
992
+ for plot_path in nested_result["plot_paths"]:
993
+ plots.append({
994
+ "title": plot_path.split("/")[-1].replace("_", " ").replace(".png", "").title(),
995
+ "path": plot_path,
996
+ "url": f"/outputs/{plot_path.replace('./outputs/', '')}"
997
+ })
998
+
999
+ # === COLLECT DATA FILES ===
1000
+ if "output_path" in nested_result and nested_result["output_path"].endswith(".csv"):
1001
+ artifacts["data_files"].append({
1002
+ "name": nested_result["output_path"].split("/")[-1],
1003
+ "path": nested_result["output_path"],
1004
+ "url": f"/outputs/{nested_result['output_path'].replace('./outputs/', '')}"
1005
+ })
1006
+
1007
+ # Build enhanced text summary
1008
+ summary_lines = [
1009
+ f"## 📊 Analysis Complete: {task_description}",
1010
+ "",
1011
+ llm_summary,
1012
+ ""
1013
+ ]
1014
+
1015
+ # Add model metrics if available
1016
+ if "best_model" in metrics:
1017
+ best = metrics["best_model"]
1018
+ summary_lines.extend([
1019
+ "### 🏆 Best Model Performance",
1020
+ f"- **Model**: {best['name']}",
1021
+ f"- **R² Score**: {best['r2_score']:.4f}",
1022
+ f"- **RMSE**: {best['rmse']:.4f}",
1023
+ f"- **MAE**: {best['mae']:.4f}",
1024
+ ""
1025
+ ])
1026
+
1027
+ if "tuned_model" in metrics:
1028
+ tuned = metrics["tuned_model"]
1029
+ summary_lines.extend([
1030
+ "### ⚙️ Hyperparameter Tuning",
1031
+ f"- **Model Type**: {tuned['model_type']}",
1032
+ f"- **Best Score**: {tuned['best_score']:.4f}",
1033
+ ""
1034
+ ])
1035
+
1036
+ if "cross_validation" in metrics:
1037
+ cv = metrics["cross_validation"]
1038
+ summary_lines.extend([
1039
+ "### ✅ Cross-Validation Results",
1040
+ f"- **Mean Score**: {cv['mean_score']:.4f} (± {cv['std_score']:.4f})",
1041
+ ""
1042
+ ])
1043
+
1044
+ # Add artifact links
1045
+ if artifacts["models"]:
1046
+ summary_lines.append("### 💾 Trained Models")
1047
+ for model in artifacts["models"]:
1048
+ summary_lines.append(f"- [{model['name']}]({model['url']})")
1049
+ summary_lines.append("")
1050
+
1051
+ if artifacts["reports"]:
1052
+ summary_lines.append("### 📄 Generated Reports")
1053
+ for report in artifacts["reports"]:
1054
+ summary_lines.append(f"- [{report['name']}]({report['url']})")
1055
+ summary_lines.append("")
1056
+
1057
+ if plots:
1058
+ summary_lines.append(f"### 📈 Visualizations ({len(plots)} plots generated)")
1059
+ for plot in plots[:5]: # Show first 5
1060
+ summary_lines.append(f"- [{plot['title']}]({plot['url']})")
1061
+ if len(plots) > 5:
1062
+ summary_lines.append(f"- ... and {len(plots) - 5} more")
1063
+ summary_lines.append("")
1064
+
1065
+ summary_lines.extend([
1066
+ "---",
1067
+ f"**Workflow Steps**: {len([s for s in workflow_history if s.get('result', {}).get('success', True)])} completed",
1068
+ f"**Iterations**: {len(workflow_history)}",
1069
+ ])
1070
+
1071
+ return {
1072
+ "text": "\n".join(summary_lines),
1073
+ "metrics": metrics,
1074
+ "artifacts": artifacts,
1075
+ "plots": plots
1076
+ }
1077
+
1078
  def _execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
1079
  """
1080
  Execute a single tool function.
 
2123
  # Final response
2124
  final_summary = final_content or "Analysis completed"
2125
 
2126
+ # 🎯 ENHANCED SUMMARY: Extract metrics and artifacts from workflow
2127
+ enhanced_summary = self._generate_enhanced_summary(
2128
+ workflow_history,
2129
+ final_summary,
2130
+ task_description
2131
+ )
2132
+
2133
  # 🧠 Save conversation to session memory
2134
  if self.session:
2135
+ self.session.add_conversation(task_description, enhanced_summary["text"])
2136
  self.session_store.save(self.session)
2137
  print(f"\n✅ Session saved: {self.session.session_id}")
2138
 
2139
  result = {
2140
  "status": "success",
2141
+ "summary": enhanced_summary["text"],
2142
+ "metrics": enhanced_summary.get("metrics", {}),
2143
+ "artifacts": enhanced_summary.get("artifacts", {}),
2144
+ "plots": enhanced_summary.get("plots", []),
2145
  "workflow_history": workflow_history,
2146
  "iterations": iteration,
2147
  "api_calls": self.api_calls_made,