Focussy commited on
Commit
bc9d7e8
·
1 Parent(s): 3572fb3

fix/concurrent-form-info

Browse files
Files changed (2) hide show
  1. src/chatbot/app_interface.py +49 -34
  2. src/chatbot/nodes.py +41 -30
src/chatbot/app_interface.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  from graph_manager import user_input_handler
3
- from nodes import form_info, get_form_info, set_form_info
4
  import ast
5
  import uuid
6
  import logging
@@ -14,7 +14,6 @@ def generate_session_id():
14
  session_id = str(uuid.uuid4())
15
  return session_id, f"Session ID: {session_id}"
16
 
17
- # CURRENT_SESSION_ID = generate_session_id()
18
  DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'database', 'hello_earth_data_2.db')
19
 
20
  def connect_db():
@@ -81,9 +80,9 @@ def is_form_complete(form_info_dict):
81
  return gr.update(interactive=True)
82
 
83
  def update_field_input(key):
84
- def inner(value):
85
- set_form_info(key, value)
86
- return get_form_info()
87
  return inner
88
 
89
  def update_form_information_handler(form_information):
@@ -113,8 +112,8 @@ def update_form_information_handler(form_information):
113
  total_payment_amount_input, expense_description_input
114
  )
115
 
116
- def handle_submit():
117
- form_data = get_form_info()
118
  required_fields = [
119
  "Associated Deliverable", "Seller Name", "Seller Address", "Seller Phone Number",
120
  "Buyer Name", "Buyer Address", "Transaction Date", "Total Payment Amount", "Expense Description"
@@ -139,15 +138,18 @@ def chat_handler(message,history,session_id):
139
  response = user_input_handler(message['text'], session_id)
140
  except Exception as e:
141
  raise e
142
- form_information = get_form_info()
143
  update_form_information_handler(form_information)
144
  return response, form_information
145
 
146
  def create_expense_register_page():
147
  with gr.Blocks() as page:
148
  session_id = gr.State("")
149
- form_information = gr.State(form_info)
150
- session_id_display = gr.Markdown("Session ID: Loading...")
 
 
 
151
  new_btn = gr.Button("Reset", variant="primary")
152
  with gr.Tabs():
153
  with gr.Tab("Chat"):
@@ -167,36 +169,36 @@ def create_expense_register_page():
167
  gr.Markdown("**Deliverable**")
168
  deliverable_dropdown = gr.Dropdown(choices=get_deliverable_titles(), filterable=False, container=False)
169
  gr.Markdown("**Seller Name**")
170
- seller_name_input = gr.Textbox(value=ast.literal_eval(form_info["Seller Name"]), label="", container=False)
171
  gr.Markdown("**Seller Address**")
172
- seller_address_input = gr.Textbox(value=ast.literal_eval(form_info["Seller Address"]), label="", container=False)
173
  gr.Markdown("**Seller Phone Number**")
174
- seller_phone_number_input = gr.Textbox(value=ast.literal_eval(form_info["Seller Phone Number"]), label="", container=False)
175
  gr.Markdown("**Buyer Name**")
176
- buyer_name_input = gr.Textbox(value=ast.literal_eval(form_info["Buyer Name"]), label="", container=False)
177
  gr.Markdown("**Buyer Address**")
178
- buyer_address_input = gr.Textbox(value=ast.literal_eval(form_info["Buyer Address"]), label="", container=False)
179
  gr.Markdown("**Transaction Date**")
180
  transaction_date_input = gr.DateTime(include_time=False, type="datetime", show_label=False)
181
  gr.Markdown("**Total Payment Amount**")
182
  with gr.Row():
183
- total_payment_amount_input = gr.Number(value=ast.literal_eval(form_info['Total Payment Amount']), label="", container=False)
184
  gr.Markdown("Baht")
185
  gr.Markdown("**Expense Description**")
186
- expense_description_input = gr.Textbox(value=ast.literal_eval(form_info["Expense Description"]), label="", container=False, lines=4)
187
 
188
  feedback_message = gr.Markdown("", visible=False)
189
  submit_btn = gr.Button("Submit", variant="primary")
190
 
191
- deliverable_dropdown.change(update_field_input("Associated Deliverable"), inputs=[deliverable_dropdown], outputs=[form_information],queue=False)
192
- seller_name_input.change(update_field_input("Seller Name"), inputs=[seller_name_input], outputs=[form_information],queue=False)
193
- seller_address_input.change(update_field_input("Seller Address"), inputs=[seller_address_input], outputs=[form_information],queue=False)
194
- seller_phone_number_input.change(update_field_input("Seller Phone Number"), inputs=[seller_phone_number_input], outputs=[form_information],queue=False)
195
- buyer_name_input.change(update_field_input("Buyer Name"), inputs=[buyer_name_input], outputs=[form_information],queue=False)
196
- buyer_address_input.change(update_field_input("Buyer Address"), inputs=[buyer_address_input], outputs=[form_information],queue=False)
197
- transaction_date_input.change(update_field_input("Transaction Date"), inputs=[transaction_date_input], outputs=[form_information],queue=False)
198
- total_payment_amount_input.change(update_field_input("Total Payment Amount"), inputs=[total_payment_amount_input], outputs=[form_information],queue=False)
199
- expense_description_input.change(update_field_input("Expense Description"), inputs=[expense_description_input], outputs=[form_information],queue=False)
200
 
201
  form_information.change(
202
  update_form_information_handler,
@@ -208,27 +210,40 @@ def create_expense_register_page():
208
  ],
209
  queue=False,
210
  )
