agentic-saul / app.py
vin00d's picture
report generator
967a65d
from typing import List, Dict, TypedDict, Union, Annotated
import chainlit as cl
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import BaseTool
from operator import itemgetter
from pydantic import BaseModel, Field, ConfigDict
from saul.tools import tools
from loguru import logger
import json
from dotenv import load_dotenv
load_dotenv()
# Types for our nodes
class AgentState(TypedDict):
"""State for the research agent."""
messages: Annotated[list, add_messages]
# Initialize the LLM
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, streaming=True)
# bind tools to the llm
llm = llm.bind_tools(tools)
# Agent node implementation
async def call_model(state: AgentState) -> Dict:
"""Agent node that decides which tool to use."""
logger.success(f"Calling agent model with state: {state}")
# print("...........................................Calling agent model...........................................")
# print(f"State:: {state}\n\n")
response = llm.invoke(state["messages"])
return {"messages": [response]}
execute_tool = ToolNode(tools)
# Create the graph
uncompiled_graph = StateGraph(AgentState)
# Add nodes
uncompiled_graph.add_node("agent", call_model)
uncompiled_graph.add_node("action", execute_tool)
# conditional edge function
def should_continue(state):
last_message = state["messages"][-1]
if last_message.tool_calls:
return "action"
return END
# Add edges
uncompiled_graph.add_conditional_edges("agent", should_continue)
uncompiled_graph.add_edge("action", "agent")
# Set entry point
uncompiled_graph.set_entry_point("agent")
# Compile the graph
compiled_graph = uncompiled_graph.compile()
system_prompt = """You are a helpful legal research assistant.
Only answer questions that are related to legal research, else politely decline to answer.
Only answer the last question.
"""
@cl.on_chat_start
async def start():
"""Initialize the chat session."""
# Initialize session state
initial_state = AgentState(
messages=[SystemMessage(content=system_prompt)]
)
image = cl.Image(path="./bcs-calling-card.jpg")
# Set initial state in session
cl.user_session.set("state", initial_state)
await cl.Message(
content="""### Better Call Agentic-Saul
🫡 Hey! I'm Saul, your legal research assistant.
I'm here to help you with your legal research needs.
I will find information from:
- πŸ“„ Legal Glossary - legal terms and definitions
- πŸ“š Wikipedia - basic information
- πŸ’¬ Reddit discussions - current chatter in social media
- πŸ“– Google Scholar Case Law - judicial opinions from numerous federal and state courts
What would you like to research?""",
elements=[image],
).send()
@cl.on_message
async def main(message: cl.Message):
"""Handle incoming messages."""
# Get current session state
state_dict = cl.user_session.get("state")
state = AgentState(**state_dict)
# Update messages in state
state["messages"].append(HumanMessage(content=message.content))
inputs = {"messages": state["messages"]}
# try:
msg = cl.Message(content="")
# Run the graph with current state
async for chunk in compiled_graph.astream(inputs, stream_mode="updates"):
for node, values in chunk.items():
logger.success(f"Receiving update from node: {node}")
# print(f"-------------- Receiving update from node: '{node}' --------------")
await msg.stream_token(f"Receiving update from node: **{node}**\n")
if node == "action":
for tool_msg in values["messages"]:
output = f"Tool used: {tool_msg.name}"
# output += f"\nTool output: {tool_msg.content}"
logger.success(output)
# print(output)
await msg.stream_token(f"{output}\n\n")
else: # node == "agent"
if values["messages"][0].tool_calls:
tool_names = [tool["name"] for tool in values["messages"][0].tool_calls]
output = f"Tool(s) Selected: {', '.join(tool_names)}"
logger.success(output)
# print(output)
await msg.stream_token(f"{output}\n\n")
else:
# output = f"\n\n\n**Final Model output**: {values['messages'][-1].content}"
output = "\n**Final output**\n"
logger.success(output)
# print(output)
print(values["messages"][-1].content)
await msg.stream_token(f"{output}")
# await msg.stream_token(values["messages"][-1].content)
print("\n\n")
# stream messages to the UI
if token := values["messages"][-1].content:
await msg.stream_token(token)
# Update messages in state
# state["messages"].extend(values["messages"])
# msg = cl.Message(content=values["messages"][-1].content)
# await message.send()
# Update session state
cl.user_session.set("state", state)