shlaiagent / evaluate.py
Utkarsh430's picture
scripts and tests
b4ccf27 verified
Raw
History Blame Contribute Delete
4.93 kB
"""
tests/evaluate.py — Automated evaluation script for the SHL Agent.
Tests all 5 required scenario types:
1. Vague query → clarification (no recommendations)
2. Clear query → recommendations (1–10 items)
3. Changed preference → refined results
4. Comparison query → grounded explanation
5. Off-topic → refusal (no recommendations)
Usage:
# Against local server
python tests/evaluate.py --base-url http://localhost:7860
# Against deployed HF Space
python tests/evaluate.py --base-url https://<your-space>.hf.space
The script prints PASS/FAIL for each scenario followed by a summary, and exits with code 1 if the health check or any test fails.
This makes it usable in CI/CD pipelines.
"""
import sys
import os
import json
import argparse
import time
import requests
def load_test_cases(path: str) -> list:
    """
    Load the test-case definitions from a JSON file.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The decoded JSON document. The caller iterates over it and expects
        each element to be a test-case dict.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON interchange files are UTF-8; decoding explicitly avoids depending on
    # the platform's locale encoding (e.g. cp1252 on Windows), which would
    # mangle or reject non-ASCII characters in the test messages.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def _failure_result(scenario, reason: str) -> dict:
    """Build a failed-result dict that carries the same keys as a completed run."""
    return {
        "scenario": scenario,
        "passed": False,
        "reason": reason,
        "reply_preview": "",
        "rec_count": 0,
        "eoc": False,
    }


def run_test(base_url: str, test: dict) -> dict:
    """
    Run a single test case against the /chat endpoint.

    Args:
        base_url: Base URL of the running API; a trailing slash is tolerated.
        test: Test-case dict. Must contain "scenario" and "messages"; may
            contain "expected_recommendations_empty" and
            "expected_end_of_conversation".

    Returns:
        A result dict with the keys "scenario", "passed", "reason",
        "reply_preview", "rec_count" and "eoc". Transport errors, non-JSON
        bodies and malformed response fields are reported as failed results
        rather than raised, so one bad response cannot abort the whole run.

    Raises:
        KeyError: If the test case lacks "scenario" or "messages".
    """
    scenario = test["scenario"]
    # Strip a trailing slash so "http://host/" does not produce "//chat"
    url = f"{base_url.rstrip('/')}/chat"
    payload = {"messages": test["messages"]}

    try:
        resp = requests.post(url, json=payload, timeout=30)
        resp.raise_for_status()
        data = resp.json()
    except requests.exceptions.Timeout:
        return _failure_result(scenario, "TIMEOUT")
    except requests.exceptions.RequestException as e:
        return _failure_result(scenario, str(e))
    except ValueError as e:
        # Older requests releases signal a non-JSON body with a plain
        # ValueError subclass that is not a RequestException.
        return _failure_result(scenario, f"Invalid JSON response: {e}")

    if not isinstance(data, dict):
        return _failure_result(
            scenario, f"Expected a JSON object but got {type(data).__name__}"
        )

    reply = data.get("reply", "")
    recs = data.get("recommendations", [])
    eoc = data.get("end_of_conversation", False)
    failures = []

    # A null / non-list value would break len() and iteration below; record it
    # as a failure and continue with an empty list.
    if not isinstance(recs, list):
        failures.append(
            f"Recommendations is not a list (got {type(recs).__name__})"
        )
        recs = []

    # Check: recommendations empty when expected
    if test.get("expected_recommendations_empty") and len(recs) > 0:
        failures.append(f"Expected empty recommendations but got {len(recs)}")

    # Check: recommendations non-empty when expected (only an explicit False
    # triggers this; a missing key means "no expectation")
    if test.get("expected_recommendations_empty") is False and len(recs) == 0:
        failures.append("Expected non-empty recommendations but got []")

    # Check: end_of_conversation
    if "expected_end_of_conversation" in test:
        if eoc != test["expected_end_of_conversation"]:
            failures.append(
                f"Expected end_of_conversation={test['expected_end_of_conversation']} but got {eoc}"
            )

    # Check: reply is a non-empty string
    if not isinstance(reply, str):
        failures.append(f"Reply is not a string (got {type(reply).__name__})")
        reply = ""
    elif not reply.strip():
        failures.append("Reply is empty")

    # Check: recommendation count 1–10 if non-empty (an empty list is judged
    # solely by the expectation checks above)
    if len(recs) > 10:
        failures.append(f"Recommendations count {len(recs)} not in [1, 10]")

    # Check: all URLs come from catalog (basic format check on the URL prefix)
    for rec in recs:
        if not isinstance(rec, dict):
            failures.append(f"Recommendation is not an object: {rec!r}")
            continue
        rec_url = rec.get("url")
        if not (isinstance(rec_url, str) and rec_url.startswith("https://www.shl.com/")):
            failures.append(f"Suspicious URL: {rec_url}")

    return {
        "scenario": scenario,
        "passed": not failures,
        "reason": "; ".join(failures) if failures else "OK",
        "reply_preview": reply[:100],
        "rec_count": len(recs),
        "eoc": eoc,
    }
def main():
    """
    Command-line entry point for the evaluation run.

    Parses --base-url and --tests, probes the /health endpoint, runs every
    test case from the JSON file through run_test (pausing briefly between
    requests), prints a per-scenario status and a final summary, and exits
    with status 0 when everything passed or 1 when the health check or any
    scenario failed.
    """
    cli = argparse.ArgumentParser(description="Evaluate SHL Agent")
    cli.add_argument(
        "--base-url",
        default="http://localhost:7860",
        help="Base URL of the running API (default: http://localhost:7860)",
    )
    # By default the test cases live next to this script
    default_cases = os.path.join(os.path.dirname(__file__), "sample_requests.json")
    cli.add_argument(
        "--tests",
        default=default_cases,
        help="Path to test cases JSON file",
    )
    opts = cli.parse_args()

    # Refuse to run the scenarios when the service is not reachable/healthy
    try:
        health = requests.get(f"{opts.base_url}/health", timeout=10)
        health.raise_for_status()
        print(f"✓ Health check passed: {health.json()}\n")
    except Exception as exc:
        print(f"✗ Health check failed: {exc}")
        sys.exit(1)

    outcomes = []
    for case in load_test_cases(opts.tests):
        print(f" Running: {case['scenario']}...", end=" ", flush=True)
        outcome = run_test(opts.base_url, case)
        outcomes.append(outcome)
        if outcome["passed"]:
            print("PASS")
            print(f" Recs: {outcome['rec_count']} | EOC: {outcome['eoc']}")
            print(f" Reply: {outcome['reply_preview']}...")
        else:
            print("FAIL")
            print(f" Reason: {outcome['reason']}")
        time.sleep(0.5)  # be gentle on rate limits

    # Summary: overall tally, then the list of failures (if any)
    failed = [entry for entry in outcomes if not entry["passed"]]
    n_total = len(outcomes)
    n_passed = n_total - len(failed)
    print("\n" + "=" * 50)
    print(f"Results: {n_passed}/{n_total} passed")

    if not failed:
        print("All tests passed.")
        sys.exit(0)

    print("\nFailed scenarios:")
    for entry in failed:
        print(f" - {entry['scenario']}: {entry['reason']}")
    sys.exit(1)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()