Seth0330 committed on
Commit
3c83ca9
·
verified ·
1 Parent(s): 5aa8f9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -81
app.py CHANGED
@@ -1,14 +1,12 @@
1
  import streamlit as st
2
- import pandas as pd
3
  import os
4
- import requests
5
  import json
 
6
 
7
  # --- Page config
8
- st.set_page_config(page_title="CSV-Backed AI Chat Agent", layout="wide")
9
 
10
- # --- Title & image
11
- st.title("CSV-Backed AI Chat Agent")
12
 
13
  # --- Load API key
14
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
@@ -21,102 +19,140 @@ HEADERS = {
21
  "Content-Type": "application/json",
22
  }
23
 
24
- # --- Sidebar: CSV upload & preview
25
- st.sidebar.header("Upload CSV File")
26
- uploaded_file = st.sidebar.file_uploader("Choose a CSV file", type="csv")
 
 
27
 
28
- # --- Conversation memory: Use Streamlit session state
 
 
29
  if "messages" not in st.session_state:
30
  st.session_state.messages = []
31
  if "temp_input" not in st.session_state:
32
  st.session_state.temp_input = ""
33
 
34
- # --- Only load df and reset chat on new file upload
35
- if uploaded_file is not None:
36
- try:
37
- df = pd.read_csv(uploaded_file)
38
- st.sidebar.success("File uploaded successfully!")
39
- st.sidebar.write("Preview of the uploaded file:")
40
- st.sidebar.dataframe(df.head())
41
- columns = ", ".join(df.columns)
42
- system_message = {
43
- "role": "system",
44
- "content": (
45
- f"You are an AI data analyst for a CSV file with these columns: {columns}. "
46
- "When the user asks a question, always use the most relevant function to get the answer directly. "
47
- "Do not describe your plan or reasoning steps. Do not ask the user for clarification. "
48
- "Just call the function needed and give the answer, as briefly as possible. "
49
- "If you need to search or filter the CSV, use the 'search_csv' function. "
50
- "If you need to count unique values, use the 'count_unique' function. "
51
- "If you use 'search_csv', use Pandas query syntax."
52
- ),
53
- }
54
- # Only reset memory on new file load
55
- if not st.session_state.messages or (
56
- st.session_state.messages and
57
- ("system" not in st.session_state.messages[0].get("role", ""))
58
- ):
59
- st.session_state.messages = [system_message]
60
- elif (
61
- st.session_state.messages and
62
- st.session_state.messages[0].get("role", "") == "system" and
63
- st.session_state.messages[0].get("content", "") != system_message["content"]
64
- ):
65
- st.session_state.messages[0] = system_message
66
- except Exception as e:
67
- st.sidebar.error(f"Error reading file: {e}")
68
- df = None
69
  else:
70
- df = None
71
 
72
- if df is not None:
73
- st.markdown(f"**Loaded CSV:** {df.shape[0]} rows × {df.shape[1]} columns")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- # --- Functions for function calling
76
- def search_csv(query: str):
77
  try:
78
- result_df = df.query(query)
79
- return result_df.head(10).to_dict(orient="records") # limit for safety
 
 
 
 
 
80
  except Exception as e:
81
- return {"error": f"Invalid query. Example: 'price > 100'. Details: {str(e)}"}
82
 
83
- def count_unique(column: str):
 
84
  try:
85
- n = df[column].nunique()
86
- return {"column": column, "unique_count": int(n)}
 
 
 
 
 
87
  except Exception as e:
88
- return {"error": f"Column '{column}' not found or not countable. Details: {str(e)}"}
89
 
90
  # --- Function schemas for OpenAI
91
  function_schema = [
92
  {
93
- "name": "search_csv",
94
- "description": "Filter the CSV rows by a Pandas query. Example: price > 100",
95
  "parameters": {
96
  "type": "object",
97
  "properties": {
98
- "query": {
99
- "type": "string",
100
- "description": "A Pandas query string, e.g. 'price > 100 and city == \"Miami\"'"
101
- },
102
  },
103
- "required": ["query"],
104
  },
105
  },
106
  {
107
- "name": "count_unique",
108
- "description": "Count the number of unique values in a column.",
109
  "parameters": {
110
  "type": "object",
111
  "properties": {
112
- "column": {
113
- "type": "string",
114
- "description": "The column name to count unique values."
115
- },
116
  },
117
- "required": ["column"],
118
  },
119
- }
 
 
 
 
 
 
 
 
 
 
 
 
120
  ]
