|
|
""" |
|
|
Workflow Engine Manager for CodeAct Agent. |
|
|
Manages the LangGraph workflow execution. |
|
|
""" |
|
|
|
|
|
import re |
|
|
import json |
|
|
import datetime |
|
|
from typing import Dict, Tuple, List, Any |
|
|
from pathlib import Path |
|
|
from rich.progress import Progress, SpinnerColumn, TextColumn |
|
|
from rich.rule import Rule |
|
|
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage |
|
|
from langgraph.graph import StateGraph, START, END |
|
|
from core.types import AgentState, AgentConfig |
|
|
from .plan_manager import PlanManager |
|
|
|
|
|
|
|
|
class WorkflowEngine:
    """Build, run, and record a LangGraph generate/execute workflow.

    The engine wires caller-supplied node functions into a ``StateGraph``,
    streams the compiled graph, forwards each new message to the console
    display, and keeps two in-memory records of the run:

    * ``message_history`` - one serialized dict per streamed state's last
      message.
    * ``trace_logs`` - one dict per displayed event (reasoning, plan,
      code execution, observation, solution, error).

    Both records are reset at the start of every ``run_workflow`` call.
    """

    # Tag patterns are compiled once instead of on every streamed message.
    _EXECUTE_RE = re.compile(r"<execute>(.*?)</execute>", re.DOTALL)
    _SOLUTION_RE = re.compile(r"<solution>(.*?)</solution>", re.DOTALL)
    _ERROR_RE = re.compile(r"<error>(.*?)</error>", re.DOTALL)
    _OBSERVATION_RE = re.compile(r"<observation>(.*?)</observation>", re.DOTALL)
    _ACTION_TAG_RE = re.compile(r'</?(execute|solution|error)>')
    # A run of consecutive "N. [x] text" checklist lines.
    _PLAN_RE = re.compile(
        r'\d+\.\s*\[[^\]]*\]\s*[^\n]+(?:\n\d+\.\s*\[[^\]]*\]\s*[^\n]+)*'
    )

    # Reasoning text must be longer than this to be displayed.
    _MIN_REASONING_CHARS = 20

    def __init__(self, model, config: "AgentConfig", console_display, state_manager):
        """Store collaborators and initialise empty run records.

        Args:
            model: Stored as ``self.model``; not used by this class itself.
            config: Agent configuration; ``verbose``, ``max_steps`` and
                ``timeout_seconds`` are read by this class.
            console_display: Display object exposing a ``console`` attribute
                and the ``print_*`` methods called during a run.
            state_manager: Stored as ``self.state_manager``; not used by
                this class itself.
        """
        self.model = model
        self.config = config
        self.console = console_display
        self.state_manager = state_manager
        self.plan_manager = PlanManager()
        self.graph = None  # set by setup_workflow()
        self.trace_logs = []
        self.message_history = []

    def setup_workflow(self, generate_func, execute_func, should_continue_func):
        """Build and compile the two-node workflow graph.

        The graph starts at ``generate``. After each ``generate`` call,
        ``should_continue_func`` must return ``"end"`` (stop) or
        ``"execute"`` (run ``execute``, then loop back to ``generate``).

        Args:
            generate_func: Node callable registered as ``"generate"``.
            execute_func: Node callable registered as ``"execute"``.
            should_continue_func: Router called after ``generate``.
        """
        workflow = StateGraph(AgentState)

        workflow.add_node("generate", generate_func)
        workflow.add_node("execute", execute_func)

        workflow.add_edge(START, "generate")
        workflow.add_edge("execute", "generate")

        workflow.add_conditional_edges("generate", should_continue_func, {
            "end": END,
            "execute": "execute"
        })

        self.graph = workflow.compile()

    def run_workflow(self, initial_state: Dict) -> Tuple:
        """Execute the workflow and handle display.

        Args:
            initial_state: State mapping handed to the compiled graph.

        Returns:
            tuple: ``(result_content, final_state)`` where ``result_content``
            is the content of the last message of the last streamed state
            (``""`` if nothing was streamed or it held no messages) and
            ``final_state`` is that last state (``None`` if nothing was
            streamed).

        Raises:
            RuntimeError: If ``setup_workflow`` has not been called yet.
        """
        if self.graph is None:
            raise RuntimeError(
                "Workflow graph is not built; call setup_workflow() before run_workflow()."
            )

        self.trace_logs = []
        self.message_history = []

        final_solution_provided = False
        previous_plan = None
        displayed_reasoning = set()
        final_state = None

        self.console.console.print(Rule(title="Execution Steps", style="yellow"))

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=self.console.console,
            transient=True
        ) as progress:
            task = progress.add_task("Executing agent...", total=None)

            for s in self.graph.stream(initial_state, stream_mode="values"):
                final_state = s
                step_count = s.get("step_count", 0)
                current_plan = s.get("current_plan")

                progress.update(task, description=f"Step {step_count}")

                # A state without messages has nothing to display or record.
                messages = s.get("messages")
                if not messages:
                    continue
                message = messages[-1]

                self.message_history.append(self._serialize_message(message))

                if isinstance(message, AIMessage):
                    # The flag is returned because a bool argument cannot be
                    # updated in place by the callee.
                    final_solution_provided = self._process_ai_message(
                        message, step_count, current_plan, previous_plan,
                        displayed_reasoning, final_solution_provided
                    )
                    if current_plan != previous_plan:
                        previous_plan = current_plan
                elif "<observation>" in message.content:
                    self._process_observation_message(message, step_count)

        final_messages = final_state.get("messages") if final_state else None
        result_content = final_messages[-1].content if final_messages else ""
        return result_content, final_state

    def _process_ai_message(self, message, step_count, current_plan, previous_plan,
                            displayed_reasoning, final_solution_provided):
        """Display and trace the parts of one AI message.

        In order: reasoning text (once per distinct text), a changed plan
        (verbose mode only, and not after a solution), then at most one of
        an ``<execute>``, ``<solution>`` or ``<error>`` section.

        Args:
            message: Object whose ``content`` holds the AI message text.
            step_count: Step number attached to display calls and traces.
            current_plan: Plan carried by the current state, or ``None``.
            previous_plan: Plan seen on the previous AI message, or ``None``.
            displayed_reasoning: Set of reasoning texts already shown;
                updated in place.
            final_solution_provided: Whether a solution was already shown.

        Returns:
            bool: ``True`` if a solution had already been provided or this
            message provided one, otherwise ``False``.
        """
        full_content = message.content

        thinking_content = self._extract_thinking_content(full_content)
        if thinking_content and len(thinking_content) > self._MIN_REASONING_CHARS:
            # Store the text itself rather than hash(text): no collision risk.
            reasoning_key = thinking_content.strip()
            if reasoning_key not in displayed_reasoning:
                self.console.print_reasoning(thinking_content, step_count)
                displayed_reasoning.add(reasoning_key)

                self._add_trace_entry("reasoning", step_count, thinking_content)

        if (current_plan and current_plan != previous_plan and
                self.config.verbose and not final_solution_provided):
            self.console.print_plan(current_plan)

            self._add_trace_entry("plan", step_count, current_plan)

        if "<execute>" in full_content and "</execute>" in full_content:
            execute_match = self._EXECUTE_RE.search(full_content)
            if execute_match:
                code = execute_match.group(1).strip()
                self.console.print_code_execution(code, step_count)

                self._add_trace_entry("code_execution", step_count, code)

        elif "<solution>" in full_content and "</solution>" in full_content:
            solution_match = self._SOLUTION_RE.search(full_content)
            if solution_match:
                if current_plan:
                    updated_plan = self.plan_manager.update_plan_for_solution(current_plan)
                    if updated_plan != current_plan:
                        self.console.print_plan(updated_plan)

                solution = solution_match.group(1).strip()
                self.console.print_solution(solution, step_count)
                final_solution_provided = True

                self._add_trace_entry("solution", step_count, solution)

        elif "<error>" in full_content:
            error_match = self._ERROR_RE.search(full_content)
            if error_match:
                error_content = error_match.group(1).strip()
                self.console.print_error(error_content, step_count)

                self._add_trace_entry("error", step_count, error_content)

        return final_solution_provided

    def _process_observation_message(self, message, step_count):
        """Display (truncated) and trace (in full) an observation message.

        Does nothing when the content has no complete
        ``<observation>...</observation>`` section.
        """
        obs_match = self._OBSERVATION_RE.search(message.content)
        if obs_match:
            observation = obs_match.group(1).strip()
            formatted_output = self._truncate_to_20_rows(observation)
            self.console.print_execution_result(formatted_output, step_count)

            # The trace keeps the untruncated observation.
            self._add_trace_entry("observation", step_count, observation)

    def _extract_thinking_content(self, content: str) -> str:
        """Return the free-text part of a message.

        The ``execute``/``solution``/``error`` tags are removed (their inner
        text is kept), checklist-style plan lines and whole observation
        sections are removed, and the remaining non-blank lines are stripped
        and joined with newlines.
        """
        content = self._ACTION_TAG_RE.sub('', content)
        content = self._PLAN_RE.sub('', content).strip()
        content = self._OBSERVATION_RE.sub('', content)

        lines = [line.strip() for line in content.split('\n') if line.strip()]
        return '\n'.join(lines)

    def _truncate_to_20_rows(self, text: str, max_rows: int = 20) -> str:
        """Limit text to its first ``max_rows`` lines (20 by default).

        Args:
            text: Text to shorten.
            max_rows: Maximum number of lines to keep.

        Returns:
            ``text`` unchanged when it has at most ``max_rows`` lines;
            otherwise the first ``max_rows`` lines followed by a notice
            stating the original line count.
        """
        lines = text.split('\n')
        total_lines = len(lines)

        if total_lines <= max_rows:
            return text

        truncated = '\n'.join(lines[:max_rows])
        truncated += (
            f"\n\n⚠️ Output truncated to {max_rows} rows. "
            f"Full output contains {total_lines} rows."
        )
        return truncated

    def _add_trace_entry(self, step_type: str, step_count: int, content: Any, metadata: Dict = None):
        """Append one timestamped entry to ``self.trace_logs``.

        Args:
            step_type: Event kind, e.g. ``"reasoning"`` or ``"solution"``.
            step_count: Step number the event belongs to.
            content: Payload of the event.
            metadata: Optional extra data; stored as ``{}`` when omitted.
        """
        entry = {
            "timestamp": datetime.datetime.now().isoformat(),
            "step_count": step_count,
            "step_type": step_type,
            "content": content,
            "metadata": metadata or {}
        }
        self.trace_logs.append(entry)

    def _serialize_message(self, message: "BaseMessage") -> Dict:
        """Convert a message into a dict with type, content and timestamp.

        ``HumanMessage`` maps to ``"human"``, ``AIMessage`` to ``"ai"`` and
        every other message class to ``"system"``.
        """
        if isinstance(message, HumanMessage):
            msg_type = "human"
        elif isinstance(message, AIMessage):
            msg_type = "ai"
        else:
            msg_type = "system"

        return {
            "type": msg_type,
            "content": message.content,
            "timestamp": datetime.datetime.now().isoformat()
        }

    @staticmethod
    def _write_json(payload: Dict, filepath: str) -> None:
        """Write ``payload`` as indented UTF-8 JSON, creating parent folders."""
        target = Path(filepath)
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open('w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)

    @staticmethod
    def _preview(text: str, limit: int) -> str:
        """Return ``text`` cut to ``limit`` characters plus ``"..."`` if longer."""
        return text[:limit] + "..." if len(text) > limit else text

    def save_trace_to_file(self, filepath: str = None) -> str:
        """Save the complete trace log to a JSON file.

        Args:
            filepath: Destination path. When ``None``, a name of the form
                ``agent_trace_<YYYYmmdd_HHMMSS>.json`` in the current
                directory is used.

        Returns:
            The path that was written.
        """
        if filepath is None:
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            filepath = f"agent_trace_{timestamp}.json"

        trace_data = {
            "execution_time": datetime.datetime.now().isoformat(),
            "config": {
                "max_steps": self.config.max_steps,
                "timeout_seconds": self.config.timeout_seconds,
                "verbose": self.config.verbose
            },
            "messages": self.message_history,
            "trace_logs": self.trace_logs
        }

        self._write_json(trace_data, filepath)
        return filepath

    def generate_summary(self) -> Dict:
        """Generate a summary of the agent execution.

        Returns:
            A dict with counts plus per-type lists built from
            ``self.trace_logs``. Reasoning is previewed to 200 characters
            and observations to 500. ``final_solution`` holds the last
            solution entry, or ``None``. Plan entries are not included.
        """
        summary = {
            # NOTE: this counts trace entries, not distinct step numbers.
            "total_steps": len(self.trace_logs),
            "message_count": len(self.message_history),
            "execution_flow": [],
            "code_executions": [],
            "observations": [],
            "errors": [],
            "final_solution": None
        }

        for entry in self.trace_logs:
            step_type = entry["step_type"]
            content = entry["content"]
            step_info = {
                "step": entry["step_count"],
                "type": step_type,
                "timestamp": entry["timestamp"]
            }

            if step_type == "reasoning":
                summary["execution_flow"].append({
                    **step_info,
                    "reasoning": self._preview(content, 200)
                })
            elif step_type == "code_execution":
                summary["code_executions"].append({**step_info, "code": content})
            elif step_type == "observation":
                summary["observations"].append({
                    **step_info,
                    "output": self._preview(content, 500)
                })
            elif step_type == "error":
                summary["errors"].append({**step_info, "error": content})
            elif step_type == "solution":
                summary["final_solution"] = {**step_info, "solution": content}

        return summary

    def save_summary_to_file(self, filepath: str = None) -> str:
        """Save the execution summary to a JSON file.

        Args:
            filepath: Destination path. When ``None``, a name of the form
                ``agent_summary_<YYYYmmdd_HHMMSS>.json`` in the current
                directory is used.

        Returns:
            The path that was written.
        """
        if filepath is None:
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            filepath = f"agent_summary_{timestamp}.json"

        summary = self.generate_summary()
        summary["timestamp"] = datetime.datetime.now().isoformat()

        self._write_json(summary, filepath)
        return filepath