more graph setup, comments on how the flow should work
Browse files- src/chatbot/graph_manager.py +59 -0
- src/chatbot/nodes.py +72 -0
- src/chatbot/prompts/agent_prompts.py +7 -0
- src/form_management/acceptance_criteria.json +0 -32
- src/form_management/form_management.py +0 -72
- src/main.py +2 -0
- src/poetry.lock +0 -0
- src/pyproject.toml +5 -2
src/chatbot/graph_manager.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# graph_manager.py
|
| 2 |
+
# graph definition and compilation
|
| 3 |
+
|
| 4 |
+
from langgraph.graph import StateGraph, START, END
|
| 5 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 6 |
+
from nodes import (
|
| 7 |
+
State,
|
| 8 |
+
speaker_chatbot
|
| 9 |
+
)
|
| 10 |
+
from form_management.form_management import Form
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
class GraphManager:
|
| 14 |
+
def __init__(self):
|
| 15 |
+
self.graph = None
|
| 16 |
+
self.init_graph()
|
| 17 |
+
|
| 18 |
+
# user input ->
|
| 19 |
+
# if user intent is clear -> speaker form management -> audit review and feedback -> speaker generate response based on feedback
|
| 20 |
+
# if user intent is unclear -> speaker generate response for clarification
|
| 21 |
+
def init_graph(self):
|
| 22 |
+
graph_memory = MemorySaver()
|
| 23 |
+
graph_builder = StateGraph(State)
|
| 24 |
+
|
| 25 |
+
graph_builder.add_node("speaker_chatbot", speaker_chatbot)
|
| 26 |
+
|
| 27 |
+
graph_builder.add_edge(START, "speaker_chatbot")
|
| 28 |
+
graph_builder.add_edge("speaker_chatbot", END)
|
| 29 |
+
|
| 30 |
+
self.graph = graph_builder.compile(checkpointer=graph_memory)
|
| 31 |
+
|
| 32 |
+
def process_user_input(self, user_input, session_id):
|
| 33 |
+
"""
|
| 34 |
+
Process the user input and update the graph state.
|
| 35 |
+
"""
|
| 36 |
+
config = {"configurable": {"thread_id": session_id}}
|
| 37 |
+
|
| 38 |
+
input_dict = {
|
| 39 |
+
"messages": [("user", user_input)],
|
| 40 |
+
"session_id": session_id
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
for event in self.graph.stream(
|
| 44 |
+
input_dict,
|
| 45 |
+
config,
|
| 46 |
+
stream_mode="values"
|
| 47 |
+
):
|
| 48 |
+
event["messages"][-1].pretty_print()
|
| 49 |
+
|
| 50 |
+
snapshot = self.graph.get_state(config)
|
| 51 |
+
return snapshot[0]["messages"][-1].content
|
| 52 |
+
|
| 53 |
+
graph_manager = GraphManager()
|
| 54 |
+
|
| 55 |
+
# Example usage
|
| 56 |
+
response = graph_manager.process_user_input(
|
| 57 |
+
"hi",
|
| 58 |
+
1
|
| 59 |
+
)
|
src/chatbot/nodes.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# nodes.py
|
| 2 |
+
|
| 3 |
+
from typing import Annotated
|
| 4 |
+
from typing_extensions import TypedDict
|
| 5 |
+
from langgraph.graph.message import add_messages
|
| 6 |
+
from langchain.tools import tool
|
| 7 |
+
from langgraph.prebuilt import create_react_agent
|
| 8 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 9 |
+
|
| 10 |
+
from agent_llm_engine import llm_overall_agent
|
| 11 |
+
from prompts.agent_prompts import speaker_system_message, auditor_system_message
|
| 12 |
+
from form_management.form_management import Form
|
| 13 |
+
|
| 14 |
+
# placeholder form
|
| 15 |
+
form = Form()
|
| 16 |
+
|
| 17 |
+
class State(TypedDict):
|
| 18 |
+
messages: Annotated[list, add_messages]
|
| 19 |
+
session_id: str
|
| 20 |
+
|
| 21 |
+
@tool
|
| 22 |
+
def placeholder_tool(input: str) -> str:
|
| 23 |
+
"""
|
| 24 |
+
Placeholder tool that does nothing.
|
| 25 |
+
"""
|
| 26 |
+
return "This is a placeholder tool."
|
| 27 |
+
|
| 28 |
+
tools = [placeholder_tool]
|
| 29 |
+
|
| 30 |
+
speaker_memory = MemorySaver()
|
| 31 |
+
speaker_agent = create_react_agent(
|
| 32 |
+
llm_overall_agent,
|
| 33 |
+
tools,
|
| 34 |
+
state_modifier=speaker_system_message,
|
| 35 |
+
checkpointer=speaker_memory
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
def speaker_chatbot(state: State):
|
| 39 |
+
config = {"configurable": {"thread_id": state["session_id"]}}
|
| 40 |
+
|
| 41 |
+
response = speaker_agent.invoke(
|
| 42 |
+
{
|
| 43 |
+
"messages": [
|
| 44 |
+
(
|
| 45 |
+
"user", "{}".format(state["messages"][-1].content)
|
| 46 |
+
)
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
config,
|
| 50 |
+
)["messages"][-1]
|
| 51 |
+
|
| 52 |
+
return {
|
| 53 |
+
"messages": response.content,
|
| 54 |
+
"session_id": state["session_id"]
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
auditor_memory = MemorySaver()
|
| 58 |
+
auditor_agent = create_react_agent(
|
| 59 |
+
llm_overall_agent,
|
| 60 |
+
tools,
|
| 61 |
+
state_modifier=auditor_system_message,
|
| 62 |
+
checkpointer=speaker_memory
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def auditor_feedback(state: State):
|
| 66 |
+
# get form based on session id, but for now just use placeholder
|
| 67 |
+
form_info = form.get_info_as_dict()
|
| 68 |
+
|
| 69 |
+
# auditor looks at form info and acceptance criteria and gives speaker feedback
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
pass
|
src/chatbot/prompts/agent_prompts.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
speaker_system_message = """
|
| 2 |
+
You are an assistant that helps users submit receipts and fill in a form.
|
| 3 |
+
Your job is to communicate with the user and determine which tools to use given the user input.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
auditor_system_message = """
|
| 7 |
+
"""
|
src/form_management/acceptance_criteria.json
DELETED
|
@@ -1,32 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"fields": [
|
| 3 |
-
{
|
| 4 |
-
"fieldName": "Project Name",
|
| 5 |
-
"priority": 1,
|
| 6 |
-
"description": "The name of the project or purpose for the payment.",
|
| 7 |
-
"validationRules": [
|
| 8 |
-
"Must be associated with environmental or sustainability projects."
|
| 9 |
-
],
|
| 10 |
-
"exampleInput": "โครงการ \"การจัดการคาร์บอนเครดิตในป่าเพื่อการพัฒนาที่ยั่งยืน\""
|
| 11 |
-
},
|
| 12 |
-
{
|
| 13 |
-
"fieldName": "Date",
|
| 14 |
-
"priority": 2,
|
| 15 |
-
"description": "The date of the transaction.",
|
| 16 |
-
"validationRules": [
|
| 17 |
-
"Must be a date that is in the past or present."
|
| 18 |
-
],
|
| 19 |
-
"exampleInput": "3 เมษายน 2567"
|
| 20 |
-
},
|
| 21 |
-
{
|
| 22 |
-
"fieldName": "Total Payment Amount",
|
| 23 |
-
"priority": 3,
|
| 24 |
-
"description": "The total amount paid.",
|
| 25 |
-
"validationRules": [
|
| 26 |
-
"Must be a positive amount."
|
| 27 |
-
],
|
| 28 |
-
"exampleInput": "3 เมษายน 2567"
|
| 29 |
-
}
|
| 30 |
-
]
|
| 31 |
-
}
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/form_management/form_management.py
DELETED
|
@@ -1,72 +0,0 @@
|
|
| 1 |
-
class Form:
|
| 2 |
-
class PaymentDetailsTable:
|
| 3 |
-
def __init__(self, column_headers, row_entries):
|
| 4 |
-
self.column_headers = column_headers
|
| 5 |
-
self.row_entries = row_entries
|
| 6 |
-
|
| 7 |
-
def __init__(self, receipt_title, project_name, date, receipt_number, payer_name, payment_details_table, total_payment_amount, incompleteness_description):
|
| 8 |
-
self.receipt_title = receipt_title
|
| 9 |
-
self.project_name = project_name
|
| 10 |
-
self.date = date
|
| 11 |
-
self.receipt_number = receipt_number
|
| 12 |
-
self.payer_name = payer_name
|
| 13 |
-
self.payment_details_table = self.PaymentDetailsTable(**payment_details_table)
|
| 14 |
-
self.total_payment_amount = total_payment_amount
|
| 15 |
-
self.incompleteness_description = incompleteness_description
|
| 16 |
-
|
| 17 |
-
# Getter and Setter for receipt_title
|
| 18 |
-
def get_receipt_title(self):
|
| 19 |
-
return self.receipt_title
|
| 20 |
-
|
| 21 |
-
def set_receipt_title(self, receipt_title):
|
| 22 |
-
self.receipt_title = receipt_title
|
| 23 |
-
|
| 24 |
-
# Getter and Setter for project_name
|
| 25 |
-
def get_project_name(self):
|
| 26 |
-
return self.project_name
|
| 27 |
-
|
| 28 |
-
def set_project_name(self, project_name):
|
| 29 |
-
self.project_name = project_name
|
| 30 |
-
|
| 31 |
-
# Getter and Setter for date
|
| 32 |
-
def get_date(self):
|
| 33 |
-
return self.date
|
| 34 |
-
|
| 35 |
-
def set_date(self, date):
|
| 36 |
-
self.date = date
|
| 37 |
-
|
| 38 |
-
# Getter and Setter for receipt_number
|
| 39 |
-
def get_receipt_number(self):
|
| 40 |
-
return self.receipt_number
|
| 41 |
-
|
| 42 |
-
def set_receipt_number(self, receipt_number):
|
| 43 |
-
self.receipt_number = receipt_number
|
| 44 |
-
|
| 45 |
-
# Getter and Setter for payer_name
|
| 46 |
-
def get_payer_name(self):
|
| 47 |
-
return self.payer_name
|
| 48 |
-
|
| 49 |
-
def set_payer_name(self, payer_name):
|
| 50 |
-
self.payer_name = payer_name
|
| 51 |
-
|
| 52 |
-
# Getter and Setter for payment_details_table
|
| 53 |
-
def get_payment_details_table(self):
|
| 54 |
-
return self.payment_details_table
|
| 55 |
-
|
| 56 |
-
def set_payment_details_table(self, payment_details_table):
|
| 57 |
-
self.payment_details_table = self.PaymentDetailsTable(**payment_details_table)
|
| 58 |
-
|
| 59 |
-
# Getter and Setter for total_payment_amount
|
| 60 |
-
def get_total_payment_amount(self):
|
| 61 |
-
return self.total_payment_amount
|
| 62 |
-
|
| 63 |
-
def set_total_payment_amount(self, total_payment_amount):
|
| 64 |
-
self.total_payment_amount = total_payment_amount
|
| 65 |
-
|
| 66 |
-
# Getter and Setter for incompleteness_description
|
| 67 |
-
def get_incompleteness_description(self):
|
| 68 |
-
return self.incompleteness_description
|
| 69 |
-
|
| 70 |
-
def set_incompleteness_description(self, incompleteness_description):
|
| 71 |
-
self.incompleteness_description = incompleteness_description
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/main.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
def main():
|
| 2 |
print("Hello, World!")
|
| 3 |
|
|
|
|
| 1 |
+
# from chatbot.graph_manager
|
| 2 |
+
|
| 3 |
def main():
|
| 4 |
print("Hello, World!")
|
| 5 |
|
src/poetry.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/pyproject.toml
CHANGED
|
@@ -6,11 +6,14 @@ authors = [
|
|
| 6 |
{name = "suratan.boon",email = "suratan.boon@cjexpress.co.th"}
|
| 7 |
]
|
| 8 |
readme = "README.md"
|
| 9 |
-
requires-python = ">=3.
|
| 10 |
dependencies = [
|
| 11 |
"gradio (>=5.28.0,<6.0.0)",
|
| 12 |
"openai (>=1.76.2,<2.0.0)",
|
| 13 |
-
"dotenv (>=0.9.9,<0.10.0)"
|
|
|
|
|
|
|
|
|
|
| 14 |
]
|
| 15 |
|
| 16 |
|
|
|
|
| 6 |
{name = "suratan.boon",email = "suratan.boon@cjexpress.co.th"}
|
| 7 |
]
|
| 8 |
readme = "README.md"
|
| 9 |
+
requires-python = ">=3.12.4,<4.0.0"
|
| 10 |
dependencies = [
|
| 11 |
"gradio (>=5.28.0,<6.0.0)",
|
| 12 |
"openai (>=1.76.2,<2.0.0)",
|
| 13 |
+
"dotenv (>=0.9.9,<0.10.0)",
|
| 14 |
+
"langgraph (>=0.4.1,<0.5.0)",
|
| 15 |
+
"langchain (>=0.3.24,<0.4.0)",
|
| 16 |
+
"langchain-openai (>=0.3.14,<0.4.0)"
|
| 17 |
]
|
| 18 |
|
| 19 |
|