gkijkul commited on
Commit
6f25479
·
1 Parent(s): 6c4e5d3

more graph setup, comments on how the flow should work

Browse files
src/chatbot/graph_manager.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # graph_manager.py
2
+ # graph definition and compilation
3
+
4
+ from langgraph.graph import StateGraph, START, END
5
+ from langgraph.checkpoint.memory import MemorySaver
6
+ from nodes import (
7
+ State,
8
+ speaker_chatbot
9
+ )
10
+ from form_management.form_management import Form
11
+ import sys
12
+
13
+ class GraphManager:
14
+ def __init__(self):
15
+ self.graph = None
16
+ self.init_graph()
17
+
18
+ # user input ->
19
+ # if user intent is clear -> speaker form management -> audit review and feedback -> speaker generate response based on feedback
20
+ # if user intent is unclear -> speaker generate response for clarification
21
+ def init_graph(self):
22
+ graph_memory = MemorySaver()
23
+ graph_builder = StateGraph(State)
24
+
25
+ graph_builder.add_node("speaker_chatbot", speaker_chatbot)
26
+
27
+ graph_builder.add_edge(START, "speaker_chatbot")
28
+ graph_builder.add_edge("speaker_chatbot", END)
29
+
30
+ self.graph = graph_builder.compile(checkpointer=graph_memory)
31
+
32
+ def process_user_input(self, user_input, session_id):
33
+ """
34
+ Process the user input and update the graph state.
35
+ """
36
+ config = {"configurable": {"thread_id": session_id}}
37
+
38
+ input_dict = {
39
+ "messages": [("user", user_input)],
40
+ "session_id": session_id
41
+ }
42
+
43
+ for event in self.graph.stream(
44
+ input_dict,
45
+ config,
46
+ stream_mode="values"
47
+ ):
48
+ event["messages"][-1].pretty_print()
49
+
50
+ snapshot = self.graph.get_state(config)
51
+ return snapshot[0]["messages"][-1].content
52
+
53
+ graph_manager = GraphManager()
54
+
55
+ # Example usage
56
+ response = graph_manager.process_user_input(
57
+ "hi",
58
+ 1
59
+ )
src/chatbot/nodes.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # nodes.py
2
+
3
+ from typing import Annotated
4
+ from typing_extensions import TypedDict
5
+ from langgraph.graph.message import add_messages
6
+ from langchain.tools import tool
7
+ from langgraph.prebuilt import create_react_agent
8
+ from langgraph.checkpoint.memory import MemorySaver
9
+
10
+ from agent_llm_engine import llm_overall_agent
11
+ from prompts.agent_prompts import speaker_system_message, auditor_system_message
12
+ from form_management.form_management import Form
13
+
14
+ # placeholder form
15
+ form = Form()
16
+
17
+ class State(TypedDict):
18
+ messages: Annotated[list, add_messages]
19
+ session_id: str
20
+
21
+ @tool
22
+ def placeholder_tool(input: str) -> str:
23
+ """
24
+ Placeholder tool that does nothing.
25
+ """
26
+ return "This is a placeholder tool."
27
+
28
+ tools = [placeholder_tool]
29
+
30
+ speaker_memory = MemorySaver()
31
+ speaker_agent = create_react_agent(
32
+ llm_overall_agent,
33
+ tools,
34
+ state_modifier=speaker_system_message,
35
+ checkpointer=speaker_memory
36
+ )
37
+
38
+ def speaker_chatbot(state: State):
39
+ config = {"configurable": {"thread_id": state["session_id"]}}
40
+
41
+ response = speaker_agent.invoke(
42
+ {
43
+ "messages": [
44
+ (
45
+ "user", "{}".format(state["messages"][-1].content)
46
+ )
47
+ ]
48
+ },
49
+ config,
50
+ )["messages"][-1]
51
+
52
+ return {
53
+ "messages": response.content,
54
+ "session_id": state["session_id"]
55
+ }
56
+
57
+ auditor_memory = MemorySaver()
58
+ auditor_agent = create_react_agent(
59
+ llm_overall_agent,
60
+ tools,
61
+ state_modifier=auditor_system_message,
62
+ checkpointer=speaker_memory
63
+ )
64
+
65
+ def auditor_feedback(state: State):
66
+ # get form based on session id, but for now just use placeholder
67
+ form_info = form.get_info_as_dict()
68
+
69
+ # auditor looks at form info and acceptance criteria and gives speaker feedback
70
+
71
+
72
+ pass
src/chatbot/prompts/agent_prompts.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ speaker_system_message = """
2
+ You are an assistant that helps users submit receipts and fill in a form.
3
+ Your job is to communicate with the user and determine which tools to use given the user input.
4
+ """
5
+
6
+ auditor_system_message = """
7
+ """
src/form_management/acceptance_criteria.json DELETED
@@ -1,32 +0,0 @@
1
- {
2
- "fields": [
3
- {
4
- "fieldName": "Project Name",
5
- "priority": 1,
6
- "description": "The name of the project or purpose for the payment.",
7
- "validationRules": [
8
- "Must be associated with environmental or sustainability projects."
9
- ],
10
- "exampleInput": "โครงการ \"การจัดการคาร์บอนเครดิตในป่าเพื่อการพัฒนาที่ยั่งยืน\""
11
- },
12
- {
13
- "fieldName": "Date",
14
- "priority": 2,
15
- "description": "The date of the transaction.",
16
- "validationRules": [
17
- "Must be a date that is in the past or present."
18
- ],
19
- "exampleInput": "3 เมษายน 2567"
20
- },
21
- {
22
- "fieldName": "Total Payment Amount",
23
- "priority": 3,
24
- "description": "The total amount paid.",
25
- "validationRules": [
26
- "Must be a positive amount."
27
- ],
28
- "exampleInput": "3 เมษายน 2567"
29
- }
30
- ]
31
- }
32
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/form_management/form_management.py DELETED
@@ -1,72 +0,0 @@
1
- class Form:
2
- class PaymentDetailsTable:
3
- def __init__(self, column_headers, row_entries):
4
- self.column_headers = column_headers
5
- self.row_entries = row_entries
6
-
7
- def __init__(self, receipt_title, project_name, date, receipt_number, payer_name, payment_details_table, total_payment_amount, incompleteness_description):
8
- self.receipt_title = receipt_title
9
- self.project_name = project_name
10
- self.date = date
11
- self.receipt_number = receipt_number
12
- self.payer_name = payer_name
13
- self.payment_details_table = self.PaymentDetailsTable(**payment_details_table)
14
- self.total_payment_amount = total_payment_amount
15
- self.incompleteness_description = incompleteness_description
16
-
17
- # Getter and Setter for receipt_title
18
- def get_receipt_title(self):
19
- return self.receipt_title
20
-
21
- def set_receipt_title(self, receipt_title):
22
- self.receipt_title = receipt_title
23
-
24
- # Getter and Setter for project_name
25
- def get_project_name(self):
26
- return self.project_name
27
-
28
- def set_project_name(self, project_name):
29
- self.project_name = project_name
30
-
31
- # Getter and Setter for date
32
- def get_date(self):
33
- return self.date
34
-
35
- def set_date(self, date):
36
- self.date = date
37
-
38
- # Getter and Setter for receipt_number
39
- def get_receipt_number(self):
40
- return self.receipt_number
41
-
42
- def set_receipt_number(self, receipt_number):
43
- self.receipt_number = receipt_number
44
-
45
- # Getter and Setter for payer_name
46
- def get_payer_name(self):
47
- return self.payer_name
48
-
49
- def set_payer_name(self, payer_name):
50
- self.payer_name = payer_name
51
-
52
- # Getter and Setter for payment_details_table
53
- def get_payment_details_table(self):
54
- return self.payment_details_table
55
-
56
- def set_payment_details_table(self, payment_details_table):
57
- self.payment_details_table = self.PaymentDetailsTable(**payment_details_table)
58
-
59
- # Getter and Setter for total_payment_amount
60
- def get_total_payment_amount(self):
61
- return self.total_payment_amount
62
-
63
- def set_total_payment_amount(self, total_payment_amount):
64
- self.total_payment_amount = total_payment_amount
65
-
66
- # Getter and Setter for incompleteness_description
67
- def get_incompleteness_description(self):
68
- return self.incompleteness_description
69
-
70
- def set_incompleteness_description(self, incompleteness_description):
71
- self.incompleteness_description = incompleteness_description
72
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/main.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  def main():
2
  print("Hello, World!")
