CyberLegalAIendpoint / agents /doc_editor.py
Charles Grandjean
update to suit gemini format
bd87ed7
#!/usr/bin/env python3
"""
Document Editor Agent - LangGraph agent for modifying HTML documents
Implements Cline-like iterative editing with validation
"""
import uuid
import logging
import traceback
from typing import Dict, Any, List, Optional
from langgraph.graph import StateGraph, END
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from agent_states.doc_editor_state import DocEditorState
from utils.tools import (
tools_for_doc_editor_facade as doc_editor_tools_facade,
tools_for_doc_editor as doc_editor_tools
)
from prompts.doc_editor import (
get_doc_editor_system_prompt,
get_summary_system_prompt
)
from utils.utils_fn import push_document_update
# Module-level logger used by the document editor agent.
logger = logging.getLogger(__name__)
class DocumentEditorAgent:
    """Agent for editing HTML documents using a Cline-like iterative approach.

    The LangGraph workflow alternates between an ``agent`` node (the
    tool-calling LLM picks facade tools to call) and a ``tools`` node (the
    matching implementation tools run against the current document). Once the
    ``attempt_completion`` tool has been executed, a ``summary`` node asks the
    main LLM to summarize the modifications.
    """

    # Tools that receive the current document through an injected ``doc_text`` argument.
    _DOC_TEXT_TOOLS = frozenset({"replace_html", "add_html", "delete_html", "view_current_document"})
    # Tools whose successful result changes the document (and triggers a live update push).
    _EDITING_TOOLS = frozenset({"replace_html", "add_html", "delete_html"})
    # LangGraph's default recursion limit; never go below it.
    _MIN_RECURSION_LIMIT = 25

    def __init__(self, llm, llm_tool_calling, tools=doc_editor_tools, tools_facade=doc_editor_tools_facade):
        """
        Initialize the document editor agent.

        Args:
            llm: Main LLM, used to generate the final modification summary.
            llm_tool_calling: LLM used for the tool calls during editing.
            tools: Implementation tools (they receive the injected ``doc_text``).
                Each one is looked up as ``"_" + <facade tool name>``.
            tools_facade: Facade tools, bound on the tool-calling LLM.
        """
        self.llm = llm
        self.llm_tool_calling = llm_tool_calling
        self.tools = tools
        self.tools_facade = tools_facade
        # tool_choice="any" forces the model to emit at least one tool call per turn.
        self.llm_with_tools = self.llm_tool_calling.bind_tools(self.tools_facade, tool_choice="any")
        logger.info("🔧 Tool binding configured with tool_choice='any' to force tool calls")
        logger.info(f"🤖 Using {type(llm_tool_calling).__name__} for tool calling")
        logger.info(f"📝 Using {type(llm).__name__} for summary generation")
        logger.info(f"🛠️ Tools loaded: {len(self.tools_facade)} facade, {len(self.tools)} implementation")
        self.workflow = self._build_workflow()

    def _build_workflow(self) -> StateGraph:
        """Build and compile the LangGraph workflow (agent -> tools -> ... -> summary)."""
        workflow = StateGraph(DocEditorState)
        workflow.add_node("agent", self._agent_node)
        workflow.add_node("tools", self._tools_node)
        workflow.add_node("summary", self._summary_node)
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges("agent", self._should_continue, {"continue": "tools", "end": END})
        workflow.add_conditional_edges("tools", self._after_tools, {"continue": "agent", "summary": "summary", "end": END})
        workflow.add_edge("summary", END)
        return workflow.compile()

    @staticmethod
    def _preview(text: Any, limit: int = 100) -> str:
        """Return ``str(text)`` cut to ``limit`` characters, with '...' appended when cut."""
        text = str(text)
        return text if len(text) <= limit else text[:limit] + "..."

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten a message ``content`` (a string or a list of content blocks) to plain text.

        List items that are strings or ``{"type": "text", "text": ...}`` blocks are
        concatenated; other block types are skipped.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and part.get("type") == "text":
                    parts.append(str(part.get("text", "")))
            return "".join(parts)
        return "" if content is None else str(content)

    @staticmethod
    def _trailing_tool_messages(steps: List[Any]) -> List[Any]:
        """Return the ToolMessages at the end of ``steps`` (the results of the last tools pass)."""
        trailing = []
        for msg in reversed(steps):
            if not isinstance(msg, ToolMessage):
                break
            trailing.append(msg)
        return trailing

    def _should_continue(self, state: DocEditorState) -> str:
        """Route after the agent node: "continue" to the tools node or "end"."""
        iteration_count = state.get("iteration_count", 0)
        max_iterations = state.get("max_iterations", 10)
        if iteration_count >= max_iterations:
            logger.warning(f"Max iterations ({max_iterations}) reached - ending workflow")
            return "end"
        intermediate_steps = state.get("intermediate_steps", [])
        if not intermediate_steps:
            return "continue"
        last_message = intermediate_steps[-1]
        if getattr(last_message, "tool_calls", None):
            tool_names = [tc['name'] for tc in last_message.tool_calls]
            logger.info(f"Agent calling tools: {tool_names}")
            return "continue"
        # Check if attempt_completion was ever called
        attempt_completion_called = any(
            isinstance(msg, ToolMessage) and msg.name == "attempt_completion"
            for msg in intermediate_steps
        )
        if attempt_completion_called:
            logger.info("attempt_completion already called - ending workflow")
            return "end"
        logger.warning("Agent response without tool calls - forcing continue")
        return "continue"

    def _after_tools(self, state: DocEditorState) -> str:
        """Route after the tools node.

        Returns "summary" when ``attempt_completion`` was among the tool calls
        just executed, "end" once ``max_iterations`` is reached, else "continue".
        """
        executed = self._trailing_tool_messages(state.get("intermediate_steps", []))
        # Look at every result of this pass, not only the last one: attempt_completion
        # may have been followed by another tool call in the same turn.
        if any(msg.name == "attempt_completion" for msg in executed):
            logger.info("attempt_completion called - generating summary")
            return "summary"
        iteration_count = state.get("iteration_count", 0)
        max_iterations = state.get("max_iterations", 10)
        if iteration_count >= max_iterations:
            logger.warning(f"Max iterations ({max_iterations}) reached - ending workflow")
            return "end"
        return "continue"

    def _initial_messages(self, state: DocEditorState) -> List[Any]:
        """Build the opening conversation: system prompt, user request, initial document.

        The document is delivered through a synthetic ``view_current_document``
        tool call/result pair, so the model starts from the current HTML.
        """
        conversation_history = state.get("conversation_history") or []
        doc_summaries = state.get("doc_summaries") or []
        context_msg = ""
        if conversation_history:
            context_msg += "\n\n##Previous conversation:\n"
            context_msg += "".join(
                f"{msg.get('role', '').capitalize()}: {msg.get('content', '')}\n"
                for msg in conversation_history
            )
        if doc_summaries:
            context_msg += "\n\n##Document summaries:\n" + "\n".join(f"- {s}" for s in doc_summaries)
        full_message = f"{context_msg}\n\n##Instruction\n{state['user_instruction']}"
        plan = state.get("plan")
        # Only include the plan section when a plan was provided (avoids a literal "None").
        if plan:
            full_message += (
                "\n\n## EXECUTION PLAN\n"
                "You have been provided with an execution plan. Follow this plan carefully to complete the task:\n"
                f"{plan}\n"
                "Use this plan as your guide for the editing steps to perform."
            )
        doc_text = state["doc_text"]
        fake_tool_call_id = str(uuid.uuid4())
        messages = [
            SystemMessage(content=get_doc_editor_system_prompt()),
            HumanMessage(content=full_message),
            # Fake tool call/response after the user message, carrying the initial document.
            AIMessage(content="", tool_calls=[{
                "id": fake_tool_call_id,
                "name": "view_current_document",
                "args": {}
            }]),
            ToolMessage(content=doc_text, tool_call_id=fake_tool_call_id, name="view_current_document"),
        ]
        logger.info(f"🔍 Initial document provided via fake view_current_document ({len(doc_text)}b)")
        logger.info(f"Context: {len(conversation_history)} hist + {len(doc_summaries)} summaries")
        return messages

    async def _agent_node(self, state: DocEditorState) -> DocEditorState:
        """Agent node: ask the tool-calling LLM for the next tool call(s)."""
        intermediate_steps = state.get("intermediate_steps", [])
        iteration_count = state.get("iteration_count", 0)
        max_iterations = state.get("max_iterations", 10)
        logger.info(f"Agent iteration {iteration_count + 1}/{max_iterations}")
        # Seed the conversation exactly once, on the very first agent pass.
        if not intermediate_steps:
            intermediate_steps.extend(self._initial_messages(state))
        response = await self.llm_with_tools.ainvoke(intermediate_steps)
        intermediate_steps.append(response)
        state["intermediate_steps"] = intermediate_steps
        return state

    async def _push_live_update(self, state: DocEditorState, tool_name: str) -> None:
        """Push the current document as a live update when document_id and user_id are set.

        Failures are logged, not raised: the edit itself already succeeded, and
        surfacing the push error as a tool error could make the LLM redo the edit.
        """
        document_id = state.get("document_id")
        user_id = state.get("user_id")
        if not (document_id and user_id):
            logger.debug("No document_id/user_id, skipping update push")
            return
        logger.info(f"Pushing document update after successful {tool_name}...")
        try:
            await push_document_update(
                document_id=document_id,
                content=state["doc_text"],
                user_id=user_id
            )
        except Exception as e:
            logger.error(f"Document update push failed after {tool_name}: {str(e)}")

    async def _tools_node(self, state: DocEditorState) -> DocEditorState:
        """Tools node: execute the pending tool calls and update the document state.

        Every tool call receives exactly one ToolMessage (result or error) so the
        next LLM request never contains an unanswered tool call.
        """
        intermediate_steps = state.get("intermediate_steps", [])
        # Count every pass through this node, including passes where the agent made
        # no tool call, so that max_iterations always bounds the agent <-> tools loop.
        state["iteration_count"] = state.get("iteration_count", 0) + 1
        last_message = intermediate_steps[-1] if intermediate_steps else None
        tool_calls = getattr(last_message, "tool_calls", None)
        if not tool_calls:
            return state
        for tool_call in tool_calls:
            tool_name = tool_call['name']
            tool_call_id = tool_call['id']
            # Look up implementation function with underscore prefix
            tool_func = next((t for t in self.tools if t.name == "_" + tool_name), None)
            if not tool_func:
                logger.warning(f"Tool function not found for {tool_name}")
                intermediate_steps.append(
                    ToolMessage(content=f"Error: unknown tool '{tool_name}'", tool_call_id=tool_call_id, name=tool_name)
                )
                continue
            args = dict(tool_call.get('args') or {})
            # Inject doc_text for editing tools AND view_current_document
            if tool_name in self._DOC_TEXT_TOOLS:
                args["doc_text"] = state["doc_text"]
                logger.info(f"Injecting doc_text ({len(state['doc_text'])}b) into {tool_name}")
            try:
                result = await tool_func.ainvoke(args)
            except Exception as e:
                intermediate_steps.append(
                    ToolMessage(content=f"Error: {str(e)}", tool_call_id=tool_call_id, name=tool_name)
                )
                logger.error(f"{tool_name} error: {str(e)}")
                continue
            if isinstance(result, dict) and result.get("ok") and "doc_text" in result:
                state["doc_text"] = result["doc_text"]
                if tool_name in self._EDITING_TOOLS:
                    await self._push_live_update(state, tool_name)
            intermediate_steps.append(
                ToolMessage(content=str(result), tool_call_id=tool_call_id, name=tool_name)
            )
        state["intermediate_steps"] = intermediate_steps
        return state

    async def _summary_node(self, state: DocEditorState) -> DocEditorState:
        """Summary node: ask the main LLM for a clean summary of all modifications."""
        logger.info("Generating modification summary...")
        summary_messages = [
            SystemMessage(content=get_summary_system_prompt()),
            HumanMessage(
                content=(
                    f"\nOriginal instruction: {state['user_instruction']}\n"
                    "Full conversation history (including all tool calls and results):\n"
                ),
                name="user"
            ),
        ]
        # Drop the editor's own system prompt: it is irrelevant for summarizing, and a
        # second SystemMessage mid-history is handled inconsistently across providers
        # (NOTE(review): some merge it into the system instruction, some reject it).
        summary_messages.extend(
            msg for msg in state.get("intermediate_steps", [])
            if not isinstance(msg, SystemMessage)
        )
        response = await self.llm.ainvoke(summary_messages)
        # Content may be a list of content blocks (e.g. with Gemini); store plain text.
        summary = self._content_to_text(response.content)
        state["final_summary"] = summary
        logger.info("Summary generated successfully")
        logger.info(f"Summary preview: {self._preview(summary, 200)}")
        return state

    async def edit_document(
        self,
        doc_text: str,
        user_instruction: str,
        plan: Optional[str] = None,
        doc_summaries: Optional[List[str]] = None,
        conversation_history: Optional[List[Dict[str, str]]] = None,
        max_iterations: int = 10,
        document_id: Optional[str] = None,
        user_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Edit a document according to the user instruction.

        Args:
            doc_text: HTML document string
            user_instruction: What changes to make to the document
            plan: Optional execution plan provided by DocAssistant
            doc_summaries: Optional summaries of the document for context
            conversation_history: Optional previous conversation messages for context
            max_iterations: Maximum number of edit iterations (default: 10)
            document_id: Optional UUID of the document for live updates
            user_id: Optional user ID for authentication

        Returns:
            Dict with doc_text, message, success, iteration_count, final_summary.
            ``success`` is True iff the attempt_completion tool was executed.
            On an unexpected error the original doc_text is returned with
            success=False, iteration_count=0 and an empty final_summary.
        """
        # Fresh lists per call instead of shared mutable default arguments.
        doc_summaries = [] if doc_summaries is None else doc_summaries
        conversation_history = [] if conversation_history is None else conversation_history
        logger.info("=" * 80)
        logger.info("DOCUMENT EDITOR AGENT STARTING")
        logger.info("=" * 80)
        logger.info(f"Initial document size: {len(doc_text)} bytes")
        logger.info(f"Instruction: {self._preview(user_instruction)}")
        logger.info(f"Document summaries: {len(doc_summaries)}")
        logger.info(f"Conversation history: {len(conversation_history)} messages")
        logger.info(f"Max iterations: {max_iterations}")
        if document_id:
            logger.info(f"Document ID: {document_id} (live updates enabled)")
        if user_id:
            logger.info(f"User ID: {user_id}")
        if doc_summaries:
            logger.info("Document summaries loaded:")
            for i, summary in enumerate(doc_summaries[:3], 1):
                logger.info(f"  [{i}] {self._preview(summary)}")
            if len(doc_summaries) > 3:
                logger.info(f"  ... and {len(doc_summaries) - 3} more")
        if conversation_history:
            logger.info(f"Conversation history loaded ({len(conversation_history)} messages)")
        initial_state = {
            "doc_text": doc_text,
            "doc_summaries": doc_summaries,
            "plan": plan,
            "conversation_history": conversation_history,
            "user_instruction": user_instruction,
            "iteration_count": 0,
            "max_iterations": max_iterations,
            "intermediate_steps": [],
            "document_id": document_id,
            "user_id": user_id
        }
        # Each iteration costs two graph steps (agent + tools) plus one summary step;
        # LangGraph's default recursion limit would abort larger max_iterations runs.
        recursion_limit = max(self._MIN_RECURSION_LIMIT, 2 * max_iterations + 5)
        logger.info("Initial state prepared, starting workflow...")
        try:
            logger.info("Invoking LangGraph workflow...")
            final_state = await self.workflow.ainvoke(
                initial_state, config={"recursion_limit": recursion_limit}
            )
            final_summary = final_state.get("final_summary") or ""
            success = any(
                isinstance(msg, ToolMessage) and msg.name == "attempt_completion"
                for msg in final_state.get("intermediate_steps", [])
            )
            message = final_summary
            iteration_count = final_state.get("iteration_count", 0)
            final_doc_size = len(final_state["doc_text"])
            original_size = len(doc_text)
            size_change = final_doc_size - original_size
            # Guard against an empty original document (would divide by zero).
            size_pct = f" ({size_change / original_size * 100:+.1f}%)" if original_size else ""
            logger.info("=" * 80)
            logger.info("DOCUMENT EDITING COMPLETED")
            logger.info("=" * 80)
            logger.info(f"Success: {success}")
            logger.info(f"Iterations: {iteration_count}")
            logger.info(f"Final document size: {final_doc_size} bytes")
            logger.info(f"Size change: {size_change:+d} bytes{size_pct}")
            logger.info(f"Message: {self._preview(message)}")
            logger.info("=" * 80)
            logger.info("FINAL DOCUMENT CONTENT")
            logger.info("=" * 80)
            logger.info(final_state["doc_text"])
            logger.info("=" * 80)
            if final_summary:
                logger.info("=" * 80)
                logger.info("MODIFICATION SUMMARY")
                logger.info("=" * 80)
                logger.info(final_summary)
                logger.info("=" * 80)
            if not success:
                max_iters = final_state.get("max_iterations", 10)
                if iteration_count >= max_iters:
                    logger.warning(f"Failed to complete editing within {max_iters} iterations")
                    message = f"Failed to complete editing within {max_iters} iterations"
            return {
                "doc_text": final_state["doc_text"],
                "message": message,
                "success": success,
                "iteration_count": iteration_count,
                "final_summary": final_summary
            }
        except Exception as e:
            tb = traceback.extract_tb(e.__traceback__)
            # Report the real file of the innermost frame instead of a hard-coded path.
            location = f"{tb[-1].filename}:{tb[-1].lineno}" if tb else "unknown"
            logger.error("=" * 80)
            logger.error("DOCUMENT EDITING FAILED")
            logger.error("=" * 80)
            logger.error(f"Location: {location}")
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"Error message: {str(e)}")
            logger.error(f"User Instruction: {self._preview(user_instruction)}")
            logger.error(f"Document Size: {len(doc_text):,} bytes")
            if document_id:
                logger.error(f"Document ID: {document_id}")
            if user_id:
                logger.error(f"User ID: {user_id}")
            logger.error(f"Max Iterations: {max_iterations}")
            logger.error(f"Document Summaries: {len(doc_summaries)}")
            logger.error(f"Conversation History: {len(conversation_history)} messages")
            if plan:
                logger.error(f"Plan: {self._preview(plan, 200)}")
            logger.error(f"Main LLM: {type(self.llm).__name__}")
            logger.error(f"Tool Calling LLM: {type(self.llm_tool_calling).__name__}")
            logger.error(f"Tools Available: {len(self.tools)}")
            logger.error(f"Tool Names: {', '.join([t.name for t in self.tools])}")
            logger.error(f"Traceback:\n{traceback.format_exc()}")
            logger.error(f"Document Preview: {doc_text[:200]}")
            logger.error("=" * 80)
            return {
                "doc_text": doc_text,
                "message": f"Error during editing: {str(e)}",
                "success": False,
                "iteration_count": 0,
                "final_summary": ""
            }