# govon-runtime / src /inference /agent_loop.py
# GovOn Deploy
# sync: PR#584 RAG removal + ReAct architecture
# 1635ec4
"""Session-based task loop (for the v1 endpoint)."""
from __future__ import annotations
import asyncio
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional
from loguru import logger
from .query_builder import build_runtime_query_context, normalize_text
from .session_context import SessionContext
from .tool_router import ToolName, ToolType, tool_name
@dataclass
class ToolResult:
    """Outcome of a single tool invocation.

    Holds which tool ran, whether it succeeded, the payload it produced,
    an optional error message, and how long the call took in milliseconds.
    """

    tool: ToolName
    success: bool
    data: Dict[str, Any] = field(default_factory=dict)
    error: Optional[str] = None
    latency_ms: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of this result.

        The tool is converted with ``tool_name`` and the latency is rounded
        to two decimal places; ``data`` is passed through by reference.
        """
        label = tool_name(self.tool)
        rounded_latency = round(self.latency_ms, 2)
        return dict(
            tool=label,
            success=self.success,
            data=self.data,
            error=self.error,
            latency_ms=rounded_latency,
        )
@dataclass
class AgentTrace:
    """Record of one agent-loop request.

    Collects the planned tool names, the reason for the plan, every tool
    result, the total latency in milliseconds, the final text and an
    optional error message.
    """

    request_id: str
    session_id: str
    plan_tools: List[str] = field(default_factory=list)
    plan_reason: str = ""
    tool_results: List[ToolResult] = field(default_factory=list)
    total_latency_ms: float = 0.0
    final_text: str = ""
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the trace.

        ``plan_tools`` is exposed under the key ``"plan"``, each tool result
        is serialized through its own ``to_dict``, and the total latency is
        rounded to two decimal places. ``final_text`` is not included.
        """
        payload: Dict[str, Any] = {}
        payload["request_id"] = self.request_id
        payload["session_id"] = self.session_id
        payload["plan"] = self.plan_tools
        payload["plan_reason"] = self.plan_reason

        serialized_results = []
        for entry in self.tool_results:
            serialized_results.append(entry.to_dict())
        payload["tool_results"] = serialized_results

        payload["total_latency_ms"] = round(self.total_latency_ms, 2)
        payload["error"] = self.error
        return payload
# A tool runner: AgentLoop calls it with the keyword arguments ``query``,
# ``context`` and ``session`` and awaits whatever it returns.
ToolFunction = Callable[..., Any]

# Default per-tool timeout, in seconds, handed to ``asyncio.wait_for``.
DEFAULT_TOOL_TIMEOUT = 30.0


