Seth0330 commited on
Commit
8b91f01
·
verified ·
1 Parent(s): 644617d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -24
app.py CHANGED
@@ -50,21 +50,42 @@ if uploaded_files:
50
  except Exception as e:
51
  st.sidebar.error(f"Error reading {f.name}: {e}")
52
 
 
53
  system_message = {
54
  "role": "system",
55
  "content": (
56
- "You are an AI data analyst for the uploaded JSON files. "
57
- "Each file may have a different structure and set of keys. "
58
- "When the user asks a question about all data (like 'How many females are there?'), "
59
- "use the 'search_all_jsons' function to search through every uploaded file recursively. "
60
- "Return results with the count and details, grouped by file if needed. "
61
- "If the question is specific to one file, use the appropriate function for that file."
 
 
 
62
  )
63
  }
64
  st.session_state.messages = [system_message]
65
  else:
66
  st.session_state.json_data.clear()
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  def search_json(file_name, key, value):
69
  def recursive_search(obj, key, value, found):
70
  if isinstance(obj, dict):
@@ -107,22 +128,7 @@ def count_key_occurrences(file_name, key):
107
  except Exception as e:
108
  return {"error": str(e)}
109
 
110
- def search_all_jsons(key, value):
111
- """Search all JSON files for dicts with key==value, recursively."""
112
- found = []
113
- for file_name, data in st.session_state.json_data.items():
114
- def recursive_search(obj):
115
- if isinstance(obj, dict):
116
- if key in obj and str(obj[key]).lower() == str(value).lower():
117
- found.append({**obj, "__file__": file_name})
118
- for v in obj.values():
119
- recursive_search(v)
120
- elif isinstance(obj, list):
121
- for item in obj:
122
- recursive_search(item)
123
- recursive_search(data)
124
- return found
125
-
126
  function_schema = [
127
  {
128
  "name": "search_json",
@@ -174,6 +180,7 @@ function_schema = [
174
  }
175
  ]
176
 
 
177
  st.markdown("### Conversation")
178
  for i, msg in enumerate(st.session_state.messages[1:]):
179
  if msg["role"] == "user":
@@ -191,6 +198,7 @@ for i, msg in enumerate(st.session_state.messages[1:]):
191
  except Exception:
192
  st.markdown(f"<b>Function '{msg['name']}' output:</b> {msg['content']}", unsafe_allow_html=True)
193
 
 
194
  def send_message():
195
  user_input = st.session_state.temp_input
196
  if user_input and user_input.strip():
@@ -205,7 +213,7 @@ def send_message():
205
  "https://api.openai.com/v1/chat/completions",
206
  headers=HEADERS,
207
  json={
208
- "model": "gpt-4.1",
209
  "messages": chat_messages,
210
  "functions": function_schema,
211
  "function_call": "auto",
@@ -248,7 +256,7 @@ def send_message():
248
  "https://api.openai.com/v1/chat/completions",
249
  headers=HEADERS,
250
  json={
251
- "model": "gpt-4.1",
252
  "messages": followup_messages,
253
  "temperature": 0,
254
  "max_tokens": 1500,
@@ -269,5 +277,10 @@ def send_message():
269
 
270
  if st.session_state.json_data:
271
  st.text_input("Your message:", key="temp_input", on_change=send_message)
 
 
 
 
 
272
  else:
273
  st.info("Please upload at least one JSON file to start chatting.")
 
50
  except Exception as e:
51
  st.sidebar.error(f"Error reading {f.name}: {e}")
52
 
53
+ # --- System prompt with explicit example ---
54
  system_message = {
55
  "role": "system",
56
  "content": (
57
+ "You are an AI data analyst for uploaded JSON files. "
58
+ "Each file may have different structures and keys, including lists and nested dictionaries. "
59
+ "You have access to a function 'search_all_jsons' that finds all records in all JSON files where a key matches a value, recursively. "
60
+ "If a user asks about groups of people or wants to know counts such as 'How many females are there?', "
61
+ "interpret this as 'search for all records where gender equals female'. "
62
+ "Always use the 'search_all_jsons' function with key='gender' and value='female' for such queries, unless another key/value is clear from context. "
63
+ "Example:\n"
64
+ "User: How many females are there?\n"
65
+ "Assistant: (Call search_all_jsons with key='gender', value='female')"
66
  )
67
  }
68
  st.session_state.messages = [system_message]
69
  else:
70
  st.session_state.json_data.clear()
71
 
72
+ # --- Recursive search for key/value in all files ---
73
+ def search_all_jsons(key, value):
74
+ found = []
75
+ for file_name, data in st.session_state.json_data.items():
76
+ def recursive_search(obj):
77
+ if isinstance(obj, dict):
78
+ if key in obj and str(obj[key]).lower() == str(value).lower():
79
+ found.append({**obj, "__file__": file_name})
80
+ for v in obj.values():
81
+ recursive_search(v)
82
+ elif isinstance(obj, list):
83
+ for item in obj:
84
+ recursive_search(item)
85
+ recursive_search(data)
86
+ return found
87
+
88
+ # --- Other function stubs (for completeness) ---
89
  def search_json(file_name, key, value):
90
  def recursive_search(obj, key, value, found):
91
  if isinstance(obj, dict):
 
128
  except Exception as e:
129
  return {"error": str(e)}
130
 
131
+ # --- Function schema including the new all-files search ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  function_schema = [
133
  {
134
  "name": "search_json",
 
180
  }
181
  ]
182
 
183
+ # --- Conversation UI ---
184
  st.markdown("### Conversation")
185
  for i, msg in enumerate(st.session_state.messages[1:]):
186
  if msg["role"] == "user":
 
198
  except Exception:
199
  st.markdown(f"<b>Function '{msg['name']}' output:</b> {msg['content']}", unsafe_allow_html=True)
200
 
201
+ # --- Chat input and OpenAI handling ---
202
  def send_message():
203
  user_input = st.session_state.temp_input
204
  if user_input and user_input.strip():
 
213
  "https://api.openai.com/v1/chat/completions",
214
  headers=HEADERS,
215
  json={
216
+ "model": "gpt-4o",
217
  "messages": chat_messages,
218
  "functions": function_schema,
219
  "function_call": "auto",
 
256
  "https://api.openai.com/v1/chat/completions",
257
  headers=HEADERS,
258
  json={
259
+ "model": "gpt-4o",
260
  "messages": followup_messages,
261
  "temperature": 0,
262
  "max_tokens": 1500,
 
277
 
278
  if st.session_state.json_data:
279
  st.text_input("Your message:", key="temp_input", on_change=send_message)
280
+ # --- Manual debug/test button ---
281
+ if st.button("Test: Count females in all JSONs"):
282
+ results = search_all_jsons("gender", "female")
283
+ st.write(f"Females found: {len(results)}")
284
+ st.json(results)
285
  else:
286
  st.info("Please upload at least one JSON file to start chatting.")