121
 
122
  # --- Chat interface
@@ -138,18 +174,18 @@ def send_message():
138
  user_input = st.session_state.temp_input
139
  if user_input and user_input.strip():
140
  st.session_state.messages.append({"role": "user", "content": user_input})
141
- # Limit history for context size (keep system + last 8)
142
  chat_messages = st.session_state.messages
143
  if len(chat_messages) > 10:
144
  chat_messages = [chat_messages[0]] + chat_messages[-9:]
145
  else:
146
  chat_messages = chat_messages.copy()
147
- # First OpenAI call: Check for function call
148
  chat_resp = requests.post(
149
  "https://api.openai.com/v1/chat/completions",
150
  headers=HEADERS,
151
  json={
152
- "model": "gpt-4.1",
153
  "messages": chat_messages,
154
  "functions": function_schema,
155
  "function_call": "auto",
@@ -167,11 +203,12 @@ def send_message():
167
  func_name = msg["function_call"]["name"]
168
  args_json = msg["function_call"]["arguments"]
169
  args = json.loads(args_json)
170
- # --- FIXED: Only pass the expected arg for each function
171
- if func_name == "search_csv":
172
- function_result = search_csv(args.get("query", ""))
173
- elif func_name == "count_unique":
174
- function_result = count_unique(args.get("column", ""))
 
175
  else:
176
  function_result = {"error": f"Unknown function: {func_name}"}
177
  st.session_state.messages.append({
@@ -179,7 +216,7 @@ def send_message():
179
  "name": func_name,
180
  "content": json.dumps(function_result),
181
  })
182
- # Limit history again for second call
183
  followup_messages = st.session_state.messages
184
  if len(followup_messages) > 12:
185
  followup_messages = [followup_messages[0]] + followup_messages[-11:]
@@ -204,5 +241,8 @@ def send_message():
204
 
205
  st.session_state.temp_input = ""
206
 
207
- if df is not None:
208
  st.text_input("Your message:", key="temp_input", on_change=send_message)
 
 
 
 
1
  import streamlit as st
 
2
  import os
 
3
  import json
4
+ import requests
5
 
6
  # --- Page config
7
+ st.set_page_config(page_title="JSON-Backed AI Chat Agent", layout="wide")
8
 
9
+ st.title("JSON-Backed AI Chat Agent")
 
10
 
11
  # --- Load API key
12
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 
19
  "Content-Type": "application/json",
20
  }
21
 
22
+ # --- Sidebar: Multiple JSON upload & preview
23
+ st.sidebar.header("Upload Multiple JSON Files")
24
+ uploaded_files = st.sidebar.file_uploader(
25
+ "Choose one or more JSON files", type="json", accept_multiple_files=True
26
+ )
27
 
28
+ # --- Session State for data and chat
29
+ if "json_data" not in st.session_state:
30
+ st.session_state.json_data = {}
31
  if "messages" not in st.session_state:
32
  st.session_state.messages = []
33
  if "temp_input" not in st.session_state:
34
  st.session_state.temp_input = ""
35
 
36
+ # --- Load all JSON files
37
if uploaded_files:
    # Rebuild the in-memory store from the files currently in the uploader.
    st.session_state.json_data.clear()
    file_summaries = []
    for f in uploaded_files:
        try:
            # Rewind first: this script re-executes on every interaction and
            # the upload is a file-like object that may already have been read.
            f.seek(0)
            content = json.load(f)
            st.session_state.json_data[f.name] = content
            # Summarise the keys for the system prompt: a dict contributes its
            # own keys, a list of records contributes the first record's keys.
            if isinstance(content, dict):
                keys = list(content.keys())
            elif isinstance(content, list) and content and isinstance(content[0], dict):
                keys = list(content[0].keys())
            else:
                keys = []
            # Show at most 10 keys; '...' marks a truncated list.
            preview = f"{keys[:10]}{'...' if len(keys) > 10 else ''}"
            file_summaries.append(f"{f.name}: keys={preview}")
            st.sidebar.success(f"Loaded: {f.name}")
            st.sidebar.write(f"Keys: {preview}")
        except Exception as e:
            # UI boundary: report the bad file in the sidebar and keep loading
            # the remaining files instead of aborting the whole app.
            st.sidebar.error(f"Error reading {f.name}: {e}")

    # Compose system prompt for the LLM
    system_message = {
        "role": "system",
        "content": (
            "You are an AI data analyst for the following JSON files:\n" +
            "\n".join(file_summaries) +
            "\nEach file may have a different structure and set of keys. "
            "When the user asks a question, identify which file(s) it applies to, "
            "then use the most relevant function to extract the answer. "
            "If the user does not specify a file, make your best guess based on keys/fields mentioned."
        ),
    }
    # Reset the chat only when the set of loaded files (and therefore the
    # system prompt) actually changed. Assigning unconditionally here would
    # erase the conversation on every rerun, i.e. after every user message.
    history = st.session_state.messages
    if not history or history[0] != system_message:
        st.session_state.messages = [system_message]
else:
    # No files in the uploader: drop any previously loaded data.
    st.session_state.json_data.clear()
73
 
74
+ # --- Functions for querying JSON files
75
def search_json(file_name, key, value):
    """Return records of an uploaded JSON file whose ``key`` equals ``value``.

    Args:
        file_name: Name of a file stored in ``st.session_state.json_data``.
        key: Top-level field to compare in each record.
        value: Value to match. The function schema delivers it as a string,
            so JSON numbers and booleans are also matched by their text form
            (e.g. ``"30"`` matches ``30`` and ``"true"`` matches ``True``).

    Returns:
        For a list file: up to 10 matching dict records. For a dict file:
        ``[{key: stored_value}]`` when the top-level entry matches, else
        ``[]``. Any other JSON shape gives ``[]``. On failure a dict
        ``{"error": message}`` is returned instead of raising.
    """
    max_results = 10
    files = st.session_state.json_data
    try:
        data = files[file_name]
    except (KeyError, TypeError):
        # str(KeyError) is only the quoted name; give the model something it
        # can act on by listing what is actually loaded.
        return {
            "error": f"File {file_name!r} is not loaded. "
                     f"Available files: {list(files)}"
        }

    def matches(actual):
        if actual == value:
            return True
        # Function-call arguments arrive as strings, so compare scalar JSON
        # values by their JSON/Python text forms as well.
        if isinstance(value, str) and isinstance(actual, (bool, int, float)):
            return value in (json.dumps(actual), str(actual))
        return False

    try:
        if isinstance(data, list):
            results = []
            for item in data:
                if isinstance(item, dict) and matches(item.get(key)):
                    results.append(item)
                    if len(results) == max_results:
                        break  # cap the payload sent back to the model
            return results
        if isinstance(data, dict):
            if key in data and matches(data[key]):
                return [{key: data[key]}]
            return []
    except TypeError as exc:
        # e.g. an unhashable key supplied by the model
        return {"error": f"Invalid key {key!r}: {exc}"}
    return []
91
 
92
def list_keys(file_name):
    """Return the top-level keys of an uploaded JSON file.

    Args:
        file_name: Name of a file stored in ``st.session_state.json_data``.

    Returns:
        The keys of a dict file, or the keys of the first record when the
        file is a non-empty list whose first element is a dict. Any other
        shape gives ``[]``. If the file is not loaded, a dict
        ``{"error": message}`` naming the available files is returned.
    """
    files = st.session_state.json_data
    try:
        data = files[file_name]
    except (KeyError, TypeError):
        # str(KeyError) is only the quoted name; list what is loaded so the
        # caller (the LLM) can retry with a valid file name.
        return {
            "error": f"File {file_name!r} is not loaded. "
                     f"Available files: {list(files)}"
        }
    if isinstance(data, dict):
        return list(data)
    if isinstance(data, list) and data and isinstance(data[0], dict):
        return list(data[0])
    return []
104
 
105
def count_key_occurrences(file_name, key):
    """Count top-level occurrences of ``key`` in an uploaded JSON file.

    Args:
        file_name: Name of a file stored in ``st.session_state.json_data``.
        key: The key to look for.

    Returns:
        For a dict file, 1 if ``key`` is a top-level key, else 0. For a list
        file, the number of dict records that contain ``key``. Any other
        shape gives 0. Nested structures are not searched. On failure a dict
        ``{"error": message}`` is returned instead of raising.
    """
    files = st.session_state.json_data
    try:
        data = files[file_name]
    except (KeyError, TypeError):
        # str(KeyError) is only the quoted name; list what is loaded so the
        # caller (the LLM) can retry with a valid file name.
        return {
            "error": f"File {file_name!r} is not loaded. "
                     f"Available files: {list(files)}"
        }
    try:
        if isinstance(data, dict):
            return 1 if key in data else 0
        if isinstance(data, list):
            return sum(1 for item in data if isinstance(item, dict) and key in item)
    except TypeError as exc:
        # e.g. an unhashable key supplied by the model
        return {"error": f"Invalid key {key!r}: {exc}"}
    return 0
117
 
118
  # --- Function schemas for OpenAI
119
# Tool definitions advertised to the model through the "functions" field of
# the chat completion request. The descriptions are what the model reads when
# choosing a tool, so they state what each function actually returns.
function_schema = [
    {
        "name": "search_json",
        "description": "Find records in the specified JSON file where key matches a given value.",
        "parameters": {
            "type": "object",
            "properties": {
                "file_name": {"type": "string", "description": "The uploaded JSON file to search."},
                "key": {"type": "string", "description": "The key/field to filter by."},
                "value": {"type": "string", "description": "The value to match."},
            },
            "required": ["file_name", "key", "value"],
        },
    },
    {
        "name": "list_keys",
        "description": (
            "List the top-level keys of a given JSON file "
            "(for a list of records, the keys of the first record)."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "file_name": {"type": "string", "description": "The uploaded JSON file."},
            },
            "required": ["file_name"],
        },
    },
    {
        "name": "count_key_occurrences",
        "description": (
            "Count how many top-level records in a JSON file contain a given key "
            "(nested objects are not searched)."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "file_name": {"type": "string", "description": "The uploaded JSON file."},
                "key": {"type": "string", "description": "The key to count."},
            },
            "required": ["file_name", "key"],
        },
    },
]
157
 
158
  # --- Chat interface
 
174
  user_input = st.session_state.temp_input
175
  if user_input and user_input.strip():
176
  st.session_state.messages.append({"role": "user", "content": user_input})
177
+ # Limit history for context size
178
  chat_messages = st.session_state.messages
179
  if len(chat_messages) > 10:
180
  chat_messages = [chat_messages[0]] + chat_messages[-9:]
181
  else:
182
  chat_messages = chat_messages.copy()
183
+ # OpenAI call
184
  chat_resp = requests.post(
185
  "https://api.openai.com/v1/chat/completions",
186
  headers=HEADERS,
187
  json={
188
+ "model": "gpt-4.1", # Use latest available model for this purpose
189
  "messages": chat_messages,
190
  "functions": function_schema,
191
  "function_call": "auto",
 
203
  func_name = msg["function_call"]["name"]
204
  args_json = msg["function_call"]["arguments"]
205
  args = json.loads(args_json)
206
+ if func_name == "search_json":
207
+ function_result = search_json(args.get("file_name"), args.get("key"), args.get("value"))
208
+ elif func_name == "list_keys":
209
+ function_result = list_keys(args.get("file_name"))
210
+ elif func_name == "count_key_occurrences":
211
+ function_result = count_key_occurrences(args.get("file_name"), args.get("key"))
212
  else:
213
  function_result = {"error": f"Unknown function: {func_name}"}
214
  st.session_state.messages.append({
 
216
  "name": func_name,
217
  "content": json.dumps(function_result),
218
  })
219
+ # Second call to OpenAI for the final answer
220
  followup_messages = st.session_state.messages
221
  if len(followup_messages) > 12:
222
  followup_messages = [followup_messages[0]] + followup_messages[-11:]
 
241
 
242
  st.session_state.temp_input = ""
243
 
244
+ if st.session_state.json_data:
245
  st.text_input("Your message:", key="temp_input", on_change=send_message)
246
+ else:
247
+ st.info("Please upload at least one JSON file to start chatting.")
248
+