211
- submit_btn.click(fn=handle_submit, inputs=[], outputs=[feedback_message])
212
 
213
  def reset_chat():
214
- new_session_id , new_session_id_display = generate_session_id()
215
  logger.info(f"Chat reset with new session ID: {new_session_id}")
216
- form_info.clear()
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  return (
218
  new_session_id,
219
  new_session_id_display,
220
- _chat_prefill, # ChatInterface.chatbot_value
221
- form_info, # State(form_info)
222
  "None", # deliverable_dropdown
223
  "", # seller_name_input
224
  "", # seller_address_input
225
  "", # seller_phone_number_input
226
  "", # buyer_name_input
227
  "", # buyer_address_input
228
- None, # transaction_date_input (a DateTime)
229
- None, # total_payment_amount_input (a Number)
230
  "", # expense_description_input
231
- gr.update(visible=False, value="") # hide/clear any feedback_message
232
  )
233
 
234
  new_btn.click(
 
1
  import gradio as gr
2
  from graph_manager import user_input_handler
3
+ from nodes import session_form_data, get_form_info, set_form_info
4
  import ast
5
  import uuid
6
  import logging
 
14
  session_id = str(uuid.uuid4())
15
  return session_id, f"Session ID: {session_id}"
16
 
 
17
  DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'database', 'hello_earth_data_2.db')
18
 
19
  def connect_db():
 
80
  return gr.update(interactive=True)
81
 
82
  def update_field_input(key):
83
+ def inner(session_id, value):
84
+ set_form_info(session_id, key, value)
85
+ return get_form_info(session_id)
86
  return inner
87
 
88
  def update_form_information_handler(form_information):
 
112
  total_payment_amount_input, expense_description_input
113
  )
114
 