3
 
 
1
+ # from chatbot.graph_manager
2
+
3
  def main():
4
  print("Hello, World!")
5
 
src/poetry.lock CHANGED
The diff for this file is too large to render. See raw diff
 
src/pyproject.toml CHANGED
@@ -6,11 +6,14 @@ authors = [
6
  {name = "suratan.boon",email = "suratan.boon@cjexpress.co.th"}
7
  ]
8
  readme = "README.md"
9
- requires-python = ">=3.11"
10
  dependencies = [
11
  "gradio (>=5.28.0,<6.0.0)",
12
  "openai (>=1.76.2,<2.0.0)",
13
- "dotenv (>=0.9.9,<0.10.0)"
 
 
 
14
  ]
15
 
16
 
 
6
  {name = "suratan.boon",email = "suratan.boon@cjexpress.co.th"}
7
  ]
8
  readme = "README.md"
9
+ requires-python = ">=3.12.4,<4.0.0"
10
  dependencies = [
11
  "gradio (>=5.28.0,<6.0.0)",
12
  "openai (>=1.76.2,<2.0.0)",
13
+ "dotenv (>=0.9.9,<0.10.0)",
14
+ "langgraph (>=0.4.1,<0.5.0)",
15
+ "langchain (>=0.3.24,<0.4.0)",
16
+ "langchain-openai (>=0.3.14,<0.4.0)"
17
  ]
18
 
19