Pulastya B committed on
Commit
fe14b09
·
1 Parent(s): c073e6b

Added SBERT semantic routing and EDA Safety Guard Rails

Browse files
Files changed (2) hide show
  1. src/orchestrator.py +245 -2
  2. src/reasoning/evaluator.py +12 -2
src/orchestrator.py CHANGED
@@ -2251,6 +2251,40 @@ You receive quality reports from EDA agent and deliver clean data to modeling ag
2251
  print(f" βœ“ Stripped invalid parameter '{invalid_param}': {val}")
2252
  print(f" ℹ️ create_statistical_features creates row-wise stats (mean, std, min, max)")
2253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2254
  # πŸ”₯ FIX: Generic parameter sanitization - strip any unknown kwargs
2255
  # This prevents "got an unexpected keyword argument" errors from LLM hallucinations
2256
  import inspect
@@ -2653,6 +2687,61 @@ You receive quality reports from EDA agent and deliver clean data to modeling ag
2653
  "trials_completed": r.get("n_trials")
2654
  }
2655
  compressed["next_steps"] = ["perform_cross_validation", "generate_model_performance_plots"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2656
 
2657
  else:
2658
  # Generic compression: Keep only key fields
@@ -3071,6 +3160,109 @@ You receive quality reports from EDA agent and deliver clean data to modeling ag
3071
 
3072
  return "\n".join(lines)
3073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3074
  def _run_reasoning_loop(
3075
  self,
3076
  question: str,
@@ -3112,8 +3304,59 @@ You receive quality reports from EDA agent and deliver clean data to modeling ag
3112
  synthesizer = Synthesizer(llm_caller=self._llm_text_call)
3113
  findings = FindingsAccumulator(question=question, mode=mode)
3114
 
3115
- # Get tools description for the reasoner
3116
- tools_desc = self._get_tools_description(tool_names)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3117
 
3118
  # Track for API response
3119
  workflow_history = []
 
2251
  print(f" βœ“ Stripped invalid parameter '{invalid_param}': {val}")
2252
  print(f" ℹ️ create_statistical_features creates row-wise stats (mean, std, min, max)")
2253
 
2254
+ # πŸ”§ FIX: analyze_autogluon_model path resolution
2255
+ # The Reasoner hallucinates model paths β€” resolve to actual saved path
2256
+ if tool_name == "analyze_autogluon_model":
2257
+ model_path = arguments.get("model_path", "")
2258
+ if model_path and not Path(model_path).exists():
2259
+ # Try the default AutoGluon output dir
2260
+ fallback_paths = [
2261
+ "./outputs/autogluon_model",
2262
+ "outputs/autogluon_model",
2263
+ "/tmp/data_science_agent/outputs/autogluon_model",
2264
+ ]
2265
+ for fallback in fallback_paths:
2266
+ if Path(fallback).exists():
2267
+ print(f" βœ“ Fixed model_path: '{model_path}' β†’ '{fallback}'")
2268
+ arguments["model_path"] = fallback
2269
+ break
2270
+ else:
2271
+ print(f" ⚠️ Model path '{model_path}' not found, no fallback available")
2272
+
2273
+ # πŸ”§ FIX: predict_with_autogluon path resolution (same issue)
2274
+ if tool_name == "predict_with_autogluon":
2275
+ model_path = arguments.get("model_path", "")
2276
+ if model_path and not Path(model_path).exists():
2277
+ fallback_paths = [
2278
+ "./outputs/autogluon_model",
2279
+ "outputs/autogluon_model",
2280
+ "/tmp/data_science_agent/outputs/autogluon_model",
2281
+ ]
2282
+ for fallback in fallback_paths:
2283
+ if Path(fallback).exists():
2284
+ print(f" βœ“ Fixed model_path: '{model_path}' β†’ '{fallback}'")
2285
+ arguments["model_path"] = fallback
2286
+ break
2287
+
2288
  # πŸ”₯ FIX: Generic parameter sanitization - strip any unknown kwargs
2289
  # This prevents "got an unexpected keyword argument" errors from LLM hallucinations
2290
  import inspect
 
2687
  "trials_completed": r.get("n_trials")
2688
  }
2689
  compressed["next_steps"] = ["perform_cross_validation", "generate_model_performance_plots"]
2690
+
2691
+ # ── Feature importance / selection tools ──
2692
+ elif tool_name == "auto_feature_selection":
2693
+ r = result.get("result", {})
2694
+ # Preserve the actual feature scores β€” this IS the answer for "feature importance" queries
2695
+ feature_scores = r.get("feature_scores", r.get("feature_rankings", {}))
2696
+ # Keep top 15 features max
2697
+ if isinstance(feature_scores, dict):
2698
+ sorted_feats = sorted(feature_scores.items(), key=lambda x: abs(float(x[1])) if x[1] is not None else 0, reverse=True)[:15]
2699
+ feature_scores = {k: round(float(v), 4) if v is not None else 0 for k, v in sorted_feats}
2700
+ compressed["summary"] = {
2701
+ "n_features_original": r.get("n_features_original"),
2702
+ "n_features_selected": r.get("n_features_selected"),
2703
+ "selected_features": r.get("selected_features", [])[:15],
2704
+ "feature_scores": feature_scores,
2705
+ "selection_method": r.get("selection_method"),
2706
+ "task_type": r.get("task_type"),
2707
+ "output_path": r.get("output_path")
2708
+ }
2709
+ compressed["next_steps"] = ["analyze_correlations", "generate_eda_plots"]
2710
+
2711
+ elif tool_name == "analyze_correlations":
2712
+ r = result.get("result", {})
2713
+ # Preserve high correlations and target correlations β€” key analytical data
2714
+ high_corrs = r.get("high_correlations", [])[:10] # Top 10 pairs
2715
+ target_corrs = r.get("target_correlations", {})
2716
+ if isinstance(target_corrs, dict) and "top_features" in target_corrs:
2717
+ target_corrs = {
2718
+ "target": target_corrs.get("target"),
2719
+ "top_features": target_corrs["top_features"][:10]
2720
+ }
2721
+ compressed["summary"] = {
2722
+ "numeric_columns_count": len(r.get("numeric_columns", [])),
2723
+ "high_correlations": high_corrs,
2724
+ "target_correlations": target_corrs,
2725
+ }
2726
+ compressed["next_steps"] = ["auto_feature_selection", "generate_eda_plots"]
2727
+
2728
+ elif tool_name in ["train_with_autogluon", "analyze_autogluon_model"]:
2729
+ r = result.get("result", {})
2730
+ # Preserve model metrics AND feature importance
2731
+ feature_importance = r.get("feature_importance", [])
2732
+ if isinstance(feature_importance, list):
2733
+ feature_importance = feature_importance[:10] # Top 10 features
2734
+ compressed["summary"] = {
2735
+ "task_type": r.get("task_type"),
2736
+ "best_model": r.get("best_model"),
2737
+ "best_score": r.get("best_score"),
2738
+ "eval_metric": r.get("eval_metric"),
2739
+ "n_models_trained": r.get("n_models_trained"),
2740
+ "feature_importance": feature_importance,
2741
+ "model_path": r.get("model_path", r.get("output_path")),
2742
+ "training_time_seconds": r.get("training_time_seconds")
2743
+ }
2744
+ compressed["next_steps"] = ["predict_with_autogluon", "generate_model_report"]
2745
 
2746
  else:
2747
  # Generic compression: Keep only key fields
 
3160
 
3161
  return "\n".join(lines)
3162
 
3163
+ def _get_relevant_tools_sbert(
3164
+ self,
3165
+ query: str,
3166
+ candidate_tools: Optional[set] = None,
3167
+ top_k: int = 20,
3168
+ threshold: float = 0.15
3169
+ ) -> set:
3170
+ """
3171
+ Use SBERT semantic similarity to rank tools by relevance to the query.
3172
+
3173
+ Encodes the query and each tool's (name + docstring) into embeddings,
3174
+ then keeps only tools whose cosine similarity exceeds the threshold.
3175
+ Tool embeddings are lazily computed and cached for the lifetime of the
3176
+ orchestrator instance.
3177
+
3178
+ Args:
3179
+ query: User's natural language question
3180
+ candidate_tools: Tools to score (default: all tool_functions)
3181
+ top_k: Max number of tools to return
3182
+ threshold: Minimum cosine similarity to include a tool (0.0-1.0)
3183
+
3184
+ Returns:
3185
+ Set of tool names that are semantically relevant to the query.
3186
+ Falls back to candidate_tools unchanged if SBERT is unavailable.
3187
+ """
3188
+ if not self.semantic_layer.enabled:
3189
+ return candidate_tools or set(self.tool_functions.keys())
3190
+
3191
+ try:
3192
+ from sklearn.metrics.pairwise import cosine_similarity as cos_sim
3193
+ import numpy as np
3194
+ except ImportError:
3195
+ return candidate_tools or set(self.tool_functions.keys())
3196
+
3197
+ candidates = candidate_tools or set(self.tool_functions.keys())
3198
+
3199
+ # ── Lazily build & cache tool embeddings ──
3200
+ if not hasattr(self, '_tool_embeddings_cache'):
3201
+ self._tool_embeddings_cache = {}
3202
+
3203
+ # Compute embeddings for any tools not yet cached
3204
+ tools_needing_embed = [t for t in candidates if t not in self._tool_embeddings_cache]
3205
+ if tools_needing_embed:
3206
+ texts = []
3207
+ for name in tools_needing_embed:
3208
+ func = self.tool_functions.get(name)
3209
+ doc = (func.__doc__ or "").strip().split("\n")[0][:150] if func else ""
3210
+ texts.append(f"{name}: {doc}")
3211
+
3212
+ try:
3213
+ embeddings = self.semantic_layer.model.encode(
3214
+ texts, convert_to_numpy=True, show_progress_bar=False, batch_size=32
3215
+ )
3216
+ for name, emb in zip(tools_needing_embed, embeddings):
3217
+ self._tool_embeddings_cache[name] = emb
3218
+ except Exception as e:
3219
+ print(f"⚠️ SBERT tool encoding failed: {e}, returning all candidates")
3220
+ return candidates
3221
+
3222
+ # ── Encode the query ──
3223
+ try:
3224
+ query_emb = self.semantic_layer.model.encode(
3225
+ query, convert_to_numpy=True, show_progress_bar=False
3226
+ ).reshape(1, -1)
3227
+ except Exception as e:
3228
+ print(f"⚠️ SBERT query encoding failed: {e}")
3229
+ return candidates
3230
+
3231
+ # ── Score each candidate tool ──
3232
+ scored = []
3233
+ for name in candidates:
3234
+ emb = self._tool_embeddings_cache.get(name)
3235
+ if emb is None:
3236
+ continue
3237
+ sim = float(cos_sim(query_emb, emb.reshape(1, -1))[0][0])
3238
+ scored.append((name, sim))
3239
+
3240
+ # Sort descending by similarity
3241
+ scored.sort(key=lambda x: x[1], reverse=True)
3242
+
3243
+ # Keep tools above threshold, up to top_k
3244
+ selected = {name for name, sim in scored[:top_k] if sim >= threshold}
3245
+
3246
+ # ── Always include universally-useful core tools ──
3247
+ CORE_TOOLS = {
3248
+ "profile_dataset", "analyze_correlations", "auto_feature_selection",
3249
+ "generate_eda_plots", "clean_missing_values",
3250
+ "execute_python_code",
3251
+ }
3252
+ selected |= (CORE_TOOLS & candidates)
3253
+
3254
+ if selected:
3255
+ # Log what SBERT chose
3256
+ top5 = scored[:5]
3257
+ print(f" 🧠 SBERT tool routing: {len(selected)}/{len(candidates)} tools selected")
3258
+ print(f" Top-5 by similarity: {[(n, f'{s:.3f}') for n, s in top5]}")
3259
+ else:
3260
+ # Safety: if nothing passed threshold, return all candidates
3261
+ print(f" ⚠️ SBERT: no tools above threshold {threshold}, using all {len(candidates)} candidates")
3262
+ selected = candidates
3263
+
3264
+ return selected
3265
+
3266
  def _run_reasoning_loop(
3267
  self,
3268
  question: str,
 
3304
  synthesizer = Synthesizer(llm_caller=self._llm_text_call)
3305
  findings = FindingsAccumulator(question=question, mode=mode)
3306
 
3307
+ # ── Intelligent tool filtering for the reasoning loop ──
3308
+ # Step 1: Hard-exclude tools that can never work in the reasoning loop
3309
+ EXCLUDED_FROM_REASONING = {
3310
+ "generate_feature_importance_plot", # needs Dict[str, float] β€” Reasoner can't supply
3311
+ }
3312
+ TRAINING_TOOLS = {
3313
+ "train_with_autogluon", "train_baseline_models", "train_model",
3314
+ "hyperparameter_tuning", "predict_with_autogluon",
3315
+ "analyze_autogluon_model", "advanced_model_training",
3316
+ "neural_architecture_search"
3317
+ }
3318
+
3319
+ # Build initial candidate pool
3320
+ effective_tool_names = set(tool_names) if tool_names else set(self.tool_functions.keys())
3321
+ effective_tool_names -= EXCLUDED_FROM_REASONING
3322
+
3323
+ # Step 2: SBERT semantic routing β€” score tools against the query
3324
+ # This replaces the old keyword-only approach with real semantic understanding
3325
+ if self.semantic_layer.enabled:
3326
+ print(f" 🧠 Using SBERT semantic routing for tool selection...")
3327
+ effective_tool_names = self._get_relevant_tools_sbert(
3328
+ query=question,
3329
+ candidate_tools=effective_tool_names,
3330
+ top_k=20,
3331
+ threshold=0.15
3332
+ )
3333
+
3334
+ # Step 3: Hard safety rail β€” even if SBERT scores a training tool highly,
3335
+ # block it for pure EDA queries (training wastes 120-180s for no benefit)
3336
+ question_lower = question.lower()
3337
+ explicitly_wants_training = any(kw in question_lower for kw in [
3338
+ "train", "predict", "build a model", "classification", "regression",
3339
+ "classify", "forecast", "deploy model", "autogluon"
3340
+ ])
3341
+ if not explicitly_wants_training:
3342
+ EDA_KEYWORDS = [
3343
+ "feature importance", "important features", "most important",
3344
+ "correlations", "correlation", "explore", "explain",
3345
+ "understand", "patterns", "insights", "eda", "profiling",
3346
+ "distribution", "outliers", "summary", "describe", "overview",
3347
+ "what drives", "what affects", "key factors", "top features",
3348
+ "feature ranking", "data quality", "missing values"
3349
+ ]
3350
+ is_eda_query = any(kw in question_lower for kw in EDA_KEYWORDS)
3351
+ if is_eda_query:
3352
+ removed = effective_tool_names & TRAINING_TOOLS
3353
+ if removed:
3354
+ print(f" 🚫 EDA safety rail β€” removing training tools: {removed}")
3355
+ effective_tool_names -= TRAINING_TOOLS
3356
+
3357
+ # Get tools description for the reasoner (filtered)
3358
+ tools_desc = self._get_tools_description(list(effective_tool_names))
3359
+ print(f" πŸ“‹ Reasoning loop will see {len(effective_tool_names)} tools (of {len(self.tool_functions)})")
3360
 
3361
  # Track for API response
3362
  workflow_history = []
src/reasoning/evaluator.py CHANGED
@@ -51,6 +51,12 @@ Be concise but insightful. Focus on:
51
  - Confounders and caveats
52
  - What's surprising vs expected
53
 
 
 
 
 
 
 
54
  CRITICAL: Output ONLY valid JSON, no other text."""
55
 
56
  EVALUATOR_USER_TEMPLATE = """**User's original question**: {question}
@@ -77,11 +83,15 @@ Guidelines for should_stop:
77
  - true: Question is fully answered OR we've gathered enough evidence OR no more useful actions
78
  - false: Important aspects remain uninvestigated
79
 
 
 
 
 
80
  Guidelines for confidence:
81
  - 0.0-0.3: Weak evidence, need more investigation
82
  - 0.3-0.6: Moderate evidence, some aspects unclear
83
- - 0.6-0.8: Strong evidence, minor questions remain
84
- - 0.8-1.0: Very strong evidence, question well answered"""
85
 
86
 
87
  class Evaluator:
 
51
  - Confounders and caveats
52
  - What's surprising vs expected
53
 
54
+ IMPORTANT CONFIDENCE RULES:
55
+ - If the tool returned feature_scores, feature_importance, or correlation values, and the user asked about features/importance/correlations → this IS the answer. Set answered=true, confidence ≥ 0.7.
56
+ - If the tool returned actual ranked data (top features, sorted scores, correlation pairs), set confidence ≥ 0.6.
57
+ - Do NOT keep saying "not answered" when the tool literally returned the requested information.
58
+ - Only say answered=false when the result is genuinely unrelated to the question or contains NO useful data.
59
+
60
  CRITICAL: Output ONLY valid JSON, no other text."""
61
 
62
  EVALUATOR_USER_TEMPLATE = """**User's original question**: {question}
 
83
  - true: Question is fully answered OR we've gathered enough evidence OR no more useful actions
84
  - false: Important aspects remain uninvestigated
85
 
86
+ Guidelines for answered:
87
+ - true: The result contains data that directly addresses the user's question (e.g., feature scores for "which features are important?", correlations for "what correlates with X?")
88
+ - false: Result is unrelated to the question or contains only metadata without actual answers
89
+
90
  Guidelines for confidence:
91
  - 0.0-0.3: Weak evidence, need more investigation
92
  - 0.3-0.6: Moderate evidence, some aspects unclear
93
+ - 0.6-0.8: Strong evidence, minor questions remain (e.g., got feature importance scores but could add more context)
94
+ - 0.8-1.0: Very strong evidence, question well answered (e.g., got ranked feature list with scores AND correlations)"""
95
 
96
 
97
  class Evaluator: