Spaces:
Running
Running
| """ | |
| Agent Routing Test Suite | |
| ======================== | |
| Tests that every query type routes to the correct agent (rule-based router), | |
| that the 'list' command is correctly disambiguated, and that the LLM plan | |
| column-validation guard works. | |
| No Ollama required — all tests use the rule-based fallback router directly. | |
| """ | |
| import sys | |
| import os | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from unittest.mock import patch | |
| from core.query_router import QueryRouter | |
| from data.registry import DatasetRegistry | |
| from cli_app.command_handler import _validate_plan_column, _is_list_with_context | |
# Script-wide fixtures: one router instance reused by every routing check,
# plus the running pass/fail tallies that the reporting helpers update and
# the summary prints at the end.
router = QueryRouter()
passed = failed = 0
def _record(label, ok, expected, got):
    """Print one check's outcome and bump the module-level pass/fail tally.

    Shared by ``run_test`` and ``run_bool_test`` so both report in exactly the
    same format and update the same counters.

    Args:
        label: Human-readable name of the check.
        ok: Whether the check passed (any truthy value counts as a pass).
        expected: Value shown on the "Expected" line.
        got: Value shown on the "Got" line.

    Returns:
        ``ok`` unchanged, so callers can act on the outcome.
    """
    global passed, failed
    tag = "[PASS]" if ok else "[FAIL]"
    print(f"{tag} {label}")
    print(f" Expected : {expected}")
    print(f" Got : {got}\n")
    if ok:
        passed += 1
    else:
        failed += 1
    return ok


def run_test(label, got, expected):
    """Record a check that passes when ``got == expected``.

    Args:
        label: Human-readable name of the check.
        got: Actual value produced by the code under test.
        expected: Value the check requires.

    Returns:
        The comparison result (truthy when the check passed).
    """
    return _record(label, got == expected, expected, got)


def run_bool_test(label, got, expected=True):
    """Record a check that passes when the truthiness of ``got`` equals ``expected``.

    The raw ``got`` value (not its ``bool()``) is printed, which keeps
    non-boolean return values visible in the report.

    Args:
        label: Human-readable name of the check.
        got: Actual value; only its truthiness is compared.
        expected: Required truthiness, ``True`` by default.

    Returns:
        True if the check passed, else False.
    """
    return _record(label, bool(got) == expected, expected, got)
# Suite banner: rule line, title, rule line.
print("=" * 60, " Agent Routing Test Suite", "=" * 60, sep="\n")
# -- METADATA AGENT ---------------------------------------------------------
# Column, schema and missing-value questions must route to the metadata agent.
print("\n--- Metadata Agent Routing ---\n")
for _label, _query in (
    ("Columns query", "show all columns in leads"),
    ("Numeric columns query", "what are the numeric columns in leads"),
    ("Categorical columns query", "list categorical columns in organizations"),
    ("Missing values query", "how many missing values in people"),
    ("Schema query", "show schema for organizations"),
):
    run_test(_label, router.route(_query), "metadata_agent")
# -- DATAFRAME AGENT --------------------------------------------------------
# Aggregates (average/mean/max/min), top-N rows and row counts must route to
# the dataframe agent.
print("--- DataFrame Agent Routing ---\n")
for _label, _query in (
    ("Average query", "average annual_revenue in leads"),
    ("Mean query", "mean of employees in organizations"),
    ("Max query", "max annual_revenue in leads"),
    ("Min query", "min employees in organizations"),
    ("Top rows query", "show top 10 rows in leads"),
    ("Row count query", "how many rows in leads"),
):
    run_test(_label, router.route(_query), "dataframe_agent")
# -- VISUALIZATION AGENT ----------------------------------------------------
# Chart keywords (histogram, bar chart, plot, graph) must route to the
# visualization agent.
print("--- Visualization Agent Routing ---\n")
for _label, _query in (
    ("Histogram query", "histogram of annual_revenue in leads"),
    ("Bar chart query", "bar chart of industry in leads"),
    ("Plot query", "plot distribution in organizations"),
    ("Graph query", "graph of employees"),
):
    run_test(_label, router.route(_query), "visualization_agent")
# -- TRANSFORMER AGENT: existing ops ----------------------------------------
# Cleaning/reshaping verbs must route to the transformer agent, including
# phrasings that mention "column" or "missing" and could otherwise be
# mistaken for metadata questions.
print("--- Transformer Agent Routing (existing ops) ---\n")
for _label, _query in (
    ("Drop duplicates", "drop duplicates in leads"),
    ("Fill nulls", "fill nulls in organizations"),
    ("Normalize", "normalize annual_revenue in leads"),
    ("Encode", "encode industry in leads"),
    ("Rename", "rename industry to sector in leads"),
    ("Drop column (no metadata collision)", "drop column description in leads"),
    ("Impute (no metadata collision)", "impute missing in organizations"),
    ("Strip whitespace", "strip whitespace in people"),
):
    run_test(_label, router.route(_query), "transformer_agent")
