Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Document Editor Agent - LangGraph agent for modifying HTML documents | |
| Implements Cline-like iterative editing with validation | |
| """ | |
| import uuid | |
| import logging | |
| import traceback | |
| from typing import Dict, Any, List, Optional | |
| from langgraph.graph import StateGraph, END | |
| from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage | |
| from agent_states.doc_editor_state import DocEditorState | |
| from utils.tools import ( | |
| tools_for_doc_editor_facade as doc_editor_tools_facade, | |
| tools_for_doc_editor as doc_editor_tools | |
| ) | |
| from prompts.doc_editor import ( | |
| get_doc_editor_system_prompt, | |
| get_summary_system_prompt | |
| ) | |
| from utils.utils_fn import push_document_update | |
# Module-level logger named after this module; used by DocumentEditorAgent for progress and error reporting.
logger = logging.getLogger(__name__)
class DocumentEditorAgent:
    """Agent that edits HTML documents with a Cline-like iterative tool loop.

    The compiled LangGraph workflow alternates between an ``agent`` node (the
    tool-calling LLM proposes tool calls) and a ``tools`` node (the matching
    implementation tools run against ``state["doc_text"]``). Once an
    ``attempt_completion`` tool result is produced, a ``summary`` node asks the
    main LLM to summarize the modifications. The loop also stops after
    ``max_iterations`` passes through the tools node.
    """

    # Editing tools: receive the current doc_text and, on success, trigger a live push.
    _EDIT_TOOLS = ("replace_html", "add_html", "delete_html")
    # Every tool that gets the current doc_text injected into its arguments.
    _DOC_TEXT_TOOLS = _EDIT_TOOLS + ("view_current_document",)
    # Tool whose result marks the task as finished and routes to the summary node.
    _COMPLETION_TOOL = "attempt_completion"

    def __init__(self, llm, llm_tool_calling, tools=doc_editor_tools, tools_facade=doc_editor_tools_facade):
        """Initialize the document editor agent.

        Args:
            llm: Main LLM, used to generate the final modification summary.
            llm_tool_calling: LLM used for the tool-calling loop.
            tools: Implementation tools. Each one is looked up by the facade
                tool name prefixed with ``_``; some receive an injected
                ``doc_text`` argument.
            tools_facade: Facade tools bound to the tool-calling LLM.
        """
        self.llm = llm
        self.llm_tool_calling = llm_tool_calling
        self.tools = tools
        self.tools_facade = tools_facade
        # tool_choice="any" is meant to force the model to call a tool on every turn.
        self.llm_with_tools = self.llm_tool_calling.bind_tools(self.tools_facade, tool_choice="any")
        logger.info("🔧 Tool binding configured with tool_choice='any' to force tool calls")
        logger.info(f"🤖 Using {type(llm_tool_calling).__name__} for tool calling")
        logger.info(f"📝 Using {type(llm).__name__} for summary generation")
        logger.info(f"🛠️ Tools loaded: {len(self.tools_facade)} facade, {len(self.tools)} implementation")
        self.workflow = self._build_workflow()

    def _build_workflow(self) -> StateGraph:
        """Build the agent -> tools -> (agent | summary | END) graph.

        Returns the *compiled* graph (``workflow.compile()``), which is what
        ``edit_document`` invokes.
        """
        workflow = StateGraph(DocEditorState)
        workflow.add_node("agent", self._agent_node)
        workflow.add_node("tools", self._tools_node)
        workflow.add_node("summary", self._summary_node)
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges("agent", self._should_continue, {"continue": "tools", "end": END})
        workflow.add_conditional_edges("tools", self._after_tools, {"continue": "agent", "summary": "summary", "end": END})
        workflow.add_edge("summary", END)
        return workflow.compile()

    @classmethod
    def _completion_called(cls, messages) -> bool:
        """Return True if any of ``messages`` is a ToolMessage from attempt_completion."""
        return any(
            isinstance(msg, ToolMessage) and msg.name == cls._COMPLETION_TOOL
            for msg in messages
        )

    @staticmethod
    def _latest_tool_results(messages) -> list:
        """Return the trailing run of ToolMessages, i.e. the results of the latest tools round."""
        results = []
        for msg in reversed(messages):
            if not isinstance(msg, ToolMessage):
                break
            results.append(msg)
        results.reverse()
        return results

    def _should_continue(self, state: DocEditorState) -> str:
        """Route after the agent node: "continue" to the tools node, or "end"."""
        iteration_count = state.get("iteration_count", 0)
        max_iterations = state.get("max_iterations", 10)
        if iteration_count >= max_iterations:
            logger.warning(f"Max iterations ({max_iterations}) reached - ending workflow")
            return "end"
        intermediate_steps = state.get("intermediate_steps", [])
        if not intermediate_steps:
            return "continue"
        tool_calls = getattr(intermediate_steps[-1], "tool_calls", None)
        if tool_calls:
            tool_names = [tc['name'] for tc in tool_calls]
            logger.info(f"Agent calling tools: {tool_names}")
            return "continue"
        if self._completion_called(intermediate_steps):
            logger.info("attempt_completion already called - ending workflow")
            return "end"
        # The tools node counts this pass as an iteration even though it runs
        # nothing, so repeated tool-less answers are still bounded by max_iterations.
        logger.warning("Agent response without tool calls - forcing continue")
        return "continue"

    def _after_tools(self, state: DocEditorState) -> str:
        """Route after the tools node: "summary", "end" (max iterations) or "continue"."""
        intermediate_steps = state.get("intermediate_steps", [])
        # Inspect every result of the latest round, not only the last message:
        # attempt_completion may have been called alongside other tools in one turn.
        if self._completion_called(self._latest_tool_results(intermediate_steps)):
            logger.info("attempt_completion called - generating summary")
            return "summary"
        iteration_count = state.get("iteration_count", 0)
        max_iterations = state.get("max_iterations", 10)
        if iteration_count >= max_iterations:
            logger.warning(f"Max iterations ({max_iterations}) reached - ending workflow")
            return "end"
        return "continue"

    def _build_initial_messages(self, state: DocEditorState) -> list:
        """Build the opening messages of the conversation.

        Returns the system prompt, the user message (context, instruction and
        optional plan), and a synthetic ``view_current_document`` call plus its
        result, which hands the model the initial document.
        """
        conversation_history = state.get("conversation_history") or []
        doc_summaries = state.get("doc_summaries") or []
        doc_text = state["doc_text"]

        context_msg = ""
        if conversation_history:
            context_msg += "\n\n##Previous conversation:\n"
            for msg in conversation_history:
                context_msg += f"{msg.get('role', '').capitalize()}: {msg.get('content', '')}\n"
        if doc_summaries:
            context_msg += "\n\n##Document summaries:\n" + "\n".join(f"- {s}" for s in doc_summaries)

        full_message = f"{context_msg}\n\n##Instruction\n{state['user_instruction']}"
        plan = state.get("plan")
        # Only include the plan section when a plan was provided; previously a
        # missing plan was rendered into the prompt as the literal text "None".
        if plan:
            plan_formatted = (
                "## EXECUTION PLAN\n"
                "You have been provided with an execution plan. Follow this plan carefully to complete the task:\n"
                f"{plan}\n"
                "Use this plan as your guide for the editing steps to perform."
            )
            full_message += f"\n\n{plan_formatted}"

        # Fake tool call/response after the user message, so the model sees the
        # current document as the result of a view_current_document call.
        fake_tool_call_id = str(uuid.uuid4())
        messages = [
            SystemMessage(content=get_doc_editor_system_prompt()),
            HumanMessage(content=full_message),
            AIMessage(content="", tool_calls=[{
                "id": fake_tool_call_id,
                "name": "view_current_document",
                "args": {}
            }]),
            ToolMessage(content=doc_text, tool_call_id=fake_tool_call_id, name="view_current_document"),
        ]
        logger.info(f"🔍 Initial document provided via fake view_current_document ({len(doc_text)}b)")
        logger.info(f"Context: {len(conversation_history)} hist + {len(doc_summaries)} summaries")
        return messages

    async def _agent_node(self, state: DocEditorState) -> DocEditorState:
        """Agent node: ask the tool-calling LLM for the next tool calls."""
        intermediate_steps = state.get("intermediate_steps", [])
        iteration_count = state.get("iteration_count", 0)
        max_iterations = state.get("max_iterations", 10)
        logger.info(f"Agent iteration {iteration_count + 1}/{max_iterations}")
        # Seed the conversation exactly once. Keying on an empty history (rather
        # than iteration_count == 0) prevents re-appending the preamble when the
        # agent is re-entered before any tools round has been counted.
        if not intermediate_steps:
            intermediate_steps.extend(self._build_initial_messages(state))
        response = await self.llm_with_tools.ainvoke(intermediate_steps)
        intermediate_steps.append(response)
        state["intermediate_steps"] = intermediate_steps
        return state

    async def _tools_node(self, state: DocEditorState) -> DocEditorState:
        """Tools node: run the tool calls of the last agent message and update the document."""
        intermediate_steps = state.get("intermediate_steps", [])
        last_message = intermediate_steps[-1] if intermediate_steps else None
        tool_calls = getattr(last_message, "tool_calls", None) or []
        # Count every pass through this node, including the "agent answered
        # without tool calls" retry path; otherwise that path never advances the
        # counter and max_iterations cannot stop the loop.
        state["iteration_count"] = state.get("iteration_count", 0) + 1
        if not tool_calls:
            return state
        for tool_call in tool_calls:
            intermediate_steps.append(await self._run_tool_call(state, tool_call))
        state["intermediate_steps"] = intermediate_steps
        return state

    async def _run_tool_call(self, state: DocEditorState, tool_call: Dict[str, Any]) -> ToolMessage:
        """Execute one tool call and return the ToolMessage answering it.

        A ToolMessage is returned for every call, including unknown tools and
        failures. Chat APIs such as OpenAI's reject a request in which an
        assistant tool call has no matching tool result. Side effect: on a
        successful result carrying ``doc_text``, ``state["doc_text"]`` is
        replaced; editing tools then trigger a live push.
        """
        tool_name = tool_call['name']
        tool_call_id = tool_call['id']
        # Implementation tools are registered under the facade name with a leading underscore.
        tool_func = next((t for t in self.tools if t.name == "_" + tool_name), None)
        if tool_func is None:
            logger.warning(f"Tool function not found for {tool_name}")
            return ToolMessage(content=f"Error: unknown tool '{tool_name}'", tool_call_id=tool_call_id, name=tool_name)

        args = dict(tool_call.get('args') or {})
        # Inject doc_text for editing tools AND view_current_document
        if tool_name in self._DOC_TEXT_TOOLS:
            args["doc_text"] = state["doc_text"]
            logger.info(f"Injecting doc_text ({len(state['doc_text'])}b) into {tool_name}")

        try:
            result = await tool_func.ainvoke(args)
        except Exception as e:
            logger.error(f"{tool_name} error: {str(e)}")
            return ToolMessage(content=f"Error: {str(e)}", tool_call_id=tool_call_id, name=tool_name)

        # Non-dict results (if any tool returns plain text) are passed through as-is.
        if isinstance(result, dict) and result.get("ok") and "doc_text" in result:
            state["doc_text"] = result["doc_text"]
            if tool_name in self._EDIT_TOOLS:
                await self._push_update(state, tool_name)
        return ToolMessage(content=str(result), tool_call_id=tool_call_id, name=tool_name)

    async def _push_update(self, state: DocEditorState, tool_name: str) -> None:
        """Best-effort live push of the edited document.

        A push failure is logged but not reported to the model as a tool
        failure. The edit has already been applied to ``state["doc_text"]``, so
        reporting it as failed could make the model re-apply the same edit.
        """
        document_id = state.get("document_id")
        user_id = state.get("user_id")
        if not (document_id and user_id):
            logger.debug("No document_id/user_id, skipping update push")
            return
        logger.info(f"Pushing document update after successful {tool_name}...")
        try:
            await push_document_update(
                document_id=document_id,
                content=state["doc_text"],
                user_id=user_id
            )
        except Exception as e:
            logger.error(f"Failed to push document update after {tool_name}: {e}")

    async def _summary_node(self, state: DocEditorState) -> DocEditorState:
        """Summary node: ask the main LLM for a summary of all modifications."""
        logger.info("Generating modification summary...")
        summary_messages = [
            SystemMessage(content=get_summary_system_prompt()),
            HumanMessage(
                content=(
                    "\n"
                    f"Original instruction: {state['user_instruction']}\n"
                    "Full conversation history (including all tool calls and results):\n"
                ),
                name="user",
            ),
        ]
        # Drop the editor's own system prompt. Some chat providers only accept a
        # system message at the start of the list, and the summarizer should
        # follow the summary prompt alone.
        summary_messages.extend(
            msg for msg in state.get("intermediate_steps", [])
            if not isinstance(msg, SystemMessage)
        )
        try:
            response = await self.llm.ainvoke(summary_messages)
        except Exception as e:
            # The edits are already applied (and possibly pushed live); a summary
            # failure must not turn the whole run into a reported failure.
            logger.error(f"Summary generation failed: {e}")
            state["final_summary"] = "The editing task completed, but the modification summary could not be generated."
            return state
        summary = response.content
        state["final_summary"] = summary
        logger.info("Summary generated successfully")
        logger.info(f"Summary preview: {str(summary)[:200]}...")
        return state

    async def edit_document(
        self,
        doc_text: str,
        user_instruction: str,
        plan: Optional[str] = None,
        doc_summaries: Optional[List[str]] = None,
        conversation_history: Optional[List[Dict[str, str]]] = None,
        max_iterations: int = 10,
        document_id: Optional[str] = None,
        user_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Edit a document according to the user instruction.

        Args:
            doc_text: HTML document string
            user_instruction: What changes to make to the document
            plan: Optional execution plan provided by DocAssistant
            doc_summaries: Optional summaries of the document for context
            conversation_history: Optional previous conversation messages
                (dicts with 'role' and 'content') for context
            max_iterations: Maximum number of edit iterations (default: 10)
            document_id: Optional UUID of the document for live updates
            user_id: Optional user ID for authentication

        Returns:
            Dict with doc_text, message, success, iteration_count, final_summary.
            success is True only if attempt_completion was executed. On an
            unexpected exception the original doc_text is returned with success=False.
        """
        # None defaults instead of shared mutable [] defaults; copy so caller lists are not aliased into state.
        doc_summaries = list(doc_summaries) if doc_summaries else []
        conversation_history = list(conversation_history) if conversation_history else []

        logger.info("=" * 80)
        logger.info("DOCUMENT EDITOR AGENT STARTING")
        logger.info("=" * 80)
        logger.info(f"Initial document size: {len(doc_text)} bytes")
        logger.info(f"Instruction: {user_instruction[:100]}{'...' if len(user_instruction) > 100 else ''}")
        logger.info(f"Document summaries: {len(doc_summaries)}")
        logger.info(f"Conversation history: {len(conversation_history)} messages")
        logger.info(f"Max iterations: {max_iterations}")
        if document_id:
            logger.info(f"Document ID: {document_id} (live updates enabled)")
        if user_id:
            logger.info(f"User ID: {user_id}")
        if doc_summaries:
            logger.info("Document summaries loaded:")
            for i, summary in enumerate(doc_summaries[:3], 1):
                logger.info(f" [{i}] {str(summary)[:100]}...")
            if len(doc_summaries) > 3:
                logger.info(f" ... and {len(doc_summaries) - 3} more")
        if conversation_history:
            logger.info(f"Conversation history loaded ({len(conversation_history)} messages)")

        initial_state = {
            "doc_text": doc_text,
            "doc_summaries": doc_summaries,
            "plan": plan,
            "conversation_history": conversation_history,
            "user_instruction": user_instruction,
            "iteration_count": 0,
            "max_iterations": max_iterations,
            "intermediate_steps": [],
            "document_id": document_id,
            "user_id": user_id
        }
        # LangGraph counts every node execution against recursion_limit (default 25).
        # Each iteration is one agent step plus one tools step (plus the final
        # summary), so derive the limit from max_iterations; otherwise values
        # above ~12 abort with GraphRecursionError. Never go below the default.
        run_config = {"recursion_limit": max(25, 2 * max_iterations + 5)}
        logger.info("Initial state prepared, starting workflow...")
        try:
            logger.info("Invoking LangGraph workflow...")
            final_state = await self.workflow.ainvoke(initial_state, config=run_config)
            final_doc = final_state["doc_text"]
            final_summary = final_state.get("final_summary") or ""
            success = self._completion_called(final_state.get("intermediate_steps", []))
            message = final_summary
            iteration_count = final_state.get("iteration_count", 0)
            final_doc_size = len(final_doc)
            size_change = final_doc_size - len(doc_text)
            # Guard the percentage against an empty input document (ZeroDivisionError).
            pct = f" ({size_change / len(doc_text) * 100:+.1f}%)" if doc_text else ""

            logger.info("=" * 80)
            logger.info("DOCUMENT EDITING COMPLETED")
            logger.info("=" * 80)
            logger.info(f"Success: {success}")
            logger.info(f"Iterations: {iteration_count}")
            logger.info(f"Final document size: {final_doc_size} bytes")
            logger.info(f"Size change: {size_change:+d} bytes{pct}")
            logger.info(f"Message: {message[:100]}{'...' if len(message) > 100 else ''}")
            logger.info("=" * 80)
            # The full document can be large and contain user data: log it at DEBUG only.
            logger.debug("FINAL DOCUMENT CONTENT")
            logger.debug(final_doc)
            logger.debug("=" * 80)
            if final_summary:
                logger.info("=" * 80)
                logger.info("MODIFICATION SUMMARY")
                logger.info("=" * 80)
                logger.info(final_summary)
                logger.info("=" * 80)

            if not success:
                max_iters = final_state.get("max_iterations", 10)
                if iteration_count >= max_iters:
                    logger.warning(f"Failed to complete editing within {max_iters} iterations")
                    message = f"Failed to complete editing within {max_iters} iterations"

            return {
                "doc_text": final_doc,
                "message": message,
                "success": success,
                "iteration_count": iteration_count,
                "final_summary": final_summary
            }
        except Exception as e:
            frames = traceback.extract_tb(e.__traceback__)
            # Report where the exception was actually raised, which may be another module.
            location = f"{frames[-1].filename}:{frames[-1].lineno}" if frames else "unknown"
            logger.error("=" * 80)
            logger.error("DOCUMENT EDITING FAILED")
            logger.error("=" * 80)
            logger.error(f"Location: {location}")
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"Error message: {str(e)}")
            logger.error(f"User Instruction: {user_instruction[:100]}{'...' if len(user_instruction) > 100 else ''}")
            logger.error(f"Document Size: {len(doc_text):,} bytes")
            if document_id:
                logger.error(f"Document ID: {document_id}")
            if user_id:
                logger.error(f"User ID: {user_id}")
            logger.error(f"Max Iterations: {max_iterations}")
            logger.error(f"Document Summaries: {len(doc_summaries)}")
            logger.error(f"Conversation History: {len(conversation_history)} messages")
            if plan:
                logger.error(f"Plan: {plan[:200]}{'...' if len(plan) > 200 else ''}")
            logger.error(f"Main LLM: {type(self.llm).__name__}")
            logger.error(f"Tool Calling LLM: {type(self.llm_tool_calling).__name__}")
            logger.error(f"Tools Available: {len(self.tools)}")
            logger.error(f"Tool Names: {', '.join([t.name for t in self.tools])}")
            logger.error(f"Traceback:\n{traceback.format_exc()}")
            logger.error(f"Document Preview: {doc_text[:200]}")
            logger.error("=" * 80)
            # NOTE(review): edits pushed live by the tools node before the failure
            # are not reflected here; the ORIGINAL doc_text is returned.
            return {
                "doc_text": doc_text,
                "message": f"Error during editing: {str(e)}",
                "success": False,
                "iteration_count": 0,
                "final_summary": ""
            }