# data-quality-env / scripts/qa_accuracy_check.py
# Hemanth Kunta
# Meta hackathon submission
# 91e7690
from __future__ import annotations

import json
import os
import time

import requests
# Root URL of the environment server. Set the QA_BASE_URL environment variable
# to target a non-local deployment. A trailing "/" is stripped because every
# request path in this script already starts with "/".
BASE = os.environ.get("QA_BASE_URL", "http://localhost:7860").rstrip("/")
def post(path: str, payload: dict, *, timeout: float = 30) -> dict:
    """POST ``payload`` as JSON to ``BASE + path`` and return the decoded JSON body.

    Args:
        path: Endpoint path appended to ``BASE`` (e.g. ``"/reset"``, ``"/step"``).
        payload: JSON-serialisable request body.
        timeout: Per-request timeout in seconds (default 30).

    Returns:
        The response body parsed with ``Response.json()``.

    Raises:
        requests.HTTPError: On a 4xx/5xx status. The message includes the
            first 500 characters of the response body.
    """
    r = requests.post(BASE + path, json=payload, timeout=timeout)
    try:
        r.raise_for_status()
    except requests.HTTPError as err:
        # raise_for_status() only reports status + URL; the body is where the
        # server says *why* a payload was rejected, so keep it in the error.
        raise requests.HTTPError(
            f"{err} | response body: {r.text[:500]}", response=r
        ) from err
    return r.json()
def task1_bad_vs_good(seed: int = 42) -> dict:
    """Compare the reward for an empty task-1 report with a query-backed one.

    Episode 1: reset task 1 with ``seed`` and submit an all-empty report.
    Episode 2: reset again, run two SQL queries against ``customers``
    (real + disguised null counts, exact-duplicate row count), then submit a
    report filled from the first row of each query result.

    Args:
        seed: Seed sent to ``/reset`` for both episodes.

    Returns:
        ``{"task1_bad_score": ..., "task1_good_score": ...}`` holding the
        ``reward.value`` returned by each submission.
    """

    def submit(report: dict):
        # Submit a report and pull the scalar reward out of the step response.
        step = post("/step", {"action": {"action_type": "submit_report", "report": report}})
        return step["reward"]["value"]

    def run_sql(sql: str) -> dict:
        return post("/step", {"action": {"action_type": "query", "sql": sql}})

    def first_row(response: dict) -> dict:
        # First row of the query result; {} when no rows were returned.
        rows = response.get("observation", {}).get("last_query_result")
        return rows[0] if rows else {}

    def as_count(row: dict, key: str) -> int:
        # A missing key and a SQL NULL (None) both become 0.
        return int(row.get(key, 0) or 0)

    empty_report = {
        "null_issues": {},
        "duplicate_row_count": {"value": 0, "confidence": 0.1},
        "schema_violations": [],
        "drifted_columns": [],
        "drift_details": {},
        "relational_issues": [],
        "recommended_fixes": [],
    }

    post("/reset", {"task_id": 1, "seed": seed})
    bad_score = submit(empty_report)

    post("/reset", {"task_id": 1, "seed": seed})
    null_resp = run_sql(
        "SELECT SUM(CASE WHEN email IS NULL OR lower(trim(cast(email as varchar))) IN ('null','n/a','unknown','-','','0','none') THEN 1 ELSE 0 END) AS email_null_total, SUM(CASE WHEN customer_id IS NULL THEN 1 ELSE 0 END) AS cid_nulls FROM customers"
    )
    dup_resp = run_sql(
        "SELECT COALESCE(SUM(c-1),0) AS exact_duplicate_rows FROM (SELECT customer_id,email,name,signup_date,country, COUNT(*) c FROM customers GROUP BY 1,2,3,4,5 HAVING COUNT(*)>1) t"
    )
    null_row = first_row(null_resp)
    dup_row = first_row(dup_resp)

    good_report = {
        "null_issues": {
            "email": {"value": as_count(null_row, "email_null_total"), "confidence": 0.92},
            "customer_id": {"value": as_count(null_row, "cid_nulls"), "confidence": 0.92},
        },
        "duplicate_row_count": {
            "value": as_count(dup_row, "exact_duplicate_rows"),
            "confidence": 0.9,
        },
        # NOTE(review): this entry is hard-coded, not derived from any query above.
        "schema_violations": [
            {
                "column": "customers",
                "issue_type": "near_duplicate_pattern",
                "example": "country changed",
                "count": 1,
                "confidence": 0.6,
            }
        ],
        "drifted_columns": [],
        "drift_details": {},
        "relational_issues": [],
        "recommended_fixes": ["dedupe and normalize disguised nulls"],
    }
    good_score = submit(good_report)

    return {"task1_bad_score": bad_score, "task1_good_score": good_score}
