File size: 8,823 Bytes
9eecab5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""
Agent Routing Test Suite
========================
Tests that every query type routes to the correct agent (rule-based router),
that the 'list' command is correctly disambiguated, and that the LLM plan
column-validation guard works.

No Ollama required — every check calls the rule-based router or the command-handler helpers directly.
"""

import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from unittest.mock import patch
from core.query_router import QueryRouter
from data.registry import DatasetRegistry
from cli_app.command_handler import _validate_plan_column, _is_list_with_context

# Router instance shared by every routing check below.
router = QueryRouter()

# Running tallies, updated through ``global`` in run_test / run_bool_test.
passed = failed = 0


def run_test(label, got, expected):
    """Compare *got* with *expected* for equality, print a report, and tally it.

    Prints a PASS/FAIL line followed by the expected and actual values, then
    increments the module-level ``passed`` counter on equality or ``failed``
    otherwise.
    """
    global passed, failed
    matched = got == expected
    status = "[FAIL]"
    if matched:
        status = "[PASS]"
    print(
        f"{status} {label}\n"
        f"       Expected : {expected}\n"
        f"       Got      : {got}\n"
    )
    if matched:
        passed += 1
    else:
        failed += 1


def run_bool_test(label, got, expected=True):
    """Check the truthiness of *got* against *expected*, print a report, and tally it.

    *got* is coerced with ``bool()`` before comparing, so any truthy/falsy
    value may be passed; the raw value is what gets printed. Updates the
    module-level ``passed`` / ``failed`` counters.
    """
    global passed, failed
    truthy_match = bool(got) == expected
    verdict = "[PASS]" if truthy_match else "[FAIL]"
    print("\n".join((
        f"{verdict} {label}",
        f"       Expected : {expected}",
        f"       Got      : {got}",
        "",
    )))
    if truthy_match:
        passed += 1
    else:
        failed += 1


print("=" * 60)
print("  Agent Routing Test Suite")
print("=" * 60)


# ── METADATA AGENT ────────────────────────────────────────────
# Questions about columns, column types, missing values and schema are
# expected to route to metadata_agent.
print("\n--- Metadata Agent Routing ---\n")

for _label, _query in (
    ("Columns query", "show all columns in leads"),
    ("Numeric columns query", "what are the numeric columns in leads"),
    ("Categorical columns query", "list categorical columns in organizations"),
    ("Missing values query", "how many missing values in people"),
    ("Schema query", "show schema for organizations"),
):
    run_test(_label, router.route(_query), "metadata_agent")


# ── DATAFRAME AGENT ───────────────────────────────────────────
# Aggregations (average/mean/max/min), top-N rows and row counts are
# expected to route to dataframe_agent.
print("--- DataFrame Agent Routing ---\n")

for _label, _query in (
    ("Average query", "average annual_revenue in leads"),
    ("Mean query", "mean of employees in organizations"),
    ("Max query", "max annual_revenue in leads"),
    ("Min query", "min employees in organizations"),
    ("Top rows query", "show top 10 rows in leads"),
    ("Row count query", "how many rows in leads"),
):
    run_test(_label, router.route(_query), "dataframe_agent")


# ── VISUALIZATION AGENT ───────────────────────────────────────
# Chart keywords (histogram, bar chart, plot, graph) are expected to route
# to visualization_agent, with or without a dataset name.
print("--- Visualization Agent Routing ---\n")

for _label, _query in (
    ("Histogram query", "histogram of annual_revenue in leads"),
    ("Bar chart query", "bar chart of industry in leads"),
    ("Plot query", "plot distribution in organizations"),
    ("Graph query", "graph of employees"),
):
    run_test(_label, router.route(_query), "visualization_agent")


# ── TRANSFORMER AGENT — existing ops ─────────────────────────
# Cleaning/reshaping verbs are expected to route to transformer_agent.
# 'drop column' and 'impute missing' contain words ("column", "missing")
# that also appear in metadata queries; those cases check they are not
# captured by the metadata route.
print("--- Transformer Agent Routing (existing ops) ---\n")

for _label, _query in (
    ("Drop duplicates", "drop duplicates in leads"),
    ("Fill nulls", "fill nulls in organizations"),
    ("Normalize", "normalize annual_revenue in leads"),
    ("Encode", "encode industry in leads"),
    ("Rename", "rename industry to sector in leads"),
    ("Drop column (no metadata collision)", "drop column description in leads"),
    ("Impute (no metadata collision)", "impute missing in organizations"),
    ("Strip whitespace", "strip whitespace in people"),
):
    run_test(_label, router.route(_query), "transformer_agent")


# ── TRANSFORMER AGENT — new preprocessing ops ─────────────────
# Keyword variants for standardization (standardize / z-score / zscore),
# one-hot encoding (one hot / onehot / dummies), statistical fills
# (mean / median / mode / zero) and missing-value drops must all route to
# transformer_agent.
print("--- Transformer Agent Routing (new preprocessing ops) ---\n")

for _label, _query in (
    ("Standardize", "standardize number of employees in organizations"),
    ("Z-score keyword", "z-score normalize founded in organizations"),
    ("Zscore keyword", "zscore the index column in leads"),
    ("One-hot encoding", "one hot encode industry in organizations"),
    ("Onehot keyword", "onehot encode sex in people"),
    ("Dummies keyword", "get dummies for industry in organizations"),
    ("Fill with mean", "fill with mean in organizations"),
    ("Fill with median", "fill nulls with median in leads"),
    ("Fill with mode", "fill missing using mode in people"),
    ("Fill zero", "fill with zero in leads"),
    ("Drop missing rows", "drop missing rows in organizations"),
    ("Drop missing cols", "drop missing columns in leads"),
    ("Dropna keyword", "dropna in organizations"),
):
    run_test(_label, router.route(_query), "transformer_agent")


# ── LIST DISAMBIGUATION ───────────────────────────────────────
# _is_list_with_context should be truthy when "list" is followed by a
# metadata request (columns in a dataset) and falsy for a bare "list" or
# "list datasets", which the labels treat as the dataset-list command.
# Labels use ASCII "->" (the previous text was a mis-decoded "→") so the
# report prints cleanly under any console encoding.
print("--- List Ambiguity Detection ---\n")

for _label, _query, _has_context in (
    ("'list columns in leads' -> metadata context",
     "list columns in leads", True),
    ("'list all numeric columns in people' -> metadata context",
     "list all numeric columns in people", True),
    ("'list' alone -> no context (dataset list)",
     "list", False),
    ("'list datasets' -> no context (dataset list)",
     "list datasets", False),
):
    run_bool_test(_label, _is_list_with_context(_query), expected=_has_context)


# ── COLUMN VALIDATION ─────────────────────────────────────────
# Exercises _validate_plan_column (the LLM plan guard) against a real
# dataset from the registry. When no suitable dataset or column is
# available the checks are reported as skipped rather than crashing.
print("--- Column Validation (LLM plan guard) ---\n")

registry = DatasetRegistry()
datasets = registry.list_datasets()
# Names ending in "_clean" are excluded; presumably these are outputs of an
# earlier cleaning step rather than source datasets - confirm.
raw_datasets = [d for d in datasets if not d.endswith("_clean")] if datasets else []

if not datasets:
    print("[SKIP] No datasets loaded - skipping column validation tests\n")
elif not raw_datasets:
    # Previously this case raised IndexError on an empty list.
    print("[SKIP] Only '_clean' datasets loaded - skipping column validation tests\n")
else:
    sample_dataset = raw_datasets[0]
    # Guard against get_info returning a falsy value (e.g. None) for the dataset.
    info = registry.get_info(sample_dataset) or {}
    real_columns = info.get("columns", [])

    if not real_columns:
        print(f"[SKIP] '{sample_dataset}' reports no columns - "
              "skipping column validation tests\n")
    else:
        real_col = real_columns[0]

        # Replace the command handler's module-level ``registry`` with this
        # instance so the validator checks columns against the same registry.
        # patch() raises AttributeError if that attribute does not exist.
        with patch("cli_app.command_handler.registry", registry):
            ok, _ = _validate_plan_column({
                "agent": "transformer_agent", "operation": "fill_mean",
                "dataset": sample_dataset, "column": real_col
            })
            run_bool_test(f"Valid column '{real_col}' in '{sample_dataset}' -> passes",
                          ok, expected=True)

            ok, _ = _validate_plan_column({
                "agent": "transformer_agent", "operation": "standardize",
                "dataset": sample_dataset, "column": "ghost_col_xyz"
            })
            run_bool_test("Non-existent column 'ghost_col_xyz' -> fails validation",
                          not ok, expected=True)

            ok, _ = _validate_plan_column({
                "agent": "transformer_agent", "operation": "drop_missing_rows",
                "dataset": sample_dataset, "column": None
            })
            run_bool_test("Plan with column=None -> always passes",
                          ok, expected=True)


# ── SUMMARY ───────────────────────────────────────────────────
print("=" * 60)
print(f"Results: {passed} passed, {failed} failed")
if failed == 0:
    print("All tests passed.")
print("=" * 60)

# Report failure through the exit status so CI and shell callers can detect
# it. Only do this when run as a script: raising SystemExit while the module
# is being imported (e.g. by a test collector) would abort the importer.
if failed and __name__ == "__main__":
    sys.exit(1)