File size: 4,929 Bytes
b4ccf27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
tests/evaluate.py — Automated evaluation script for the SHL Agent.

Tests all 5 required scenario types:
  1. Vague query → clarification (no recommendations)
  2. Clear query → recommendations (1–10 items)
  3. Changed preference → refined results
  4. Comparison query → grounded explanation
  5. Off-topic → refusal (no recommendations)

Usage:
    # Against local server
    python tests/evaluate.py --base-url http://localhost:7860

    # Against deployed HF Space
    python tests/evaluate.py --base-url https://<your-space>.hf.space

The script prints a PASS/FAIL line per scenario followed by a summary, and
exits with code 1 if the health check or any test fails.
This makes it usable in CI/CD pipelines.
"""

import sys
import os
import json
import argparse
import time

import requests


def load_test_cases(path: str) -> list:
    """Load evaluation test cases from a JSON file.

    Args:
        path: Path to a JSON file whose top-level value is a list of
            test-case objects (each one is later passed to ``run_test``).

    Returns:
        The parsed list of test cases.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        ValueError: If the top-level JSON value is not a list.
    """
    # Explicit UTF-8: test prompts may contain non-ASCII text, and the
    # platform default encoding (e.g. cp1252 on Windows) would garble or
    # reject it.
    with open(path, "r", encoding="utf-8") as f:
        cases = json.load(f)
    # Fail fast with a clear message instead of a confusing KeyError/TypeError
    # later when the caller iterates and indexes each case.
    if not isinstance(cases, list):
        raise ValueError(
            f"Expected a JSON list of test cases in {path!r}, "
            f"got {type(cases).__name__}"
        )
    return cases


def _check_response(test: dict, reply, recs: list, eoc) -> list:
    """Return human-readable failure messages for one parsed /chat response."""
    failures = []

    # Expectation: no recommendations (e.g. clarification / refusal scenarios).
    if test.get("expected_recommendations_empty") and len(recs) > 0:
        failures.append(f"Expected empty recommendations but got {len(recs)}")

    # Expectation: at least one recommendation. `is False` so that an absent
    # key (None) skips this check rather than demanding recommendations.
    if test.get("expected_recommendations_empty") is False and len(recs) == 0:
        failures.append("Expected non-empty recommendations but got []")

    # Check: end_of_conversation, only when the test case specifies it.
    if "expected_end_of_conversation" in test:
        expected_eoc = test["expected_end_of_conversation"]
        if eoc != expected_eoc:
            failures.append(
                f"Expected end_of_conversation={expected_eoc} but got {eoc}"
            )

    # Check: reply is a non-blank string.
    if not isinstance(reply, str):
        failures.append(f"Reply is not a string: {type(reply).__name__}")
    elif not reply.strip():
        failures.append("Reply is empty")

    # Check: recommendation count 1–10 if non-empty.
    if recs and not (1 <= len(recs) <= 10):
        failures.append(f"Recommendations count {len(recs)} not in [1, 10]")

    # Check: all URLs come from catalog (basic format check). Non-dict entries
    # and missing/null URLs are reported instead of crashing the run.
    for rec in recs:
        rec_url = rec.get("url") if isinstance(rec, dict) else None
        if not isinstance(rec_url, str) or not rec_url.startswith("https://www.shl.com/"):
            failures.append(f"Suspicious URL: {rec_url}")

    return failures


def run_test(base_url: str, test: dict, timeout: float = 30) -> dict:
    """
    Run a single test case against the /chat endpoint.

    POSTs ``{"messages": test["messages"]}`` to ``{base_url}/chat`` and checks
    the JSON response against the expectations recorded in ``test``.

    Args:
        base_url: Base URL of the running API (without the ``/chat`` suffix).
        test: Test-case dict. Requires ``"scenario"`` and ``"messages"``;
            may contain ``"expected_recommendations_empty"`` (bool) and
            ``"expected_end_of_conversation"``.
        timeout: Seconds to wait for the HTTP response (default 30).

    Returns:
        A dict with ``"scenario"``, ``"passed"`` (bool) and ``"reason"``
        (``"OK"`` or ``"; "``-joined failure messages). When a JSON object
        was received it also contains ``"reply_preview"`` (first 100 chars of
        the reply), ``"rec_count"`` and ``"eoc"``.
    """
    url = f"{base_url}/chat"
    payload = {"messages": test["messages"]}

    try:
        resp = requests.post(url, json=payload, timeout=timeout)
        resp.raise_for_status()
        data = resp.json()
    except requests.exceptions.Timeout:
        return {"scenario": test["scenario"], "passed": False, "reason": "TIMEOUT"}
    except requests.exceptions.RequestException as e:
        return {"scenario": test["scenario"], "passed": False, "reason": str(e)}
    except ValueError as e:
        # Older requests versions raise a plain ValueError (not a
        # RequestException) when the body is not valid JSON.
        return {
            "scenario": test["scenario"],
            "passed": False,
            "reason": f"Invalid JSON response: {e}",
        }

    if not isinstance(data, dict):
        return {
            "scenario": test["scenario"],
            "passed": False,
            "reason": f"Expected a JSON object but got {type(data).__name__}",
        }

    # `or` maps explicit JSON nulls to the same defaults as missing keys, so
    # the checks never call string/sequence methods on None.
    reply = data.get("reply") or ""
    recs = data.get("recommendations") or []
    eoc = data.get("end_of_conversation", False)

    failures = []
    if not isinstance(recs, list):
        failures.append(f"Recommendations is not a list: {type(recs).__name__}")
        recs = []
    failures.extend(_check_response(test, reply, recs, eoc))

    return {
        "scenario": test["scenario"],
        "passed": not failures,
        "reason": "; ".join(failures) if failures else "OK",
        "reply_preview": str(reply)[:100],
        "rec_count": len(recs),
        "eoc": eoc,
    }


def main():
    """Command-line entry point for the evaluation run.

    Parses ``--base-url`` and ``--tests``, performs a ``/health`` check, runs
    every test case through ``run_test``, prints per-scenario results and a
    summary, then exits with status 0 if all tests pass or 1 if the health
    check or any test fails.
    """
    parser = argparse.ArgumentParser(description="Evaluate SHL Agent")
    parser.add_argument(
        "--base-url",
        default="http://localhost:7860",
        help="Base URL of the running API (default: http://localhost:7860)",
    )
    parser.add_argument(
        "--tests",
        default=os.path.join(os.path.dirname(__file__), "sample_requests.json"),
        help="Path to test cases JSON file",
    )
    args = parser.parse_args()
    # Drop trailing slashes so "https://host/" does not yield "https://host//chat".
    base_url = args.base_url.rstrip("/")

    # Health check first
    try:
        health_resp = requests.get(f"{base_url}/health", timeout=10)
        health_resp.raise_for_status()
        health = health_resp.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON health body on older requests versions.
        print(f"✗ Health check failed: {e}")
        sys.exit(1)
    print(f"✓ Health check passed: {health}\n")

    test_cases = load_test_cases(args.tests)
    results = []

    for index, test in enumerate(test_cases):
        # Pause between consecutive requests only; no wasted wait after the last.
        if index:
            time.sleep(0.5)  # be gentle on rate limits
        print(f"  Running: {test['scenario']}...", end=" ", flush=True)
        result = run_test(base_url, test)
        results.append(result)
        print("PASS" if result["passed"] else "FAIL")
        if not result["passed"]:
            print(f"    Reason: {result['reason']}")
        else:
            print(f"    Recs: {result['rec_count']} | EOC: {result['eoc']}")
            preview = result["reply_preview"]
            # run_test caps the preview at 100 chars; only show an ellipsis
            # when the text may actually have been cut off.
            ellipsis = "..." if len(preview) >= 100 else ""
            print(f"    Reply: {preview}{ellipsis}")

    passed = sum(1 for r in results if r["passed"])
    total = len(results)
    print(f"\n{'=' * 50}")
    print(f"Results: {passed}/{total} passed")

    if passed < total:
        print("\nFailed scenarios:")
        for r in results:
            if not r["passed"]:
                print(f"  - {r['scenario']}: {r['reason']}")
        sys.exit(1)

    print("All tests passed.")
    sys.exit(0)


if __name__ == "__main__":
    main()