115
+ def handle_submit(session_id):
116
+ form_data = get_form_info(session_id)
117
  required_fields = [
118
  "Associated Deliverable", "Seller Name", "Seller Address", "Seller Phone Number",
119
  "Buyer Name", "Buyer Address", "Transaction Date", "Total Payment Amount", "Expense Description"
 
138
  response = user_input_handler(message['text'], session_id)
139
  except Exception as e:
140
  raise e
141
+ form_information = get_form_info(session_id)
142
  update_form_information_handler(form_information)
143
  return response, form_information
144
 
145
  def create_expense_register_page():
146
  with gr.Blocks() as page:
147
  session_id = gr.State("")
148
+ form_information = gr.State({})
149
+
150
+ with gr.Row():
151
+ session_id_display = gr.Markdown("Session ID: Loading...")
152
+
153
  new_btn = gr.Button("Reset", variant="primary")
154
  with gr.Tabs():
155
  with gr.Tab("Chat"):
 
169
  gr.Markdown("**Deliverable**")
170
  deliverable_dropdown = gr.Dropdown(choices=get_deliverable_titles(), filterable=False, container=False)
171
  gr.Markdown("**Seller Name**")
172
+ seller_name_input = gr.Textbox(value=ast.literal_eval(session_form_data[session_id]["Seller Name"]), label="", container=False)
173
  gr.Markdown("**Seller Address**")
174
+ seller_address_input = gr.Textbox(value=ast.literal_eval(session_form_data[session_id]["Seller Address"]), label="", container=False)
175
  gr.Markdown("**Seller Phone Number**")
176
+ seller_phone_number_input = gr.Textbox(value=ast.literal_eval(session_form_data[session_id]["Seller Phone Number"]), label="", container=False)
177
  gr.Markdown("**Buyer Name**")
178
+ buyer_name_input = gr.Textbox(value=ast.literal_eval(session_form_data[session_id]["Buyer Name"]), label="", container=False)
179
  gr.Markdown("**Buyer Address**")
180
+ buyer_address_input = gr.Textbox(value=ast.literal_eval(session_form_data[session_id]["Buyer Address"]), label="", container=False)
181
  gr.Markdown("**Transaction Date**")
182
  transaction_date_input = gr.DateTime(include_time=False, type="datetime", show_label=False)
183
  gr.Markdown("**Total Payment Amount**")
184
  with gr.Row():
185
+ total_payment_amount_input = gr.Number(value=ast.literal_eval(session_form_data[session_id]['Total Payment Amount']), label="", container=False)
186
  gr.Markdown("Baht")
187
  gr.Markdown("**Expense Description**")
188
+ expense_description_input = gr.Textbox(value=ast.literal_eval(session_form_data[session_id]["Expense Description"]), label="", container=False, lines=4)
189
 
190
  feedback_message = gr.Markdown("", visible=False)
191
  submit_btn = gr.Button("Submit", variant="primary")
192
 
193
+ deliverable_dropdown.change(update_field_input("Associated Deliverable"), inputs=[session_id, deliverable_dropdown], outputs=[form_information],queue=False)
194
+ seller_name_input.change(update_field_input("Seller Name"), inputs=[session_id, seller_name_input], outputs=[form_information],queue=False)
195
+ seller_address_input.change(update_field_input("Seller Address"), inputs=[session_id, seller_address_input], outputs=[form_information],queue=False)
196
+ seller_phone_number_input.change(update_field_input("Seller Phone Number"), inputs=[session_id, seller_phone_number_input], outputs=[form_information],queue=False)
197
+ buyer_name_input.change(update_field_input("Buyer Name"), inputs=[session_id, buyer_name_input], outputs=[form_information],queue=False)
198
+ buyer_address_input.change(update_field_input("Buyer Address"), inputs=[session_id, buyer_address_input], outputs=[form_information],queue=False)
199
+ transaction_date_input.change(update_field_input("Transaction Date"), inputs=[session_id, transaction_date_input], outputs=[form_information],queue=False)
200
+ total_payment_amount_input.change(update_field_input("Total Payment Amount"), inputs=[session_id, total_payment_amount_input], outputs=[form_information],queue=False)
201
+ expense_description_input.change(update_field_input("Expense Description"), inputs=[session_id, expense_description_input], outputs=[form_information],queue=False)
202
 
203
  form_information.change(
204
  update_form_information_handler,
 
210
  ],
211
  queue=False,
212
  )
213
+ submit_btn.click(fn=handle_submit, inputs=[session_id], outputs=[feedback_message])
214
 
215
  def reset_chat():
216
+ new_session_id, new_session_id_display = generate_session_id()
217
  logger.info(f"Chat reset with new session ID: {new_session_id}")
218
+
219
+ # Reset the state dictionary for the new session
220
+ session_form_data[new_session_id] = {
221
+ "Associated Deliverable": "None",
222
+ "Seller Name": "",
223
+ "Seller Address": "",
224
+ "Seller Phone Number": "",
225
+ "Buyer Name": "",
226
+ "Buyer Address": "",
227
+ "Transaction Date": None,
228
+ "Total Payment Amount": None,
229
+ "Expense Description": "",
230
+ }
231
+
232
  return (
233
  new_session_id,
234
  new_session_id_display,
235
+ _chat_prefill,
236
+ session_form_data[new_session_id], # return the new form state
237
  "None", # deliverable_dropdown
238
  "", # seller_name_input
239
  "", # seller_address_input
240
  "", # seller_phone_number_input
241
  "", # buyer_name_input
242
  "", # buyer_address_input
243
+ None, # transaction_date_input
244
+ None, # total_payment_amount_input
245
  "", # expense_description_input
246
+ gr.update(visible=False, value="") # feedback_message
247
  )
248
 
