#!/usr/bin/env python3
"""
Flexible LangGraph agent for a cyber-legal assistant.
The agent can call tools, process their results, and decide whether to
continue with more tool calls or answer the user directly.
"""
import logging
import traceback
from typing import Dict, Any, List, Optional
from datetime import datetime

from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage

from agent_states.agent_state import AgentState
from utils.utils_fn import PerformanceMonitor, find_tool
from utils.tools import tools, tools_for_client, tools_for_lawyer

logger = logging.getLogger(__name__)


class CyberLegalAgent:
    def __init__(self, llm, tools: List[Any] = tools, tools_facade: List[Any] = tools):
        self.tools = tools
        self.tools_facade = tools_facade
        self.llm = llm
        self.performance_monitor = PerformanceMonitor()
        # Bind the facade tool set to the LLM. To constrain the model to one
        # tool call at a time, pass a provider-specific option here (e.g.
        # parallel_tool_calls=False for OpenAI-compatible models).
        self.llm_with_tools = self.llm.bind_tools(self.tools_facade)
        self.workflow = self._build_workflow()
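
    # Graph topology: the "agent" node runs the LLM; when it emits tool calls
    # we route to the "tools" node, whose output feeds back into the agent
    # until the model answers directly (or a terminal tool fires).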
    def _build_workflow(self) -> StateGraph:
        workflow = StateGraph(AgentState)
        workflow.add_node("agent", self._agent_node)
        workflow.add_node("tools", self._tools_node)
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges("agent", self._should_continue, {"continue": "tools", "end": END})
        workflow.add_conditional_edges("tools", self._after_tools, {"continue": "agent", "end": END})
        return workflow.compile()
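
    # Post-tool routing: most tool results go back to the agent for
    # interpretation; find_lawyers is terminal, so its raw output becomes
    # the final answer.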
    def _after_tools(self, state: AgentState) -> str:
        intermediate_steps = state.get("intermediate_steps", [])
        if not intermediate_steps:
            return "continue"
        # Check whether the last message is a ToolMessage from find_lawyers
        last_message = intermediate_steps[-1]
        if isinstance(last_message, ToolMessage) and last_message.name == "_find_lawyers":
            logger.info("find_lawyers tool completed - ending with tool output")
            return "end"
        return "continue"
    def _should_continue(self, state: AgentState) -> str:
        intermediate_steps = state.get("intermediate_steps", [])
        if not intermediate_steps:
            return "continue"
        last_message = intermediate_steps[-1]
        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            logger.info(f"Calling tools: {[tc['name'] for tc in last_message.tool_calls]}")
            return "continue"
        return "end"
    async def _agent_node(self, state: AgentState) -> AgentState:
        intermediate_steps = state.get("intermediate_steps", [])
        if not intermediate_steps:
            history = state.get("conversation_history", [])
            jurisdiction = state.get("jurisdiction", "unknown")
            # Use the provided system prompt if one was supplied; str.format
            # returns a new string, so the original template is not modified.
            system_prompt_to_use = state.get("system_prompt")
            if system_prompt_to_use:
                system_prompt_to_use = system_prompt_to_use.format(jurisdiction=jurisdiction)
                logger.info(f"Formatted system prompt with jurisdiction: {jurisdiction}")
                intermediate_steps.append(SystemMessage(content=system_prompt_to_use))
            for msg in history:
                if isinstance(msg, dict):
                    if msg.get("role") == "user":
                        intermediate_steps.append(HumanMessage(content=msg.get("content")))
                    elif msg.get("role") == "assistant":
                        intermediate_steps.append(AIMessage(content=msg.get("content")))
            intermediate_steps.append(HumanMessage(content=state["user_query"]))
        response = await self.llm_with_tools.ainvoke(intermediate_steps)
        intermediate_steps.append(response)
        state["intermediate_steps"] = intermediate_steps
        return state
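
    # Tools node: execute every tool call from the last LLM message, injecting
    # state-derived parameters (conversation history, jurisdiction, user_id)
    # that the model itself cannot supply, and record each result (or failure)
    # as a ToolMessage.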
    async def _tools_node(self, state: AgentState) -> AgentState:
        logger.info("=" * 80)
        logger.info("TOOLS NODE STARTED")
        logger.info("=" * 80)
        intermediate_steps = state.get("intermediate_steps", [])
        last_message = intermediate_steps[-1]
        if not (hasattr(last_message, "tool_calls") and last_message.tool_calls):
            logger.info("No tool calls found in last message, skipping tools node")
            logger.info("=" * 80)
            return state
        logger.info(f"Found {len(last_message.tool_calls)} tool call(s) to execute")
        for tool_call in last_message.tool_calls:
            try:
                tool_name = tool_call['name']
                logger.info(f"Executing tool: {tool_name}")
                logger.info(f"Tool call ID: {tool_call['id']}")
                logger.info(f"Tool args: {tool_call['args']}")
                # Use utility function to find the tool with smart naming
                tool_func, tool_name_normalized = find_tool(tool_name, self.tools)
                if not tool_func:
                    # Report the failure back to the model instead of silently
                    # skipping: every tool call needs a matching ToolMessage.
                    logger.warning(f"Tool {tool_name} not found")
                    intermediate_steps.append(ToolMessage(
                        content=f"Tool {tool_name} not found",
                        tool_call_id=tool_call['id'],
                        name=tool_name
                    ))
                    continue
                logger.info(f"Tool function found: {tool_func.name}")
                # Inject parameters from state into the tool call
                args = tool_call['args'].copy()
                # Use the actual tool name for comparisons (includes underscore prefix)
                if tool_name_normalized in ["_find_lawyers", "_query_knowledge_graph", "_message_lawyer"]:
                    args["conversation_history"] = state.get("conversation_history", [])
                    logger.info(f"Injecting conversation_history to {tool_name_normalized}: {len(args['conversation_history'])} messages")
                # Inject jurisdiction for the query_knowledge_graph tool
                if tool_name_normalized == "_query_knowledge_graph":
                    args["jurisdiction"] = state.get("jurisdiction")
                    logger.info(f"Injecting jurisdiction: {args['jurisdiction']}")
                # Inject user_id for the message_lawyer tool
                if tool_name_normalized == "_message_lawyer":
                    args["user_id"] = state.get("user_id")
                    logger.info(f"Injecting user_id: {args['user_id']}")
                # Inject user_id for the retrieve_lawyer_document tool
                if tool_name_normalized == "_retrieve_lawyer_document":
                    args["user_id"] = state.get("user_id")
                    logger.info(f"Injecting user_id for retrieve_lawyer_document: {args['user_id']}")
                # Inject user_id, instruction, and context for the create_draft_document tool
                if tool_name_normalized == "_create_draft_document":
                    args["user_id"] = state.get("user_id")
                    args["instruction"] = state.get("user_query", "")
                    args["conversation_history"] = state.get("conversation_history", [])
                    args["doc_summaries"] = state.get("doc_summaries", [])
                    logger.info(f"Injecting conversation_history for create_draft_document: {len(args['conversation_history'])} messages")
                    logger.info(f"Injecting doc_summaries for create_draft_document: {len(args['doc_summaries'])} documents")
                    logger.info(f"Injecting user_id for create_draft_document: {args['user_id']} (user_query used as instruction)")
                # Use the normalized name for the ToolMessage
                tool_call['name'] = tool_name_normalized
                logger.info(f"Calling tool_func.ainvoke with args: {args}")
                result = await tool_func.ainvoke(args)
                logger.info(f"Tool {tool_name_normalized} returned successfully")
                logger.info(f"Result type: {type(result)}")
                logger.info(f"Result preview (first 500 chars): {str(result)[:500]}")
                logger.info(f"Result length: {len(str(result))} characters")
                intermediate_steps.append(ToolMessage(content=str(result), tool_call_id=tool_call['id'], name=tool_name_normalized))
                logger.info("ToolMessage added to intermediate_steps")
            except Exception as e:
                # tool_name_normalized may be unbound if the failure happened
                # before find_tool returned, so fall back to the raw call name.
                tool_name = tool_call.get('name', 'unknown')
                logger.error("=" * 80)
                logger.error(f"ERROR EXECUTING TOOL: {tool_name}")
                logger.error("=" * 80)
                logger.error(f"Error message: {str(e)}")
                logger.error(f"Error type: {type(e).__name__}")
                logger.error(f"Tool call: {tool_call}")
                logger.error("Full traceback:")
                logger.error(traceback.format_exc())
                logger.error("=" * 80)
                # Add the error as a ToolMessage so the agent knows about the failure
                intermediate_steps.append(ToolMessage(
                    content=f"Tool {tool_name} failed: {str(e)}",
                    tool_call_id=tool_call['id'],
                    name=tool_name
                ))
                logger.warning("Error ToolMessage added to intermediate_steps, continuing workflow")
        state["intermediate_steps"] = intermediate_steps
        logger.info("=" * 80)
        logger.info("TOOLS NODE COMPLETED")
        logger.info(f"Total intermediate_steps: {len(intermediate_steps)}")
        logger.info("=" * 80)
        return state
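
    # Public entry point: build the initial state, run the compiled graph, and
    # package the final LLM (or terminal tool) output with timing metadata.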
    async def process_query(
        self,
        user_query: str,
        user_id: Optional[str] = None,
        jurisdiction: str = "Romania",
        conversation_history: Optional[List[Dict[str, str]]] = None,
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        initial_state = {
            "user_query": user_query,
            "user_id": user_id,
            "conversation_history": conversation_history or [],
            "intermediate_steps": [],
            "relevant_documents": [],
            "query_timestamp": datetime.now().isoformat(),
            "processing_time": None,
            "jurisdiction": jurisdiction,
            "system_prompt": system_prompt
        }
        self.performance_monitor.reset()
        final_state = await self.workflow.ainvoke(initial_state)
        intermediate_steps = final_state.get("intermediate_steps", [])
        final_response = intermediate_steps[-1].content
        logger.info("=" * 80)
        logger.info("SENDING RESPONSE TO CLIENT")
        logger.info("=" * 80)
        logger.info(f"User ID: {user_id}")
        logger.info(f"Timestamp: {final_state.get('query_timestamp')}")
        logger.info(f"Processing time: {sum(self.performance_monitor.get_metrics().values())}s")
        logger.info(f"Response length: {len(final_response) if final_response else 0} characters")
        logger.info(f"Response preview (first 200 chars): {final_response[:200] if final_response else '(empty)'}")
        # Build the full response object
        response_obj = {
            "response": final_response or "I apologize, but I couldn't generate a response.",
            "processing_time": sum(self.performance_monitor.get_metrics().values()),
            "references": final_state.get("relevant_documents", []),
            "timestamp": final_state.get("query_timestamp")
        }
        logger.info(f"Response object keys: {list(response_obj.keys())}")
        logger.info(f"Response value type: {type(response_obj['response'])}")
        logger.info("=" * 80)
        logger.info("Returning response to user")
        logger.info("=" * 80)
        return response_obj
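

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions: `init_chat_model` and the model string
# "openai:gpt-4o-mini" are illustrative placeholders, not this project's
# actual model setup; the system prompt below is a hypothetical template
# containing the {jurisdiction} slot that _agent_node expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio
    from langchain.chat_models import init_chat_model

    async def _demo():
        llm = init_chat_model("openai:gpt-4o-mini", temperature=0)  # hypothetical model choice
        agent = CyberLegalAgent(llm)
        result = await agent.process_query(
            user_query="What must a company do after a personal data breach?",
            user_id="demo-user",
            jurisdiction="Romania",
            system_prompt="You are a cyber-legal assistant answering under the laws of {jurisdiction}.",
        )
        print(result["response"])

    asyncio.run(_demo())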