def task3_bad_vs_good(seed: int = 42, *, new_user_min_id: int = 3000) -> dict:
    """Compare the reward for an empty task-3 report with a query-backed drift report.

    Episode 1: reset task 3 with ``seed`` and submit an all-empty report.
    Episode 2: reset again and run three SQL queries over
    ``transactions_baseline`` / ``transactions_current``:
    the mean ``amount`` of each table, the categories present only in
    current, and the share of current rows with
    ``user_id >= new_user_min_id``. Then submit a drift report built from
    those results.

    Args:
        seed: Seed sent to ``/reset`` for both episodes.
        new_user_min_id: Smallest ``user_id`` counted as a "new" user in the
            ``user_id`` drift figure (default 3000). NOTE(review): confirm
            this threshold matches the data generator for other seeds.

    Returns:
        ``{"task3_bad_score": ..., "task3_good_score": ...}`` holding the
        ``reward.value`` returned by each submission.
    """

    def result_rows(response: dict) -> list:
        # Query rows from a step response. A missing *or null* observation or
        # result yields [] instead of raising AttributeError on None.
        return (response.get("observation") or {}).get("last_query_result") or []

    def first_row(response: dict) -> dict:
        rows = result_rows(response)
        return rows[0] if rows else {}

    def as_float(row: dict, key: str) -> float:
        # A missing key and a SQL NULL (None) both become 0.0.
        return float(row.get(key, 0.0) or 0.0)

    post("/reset", {"task_id": 3, "seed": seed})
    bad = {
        "action": {
            "action_type": "submit_report",
            "report": {
                "null_issues": {},
                "duplicate_row_count": {"value": 0, "confidence": 0.1},
                "schema_violations": [],
                "drifted_columns": [],
                "drift_details": {},
                "relational_issues": [],
                "recommended_fixes": [],
            },
        }
    }
    bad_score = post("/step", bad)["reward"]["value"]

    post("/reset", {"task_id": 3, "seed": seed})
    m = post(
        "/step",
        {
            "action": {
                "action_type": "query",
                "sql": "SELECT (SELECT AVG(amount) FROM transactions_baseline) AS baseline_mean, (SELECT AVG(amount) FROM transactions_current) AS current_mean",
            }
        },
    )
    c = post(
        "/step",
        {
            "action": {
                "action_type": "query",
                "sql": "SELECT DISTINCT c.category FROM transactions_current c LEFT JOIN (SELECT DISTINCT category FROM transactions_baseline) b ON c.category=b.category WHERE b.category IS NULL ORDER BY c.category",
            }
        },
    )
    # int() guarantees only an integer literal is interpolated into the SQL.
    u = post(
        "/step",
        {
            "action": {
                "action_type": "query",
                "sql": f"SELECT AVG(CASE WHEN user_id >= {int(new_user_min_id)} THEN 1.0 ELSE 0.0 END) AS new_user_row_pct FROM transactions_current",
            }
        },
    )
    mr = first_row(m)
    ur = first_row(u)
    cats = [
        str(row.get("category"))
        for row in result_rows(c)
        if row.get("category") is not None
    ]

    good = {
        "action": {
            "action_type": "submit_report",
            "report": {
                "null_issues": {},
                "duplicate_row_count": {"value": 0, "confidence": 0.6},
                "schema_violations": [],
                "drifted_columns": ["amount", "category", "user_id"],
                "drift_details": {
                    "amount": {
                        "value": f"mean shift from {as_float(mr, 'baseline_mean'):.2f} to {as_float(mr, 'current_mean'):.2f}",
                        "confidence": 0.9,
                    },
                    # Comma-joined categories seen only in current; "" if none.
                    "category": {"value": ",".join(cats), "confidence": 0.88},
                    "user_id": {
                        "value": f"{as_float(ur, 'new_user_row_pct') * 100:.1f}%",
                        "confidence": 0.87,
                    },
                },
                "relational_issues": [],
                "recommended_fixes": ["enable drift monitors"],
            },
        }
    }
    good_score = post("/step", good)["reward"]["value"]
    return {"task3_bad_score": bad_score, "task3_good_score": good_score}
def main() -> None:
    """Run the scorer-sensitivity checks against the server and print JSON.

    Queries ``/health`` first and fails fast on a non-2xx response. Then it
    runs the task-1 and task-3 bad-vs-good comparisons with seed 42 and
    prints the health payload, the combined scores and the elapsed seconds.
    """
    # perf_counter() is monotonic, so the elapsed figure cannot jump or go
    # negative if the wall clock is adjusted during the run.
    t0 = time.perf_counter()
    resp = requests.get(BASE + "/health", timeout=10)
    # Surface HTTP errors directly rather than failing on .json() of an
    # error page.
    resp.raise_for_status()
    health = resp.json()
    t1 = task1_bad_vs_good(42)
    t3 = task3_bad_vs_good(42)
    out = {
        "health": health,
        "scorer_sensitivity": {**t1, **t3},
        "elapsed_sec": round(time.perf_counter() - t0, 3),
    }
    print(json.dumps(out, indent=2))


if __name__ == "__main__":
    main()