class AgentLoop:
    """GovOn MVP capability loop (v1).

    A simple loop that runs the registered tools one after another. Each
    tool receives the context accumulated so far; every tool run and the
    overall run are recorded on the session.
    """

    # Plan reason stored on the trace and in graph-run metadata
    # (Korean: "run registered tools sequentially").
    _PLAN_REASON = "등록된 도구 순차 실행"

    def __init__(
        self,
        tool_registry: Dict[ToolName, ToolFunction],
        tool_timeout: float = DEFAULT_TOOL_TIMEOUT,
    ) -> None:
        """Store the tool registry keyed by normalized tool name.

        Args:
            tool_registry: Mapping of tool identifier to its runner. Keys are
                normalized with ``tool_name``; insertion order is the default
                execution order.
            tool_timeout: Per-tool timeout in seconds.
        """
        self._tools = {tool_name(name): runner for name, runner in tool_registry.items()}
        self._tool_timeout = tool_timeout

    async def run(
        self,
        query: str,
        session: SessionContext,
        request_id: Optional[str] = None,
        force_tools: Optional[List[ToolName]] = None,
    ) -> AgentTrace:
        """Run the tools sequentially and return the completed trace.

        Args:
            query: Raw user query; it is added to the session as a user turn.
            session: Session that receives turns, tool runs and the graph run.
            request_id: Identifier for this run; a UUID4 is generated if falsy.
            force_tools: Tools to run instead of the full registry. ``None``
                or an empty list means "run every registered tool".

        Returns:
            The trace of the run. Exceptions raised while processing are
            logged and stored in ``trace.error`` instead of propagating.
        """
        rid = request_id or str(uuid.uuid4())
        trace = AgentTrace(request_id=rid, session_id=session.session_id)
        loop_start = time.monotonic()
        started_at = time.time()
        try:
            session.add_turn("user", query)
            tool_names = self._resolve_tool_names(force_tools)
            trace.plan_tools = tool_names
            trace.plan_reason = self._PLAN_REASON

            accumulated = self._build_initial_context(session, query)
            for step_name in tool_names:
                await self._run_tool_step(step_name, accumulated, session, trace)

            trace.final_text = self._extract_final_text(accumulated, tool_names)
            session.add_turn("assistant", trace.final_text)
        except Exception as exc:
            trace.error = str(exc)
            # loguru: pass values as format arguments (never pre-interpolate the
            # message when extra args are given) and attach the traceback via opt().
            logger.opt(exception=True).error("[AgentLoop] request_id={} 오류: {}", rid, exc)
        finally:
            trace.total_latency_ms = (time.monotonic() - loop_start) * 1000
            self._record_graph_run(
                session=session,
                trace=trace,
                started_at=started_at,
                completed_at=time.time(),
            )
        return trace

    async def run_stream(
        self,
        query: str,
        session: SessionContext,
        request_id: Optional[str] = None,
        force_tools: Optional[List[ToolName]] = None,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Run the tools sequentially, yielding progress events.

        Events are dicts whose ``"type"`` is one of ``"plan"``,
        ``"tool_start"``, ``"tool_result"``, ``"final"`` or ``"error"``.
        The ``"final"`` and ``"error"`` events carry ``"finished": True``.
        The error event exposes only a generic message; the real exception
        text is kept in the trace and the log.

        Args:
            query: Raw user query; it is added to the session as a user turn.
            session: Session that receives turns, tool runs and the graph run.
            request_id: Identifier for this run; a UUID4 is generated if falsy.
            force_tools: Tools to run instead of the full registry. ``None``
                or an empty list means "run every registered tool".
        """
        rid = request_id or str(uuid.uuid4())
        loop_start = time.monotonic()
        started_at = time.time()
        trace = AgentTrace(request_id=rid, session_id=session.session_id)
        try:
            session.add_turn("user", query)
            tool_names = self._resolve_tool_names(force_tools)
            trace.plan_tools = tool_names
            trace.plan_reason = self._PLAN_REASON
            yield {
                "type": "plan",
                "request_id": rid,
                "plan": tool_names,
                "reason": trace.plan_reason,
            }

            # Same initial context as run(), including recent conversation.
            accumulated = self._build_initial_context(session, query)
            for step_name in tool_names:
                yield {"type": "tool_start", "request_id": rid, "tool": step_name}
                result = await self._run_tool_step(step_name, accumulated, session, trace)
                yield {
                    "type": "tool_result",
                    "request_id": rid,
                    "tool": step_name,
                    "success": result.success,
                    "latency_ms": round(result.latency_ms, 2),
                    "error": result.error,
                }

            trace.final_text = self._extract_final_text(accumulated, tool_names)
            session.add_turn("assistant", trace.final_text)
            trace.total_latency_ms = (time.monotonic() - loop_start) * 1000
            yield {
                "type": "final",
                "request_id": rid,
                "text": trace.final_text,
                "trace": trace.to_dict(),
                "finished": True,
            }
        except Exception as exc:
            trace.error = str(exc)
            trace.total_latency_ms = (time.monotonic() - loop_start) * 1000
            logger.opt(exception=True).error(
                "[AgentLoop] stream request_id={} 오류: {}", rid, exc
            )
            yield {
                "type": "error",
                "request_id": rid,
                "error": "에이전트 처리 중 내부 오류가 발생했습니다.",
                "finished": True,
            }
        finally:
            # Latency is still unset when the generator is closed before the
            # final/error event was produced.
            if trace.total_latency_ms == 0.0:
                trace.total_latency_ms = (time.monotonic() - loop_start) * 1000
            self._record_graph_run(
                session=session,
                trace=trace,
                started_at=started_at,
                completed_at=time.time(),
            )

    def _resolve_tool_names(self, force_tools: Optional[List[ToolName]]) -> List[str]:
        """Return the normalized forced tool names, or every registered tool."""
        if force_tools:
            return [tool_name(t) for t in force_tools]
        return list(self._tools)

    @staticmethod
    def _build_initial_context(session: SessionContext, query: str) -> Dict[str, Any]:
        """Build the context dict shared by all tools of one run.

        Starts from ``build_runtime_query_context`` and then sets the last
        five turns of ``session.recent_history`` under ``"conversation"`` and
        the normalized query under ``"query"``.
        """
        accumulated: Dict[str, Any] = build_runtime_query_context(session, query)
        accumulated["conversation"] = [
            {"role": turn.role, "content": turn.content} for turn in session.recent_history[-5:]
        ]
        accumulated["query"] = normalize_text(query)
        return accumulated

    async def _run_tool_step(
        self,
        step_name: str,
        accumulated: Dict[str, Any],
        session: SessionContext,
        trace: AgentTrace,
    ) -> ToolResult:
        """Execute one tool and record its outcome.

        Appends the result to the trace, stores the tool's data under its
        name in ``accumulated`` (an empty dict on failure) and logs the run
        on the session.
        """
        result = await self._execute_tool(step_name, accumulated, session)
        trace.tool_results.append(result)
        accumulated[step_name] = result.data if result.success else {}
        session.add_tool_run(
            tool=step_name,
            graph_run_request_id=trace.request_id,
            success=result.success,
            latency_ms=result.latency_ms,
            error=result.error,
            metadata=self._build_tool_log_metadata(result.data),
        )
        return result

    async def _execute_tool(
        self,
        step_name: str,
        accumulated: Dict[str, Any],
        session: SessionContext,
    ) -> ToolResult:
        """Invoke one registered tool under the configured timeout.

        Never raises for tool failures: an unknown tool, a timeout or an
        exception inside the tool are all reported as an unsuccessful
        ``ToolResult``. A non-dict return value is wrapped as
        ``{"result": value}``.
        """
        tool_fn = self._tools.get(step_name)
        if tool_fn is None:
            return ToolResult(
                tool=step_name, success=False, error=f"등록되지 않은 tool: {step_name}"
            )

        start = time.monotonic()
        try:
            execution_query = normalize_text(accumulated.get("query", ""))
            result_data = await asyncio.wait_for(
                tool_fn(
                    query=execution_query,
                    context=accumulated,
                    session=session,
                ),
                timeout=self._tool_timeout,
            )
            return ToolResult(
                tool=step_name,
                success=True,
                data=result_data if isinstance(result_data, dict) else {"result": result_data},
                latency_ms=(time.monotonic() - start) * 1000,
            )
        except asyncio.TimeoutError:
            return ToolResult(
                tool=step_name,
                success=False,
                error=f"tool {step_name} 타임아웃 ({self._tool_timeout}초)",
                latency_ms=(time.monotonic() - start) * 1000,
            )
        except Exception as exc:
            # See run(): format arguments + opt(exception=True) keep the log
            # call safe when the exception text contains braces.
            logger.opt(exception=True).error("[AgentLoop] tool {} 실행 오류: {}", step_name, exc)
            return ToolResult(
                tool=step_name,
                success=False,
                error=str(exc),
                latency_ms=(time.monotonic() - start) * 1000,
            )

    @staticmethod
    def _build_tool_log_metadata(data: Dict[str, Any]) -> Dict[str, Any]:
        """Keep only a small preview of the tool output for the tool log.

        Copies ``count`` and ``query`` if present, records the length of a
        list-valued ``results`` as ``result_count``, and stores the first
        200 characters of ``text`` as ``text_preview``.
        """
        metadata: Dict[str, Any] = {}
        if "count" in data:
            metadata["count"] = data["count"]
        if "query" in data:
            metadata["query"] = data["query"]
        if "results" in data and isinstance(data["results"], list):
            metadata["result_count"] = len(data["results"])
        if "text" in data:
            metadata["text_preview"] = str(data["text"])[:200]
        return metadata

    @staticmethod
    def _build_plan_summary(trace: AgentTrace) -> str:
        """Return ``"<reason> | tools: a -> b"``, or just the tool chain if no reason."""
        tools = " -> ".join(trace.plan_tools)
        if trace.plan_reason:
            return f"{trace.plan_reason} | tools: {tools}"
        return tools

    @staticmethod
    def _graph_run_status(trace: AgentTrace) -> str:
        """Map a trace to ``failed``, ``completed_with_errors`` or ``completed``."""
        if trace.error:
            return "failed"
        if any(not result.success for result in trace.tool_results):
            return "completed_with_errors"
        return "completed"

    @classmethod
    def _record_graph_run(
        cls,
        session: SessionContext,
        trace: AgentTrace,
        started_at: float,
        completed_at: float,
    ) -> None:
        """Record the finished run on the session via ``session.add_graph_run``.

        Args:
            session: Session that stores the graph-run record.
            trace: Trace to summarize (status, counts, final-text preview).
            started_at: Wall-clock start time as returned by ``time.time()``.
            completed_at: Wall-clock completion time from ``time.time()``.
        """
        success_count = sum(1 for result in trace.tool_results if result.success)
        failure_count = len(trace.tool_results) - success_count
        session.add_graph_run(
            request_id=trace.request_id,
            plan_summary=cls._build_plan_summary(trace),
            approval_status="not_requested",
            executed_capabilities=[tool_name(result.tool) for result in trace.tool_results],
            status=cls._graph_run_status(trace),
            error=trace.error,
            total_latency_ms=trace.total_latency_ms,
            metadata={
                "plan_reason": trace.plan_reason,
                "tool_result_count": len(trace.tool_results),
                "success_count": success_count,
                "failure_count": failure_count,
                "final_text_preview": trace.final_text[:200],
            },
            started_at=started_at,
            completed_at=completed_at,
        )

    @staticmethod
    def _extract_final_text(accumulated: Dict[str, Any], tool_names: List[str]) -> str:
        """Pick the final answer text from the accumulated tool outputs.

        Priority: the ``draft_response`` text, then the first tool (in plan
        order) whose payload has a non-empty ``text``, then a summary built
        from the API-lookup payload, and finally a fixed fallback message.
        """
        # Prefer the draft_response result when it is available.
        payload = accumulated.get("draft_response", {})
        if isinstance(payload, dict) and payload.get("text"):
            return str(payload["text"])

        # Otherwise try to take the text from each tool result in plan order.
        for step_name in tool_names:
            payload = accumulated.get(step_name, {})
            if isinstance(payload, dict) and payload.get("text"):
                return str(payload["text"])

        parts: List[str] = []
        api_data = accumulated.get(ToolType.API_LOOKUP.value, {})
        if not isinstance(api_data, dict):
            # A malformed payload must not crash the whole run.
            api_data = {}

        results = api_data.get("results")
        if api_data.get("context_text"):
            parts.append(str(api_data["context_text"]))
        elif results and isinstance(results, (list, tuple)):
            lines = ["[외부 조회 결과]"]
            for item in results[:3]:
                if not isinstance(item, dict):
                    continue
                title = item.get("title", item.get("qnaTitle", ""))
                # content may be None or a non-string; coerce before slicing.
                raw_content = item.get("content", item.get("qnaContent", ""))
                content = str(raw_content or "")[:120]
                lines.append(f"- {title}: {content}")
            parts.append("\n".join(lines))

        return "\n\n".join(parts) if parts else "요청을 처리할 수 없습니다."