249
  new_btn.click(
src/chatbot/nodes.py CHANGED
@@ -6,6 +6,7 @@ from langgraph.graph.message import add_messages
6
  from langchain.tools import tool
7
  from langgraph.prebuilt import create_react_agent, ToolNode
8
  from langgraph.checkpoint.memory import MemorySaver
 
9
 
10
  # from chatbot.llm_engine import llm_overall_agent, generate_expense_info_feedback
11
  from llm_engine import (
@@ -23,6 +24,7 @@ import ast
23
  import json
24
  import sqlite3
25
  import os
 
26
 
27
  # conn = sqlite3.connect(r'C:\Users\Gavin\Desktop\hello-earth\db\project_data.db')
28
  db_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'database', 'hello_earth_data_2.db')
@@ -30,22 +32,17 @@ conn = sqlite3.connect(db_path)
30
  cursor = conn.cursor()
31
 
32
  # placeholder, init empty form info
33
- form_info = {
34
- "Seller Name": 'None',
35
- "Seller Address": 'None',
36
- "Seller Phone Number": 'None',
37
- "Buyer Name": 'None',
38
- "Buyer Address": 'None',
39
- "Transaction Date": 'None',
40
- "Total Payment Amount": 'None',
41
- # "Payment Details Table": {
42
- # "Column Headers": [],
43
- # "Row Entries": [],
44
- # },
45
- # "Incompleteness Description": 'None',
46
- "Associated Deliverable": 'None',
47
- "Expense Description": 'None',
48
- }
49
 
50
  # define the project_deliverables from the db
51
  # cursor.execute("SELECT title FROM deliverables")
@@ -54,11 +51,14 @@ rows = cursor.fetchall()
54
  deliverable_titles = [row[0] for row in rows]
55
  print(f"deliverable_titles: {deliverable_titles}")
56
 
57
- def get_form_info():
58
- return form_info
59
 
60
- def set_form_info(key, value):
61
- form_info[key] = value
 
 
 
62
 
63
  class State(TypedDict):
64
  messages: Annotated[list, add_messages]
@@ -74,7 +74,7 @@ class State(TypedDict):
74
 
75
  # session_id: str might need to be added later when there are concurrent users
76
  @tool
77
- def apply_ocr(image_path: str) -> str:
78
  """
79
  Use this tool when the user provides an image path.
80
  Apply optical-character-recognition (OCR) to the image and initalize the form information.
@@ -82,10 +82,11 @@ def apply_ocr(image_path: str) -> str:
82
  """
83
  print("using apply_ocr tool")
84
 
 
 
85
  temp_form_info = receipt_kie(image_path)
86
 
87
- # print(type(temp_form_info))
88
- # print(temp_form_info)
89
 
90
  for key in temp_form_info.keys():
91
  if key in form_info:
@@ -109,7 +110,7 @@ def apply_ocr(image_path: str) -> str:
109
  """
110
 
111
  @tool
112
- def edit_form(key: str, value: str) -> str:
113
  """
114
  This is used to edit the form data.
115
  The key must be one of the keys in the form_info dictionary.
@@ -119,16 +120,20 @@ def edit_form(key: str, value: str) -> str:
119
  """
120
  print("using edit_form tool")
121
 
 
 
 
 
122
  if key in form_info:
123
  form_info[key] = value
124
  print(form_info)
125
  print(f"Updated {key} to {value}")
126
  if key == "Expense Description":
127
- feedback = expense_description_feedback(value)
128
  elif key == "Associated Deliverable":
129
- res = map_associated_deliverable(value, deliverable_titles) # llm function that maps the user text to a project deliverable
130
- if res == "No Matching Deliverable": # if no matching deliverable,
131
- form_info[key] = "None" # set the associated deliverable to None
132
  feedback = "A valid associated deliverable is required."
133
  else:
134
  form_info[key] = res
@@ -152,12 +157,18 @@ def edit_form(key: str, value: str) -> str:
152
  return f"Invalid key: {key}. No changes made."
153
 
154
  @tool
155
- def inspect_form() -> str:
156
  """
157
  This is used to inspect the form data.
158
  Use this when you need to address any inquiries the user might have about the current state of the form.
159
  """
160
  print("using inspect_form tool")
 
 
 
 
 
 
161
  return f"""
162
  <form_status>
163
  {form_info}
@@ -201,7 +212,7 @@ def auditor_feedback(form_info):
201
  response = generate_expense_info_feedback(acceptance_criteria, form_info)
202
  return response
203
 
204
- def expense_description_feedback(expense_description):
205
  response = generate_expense_description_feedback(form_info["Associated Deliverable"], expense_description)
206
  return response
207
 
 
6
  from langchain.tools import tool
7
  from langgraph.prebuilt import create_react_agent, ToolNode
8
  from langgraph.checkpoint.memory import MemorySaver
9
+ from langchain_core.runnables import RunnableConfig
10
 
11
  # from chatbot.llm_engine import llm_overall_agent, generate_expense_info_feedback
12
  from llm_engine import (
 
24
  import json
25
  import sqlite3
26
  import os
27
+ from collections import defaultdict
28
 
29
  # conn = sqlite3.connect(r'C:\Users\Gavin\Desktop\hello-earth\db\project_data.db')
30
  db_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'database', 'hello_earth_data_2.db')
 
32
  cursor = conn.cursor()
33
 
34
  # placeholder, init empty form info
35
+ session_form_data = defaultdict(lambda: {
36
+ "Seller Name": 'None',
37
+ "Seller Address": 'None',
38
+ "Seller Phone Number": 'None',
39
+ "Buyer Name": 'None',
40
+ "Buyer Address": 'None',
41
+ "Transaction Date": 'None',
42
+ "Total Payment Amount": 'None',
43
+ "Associated Deliverable": 'None',
44
+ "Expense Description": 'None',
45
+ })
 
 
 
 
 
46
 
47
  # define the project_deliverables from the db
48
  # cursor.execute("SELECT title FROM deliverables")
 
51
  deliverable_titles = [row[0] for row in rows]
52
  print(f"deliverable_titles: {deliverable_titles}")
53
 
54
+ def get_form_info(session_id):
55
+ return session_form_data[session_id]
56
 
57
+ def set_form_info(session_id, key, value):
58
+ print("set form info")
59
+ session_form_data[session_id][key] = value
60
+ print(session_id)
61
+ print(session_form_data[session_id])
62
 
63
  class State(TypedDict):
64
  messages: Annotated[list, add_messages]
 
74
 
75
  # session_id: str might need to be added later when there are concurrent users
76
  @tool
77
+ def apply_ocr(image_path: str, config: RunnableConfig) -> str:
78
  """
79
  Use this tool when the user provides an image path.
80
  Apply optical-character-recognition (OCR) to the image and initalize the form information.
 
82
  """
83
  print("using apply_ocr tool")
84
 
85
+ session_id = config["configurable"].get("thread_id")
86
+
87
  temp_form_info = receipt_kie(image_path)
88
 
89
+ form_info = session_form_data[session_id]
 
90
 
91
  for key in temp_form_info.keys():
92
  if key in form_info:
 
110
  """
111
 
112
  @tool
113
+ def edit_form(key: str, value: str, config: RunnableConfig) -> str:
114
  """
115
  This is used to edit the form data.
116
  The key must be one of the keys in the form_info dictionary.
 
120
  """
121
  print("using edit_form tool")
122
 
123
+ session_id = config["configurable"].get("thread_id")
124
+
125
+ form_info = session_form_data[session_id]
126
+
127
  if key in form_info:
128
  form_info[key] = value
129
  print(form_info)
130
  print(f"Updated {key} to {value}")
131
  if key == "Expense Description":
132
+ feedback = expense_description_feedback(form_info, value)
133
  elif key == "Associated Deliverable":
134
+ res = map_associated_deliverable(value, deliverable_titles)
135
+ if res == "No Matching Deliverable":
136
+ form_info[key] = "None"
137
  feedback = "A valid associated deliverable is required."
138
  else:
139
  form_info[key] = res
 
157
  return f"Invalid key: {key}. No changes made."
158
 
159
  @tool
160
+ def inspect_form(config: RunnableConfig) -> str:
161
  """
162
  This is used to inspect the form data.
163
  Use this when you need to address any inquiries the user might have about the current state of the form.
164
  """
165
  print("using inspect_form tool")
166
+
167
+ session_id = config["configurable"].get("thread_id")
168
+ form_info = session_form_data[session_id]
169
+ print("session_id:",session_id)
170
+ print("form_info:",form_info)
171
+
172
  return f"""
173
  <form_status>
174
  {form_info}
 
212
  response = generate_expense_info_feedback(acceptance_criteria, form_info)
213
  return response
214
 
215
+ def expense_description_feedback(form_info, expense_description):
216
  response = generate_expense_description_feedback(form_info["Associated Deliverable"], expense_description)
217
  return response
218