from __future__ import annotations

import json
import os
from dataclasses import dataclass
from typing import Any

from openai import OpenAI

from env.agent_memory import MemoryStore
from env.knowledge_brain import KnowledgeBrain
from env.reasoning_stack import (
    build_plan_prompt,
    parse_plan_json,
    safe_query_filter,
    validate_and_repair_report,
)

API_BASE_URL = os.environ.get("API_BASE_URL", "")
MODEL_NAME = os.environ.get("MODEL_NAME", "")
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY", "")

# Keywords in the user's message that switch build_chat_response from running a
# query to submitting a report.
# NOTE(review): matched as substrings, so e.g. "finally" also triggers on "final".
_SUBMIT_KEYWORDS = ("final", "submit", "report", "done", "finish")


def _get_client() -> OpenAI | None:
    """Return an OpenAI client, or None when it cannot or should not be built.

    A client is only constructed when API_BASE_URL, MODEL_NAME and an API key
    (HF_TOKEN, else OPENAI_API_KEY) are all non-empty. Any error raised by the
    constructor also yields None, so the orchestrator degrades to its
    non-LLM code paths instead of failing at construction time.
    """
    if not (API_BASE_URL and MODEL_NAME and API_KEY):
        return None
    try:
        return OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    except Exception:
        return None


def _strip_code_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence (``` or ```json) if present.

    Chat models frequently wrap JSON answers in a fence even when asked for
    raw JSON; without this, ``json.loads`` rejects an otherwise valid reply.
    Text that does not start with a fence is returned stripped but unchanged.
    """
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    # Drop the opening fence line, which may carry a language tag ("json").
    newline = stripped.find("\n")
    body = stripped[newline + 1 :] if newline != -1 else stripped[3:]
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def _as_int(value: Any, default: int = 0) -> int:
    """Convert ``value`` with ``int()``, returning ``default`` if that fails."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _extend_unique(target: list[Any], extra: Any) -> None:
    """Append each item of ``extra`` to ``target`` unless an equal item is present.

    Uses ``==`` membership rather than hashing, so unhashable items such as
    dicts are supported. ``extra`` may be None, which is treated as empty.
    """
    for item in extra or []:
        if item not in target:
            target.append(item)


@dataclass
class OrchestratorPlan:
    """Result of one build_chat_response turn."""

    assistant_message: str  # text shown to the user for this turn
    action: dict[str, Any]  # environment action: "query" or "submit_report"
    hypotheses: list[str]  # planner hypotheses (at most 6)
    selected_queries: list[str]  # base queries followed by planner extras


