Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Flexible LangGraph agent for cyber-legal assistant | |
| Agent can call tools, process results, and decide to continue or answer | |
| """ | |
| import os | |
| import copy | |
| import logging | |
| from typing import Dict, Any, List, Optional | |
| from datetime import datetime | |
| from langgraph.graph import StateGraph, END | |
| from langchain_openai import ChatOpenAI | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage | |
| logger = logging.getLogger(__name__) | |
| from agent_states.agent_state import AgentState | |
| from utils.utils import PerformanceMonitor | |
| from utils.lightrag_client import LightRAGClient | |
| from utils.tools import tools, tools_for_client, tools_for_lawyer | |
class CyberLegalAgent:
    """Tool-calling LangGraph agent for the cyber-legal assistant.

    The compiled graph has two nodes: ``agent`` (one LLM call with the facade
    tools bound) and ``tools`` (execution of the tool calls requested by the
    LLM). The LLM sees tools under their facade names; the executable tools in
    ``self.tools`` are looked up under the same name prefixed with ``_``.
    """

    # Facade tool names whose calls receive ``conversation_history`` from state.
    _HISTORY_TOOLS = ("find_lawyers", "query_knowledge_graph", "message_lawyer")

    def __init__(self, llm, tools: List[Any] = tools, tools_facade: List[Any] = tools):
        """Bind the facade tools to the LLM and compile the workflow.

        Args:
            llm: Chat model exposing ``bind_tools``.
            tools: Executable tools, matched by ``"_" + <facade name>``.
            tools_facade: Tool definitions bound to (and shown to) the LLM.
        """
        self.tools = tools
        self.tools_facade = tools_facade
        self.llm = llm
        self.performance_monitor = PerformanceMonitor()
        self.llm_with_tools = self.llm.bind_tools(self.tools_facade)
        self.workflow = self._build_workflow()

    def _build_workflow(self) -> StateGraph:
        """Assemble the agent/tools graph and return the result of ``compile()``."""
        workflow = StateGraph(AgentState)
        workflow.add_node("agent", self._agent_node)
        workflow.add_node("tools", self._tools_node)
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges(
            "agent", self._should_continue, {"continue": "tools", "end": END}
        )
        workflow.add_conditional_edges(
            "tools", self._after_tools, {"continue": "agent", "end": END}
        )
        return workflow.compile()

    def _after_tools(self, state: AgentState) -> str:
        """Route after the tools node.

        Returns ``"end"`` when the last message is the ``ToolMessage`` produced
        by the ``_find_lawyers`` tool (its output is used as the final answer),
        otherwise ``"continue"`` to hand control back to the agent node.
        """
        intermediate_steps = state.get("intermediate_steps", [])
        if not intermediate_steps:
            return "continue"

        last_message = intermediate_steps[-1]
        # The tools node names its ToolMessages with the underscored tool name.
        if isinstance(last_message, ToolMessage) and last_message.name == "_find_lawyers":
            logger.info("π find_lawyers tool completed - ending with tool output")
            return "end"
        return "continue"

    def _should_continue(self, state: AgentState) -> str:
        """Route after the agent node.

        Returns ``"continue"`` when there are no messages yet or the last
        message carries tool calls, otherwise ``"end"``.
        """
        intermediate_steps = state.get("intermediate_steps", [])
        if not intermediate_steps:
            return "continue"

        tool_calls = getattr(intermediate_steps[-1], "tool_calls", None)
        if tool_calls:
            # Raw tool-call payload goes to the debug log, not stdout.
            logger.debug("Tool calls requested: %s", tool_calls)
            logger.info(f"π§ Calling tools: {[tc['name'] for tc in tool_calls]}")
            return "continue"
        return "end"

    async def _agent_node(self, state: AgentState) -> AgentState:
        """Call the LLM once and append its response to ``intermediate_steps``.

        On the first pass (no messages yet) the message list is seeded with the
        system prompt (if any), the prior conversation and the user query.
        """
        intermediate_steps = state.get("intermediate_steps", [])
        if not intermediate_steps:
            system_prompt = state.get("system_prompt")
            if system_prompt is not None:
                jurisdiction = state.get("jurisdiction", "unknown")
                # str.format returns a new string, so the caller's prompt is
                # never modified and no copy is needed.
                system_prompt = system_prompt.format(jurisdiction=jurisdiction)
                logger.info(f"π Formatted system prompt with jurisdiction: {jurisdiction}")
                intermediate_steps.append(SystemMessage(content=system_prompt))
            else:
                # process_query allows system_prompt=None; run without a system
                # message instead of failing on None.format(...).
                logger.warning("No system prompt provided - calling the LLM without a system message")

            for msg in state.get("conversation_history", []):
                if not isinstance(msg, dict):
                    continue
                role = msg.get("role")
                if role == "user":
                    intermediate_steps.append(HumanMessage(content=msg.get("content")))
                elif role == "assistant":
                    intermediate_steps.append(AIMessage(content=msg.get("content")))
            intermediate_steps.append(HumanMessage(content=state["user_query"]))

        response = await self.llm_with_tools.ainvoke(intermediate_steps)
        intermediate_steps.append(response)
        state["intermediate_steps"] = intermediate_steps
        return state

    def _inject_state_args(self, tool_name: str, call_args: Dict[str, Any], state: AgentState) -> Dict[str, Any]:
        """Return a copy of ``call_args`` extended with values taken from state.

        ``tool_name`` is the facade (non-underscored) name requested by the LLM.
        """
        args = call_args.copy()
        if tool_name in self._HISTORY_TOOLS:
            args["conversation_history"] = state.get("conversation_history", [])
            logger.info(f"π Injecting conversation_history to {tool_name}: {len(args['conversation_history'])} messages")
        if tool_name == "query_knowledge_graph":
            args["jurisdiction"] = state.get("jurisdiction")
            logger.info(f"π Injecting jurisdiction: {args['jurisdiction']}")
        if tool_name == "message_lawyer":
            args["client_id"] = state.get("client_id")
            logger.info(f"π€ Injecting client_id: {args['client_id']}")
        return args

    async def _tools_node(self, state: AgentState) -> AgentState:
        """Execute the tool calls of the last message and append their results.

        Every tool call gets exactly one ``ToolMessage``: the tool's result, or
        an error message when no matching tool exists, so the following LLM
        call never sees an unanswered tool call.
        """
        intermediate_steps = state.get("intermediate_steps", [])
        if not intermediate_steps:
            return state

        tool_calls = getattr(intermediate_steps[-1], "tool_calls", None)
        if not tool_calls:
            return state

        for tool_call in tool_calls:
            facade_name = tool_call['name']
            internal_name = "_" + facade_name
            tool_func = next((t for t in self.tools if t.name == internal_name), None)

            if tool_func is None:
                logger.warning("Unknown tool requested: %s", facade_name)
                intermediate_steps.append(ToolMessage(
                    content=f"Error: tool '{facade_name}' is not available.",
                    tool_call_id=tool_call['id'],
                    name=facade_name,
                ))
                continue

            args = self._inject_state_args(facade_name, tool_call['args'], state)
            # The tool call stored on the AI message is renamed in place so it
            # carries the same (underscored) name as the ToolMessage below,
            # which is also the name _after_tools checks for.
            tool_call['name'] = internal_name
            result = await tool_func.ainvoke(args)
            logger.info(f"π§ Tool {tool_call} returned: {result}")
            intermediate_steps.append(ToolMessage(
                content=str(result),
                tool_call_id=tool_call['id'],
                name=internal_name,
            ))

        state["intermediate_steps"] = intermediate_steps
        return state

    async def process_query(
        self,
        user_query: str,
        client_id: Optional[str] = None,
        jurisdiction: str = "Romania",
        conversation_history: Optional[List[Dict[str, str]]] = None,
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run the workflow for one user query.

        Args:
            user_query: The user's current message.
            client_id: Passed to the ``message_lawyer`` tool when it is called.
            jurisdiction: Substituted into the system prompt's
                ``{jurisdiction}`` placeholder and passed to the
                ``query_knowledge_graph`` tool.
            conversation_history: Prior turns as dicts with ``role`` and
                ``content`` keys; ``None`` is treated as an empty history.
            system_prompt: Template formatted with ``jurisdiction``; when
                ``None`` no system message is sent.

        Returns:
            Dict with keys ``response``, ``processing_time``, ``references``
            and ``timestamp``.
        """
        initial_state = {
            "user_query": user_query,
            "client_id": client_id,
            "conversation_history": conversation_history or [],
            "intermediate_steps": [],
            "relevant_documents": [],
            "query_timestamp": datetime.now().isoformat(),
            "processing_time": None,
            "jurisdiction": jurisdiction,
            "system_prompt": system_prompt,
        }
        self.performance_monitor.reset()
        final_state = await self.workflow.ainvoke(initial_state)

        intermediate_steps = final_state.get("intermediate_steps", [])
        # The answer is the content of the last message (LLM reply or, for
        # find_lawyers, the tool output); empty results fall back to an apology.
        final_response = intermediate_steps[-1].content if intermediate_steps else None
        return {
            "response": final_response or "I apologize, but I couldn't generate a response.",
            "processing_time": sum(self.performance_monitor.get_metrics().values()),
            "references": final_state.get("relevant_documents", []),
            "timestamp": final_state.get("query_timestamp"),
        }