File size: 5,828 Bytes
91e7690
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
Chat-style AI auditor for DataQualityEnv.

This wrapper now behaves like a modern assistant stack:
- planner produces hypotheses and safe probe ideas
- executor runs OpenEnv tool calls
- critic normalizes/repairs the final report
- memory influences future turns
"""

from __future__ import annotations

import argparse
import json
import os
from typing import Any

import requests
from openai import OpenAI

from env.agent_memory import MemoryStore
from env.multi_agent_orchestrator import MultiAgentOrchestrator

# Runtime configuration, read once from the process environment at import time.
# API_BASE_URL, MODEL_NAME and API_KEY fall back to an empty string when unset.
API_BASE_URL = os.getenv("API_BASE_URL", "")
MODEL_NAME = os.getenv("MODEL_NAME", "")
# A non-empty HF_TOKEN takes precedence; otherwise OPENAI_API_KEY is used.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
# Base URL of the environment server and path handed to the memory store.
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
MEMORY_PATH = os.getenv("AGENT_MEMORY_PATH", "outputs/agent_memory.json")


# Instruction text for a chat model: it spells out the JSON reply schema
# (an "assistant_message" plus an "action" that is either a SQL "query" or a
# "submit_report" carrying the report fields) and the rules for choosing one.
# NOTE(review): nothing in the visible part of this file references
# SYSTEM_PROMPT (ChatAuditor.decide delegates to the orchestrator) -- confirm
# whether it is consumed elsewhere before changing or removing it.
SYSTEM_PROMPT = """You are a data quality auditing assistant.
You can investigate data via SQL and then submit a final JSON report.

Return valid JSON only in this schema:
{
  "assistant_message": "short natural language reply",
  "action": {
    "action_type": "query" | "submit_report",
    "sql": "... optional when query ...",
    "report": {
      "null_issues": {"col": 0},
      "duplicate_row_count": 0,
      "schema_violations": [],
      "drifted_columns": [],
      "drift_details": {},
      "recommended_fixes": []
    }
  }
}