class MultiAgentOrchestrator:
    """
    Planner -> Critic -> Executor -> Fixer stack.
    Designed to feel closer to a modern assistant product while still only using safe OpenEnv actions.
    """

    def __init__(self, memory: MemoryStore | None = None) -> None:
        """Create the orchestrator; the LLM client is None when not configured."""
        self.client = _get_client()
        self.memory = memory
        self.brain = KnowledgeBrain()

    def _llm_json(self, system: str, user: dict[str, Any], max_tokens: int = 600) -> dict[str, Any]:
        """Ask the LLM for a JSON object and return it parsed.

        ``user`` is serialized with ``json.dumps`` as the user message. A
        surrounding Markdown code fence in the reply is tolerated.

        Returns:
            The parsed object, or ``{}`` when no client is configured, the call
            fails, the reply is not valid JSON, or the JSON is not an object.
            This is deliberately best-effort: callers treat ``{}`` as
            "no LLM suggestions".
        """
        if self.client is None:
            return {}
        try:
            completion = self.client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": json.dumps(user)},
                ],
                temperature=0.0,
                max_tokens=max_tokens,
            )
            raw = (completion.choices[0].message.content or "").strip()
            parsed = json.loads(_strip_code_fence(raw))
        except Exception:
            return {}
        return parsed if isinstance(parsed, dict) else {}

    def plan_queries(
        self,
        task_id: int,
        obs: dict[str, Any],
        base_queries: list[str],
        reasoning_hints: list[str] | None = None,
    ) -> tuple[list[str], list[str]]:
        """Ask the planner LLM for hypotheses and extra audit queries.

        Args:
            task_id: Task identifier forwarded to the planner.
            obs: Observation dict; ``table_name`` and ``schema`` are read.
            base_queries: Queries that will run regardless; shown to the planner.
            reasoning_hints: Optional free-text hints for the planner.

        Returns:
            ``(hypotheses, extra_queries)``. At most 6 hypotheses are kept, and
            at most 3 extra queries that pass ``safe_query_filter``. Without an
            LLM, both come from parsing an empty plan.
        """
        reasoning_hints = reasoning_hints or []
        user = {
            "task_id": task_id,
            "table_name": obs.get("table_name"),
            "schema": obs.get("schema", {}),
            "base_queries": base_queries,
            "reasoning_hints": reasoning_hints,
            "instruction": "Return JSON with hypotheses and extra_queries only.",
        }
        system = (
            "You are a planning module for SQL auditing. Return JSON only with keys hypotheses and extra_queries. "
            "extra_queries must be safe SELECT/WITH only and bounded to at most 3."
        )
        parsed = self._llm_json(system, user, max_tokens=350)
        # json.dumps({}) == "{}", so an empty LLM result parses as an empty plan.
        plan = parse_plan_json(json.dumps(parsed))
        extra_queries = safe_query_filter(plan.extra_queries)[:3]
        hypotheses = plan.hypotheses[:6]
        return hypotheses, extra_queries

    def critique_report(self, task_id: int, report: dict[str, Any], evidence: dict[str, Any]) -> dict[str, Any]:
        """Merge a candidate report with the deterministic KnowledgeBrain report.

        The brain report, built from ``evidence``, is the base. The following
        details from the (repaired) candidate ``report`` are layered on top:

        * ``null_issues`` / ``drift_details``: dict-updated, so report values win
          on shared keys.
        * ``duplicate_row_count``: the larger of the two counts.
        * ``schema_violations`` / ``drifted_columns`` / ``recommended_fixes``:
          report items are appended only if not already present.

        Fields that are missing or None in ``report`` are treated as empty.

        Returns:
            The merged report, passed through ``validate_and_repair_report``.
        """
        report = validate_and_repair_report(report)
        # deterministic brain first
        brain_report = self.brain.build_report(task_id, evidence)
        merged = {
            "null_issues": dict(brain_report.null_issues),
            "duplicate_row_count": brain_report.duplicate_row_count,
            "schema_violations": list(brain_report.schema_violations),
            "drifted_columns": list(brain_report.drifted_columns),
            "drift_details": dict(brain_report.drift_details),
            "recommended_fixes": list(brain_report.recommended_fixes),
        }
        # preserve user/LLM-added details where safe
        merged["null_issues"].update(report.get("null_issues") or {})
        reported_duplicates = _as_int(report.get("duplicate_row_count", 0))
        if reported_duplicates > merged["duplicate_row_count"]:
            merged["duplicate_row_count"] = reported_duplicates
        # De-duplicate every list field the same way so overlapping findings
        # from the brain and the report are not listed twice.
        _extend_unique(merged["schema_violations"], report.get("schema_violations"))
        _extend_unique(merged["drifted_columns"], report.get("drifted_columns"))
        merged["drift_details"].update(report.get("drift_details") or {})
        _extend_unique(merged["recommended_fixes"], report.get("recommended_fixes"))
        return validate_and_repair_report(merged)

    def build_chat_response(
        self,
        user_text: str,
        obs: dict[str, Any],
        task_id: int,
        base_queries: list[str],
        reasoning_hints: list[str] | None = None,
    ) -> OrchestratorPlan:
        """Plan one assistant turn and choose the next environment action.

        If the lower-cased ``user_text`` contains any of ``_SUBMIT_KEYWORDS``,
        the action submits ``_fallback_report(task_id)``. Otherwise it runs the
        first selected query. With no queries at all, it runs a row count on
        ``obs['table_name']``, which raises KeyError if that key is absent.
        """
        hypotheses, extra_queries = self.plan_queries(task_id, obs, base_queries, reasoning_hints)
        selected_queries = base_queries + extra_queries
        assistant_message = self._assistant_message(user_text, hypotheses, selected_queries, obs)

        lower = user_text.lower().strip()
        action: dict[str, Any]
        if any(word in lower for word in _SUBMIT_KEYWORDS):
            action = {"action_type": "submit_report", "report": self._fallback_report(task_id)}
        else:
            # NOTE(review): table_name is interpolated into SQL; assumed to come
            # from the environment rather than from the user — confirm.
            sql = selected_queries[0] if selected_queries else f"SELECT COUNT(*) AS n FROM {obs['table_name']}"
            action = {"action_type": "query", "sql": sql}

        return OrchestratorPlan(
            assistant_message=assistant_message,
            action=action,
            hypotheses=hypotheses,
            selected_queries=selected_queries,
        )

    def _assistant_message(self, user_text: str, hypotheses: list[str], queries: list[str], obs: dict[str, Any]) -> str:
        """Build the user-facing reply for this turn.

        The reply leads with the first hypothesis, or a generic probe sentence
        when there are none. It announces a query if any are queued, otherwise
        the final report. ``user_text`` and ``obs`` are accepted but unused.
        """
        if hypotheses:
            lead = hypotheses[0]
        else:
            lead = "I will inspect the data with a targeted SQL probe."
        if queries:
            return f"{lead} Next I’ll run a focused query and keep the plan safe and deterministic."
        return "I’ll use the available evidence to produce the final audit report."

    def _fallback_report(self, task_id: int) -> dict[str, Any]:
        """Return an all-empty audit report skeleton.

        Every task currently gets the same skeleton, so ``task_id`` is accepted
        but unused. A new dict is built on each call, so callers may mutate it.
        """
        return {
            "null_issues": {},
            "duplicate_row_count": 0,
            "schema_violations": [],
            "drifted_columns": [],
            "drift_details": {},
            "recommended_fixes": [],
        }