Spaces:
Running
Running
File size: 8,823 Bytes
"""
Agent Routing Test Suite
========================
Tests that every query type routes to the correct agent (rule-based router),
that the 'list' command is correctly disambiguated, and that the LLM plan
column-validation guard works.
No Ollama required — the routing checks call the rule-based fallback router
directly, and the remaining checks exercise the command-handler helpers
without an LLM. Column-validation checks are skipped when no datasets are loaded.
"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from unittest.mock import patch
from core.query_router import QueryRouter
from data.registry import DatasetRegistry
from cli_app.command_handler import _validate_plan_column, _is_list_with_context
# Single rule-based router instance shared by every routing check below.
router = QueryRouter()

# Running tallies of check outcomes, updated by run_test / run_bool_test
# and reported in the results summary at the end of the script.
passed, failed = 0, 0
def run_test(label, got, expected):
    """Compare *got* with *expected* using ``==`` and record the outcome.

    Prints a ``[PASS]``/``[FAIL]`` line for *label* followed by the expected
    and actual values, then bumps the module-level ``passed`` or ``failed``
    counter accordingly.
    """
    global passed, failed
    matched = (got == expected)
    print(f"{'[PASS]' if matched else '[FAIL]'} {label}")
    print(f" Expected : {expected}")
    print(f" Got : {got}\n")
    if not matched:
        failed += 1
        return
    passed += 1
def run_bool_test(label, got, expected=True):
    """Record a truthiness check: it passes when ``bool(got) == expected``.

    The raw *got* value (not its bool) is echoed, so a failing check shows
    what the helper actually returned. Updates the module-level ``passed`` /
    ``failed`` counters.
    """
    global passed, failed
    matched = (bool(got) == expected)
    status = "[PASS]" if matched else "[FAIL]"
    print(f"{status} {label}")
    print(f" Expected : {expected}")
    print(f" Got : {got}\n")
    if matched:
        passed += 1
    else:
        failed += 1
print("=" * 60)
print(" Agent Routing Test Suite")
print("=" * 60)
# ── METADATA AGENT ────────────────────────────────────────────
print("\n--- Metadata Agent Routing ---\n")
# (label, query) pairs the rule-based router should send to metadata_agent.
for label, query in (
    ("Columns query", "show all columns in leads"),
    ("Numeric columns query", "what are the numeric columns in leads"),
    ("Categorical columns query", "list categorical columns in organizations"),
    ("Missing values query", "how many missing values in people"),
    ("Schema query", "show schema for organizations"),
):
    run_test(label, router.route(query), "metadata_agent")
# ── DATAFRAME AGENT ───────────────────────────────────────────
print("--- DataFrame Agent Routing ---\n")
# Aggregations and row-level queries expected to reach dataframe_agent.
for label, query in (
    ("Average query", "average annual_revenue in leads"),
    ("Mean query", "mean of employees in organizations"),
    ("Max query", "max annual_revenue in leads"),
    ("Min query", "min employees in organizations"),
    ("Top rows query", "show top 10 rows in leads"),
    ("Row count query", "how many rows in leads"),
):
    run_test(label, router.route(query), "dataframe_agent")
# ── VISUALIZATION AGENT ───────────────────────────────────────
print("--- Visualization Agent Routing ---\n")
# Chart/plot phrasings expected to reach visualization_agent; the last one
# deliberately names no dataset.
for label, query in (
    ("Histogram query", "histogram of annual_revenue in leads"),
    ("Bar chart query", "bar chart of industry in leads"),
    ("Plot query", "plot distribution in organizations"),
    ("Graph query", "graph of employees"),
):
    run_test(label, router.route(query), "visualization_agent")
# ── TRANSFORMER AGENT — existing ops ─────────────────────────
print("--- Transformer Agent Routing (existing ops) ---\n")
# Data-cleaning commands expected to reach transformer_agent. The two
# "no metadata collision" cases contain words ("column", "missing") that
# also appear in metadata queries.
for label, query in (
    ("Drop duplicates", "drop duplicates in leads"),
    ("Fill nulls", "fill nulls in organizations"),
    ("Normalize", "normalize annual_revenue in leads"),
    ("Encode", "encode industry in leads"),
    ("Rename", "rename industry to sector in leads"),
    ("Drop column (no metadata collision)", "drop column description in leads"),
    ("Impute (no metadata collision)", "impute missing in organizations"),
    ("Strip whitespace", "strip whitespace in people"),
):
    run_test(label, router.route(query), "transformer_agent")
# ── TRANSFORMER AGENT — new preprocessing ops ─────────────────
print("--- Transformer Agent Routing (new preprocessing ops) ---\n")
# Scaling, encoding, imputation and missing-value drop phrasings that must
# all reach transformer_agent.
for label, query in (
    ("Standardize", "standardize number of employees in organizations"),
    ("Z-score keyword", "z-score normalize founded in organizations"),
    ("Zscore keyword", "zscore the index column in leads"),
    ("One-hot encoding", "one hot encode industry in organizations"),
    ("Onehot keyword", "onehot encode sex in people"),
    ("Dummies keyword", "get dummies for industry in organizations"),
    ("Fill with mean", "fill with mean in organizations"),
    ("Fill with median", "fill nulls with median in leads"),
    ("Fill with mode", "fill missing using mode in people"),
    ("Fill zero", "fill with zero in leads"),
    ("Drop missing rows", "drop missing rows in organizations"),
    ("Drop missing cols", "drop missing columns in leads"),
    ("Dropna keyword", "dropna in organizations"),
):
    run_test(label, router.route(query), "transformer_agent")
# ── LIST DISAMBIGUATION ───────────────────────────────────────
print("--- List Ambiguity Detection ---\n")
# (label, command, expected) — True means _is_list_with_context should see
# column/metadata context; False means a bare "list"-style command.
for label, command, want in (
    ("'list columns in leads' β metadata context", "list columns in leads", True),
    ("'list all numeric columns in people' β metadata context",
     "list all numeric columns in people", True),
    ("'list' alone β no context (dataset list)", "list", False),
    ("'list datasets' β no context (dataset list)", "list datasets", False),
):
    run_bool_test(label, _is_list_with_context(command), expected=want)
# ── COLUMN VALIDATION ─────────────────────────────────────────
print("--- Column Validation (LLM plan guard) ---\n")
registry = DatasetRegistry()
datasets = registry.list_datasets()
if datasets:
    # Use a raw dataset rather than a derived "*_clean" copy. next() with a
    # default replaces the old "[...][0]", which raised IndexError when every
    # loaded dataset was a "_clean" variant.
    sample_dataset = next((d for d in datasets if not d.endswith("_clean")), None)
    if sample_dataset is None:
        print("[SKIP] Only '_clean' datasets loaded - skipping column validation tests\n")
    else:
        # NOTE(review): get_info() is assumed to return a dict or None;
        # "or {}" turns a missing entry into a skip instead of AttributeError.
        info = registry.get_info(sample_dataset) or {}
        real_columns = info.get("columns", [])
        if not real_columns:
            # Previously this case skipped every check without any output.
            print(f"[SKIP] Dataset '{sample_dataset}' has no columns - "
                  "skipping column validation tests\n")
        else:
            real_col = real_columns[0]
            # Point the command handler at this registry for all three checks
            # so each validation sees the same dataset metadata.
            with patch("cli_app.command_handler.registry", registry):
                # A column that really exists must pass validation.
                ok, _ = _validate_plan_column({
                    "agent": "transformer_agent", "operation": "fill_mean",
                    "dataset": sample_dataset, "column": real_col
                })
                run_bool_test(f"Valid column '{real_col}' in '{sample_dataset}' β passes",
                              ok, expected=True)

                # A made-up column must be rejected.
                ok, _ = _validate_plan_column({
                    "agent": "transformer_agent", "operation": "standardize",
                    "dataset": sample_dataset, "column": "ghost_col_xyz"
                })
                run_bool_test("Non-existent column 'ghost_col_xyz' β fails validation",
                              not ok, expected=True)

                # Whole-dataset operations carry no column and must pass.
                ok, _ = _validate_plan_column({
                    "agent": "transformer_agent", "operation": "drop_missing_rows",
                    "dataset": sample_dataset, "column": None
                })
                run_bool_test("Plan with column=None β always passes",
                              ok, expected=True)
else:
    print("[SKIP] No datasets loaded β skipping column validation tests\n")
# ── SUMMARY ───────────────────────────────────────────────────
print("=" * 60)
print(f"Results: {passed} passed, {failed} failed")
if failed == 0:
    print("All tests passed.")
print("=" * 60)
# The summary above is only printed, so a failing run used to exit with
# status 0 and look green to CI / shell callers. Exit 1 on failure when run
# as a script; the __main__ guard leaves import-time behaviour unchanged.
if failed and __name__ == "__main__":
    sys.exit(1)
|