# data-quality-env / chat_agent.py
# Hemanth Kunta
# Meta hackathon submission
# 91e7690
"""
Chat-style AI auditor for DataQualityEnv.
This wrapper now behaves like a modern assistant stack:
- planner produces hypotheses and safe probe ideas
- executor runs OpenEnv tool calls
- critic normalizes/repairs the final report
- memory influences future turns
"""
from __future__ import annotations
import argparse
import json
import os
from typing import Any
import requests
from openai import OpenAI
from env.agent_memory import MemoryStore
from env.multi_agent_orchestrator import MultiAgentOrchestrator
# Runtime configuration, read once from the process environment at import time.
# Endpoint and model name handed to the OpenAI-compatible client.
API_BASE_URL = os.getenv("API_BASE_URL", "")
MODEL_NAME = os.getenv("MODEL_NAME", "")
# A non-empty HF_TOKEN takes precedence; otherwise OPENAI_API_KEY is used.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
# Base URL of the environment HTTP server.
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
# Path handed to MemoryStore.
MEMORY_PATH = os.getenv("AGENT_MEMORY_PATH", "outputs/agent_memory.json")
# System prompt for a chat model: it asks for a JSON-only reply containing an
# "assistant_message" and an "action" that is either a SQL "query" or a
# "submit_report" carrying the report fields listed below.
# The text between the triple quotes is runtime data; its whitespace is part of
# the prompt and must not be reformatted.
# NOTE(review): nothing in the visible part of this file reads SYSTEM_PROMPT --
# confirm whether it is still used elsewhere before changing or removing it.
SYSTEM_PROMPT = """You are a data quality auditing assistant.
You can investigate data via SQL and then submit a final JSON report.
Return valid JSON only in this schema:
{
"assistant_message": "short natural language reply",
"action": {
"action_type": "query" | "submit_report",
"sql": "... optional when query ...",
"report": {
"null_issues": {"col": 0},
"duplicate_row_count": 0,
"schema_violations": [],
"drifted_columns": [],
"drift_details": {},
"recommended_fixes": []
}
}
}
Rules:
- If user asks to inspect, use action_type=query with safe SELECT/WITH SQL.
- If enough evidence exists or user asks to finalize, use action_type=submit_report.
- Keep assistant_message concise and helpful.
"""
class ChatAuditor:
    """Drive one DataQualityEnv episode from a turn-by-turn chat.

    For every user turn the orchestrator is asked for an action, the action is
    posted to the environment's ``step`` endpoint, the latest observation is
    stored in ``obs`` and a summary of the turn is appended to ``history``.
    """

    def __init__(self, task_id: int, seed: int) -> None:
        """Check configuration, build collaborators and reset the environment.

        Args:
            task_id: Task identifier sent to the environment's ``reset`` endpoint.
            seed: Seed sent to the environment's ``reset`` endpoint.

        Raises:
            RuntimeError: If API_BASE_URL, MODEL_NAME or the API key is empty.
            requests.HTTPError: If the reset request returns an error status.
        """
        if not API_BASE_URL or not MODEL_NAME or not API_KEY:
            raise RuntimeError("Set API_BASE_URL, MODEL_NAME, and HF_TOKEN/OPENAI_API_KEY.")
        # NOTE(review): the client is created here but not used by any method
        # visible in this class -- confirm whether it is still needed.
        self.client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        self.memory = MemoryStore(MEMORY_PATH)
        self.orchestrator = MultiAgentOrchestrator(memory=self.memory)
        self.task_id = task_id
        self.seed = seed
        # One summary dict per completed turn, oldest first.
        self.history: list[dict[str, Any]] = []
        # Latest observation returned by the environment.
        self.obs = self.call_env("reset", {"task_id": task_id, "seed": seed})

    def call_env(
        self,
        endpoint: str,
        payload: dict | None = None,
        method: str = "POST",
        timeout: float = 30,
    ) -> dict:
        """Call an environment endpoint and return its decoded JSON body.

        Args:
            endpoint: Path relative to ENV_URL, e.g. ``"reset"`` or ``"step"``.
            payload: JSON body for POST requests; an empty object when omitted.
            method: ``"POST"`` sends a POST; any other value sends a GET
                (the payload is not sent for GET).
            timeout: Request timeout in seconds.

        Returns:
            The parsed JSON response.

        Raises:
            requests.HTTPError: If the response has an error status code.
        """
        # Strip the joining slashes so a trailing "/" on ENV_URL (or a leading
        # one on the endpoint) does not produce "host//endpoint".
        url = f"{ENV_URL.rstrip('/')}/{endpoint.lstrip('/')}"
        if method == "POST":
            response = requests.post(url, json=payload or {}, timeout=timeout)
        else:
            response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.json()

    def build_user_payload(self, user_text: str) -> str:
        """Serialize the user request plus the current observation as JSON.

        Only the first five rows of ``last_query_result`` and the last six
        history entries are included. Observation keys that are absent are
        emitted as ``null``.

        Args:
            user_text: The user's message for this turn.

        Returns:
            A JSON string describing the request and environment state.
        """
        obs = self.obs
        view = {
            "user_request": user_text,
            "task_id": obs.get("task_id"),
            "task_description": obs.get("task_description"),
            "table_name": obs.get("table_name"),
            "schema": obs.get("schema"),
            "row_count": obs.get("row_count"),
            "step": obs.get("step"),
            "max_steps": obs.get("max_steps"),
            "last_query_result": (obs.get("last_query_result") or [])[:5],
            "last_action_error": obs.get("last_action_error"),
            "recent_history": self.history[-6:],
        }
        return json.dumps(view)

    def decide(self, user_text: str) -> dict:
        """Ask the orchestrator what to say and do for this turn.

        Two baseline probes (a row count and a five-row sample of the current
        table) are offered to the orchestrator as candidate queries.

        Args:
            user_text: The user's message for this turn.

        Returns:
            A dict with ``assistant_message``, ``action``, ``hypotheses`` and
            ``selected_queries`` copied from the orchestrator's plan.

        Raises:
            KeyError: If the current observation has no ``table_name``.
        """
        table = self.obs["table_name"]
        base_queries = [
            f"SELECT COUNT(*) AS n FROM {table}",
            f"SELECT * FROM {table} LIMIT 5",
        ]
        plan = self.orchestrator.build_chat_response(
            user_text=user_text,
            obs=self.obs,
            task_id=self.task_id,
            base_queries=base_queries,
            reasoning_hints=[],
        )
        return {
            "assistant_message": plan.assistant_message,
            "action": plan.action,
            "hypotheses": plan.hypotheses,
            "selected_queries": plan.selected_queries,
        }

    def step(self, user_text: str) -> tuple[str, dict]:
        """Run one chat turn: decide, act on the environment, record the turn.

        Args:
            user_text: The user's message for this turn.

        Returns:
            ``(assistant_message, raw_step_response)``.

        Raises:
            requests.HTTPError: If the step request returns an error status.
        """
        decision = self.decide(user_text)
        # `or ""` keeps a missing/None message from being rendered as "None".
        assistant_message = str(decision.get("assistant_message") or "")
        # decide() always sets the "action" key, so a .get() default would never
        # apply; `or` makes the fallback cover a None/empty action as well.
        action = decision.get("action") or {
            "action_type": "query",
            "sql": f"SELECT COUNT(*) FROM {self.obs['table_name']}",
        }
        out = self.call_env("step", {"action": action})
        # Keep the previous observation when the server returns none (or null),
        # so later turns can still read table metadata from it.
        self.obs = out.get("observation") or self.obs
        reward = out.get("reward") or {}
        self.history.append(
            {
                "user": user_text,
                "assistant_message": assistant_message,
                "action_type": action.get("action_type"),
                "reward": reward.get("value", 0.0),
                "done": reward.get("done", False),
                "selected_queries": decision.get("selected_queries", []),
            }
        )
        self.memory.save()
        return assistant_message, out
