Spaces:
Sleeping
Sleeping
File size: 7,087 Bytes
91e7690 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 | from __future__ import annotations
import json
import os
from dataclasses import dataclass
from typing import Any
from openai import OpenAI
from env.agent_memory import MemoryStore
from env.knowledge_brain import KnowledgeBrain
from env.reasoning_stack import build_plan_prompt, parse_plan_json, safe_query_filter, validate_and_repair_report
# Settings for the optional OpenAI-compatible LLM backend, read once at import.
# Each falls back to "" so an unconfigured environment is detectable by a
# simple truthiness check.
API_BASE_URL = os.getenv("API_BASE_URL", "")
MODEL_NAME = os.getenv("MODEL_NAME", "")
# HF_TOKEN wins when it is set and non-empty; otherwise use OPENAI_API_KEY.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
def _get_client() -> OpenAI | None:
    """Build an OpenAI-compatible client for the configured backend.

    Returns:
        An ``OpenAI`` client, or ``None`` when any of ``API_BASE_URL``,
        ``MODEL_NAME`` or ``API_KEY`` is empty, or when constructing the
        client raises.
    """
    is_configured = all((API_BASE_URL, MODEL_NAME, API_KEY))
    if is_configured:
        try:
            return OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        except Exception:
            # Best-effort: a broken client setup degrades to "no LLM"
            # instead of failing at startup.
            pass
    return None
@dataclass
class OrchestratorPlan:
    """Outcome of a single orchestrator chat turn.

    Attributes:
        assistant_message: Reply text to show the user for this turn.
        action: OpenEnv action payload, e.g. ``{"action_type": "query", "sql": ...}``
            or ``{"action_type": "submit_report", "report": ...}``.
        hypotheses: Planner hypotheses, in planner order.
        selected_queries: Base queries followed by the planner's extra queries.
    """

    assistant_message: str
    action: dict[str, Any]
    hypotheses: list[str]
    selected_queries: list[str]