Rules:
- If user asks to inspect, use action_type=query with safe SELECT/WITH SQL.
- If enough evidence exists or user asks to finalize, use action_type=submit_report.
- Keep assistant_message concise and helpful.
"""


class ChatAuditor:
    """Interactive auditor that turns chat requests into DataQualityEnv actions.

    For every user turn the orchestrator produces a plan, the plan's action is
    POSTed to the environment's ``step`` endpoint, a summary of the turn is
    appended to ``history``, and the memory store is saved.
    """

    def __init__(self, task_id: int, seed: int) -> None:
        """Validate configuration, build collaborators, and reset the environment.

        Args:
            task_id: Task identifier sent to the ``reset`` endpoint.
            seed: Seed sent to the ``reset`` endpoint.

        Raises:
            RuntimeError: If API_BASE_URL, MODEL_NAME, or the API key is empty.
            requests.HTTPError: If the ``reset`` call returns an error status.
        """
        if not API_BASE_URL or not MODEL_NAME or not API_KEY:
            raise RuntimeError("Set API_BASE_URL, MODEL_NAME, and HF_TOKEN/OPENAI_API_KEY.")
        # The client is created here but not used by any method of this class.
        self.client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        self.memory = MemoryStore(MEMORY_PATH)
        self.orchestrator = MultiAgentOrchestrator(memory=self.memory)
        self.task_id = task_id
        self.seed = seed
        # One summary dict per completed turn; see step() for the keys.
        self.history: list[dict[str, Any]] = []
        # Latest observation returned by the environment.
        self.obs = self.call_env("reset", {"task_id": task_id, "seed": seed})

    def call_env(self, endpoint: str, payload: dict | None = None, method: str = "POST") -> dict:
        """Call an environment endpoint and return the decoded JSON body.

        Args:
            endpoint: Path relative to ENV_URL, e.g. ``"reset"`` or ``"step"``.
            payload: JSON body for POST requests; ``None`` sends ``{}``.
            method: ``"POST"`` issues a POST; any other value issues a GET.

        Returns:
            The parsed JSON response.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # Normalise the separator so a trailing "/" on ENV_URL (or a leading
        # one on the endpoint) cannot produce "//" in the request path.
        url = f"{ENV_URL.rstrip('/')}/{endpoint.lstrip('/')}"
        if method == "POST":
            response = requests.post(url, json=payload or {}, timeout=30)
        else:
            response = requests.get(url, timeout=30)
        response.raise_for_status()
        return response.json()

    def build_user_payload(self, user_text: str) -> str:
        """Serialize the user request plus the current observation to JSON.

        Only the first five rows of ``last_query_result`` and the last six
        history entries are included. This helper is not called by the other
        methods of this class.

        Args:
            user_text: The raw user request.

        Returns:
            A JSON string describing the request and the observation fields.
        """
        obs = self.obs
        view = {
            "user_request": user_text,
            "task_id": obs.get("task_id"),
            "task_description": obs.get("task_description"),
            "table_name": obs.get("table_name"),
            "schema": obs.get("schema"),
            "row_count": obs.get("row_count"),
            "step": obs.get("step"),
            "max_steps": obs.get("max_steps"),
            # A missing or null result is treated as an empty list.
            "last_query_result": (obs.get("last_query_result") or [])[:5],
            "last_action_error": obs.get("last_action_error"),
            "recent_history": self.history[-6:],
        }
        return json.dumps(view)

    def decide(self, user_text: str) -> dict:
        """Ask the orchestrator for a plan for this user turn.

        Args:
            user_text: The raw user request.

        Returns:
            A dict with ``assistant_message``, ``action``, ``hypotheses`` and
            ``selected_queries`` copied from the orchestrator's plan.

        Raises:
            KeyError: If the current observation has no ``table_name``.
        """
        table = self.obs["table_name"]
        # Baseline probes offered to the orchestrator alongside the request.
        base_queries = [
            f"SELECT COUNT(*) AS n FROM {table}",
            f"SELECT * FROM {table} LIMIT 5",
        ]
        plan = self.orchestrator.build_chat_response(
            user_text=user_text,
            obs=self.obs,
            task_id=self.task_id,
            base_queries=base_queries,
            reasoning_hints=[],
        )
        return {
            "assistant_message": plan.assistant_message,
            "action": plan.action,
            "hypotheses": plan.hypotheses,
            "selected_queries": plan.selected_queries,
        }

    def step(self, user_text: str) -> tuple[str, dict]:
        """Run one chat turn: plan, execute against the environment, record it.

        Args:
            user_text: The raw user request.

        Returns:
            A ``(assistant_message, response)`` pair where ``response`` is the
            decoded JSON returned by the ``step`` endpoint.

        Raises:
            requests.HTTPError: If the ``step`` call returns an error status.
        """
        decision = self.decide(user_text)
        assistant_message = str(decision.get("assistant_message", ""))
        # decide() always sets the "action" key, so a plain .get() default
        # would never apply; fall back explicitly when the plan carries no
        # usable action (None or empty).
        action = decision.get("action") or {
            "action_type": "query",
            "sql": f"SELECT COUNT(*) FROM {self.obs['table_name']}",
        }

        out = self.call_env("step", {"action": action})
        # Keep the previous observation when the response omits it or sends
        # null, so later turns can still read table_name and friends.
        self.obs = out.get("observation") or self.obs
        reward = out.get("reward") or {}

        self.history.append(
            {
                "user": user_text,
                "assistant_message": assistant_message,
                "action_type": action.get("action_type"),
                "reward": reward.get("value", 0.0),
                "done": reward.get("done", False),
                "selected_queries": decision.get("selected_queries", []),
            }
        )
        self.memory.save()
        return assistant_message, out


def main() -> None:
    """Run an interactive terminal chat session against DataQualityEnv.

    Parses ``--task-id`` (1, 2 or 3; default 1) and ``--seed`` (default 42),
    creates a ``ChatAuditor``, then loops reading user input. ``exit`` or
    ``quit`` ends the session, ``finalize`` asks the auditor to submit its
    report, and the loop also stops once the environment reports ``done``.
    """
    parser = argparse.ArgumentParser(description="Chat-like AI auditor for DataQualityEnv")
    parser.add_argument("--task-id", type=int, default=1, choices=[1, 2, 3])
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    auditor = ChatAuditor(task_id=args.task_id, seed=args.seed)
    print(f"Chat auditor ready for task {args.task_id}. Type 'finalize' to submit, 'exit' to quit.")

    while True:
        try:
            user_text = input("you> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin (Ctrl-D / piped input exhausted) or Ctrl-C ends the
            # session the same way "exit" does instead of raising a traceback.
            print()
            break
        if not user_text:
            # A blank line carries no request; do not spend an environment step.
            continue

        command = user_text.lower()
        if command in {"exit", "quit"}:
            break
        if command == "finalize":
            user_text = "Finalize and submit the best report now."

        msg, result = auditor.step(user_text)
        # Treat a missing or null reward as "no reward information".
        reward = result.get("reward") or {}
        print(f"agent> {msg}")
        print(f"reward={reward.get('value', 0.0)} done={reward.get('done', False)}")
        if reward.get("done"):
            print("Episode complete.")
            break


# Start the interactive session only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()