def main() -> None:
    """Run the interactive chat loop for one DataQualityEnv episode.

    Command-line options:
        --task-id: Task to load (1, 2 or 3; default 1).
        --seed: Seed passed to the environment reset (default 42).

    Typing ``exit`` or ``quit`` leaves the loop, as does end-of-input or
    Ctrl-C. Typing ``finalize`` asks the auditor to submit its report. The
    loop also ends once the environment reports the episode as done.
    """
    parser = argparse.ArgumentParser(description="Chat-like AI auditor for DataQualityEnv")
    parser.add_argument("--task-id", type=int, default=1, choices=[1, 2, 3])
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    auditor = ChatAuditor(task_id=args.task_id, seed=args.seed)
    print(f"Chat auditor ready for task {args.task_id}. Type 'finalize' to submit, 'exit' to quit.")

    while True:
        try:
            user_text = input("you> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C ends the session cleanly instead of
            # surfacing a traceback.
            print()
            break
        if not user_text:
            # A blank line carries no request; do not spend an environment
            # step on it.
            continue

        command = user_text.lower()
        if command in {"exit", "quit"}:
            break
        if command == "finalize":
            user_text = "Finalize and submit the best report now."

        msg, result = auditor.step(user_text)
        # Tolerate a missing or null reward in the step response.
        reward = result.get("reward") or {}
        print(f"agent> {msg}")
        print(f"reward={reward.get('value', 0.0)} done={reward.get('done', False)}")
        if reward.get("done"):
            print("Episode complete.")
            break


if __name__ == "__main__":
    main()