Seth0330 commited on
Commit
02564e8
·
verified ·
1 Parent(s): 18f7a84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -51
app.py CHANGED
@@ -4,6 +4,16 @@ import json
4
  import requests
5
  import traceback
6
 
 
 
 
 
 
 
 
 
 
 
7
  # --- Page config
8
  st.set_page_config(page_title="JSON-Backed AI Chat Agent", layout="wide")
9
  st.title("JSON-Backed AI Chat Agent")
@@ -24,14 +34,8 @@ uploaded_files = st.sidebar.file_uploader(
24
  "Choose one or more JSON files", type="json", accept_multiple_files=True
25
  )
26
 
27
- if "json_data" not in st.session_state:
28
- st.session_state.json_data = {}
29
- if "messages" not in st.session_state:
30
- st.session_state.messages = []
31
- if "temp_input" not in st.session_state:
32
- st.session_state.temp_input = ""
33
-
34
- if uploaded_files:
35
  st.session_state.json_data.clear()
36
  file_summaries = []
37
  for f in uploaded_files:
@@ -50,34 +54,35 @@ if uploaded_files:
50
  except Exception as e:
51
  st.sidebar.error(f"Error reading {f.name}: {e}")
52
 
53
- # --- System prompt with explicit example ---
54
  system_message = {
55
- "role": "system",
56
- "content": (
57
- "You are an AI data analyst for uploaded JSON files. "
58
- "Each file may have different structures and keys, including lists and nested dictionaries. "
59
- "You have access to a function 'search_all_jsons' that finds all records in all JSON files where a key matches a value, recursively. "
60
- "If a user asks about groups of people or wants to know counts such as 'How many females are there?', "
61
- "interpret this as 'search for all records where gender equals female'. "
62
- "Always use the 'search_all_jsons' function with key='gender' and value='female' for such queries, unless another key/value is clear from context. "
63
- "If someone asks 'How many males?', search for gender equals male, etc. "
64
- "EXAMPLES:\n"
65
- "User: How many females are there?\n"
66
- "Assistant: (Call search_all_jsons with key='gender', value='female')\n"
67
- "User: How many males are there?\n"
68
- "Assistant: (Call search_all_jsons with key='gender', value='male')\n"
69
- "User: Show all females\n"
70
- "Assistant: (Call search_all_jsons with key='gender', value='female')\n"
71
- "User: How many people are named Emily?\n"
72
- "Assistant: (Call search_all_jsons with key='firstName', value='Emily')"
73
- )
74
- }
75
-
76
  st.session_state.messages = [system_message]
77
- else:
 
78
  st.session_state.json_data.clear()
 
79
 
80
- # --- Recursive search for key/value in all files ---
81
  def search_all_jsons(key, value):
82
  found = []
83
  for file_name, data in st.session_state.json_data.items():
@@ -93,7 +98,7 @@ def search_all_jsons(key, value):
93
  recursive_search(data)
94
  return found
95
 
96
- # --- Other function stubs (for completeness) ---
97
  def search_json(file_name, key, value):
98
  def recursive_search(obj, key, value, found):
99
  if isinstance(obj, dict):
@@ -136,7 +141,7 @@ def count_key_occurrences(file_name, key):
136
  except Exception as e:
137
  return {"error": str(e)}
138
 