# -- TRANSFORMER AGENT: new preprocessing ops -------------------------------
# Standardization, one-hot encoding, statistic-based fills and missing-data
# drops must also route to the transformer agent.
print("--- Transformer Agent Routing (new preprocessing ops) ---\n")
for _label, _query in (
    ("Standardize", "standardize number of employees in organizations"),
    ("Z-score keyword", "z-score normalize founded in organizations"),
    ("Zscore keyword", "zscore the index column in leads"),
    ("One-hot encoding", "one hot encode industry in organizations"),
    ("Onehot keyword", "onehot encode sex in people"),
    ("Dummies keyword", "get dummies for industry in organizations"),
    ("Fill with mean", "fill with mean in organizations"),
    ("Fill with median", "fill nulls with median in leads"),
    ("Fill with mode", "fill missing using mode in people"),
    ("Fill zero", "fill with zero in leads"),
    ("Drop missing rows", "drop missing rows in organizations"),
    ("Drop missing cols", "drop missing columns in leads"),
    ("Dropna keyword", "dropna in organizations"),
):
    run_test(_label, router.route(_query), "transformer_agent")
# -- LIST DISAMBIGUATION ----------------------------------------------------
# `list` followed by metadata context (e.g. "columns") is a metadata query;
# bare `list` or `list datasets` means "list the datasets".
# Labels use ASCII "->": the previous non-ASCII arrow had been corrupted to a
# stray "β", and non-ASCII output can also fail on consoles/pipes that are not
# UTF-8 encoded.
print("--- List Ambiguity Detection ---\n")
run_bool_test("'list columns in leads' -> metadata context",
              _is_list_with_context("list columns in leads"), expected=True)
run_bool_test("'list all numeric columns in people' -> metadata context",
              _is_list_with_context("list all numeric columns in people"), expected=True)
run_bool_test("'list' alone -> no context (dataset list)",
              _is_list_with_context("list"), expected=False)
run_bool_test("'list datasets' -> no context (dataset list)",
              _is_list_with_context("list datasets"), expected=False)
# -- COLUMN VALIDATION ------------------------------------------------------
# Exercises the LLM plan guard against a real dataset from the registry: a
# real column must pass, an invented column must fail, and a plan with no
# column must pass. Every path that cannot run the checks prints a [SKIP].
print("--- Column Validation (LLM plan guard) ---\n")
registry = DatasetRegistry()
datasets = registry.list_datasets()
# Prefer a dataset that is not a "_clean" derivative. next() with a default
# replaces `[...][0]`, which raised IndexError when every dataset was "_clean".
sample_dataset = next((d for d in datasets if not d.endswith("_clean")), None)
if not datasets:
    print("[SKIP] No datasets loaded -- skipping column validation tests\n")
elif sample_dataset is None:
    print("[SKIP] Only '_clean' datasets loaded -- skipping column validation tests\n")
else:
    # `or {}` tolerates an empty/None result.
    # NOTE(review): assumes get_info otherwise returns a dict-like; confirm.
    info = registry.get_info(sample_dataset) or {}
    real_columns = info.get("columns", [])
    if not real_columns:
        # This case used to be skipped silently; report it like the others.
        print(f"[SKIP] No columns found for '{sample_dataset}' -- "
              "skipping column validation tests\n")
    else:
        real_col = real_columns[0]
        # Point the handler's module-level `registry` at this instance so the
        # guard validates against the same datasets listed above.
        with patch("cli_app.command_handler.registry", registry):
            ok, _ = _validate_plan_column({
                "agent": "transformer_agent", "operation": "fill_mean",
                "dataset": sample_dataset, "column": real_col,
            })
            run_bool_test(f"Valid column '{real_col}' in '{sample_dataset}' -> passes",
                          ok, expected=True)

            ok, _ = _validate_plan_column({
                "agent": "transformer_agent", "operation": "standardize",
                "dataset": sample_dataset, "column": "ghost_col_xyz",
            })
            # State the expectation directly (guard result should be falsy)
            # instead of negating it, so Expected/Got print the raw result.
            run_bool_test("Non-existent column 'ghost_col_xyz' -> fails validation",
                          ok, expected=False)

            ok, _ = _validate_plan_column({
                "agent": "transformer_agent", "operation": "drop_missing_rows",
                "dataset": sample_dataset, "column": None,
            })
            run_bool_test("Plan with column=None -> always passes",
                          ok, expected=True)
# -- SUMMARY ----------------------------------------------------------------
print("=" * 60)
print(f"Results: {passed} passed, {failed} failed")
if failed == 0:
    print("All tests passed.")
print("=" * 60)
# Exit non-zero on any failure. Otherwise the script reports failures only
# on stdout and always exits 0, so CI and shell callers cannot detect them.
if failed:
    sys.exit(1)