Spaces:
Sleeping
Sleeping
File size: 5,828 Bytes
91e7690 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | """
Chat-style AI auditor for DataQualityEnv.
This wrapper now behaves like a modern assistant stack:
- planner produces hypotheses and safe probe ideas
- executor runs OpenEnv tool calls
- critic normalizes/repairs the final report
- memory influences future turns
"""
from __future__ import annotations
import argparse
import json
import os
from typing import Any
import requests
from openai import OpenAI
from env.agent_memory import MemoryStore
from env.multi_agent_orchestrator import MultiAgentOrchestrator
# Runtime configuration, read once from the process environment at import time.
# Endpoint and model of the OpenAI-compatible API; both default to "" when unset.
API_BASE_URL = os.getenv("API_BASE_URL", "")
MODEL_NAME = os.getenv("MODEL_NAME", "")
# HF_TOKEN wins when it is set and non-empty; otherwise OPENAI_API_KEY (or "").
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
# Base URL of the environment server that call_env() talks to.
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
# File path handed to MemoryStore.
MEMORY_PATH = os.getenv("AGENT_MEMORY_PATH", "outputs/agent_memory.json")
# Prompt text describing the JSON contract for a model reply: a short
# "assistant_message" plus an "action" that is either a SQL "query" or a
# "submit_report" carrying the final data-quality report fields.
# NOTE(review): nothing in the visible part of this file references
# SYSTEM_PROMPT -- confirm whether it is still consumed elsewhere.
SYSTEM_PROMPT = """You are a data quality auditing assistant.
You can investigate data via SQL and then submit a final JSON report.
Return valid JSON only in this schema:
{
"assistant_message": "short natural language reply",
"action": {
"action_type": "query" | "submit_report",
"sql": "... optional when query ...",
"report": {
"null_issues": {"col": 0},
"duplicate_row_count": 0,
"schema_violations": [],
"drifted_columns": [],
"drift_details": {},
"recommended_fixes": []
}
}
}
Rules:
- If user asks to inspect, use action_type=query with safe SELECT/WITH SQL.
- If enough evidence exists or user asks to finalize, use action_type=submit_report.
- Keep assistant_message concise and helpful.
"""
class ChatAuditor:
    """Drive one environment episode through a turn-by-turn chat.

    For every user message the ``MultiAgentOrchestrator`` produces a plan
    (assistant reply, action, hypotheses, selected queries). The action is
    posted to the environment's ``step`` endpoint, the outcome is appended to
    ``history`` and the memory store is saved.
    """

    def __init__(self, task_id: int, seed: int) -> None:
        """Check configuration, create collaborators and reset the environment.

        Args:
            task_id: Task identifier sent to the ``reset`` endpoint.
            seed: Seed sent to the ``reset`` endpoint.

        Raises:
            RuntimeError: If API_BASE_URL, MODEL_NAME or the API key is empty.
            requests.HTTPError: If the ``reset`` call returns an error status.
        """
        if not API_BASE_URL or not MODEL_NAME or not API_KEY:
            raise RuntimeError("Set API_BASE_URL, MODEL_NAME, and HF_TOKEN/OPENAI_API_KEY.")
        self.client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        self.memory = MemoryStore(MEMORY_PATH)
        self.orchestrator = MultiAgentOrchestrator(memory=self.memory)
        self.task_id = task_id
        self.seed = seed
        # One summary dict per completed turn, appended by step().
        self.history: list[dict[str, Any]] = []
        # Latest observation; the reset response is stored as-is.
        self.obs = self.call_env("reset", {"task_id": task_id, "seed": seed})

    def call_env(self, endpoint: str, payload: dict | None = None, method: str = "POST") -> dict:
        """Send one HTTP request to the environment server and decode the JSON reply.

        Args:
            endpoint: Path below ENV_URL, e.g. ``"reset"`` or ``"step"``.
            payload: JSON body for POST requests; ``None`` is sent as ``{}``.
            method: ``"POST"`` posts ``payload``; any other value issues a GET.

        Returns:
            The decoded JSON response body.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # Normalise slashes so an ENV_URL ending in "/" does not yield "//endpoint".
        url = f"{ENV_URL.rstrip('/')}/{endpoint.lstrip('/')}"
        if method == "POST":
            response = requests.post(url, json=payload or {}, timeout=30)
        else:
            response = requests.get(url, timeout=30)
        response.raise_for_status()
        return response.json()

    def build_user_payload(self, user_text: str) -> str:
        """Serialise the user request plus a compact view of the current state.

        The view holds selected observation fields, at most the first five
        rows of the last query result and the six most recent history entries.

        Args:
            user_text: The user's message for this turn.

        Returns:
            A JSON string of the assembled view.
        """
        obs = self.obs
        view = {
            "user_request": user_text,
            "task_id": obs.get("task_id"),
            "task_description": obs.get("task_description"),
            "table_name": obs.get("table_name"),
            "schema": obs.get("schema"),
            "row_count": obs.get("row_count"),
            "step": obs.get("step"),
            "max_steps": obs.get("max_steps"),
            # A missing or null result is treated as "no rows".
            "last_query_result": (obs.get("last_query_result") or [])[:5],
            "last_action_error": obs.get("last_action_error"),
            "recent_history": self.history[-6:],
        }
        return json.dumps(view)

    def decide(self, user_text: str) -> dict:
        """Ask the orchestrator for a plan and flatten it into a plain dict.

        Args:
            user_text: The user's message for this turn.

        Returns:
            Dict with keys ``assistant_message``, ``action``, ``hypotheses``
            and ``selected_queries`` copied from the orchestrator's plan.

        Raises:
            KeyError: If the current observation has no ``table_name``.
        """
        table = self.obs['table_name']
        # Baseline probes offered to the orchestrator alongside the user request.
        base_queries = [
            f"SELECT COUNT(*) AS n FROM {table}",
            f"SELECT * FROM {table} LIMIT 5",
        ]
        plan = self.orchestrator.build_chat_response(
            user_text=user_text,
            obs=self.obs,
            task_id=self.task_id,
            base_queries=base_queries,
            reasoning_hints=[],
        )
        return {
            "assistant_message": plan.assistant_message,
            "action": plan.action,
            "hypotheses": plan.hypotheses,
            "selected_queries": plan.selected_queries,
        }

    def step(self, user_text: str) -> tuple[str, dict]:
        """Run one chat turn: plan, act on the environment, record the outcome.

        Args:
            user_text: The user's message for this turn.

        Returns:
            ``(assistant_message, out)`` where ``out`` is the raw ``step``
            response from the environment.
        """
        decision = self.decide(user_text)
        # A null message becomes "" rather than the literal string "None".
        assistant_message = str(decision.get("assistant_message") or "")
        # decide() always sets the "action" key, so a .get() default would never
        # apply; fall back explicitly when the plan carries no action.
        action = decision.get("action") or {
            "action_type": "query",
            "sql": f"SELECT COUNT(*) FROM {self.obs['table_name']}",
        }
        out = self.call_env("step", {"action": action})
        # Keep the previous observation when the response omits it or sends null,
        # so the next turn can still read the table name.
        observation = out.get("observation")
        if isinstance(observation, dict):
            self.obs = observation
        reward = out.get("reward")
        if not isinstance(reward, dict):
            # Missing, null or non-mapping reward: record the defaults below.
            reward = {}
        self.history.append(
            {
                "user": user_text,
                "assistant_message": assistant_message,
                "action_type": action.get("action_type"),
                "reward": reward.get("value", 0.0),
                "done": reward.get("done", False),
                "selected_queries": decision.get("selected_queries", []),
            }
        )
        self.memory.save()
        return assistant_message, out
def main() -> None:
    """Run the interactive chat loop from the command line.

    Parses ``--task-id`` (1, 2 or 3; default 1) and ``--seed`` (default 42),
    creates a ``ChatAuditor`` and forwards each typed line to it until the
    user types ``exit``/``quit``, input ends, or the environment reports the
    episode as done. Typing ``finalize`` asks the agent to submit its report.
    """
    parser = argparse.ArgumentParser(description="Chat-like AI auditor for DataQualityEnv")
    parser.add_argument("--task-id", type=int, default=1, choices=[1, 2, 3])
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    auditor = ChatAuditor(task_id=args.task_id, seed=args.seed)
    print(f"Chat auditor ready for task {args.task_id}. Type 'finalize' to submit, 'exit' to quit.")
    while True:
        try:
            user_text = input("you> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D, Ctrl-C or a closed stdin ends the session without a traceback.
            print()
            break
        if not user_text:
            # A blank line would otherwise be sent as a turn and use up an
            # environment step.
            continue
        command = user_text.lower()
        if command in {"exit", "quit"}:
            break
        if command == "finalize":
            user_text = "Finalize and submit the best report now."
        msg, result = auditor.step(user_text)
        reward = result.get("reward")
        if not isinstance(reward, dict):
            # Missing or null reward: print the defaults instead of crashing.
            reward = {}
        print(f"agent> {msg}")
        print(f"reward={reward.get('value', 0.0)} done={reward.get('done', False)}")
        if reward.get("done"):
            print("Episode complete.")
            break


if __name__ == "__main__":
    main()
|