139
- # --- Function schema including the new all-files search ---
140
  function_schema = [
141
  {
142
  "name": "search_json",
@@ -188,7 +193,7 @@ function_schema = [
188
  }
189
  ]
190
 
191
- # --- Conversation UI ---
192
  st.markdown("### Conversation")
193
  for i, msg in enumerate(st.session_state.messages[1:]):
194
  if msg["role"] == "user":
@@ -206,17 +211,17 @@ for i, msg in enumerate(st.session_state.messages[1:]):
206
  except Exception:
207
  st.markdown(f"<b>Function '{msg['name']}' output:</b> {msg['content']}", unsafe_allow_html=True)
208
 
209
- # --- Chat input and OpenAI handling ---
210
  def send_message():
211
- user_input = st.session_state.temp_input
212
- if user_input and user_input.strip():
213
- st.session_state.messages.append({"role": "user", "content": user_input})
214
- chat_messages = st.session_state.messages
215
- if len(chat_messages) > 10:
216
- chat_messages = [chat_messages[0]] + chat_messages[-9:]
217
- else:
218
- chat_messages = chat_messages.copy()
219
- try:
220
  chat_resp = requests.post(
221
  "https://api.openai.com/v1/chat/completions",
222
  headers=HEADERS,
@@ -277,15 +282,13 @@ def send_message():
277
  st.session_state.messages.append({"role": "assistant", "content": answer})
278
  else:
279
  st.session_state.messages.append({"role": "assistant", "content": msg["content"]})
280
- except Exception as e:
281
- st.error("Exception: " + str(e))
282
- st.code(traceback.format_exc())
283
- finally:
284
- st.session_state.temp_input = ""
285
 
286
  if st.session_state.json_data:
287
  st.text_input("Your message:", key="temp_input", on_change=send_message)
288
- # --- Manual debug/test button ---
289
  if st.button("Test: Count females in all JSONs"):
290
  results = search_all_jsons("gender", "female")
291
  st.write(f"Females found: {len(results)}")
 
4
  import requests
5
  import traceback
6
 
7
+ # --- ALWAYS INIT SESSION STATE FIRST (before any widgets)
8
+ if "json_data" not in st.session_state:
9
+ st.session_state.json_data = {}
10
+ if "messages" not in st.session_state:
11
+ st.session_state.messages = []
12
+ if "temp_input" not in st.session_state:
13
+ st.session_state.temp_input = ""
14
+ if "files_loaded" not in st.session_state:
15
+ st.session_state.files_loaded = False
16
+
17
  # --- Page config
18
  st.set_page_config(page_title="JSON-Backed AI Chat Agent", layout="wide")
19
  st.title("JSON-Backed AI Chat Agent")
 
34
  "Choose one or more JSON files", type="json", accept_multiple_files=True
35
  )
36
 
37
+ # --- Only clear/load state when files change
38
+ if uploaded_files and not st.session_state.files_loaded:
 
 
 
 
 
 
39
  st.session_state.json_data.clear()
40
  file_summaries = []
41
  for f in uploaded_files:
 
54
  except Exception as e:
55
  st.sidebar.error(f"Error reading {f.name}: {e}")
56
 
57
+ # --- System prompt with explicit few-shot examples
58
  system_message = {
59
+ "role": "system",
60
+ "content": (
61
+ "You are an AI data analyst for uploaded JSON files. "
62
+ "Each file may have different structures and keys, including lists and nested dictionaries. "
63
+ "You have access to a function 'search_all_jsons' that finds all records in all JSON files where a key matches a value, recursively. "
64
+ "If a user asks about groups of people or wants to know counts such as 'How many females are there?', "
65
+ "interpret this as 'search for all records where gender equals female'. "
66
+ "Always use the 'search_all_jsons' function with key='gender' and value='female' for such queries, unless another key/value is clear from context. "
67
+ "If someone asks 'How many males?', search for gender equals male, etc. "
68
+ "EXAMPLES:\n"
69
+ "User: How many females are there?\n"
70
+ "Assistant: (Call search_all_jsons with key='gender', value='female')\n"
71
+ "User: How many males are there?\n"
72
+ "Assistant: (Call search_all_jsons with key='gender', value='male')\n"
73
+ "User: Show all females\n"
74
+ "Assistant: (Call search_all_jsons with key='gender', value='female')\n"
75
+ "User: How many people are named Emily?\n"
76
+ "Assistant: (Call search_all_jsons with key='firstName', value='Emily')"
77
+ )
78
+ }
 
79
  st.session_state.messages = [system_message]
80
+ st.session_state.files_loaded = True
81
+ elif not uploaded_files:
82
  st.session_state.json_data.clear()
83
+ st.session_state.files_loaded = False
84
 
85
+ # --- Recursive search for key/value in all files
86
  def search_all_jsons(key, value):
87
  found = []
88
  for file_name, data in st.session_state.json_data.items():
 
98
  recursive_search(data)
99
  return found
100
 
101
+ # --- Other functions for LLM
102
  def search_json(file_name, key, value):
103
  def recursive_search(obj, key, value, found):
104
  if isinstance(obj, dict):
 
141
  except Exception as e:
142
  return {"error": str(e)}
143
 
144
+ # --- Function schema
145
  function_schema = [
146
  {
147
  "name": "search_json",
 
193
  }
194
  ]
195
 
196
+ # --- Conversation UI
197
  st.markdown("### Conversation")
198
  for i, msg in enumerate(st.session_state.messages[1:]):
199
  if msg["role"] == "user":
 
211
  except Exception:
212
  st.markdown(f"<b>Function '{msg['name']}' output:</b> {msg['content']}", unsafe_allow_html=True)
213
 
214
+ # --- Chat input and OpenAI handling
215
  def send_message():
216
+ try:
217
+ user_input = st.session_state.temp_input
218
+ if user_input and user_input.strip():
219
+ st.session_state.messages.append({"role": "user", "content": user_input})
220
+ chat_messages = st.session_state.messages
221
+ if len(chat_messages) > 10:
222
+ chat_messages = [chat_messages[0]] + chat_messages[-9:]
223
+ else:
224
+ chat_messages = chat_messages.copy()
225
  chat_resp = requests.post(
226
  "https://api.openai.com/v1/chat/completions",
227
  headers=HEADERS,
 
282
  st.session_state.messages.append({"role": "assistant", "content": answer})
283
  else:
284
  st.session_state.messages.append({"role": "assistant", "content": msg["content"]})
285
+ st.session_state.temp_input = ""
286
+ except Exception as e:
287
+ st.error("Exception: " + str(e))
288
+ st.code(traceback.format_exc())
 
289
 
290
  if st.session_state.json_data:
291
  st.text_input("Your message:", key="temp_input", on_change=send_message)
 
292
  if st.button("Test: Count females in all JSONs"):
293
  results = search_all_jsons("gender", "female")
294
  st.write(f"Females found: {len(results)}")