class MultiAgentOrchestrator:
    """
    Planner -> Critic -> Executor -> Fixer stack.

    Designed to feel closer to a modern assistant product while still only
    using safe OpenEnv actions. The LLM is optional: when no client is
    configured, every LLM call yields an empty result and the deterministic
    paths (KnowledgeBrain, fallback report) carry the turn.
    """

    def __init__(self, memory: MemoryStore | None = None) -> None:
        """Create the orchestrator.

        Args:
            memory: Optional memory store, kept on ``self.memory``. No method
                of this class reads it.
        """
        self.client = _get_client()  # None when the LLM backend is not configured
        self.memory = memory
        self.brain = KnowledgeBrain()

    def _llm_json(self, system: str, user: dict[str, Any], max_tokens: int = 600) -> dict[str, Any]:
        """Ask the LLM for a JSON object and return it as a dict.

        Args:
            system: System prompt.
            user: Payload sent as the user message via ``json.dumps``.
            max_tokens: Completion token cap.

        Returns:
            The decoded JSON object, or ``{}`` when no client is configured,
            the request fails, or the reply is not a JSON object.
        """
        if self.client is None:
            return {}
        try:
            completion = self.client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": json.dumps(user)},
                ],
                temperature=0.0,
                max_tokens=max_tokens,
            )
            raw = completion.choices[0].message.content
        except Exception:
            # Deliberate best-effort boundary: any transport, API or
            # response-shape failure means "no LLM contribution", not a crash.
            return {}
        return self._parse_json_object(raw)

    @staticmethod
    def _parse_json_object(raw: Any) -> dict[str, Any]:
        """Decode an LLM reply into a dict, tolerating a Markdown code fence.

        Chat models may wrap JSON in a fenced block even when asked for bare
        JSON. The fence is removed before decoding, along with an optional
        language tag such as ``json`` on the opening line.

        Returns:
            The decoded object, or ``{}`` if *raw* is not a string, is not
            valid JSON, or decodes to something other than a JSON object.
        """
        if not isinstance(raw, str):
            return {}
        text = raw.strip()
        if text.startswith("```"):
            text = text[3:]
            if text.endswith("```"):
                text = text[:-3]
            first_line, newline, rest = text.partition("\n")
            # Drop a language tag, but keep JSON that starts right after the fence.
            if newline and not first_line.strip().startswith(("{", "[")):
                text = rest
            text = text.strip()
        try:
            parsed = json.loads(text)
        except (ValueError, RecursionError):
            return {}
        return parsed if isinstance(parsed, dict) else {}

    def plan_queries(
        self,
        task_id: int,
        obs: dict[str, Any],
        base_queries: list[str],
        reasoning_hints: list[str] | None = None,
    ) -> tuple[list[str], list[str]]:
        """Ask the planner LLM for hypotheses and extra audit queries.

        Args:
            task_id: Current task identifier, forwarded to the planner.
            obs: Observation dict; ``table_name`` and ``schema`` are forwarded.
            base_queries: Queries already scheduled, shown to the planner.
            reasoning_hints: Optional hints forwarded to the planner.

        Returns:
            ``(hypotheses, extra_queries)``: at most 6 hypotheses and at most
            3 queries that passed ``safe_query_filter``. Without an LLM the
            planner output is ``{}``, so the result is whatever
            ``parse_plan_json`` yields for an empty object.
        """
        reasoning_hints = reasoning_hints or []
        user = {
            "task_id": task_id,
            "table_name": obs.get("table_name"),
            "schema": obs.get("schema", {}),
            "base_queries": base_queries,
            "reasoning_hints": reasoning_hints,
            "instruction": "Return JSON with hypotheses and extra_queries only.",
        }
        system = (
            "You are a planning module for SQL auditing. Return JSON only with keys hypotheses and extra_queries. "
            "extra_queries must be safe SELECT/WITH only and bounded to at most 3."
        )
        parsed = self._llm_json(system, user, max_tokens=350)
        # An empty result serializes to "{}", the same input as the offline path.
        plan = parse_plan_json(json.dumps(parsed or {}))
        # The prompt asks for <= 3 queries; the cap is enforced here, after
        # safety filtering, so an over-eager model cannot exceed it.
        extra_queries = safe_query_filter(plan.extra_queries)[:3]
        hypotheses = plan.hypotheses[:6]
        return hypotheses, extra_queries

    def critique_report(self, task_id: int, report: dict[str, Any], evidence: dict[str, Any]) -> dict[str, Any]:
        """Reconcile a candidate report with the deterministic KnowledgeBrain report.

        The brain's report (built from *evidence*) is the base. The candidate
        *report*, after ``validate_and_repair_report``, adds detail on top:

        * ``null_issues`` / ``drift_details``: merged with ``dict.update``.
        * ``duplicate_row_count``: the larger of the two counts. A missing or
          non-numeric candidate count counts as 0.
        * ``schema_violations`` / ``drifted_columns`` / ``recommended_fixes``:
          candidate entries are appended only if not already present.

        Returns:
            The merged report, passed through ``validate_and_repair_report``.
        """
        report = validate_and_repair_report(report)
        # Deterministic brain first: its findings form the base of the merge.
        brain_report = self.brain.build_report(task_id, evidence)
        merged: dict[str, Any] = {
            "null_issues": dict(brain_report.null_issues),
            "duplicate_row_count": brain_report.duplicate_row_count,
            "schema_violations": list(brain_report.schema_violations),
            "drifted_columns": list(brain_report.drifted_columns),
            "drift_details": dict(brain_report.drift_details),
            "recommended_fixes": list(brain_report.recommended_fixes),
        }
        # Preserve user/LLM-added details where safe.
        # NOTE(review): dict.update lets the candidate's value replace the
        # brain's value for a key both contain. Confirm that precedence.
        merged["null_issues"].update(report.get("null_issues") or {})
        merged["drift_details"].update(report.get("drift_details") or {})
        merged["duplicate_row_count"] = max(
            merged["duplicate_row_count"],
            self._coerce_count(report.get("duplicate_row_count")),
        )
        # All list fields merge the same way, so a violation or fix the brain
        # already reported is not listed twice.
        self._extend_unique(merged["schema_violations"], report.get("schema_violations") or [])
        self._extend_unique(merged["drifted_columns"], report.get("drifted_columns") or [])
        self._extend_unique(merged["recommended_fixes"], report.get("recommended_fixes") or [])
        return validate_and_repair_report(merged)

    @staticmethod
    def _extend_unique(target: list[Any], items: list[Any]) -> None:
        """Append each of *items* to *target* unless already present, keeping order.

        List membership (not a set) is used so unhashable entries such as
        dicts are supported.
        """
        for item in items:
            if item not in target:
                target.append(item)

    @staticmethod
    def _coerce_count(value: Any) -> int:
        """Convert *value* to ``int``, treating ``None`` or non-numeric values as 0."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return 0

    def build_chat_response(
        self,
        user_text: str,
        obs: dict[str, Any],
        task_id: int,
        base_queries: list[str],
        reasoning_hints: list[str] | None = None,
    ) -> OrchestratorPlan:
        """Produce the assistant reply and the single OpenEnv action for this turn.

        If *user_text* asks to finish (see ``_is_submit_request``), the action
        submits the fallback report. Otherwise it runs the first selected
        query, or a ``COUNT(*)`` probe on ``obs["table_name"]`` when no queries
        were selected.

        Raises:
            KeyError: If no queries are selected and *obs* has no ``"table_name"``.
        """
        hypotheses, extra_queries = self.plan_queries(task_id, obs, base_queries, reasoning_hints)
        selected_queries = base_queries + extra_queries
        assistant_message = self._assistant_message(user_text, hypotheses, selected_queries, obs)
        action: dict[str, Any]
        if self._is_submit_request(user_text):
            action = {"action_type": "submit_report", "report": self._fallback_report(task_id)}
        else:
            # NOTE(review): table_name is interpolated unescaped; this assumes it
            # comes from the environment, not from user input. Confirm.
            sql = selected_queries[0] if selected_queries else f"SELECT COUNT(*) AS n FROM {obs['table_name']}"
            action = {"action_type": "query", "sql": sql}
        return OrchestratorPlan(
            assistant_message=assistant_message,
            action=action,
            hypotheses=hypotheses,
            selected_queries=selected_queries,
        )

    @staticmethod
    def _is_submit_request(user_text: str) -> bool:
        """Return True if *user_text* contains a word starting with a finish keyword.

        The keywords are final, submit, report, done and finish,
        case-insensitive. Matching is anchored at word starts, so "finished"
        and "reports" count, but "abandoned" (which merely contains "done")
        does not.
        """
        import re  # local import: only this helper needs regular expressions

        return re.search(r"\b(?:final|submit|report|done|finish)", user_text.lower()) is not None

    def _assistant_message(self, user_text: str, hypotheses: list[str], queries: list[str], obs: dict[str, Any]) -> str:
        """Compose the short chat reply for this turn.

        Leads with the first hypothesis when there is one. ``user_text`` and
        ``obs`` are accepted but currently unused.
        """
        if not queries:
            return "I’ll use the available evidence to produce the final audit report."
        lead = hypotheses[0] if hypotheses else "I will inspect the data with a targeted SQL probe."
        return f"{lead} Next I’ll run a focused query and keep the plan safe and deterministic."

    def _fallback_report(self, task_id: int) -> dict[str, Any]:
        """Return a fresh, empty audit report containing every expected key.

        The same skeleton is currently returned for every *task_id*. A new
        dict is built on each call, so callers may mutate the result safely.
        """
        return {
            "null_issues": {},
            "duplicate_row_count": 0,
            "schema_violations": [],
            "drifted_columns": [],
            "drift_details": {},
            "recommended_fixes": [],
        }
|