Seth0330 commited on
Commit
8a030d9
·
verified ·
1 Parent(s): 6a9dd45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -47
app.py CHANGED
@@ -26,12 +26,44 @@ HEADERS = {
26
  st.sidebar.header("Upload CSV File")
27
  uploaded_file = st.sidebar.file_uploader("Choose a CSV file", type="csv")
28
 
29
- if uploaded_file:
 
 
 
 
 
 
 
30
  try:
31
  df = pd.read_csv(uploaded_file)
32
  st.sidebar.success("File uploaded successfully!")
33
  st.sidebar.write("Preview of the uploaded file:")
34
  st.sidebar.dataframe(df.head())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  except Exception as e:
36
  st.sidebar.error(f"Error reading file: {e}")
37
  df = None
@@ -45,8 +77,7 @@ if df is not None:
45
  def search_csv(query: str):
46
  try:
47
  result_df = df.query(query)
48
- # Limit output to 50 rows for large results
49
- return result_df.head(50).to_dict(orient="records")
50
  except Exception as e:
51
  return {"error": f"Invalid query. Example: 'price > 100'. Details: {str(e)}"}
52
 
@@ -89,45 +120,13 @@ function_schema = [
89
  }
90
  ]
91
 
92
- # --- Map function names to Python functions
93
  function_map = {
94
  "search_csv": search_csv,
95
  "count_unique": count_unique,
96
  }
97
 
98
- # --- Conversation memory: Use Streamlit session state
99
- if "messages" not in st.session_state:
100
- st.session_state.messages = []
101
-
102
- if "temp_input" not in st.session_state:
103
- st.session_state.temp_input = ""
104
-
105
- # If CSV is loaded, update the system prompt with current columns
106
- if df is not None:
107
- columns = ", ".join(df.columns)
108
- system_message = {
109
- "role": "system",
110
- "content": (
111
- f"You are an AI data analyst for a CSV file with these columns: {columns}. "
112
- "When the user asks a question, always use the most relevant function to get the answer directly. "
113
- "Do not describe your plan or reasoning steps. Do not ask the user for clarification. "
114
- "Just call the function needed and give the answer, as briefly as possible. "
115
- "If you need to search or filter the CSV, use the 'search_csv' function. "
116
- "If you need to count unique values, use the 'count_unique' function. "
117
- "If you use 'search_csv', use Pandas query syntax."
118
- ),
119
- }
120
-
121
- # Ensure the system message is always at the start and up-to-date
122
- if not st.session_state.messages or st.session_state.messages[0]["role"] != "system":
123
- st.session_state.messages.insert(0, system_message)
124
- else:
125
- st.session_state.messages[0] = system_message
126
-
127
  # --- Chat interface
128
  st.markdown("### Conversation")
129
-
130
- # Display chat history (like ChatGPT)
131
  for i, msg in enumerate(st.session_state.messages[1:]): # Skip system message for display
132
  if msg["role"] == "user":
133
  st.markdown(f"<div style='color: #4F8BF9;'><b>User:</b> {msg['content']}</div>", unsafe_allow_html=True)
@@ -145,10 +144,12 @@ def send_message():
145
  user_input = st.session_state.temp_input
146
  if user_input and user_input.strip():
147
  st.session_state.messages.append({"role": "user", "content": user_input})
148
-
149
- # Compose messages for OpenAI (entire chat history)
150
- chat_messages = st.session_state.messages.copy()
151
-
 
 
152
  # First OpenAI call: Check for function call
153
  chat_resp = requests.post(
154
  "https://api.openai.com/v1/chat/completions",
@@ -177,15 +178,17 @@ def send_message():
177
  function_result = function_map[func_name](**args)
178
  else:
179
  function_result = {"error": f"Unknown function: {func_name}"}
180
- # Append function call and output to history
181
  st.session_state.messages.append({
182
  "role": "function",
183
  "name": func_name,
184
  "content": json.dumps(function_result),
185
  })
186
-
187
- # Second OpenAI call: Get final answer with function result
188
- followup_messages = st.session_state.messages.copy()
 
 
 
189
  final_resp = requests.post(
190
  "https://api.openai.com/v1/chat/completions",
191
  headers=HEADERS,
@@ -199,15 +202,11 @@ def send_message():
199
  )
200
  final_resp.raise_for_status()
201
  answer = final_resp.json()["choices"][0]["message"]["content"]
202
- # Add assistant's reply to chat
203
  st.session_state.messages.append({"role": "assistant", "content": answer})
204
  else:
205
- # No function call: Just add model's reply
206
  st.session_state.messages.append({"role": "assistant", "content": msg["content"]})
207
 
208
- # Clear input after sending (now legal and safe)
209
  st.session_state.temp_input = ""
210
 
211
- # --- User input box at bottom (like ChatGPT)
212
  if df is not None:
213
  st.text_input("Your message:", key="temp_input", on_change=send_message)
 
26
  st.sidebar.header("Upload CSV File")
27
  uploaded_file = st.sidebar.file_uploader("Choose a CSV file", type="csv")
28
 
29
+ # --- Conversation memory: Use Streamlit session state
30
+ if "messages" not in st.session_state:
31
+ st.session_state.messages = []
32
+ if "temp_input" not in st.session_state:
33
+ st.session_state.temp_input = ""
34
+
35
+ # --- Only load df and reset chat on new file upload
36
+ if uploaded_file is not None:
37
  try:
38
  df = pd.read_csv(uploaded_file)
39
  st.sidebar.success("File uploaded successfully!")
40
  st.sidebar.write("Preview of the uploaded file:")
41
  st.sidebar.dataframe(df.head())
42
+ columns = ", ".join(df.columns)
43
+ system_message = {
44
+ "role": "system",
45
+ "content": (
46
+ f"You are an AI data analyst for a CSV file with these columns: {columns}. "
47
+ "When the user asks a question, always use the most relevant function to get the answer directly. "
48
+ "Do not describe your plan or reasoning steps. Do not ask the user for clarification. "
49
+ "Just call the function needed and give the answer, as briefly as possible. "
50
+ "If you need to search or filter the CSV, use the 'search_csv' function. "
51
+ "If you need to count unique values, use the 'count_unique' function. "
52
+ "If you use 'search_csv', use Pandas query syntax."
53
+ ),
54
+ }
55
+ # Only reset memory on new file load
56
+ if not st.session_state.messages or (
57
+ st.session_state.messages and
58
+ ("system" not in st.session_state.messages[0].get("role", ""))
59
+ ):
60
+ st.session_state.messages = [system_message]
61
+ elif (
62
+ st.session_state.messages and
63
+ st.session_state.messages[0].get("role", "") == "system" and
64
+ st.session_state.messages[0].get("content", "") != system_message["content"]
65
+ ):
66
+ st.session_state.messages[0] = system_message
67
  except Exception as e:
68
  st.sidebar.error(f"Error reading file: {e}")
69
  df = None
 
77
  def search_csv(query: str):
78
  try:
79
  result_df = df.query(query)
80
+ return result_df.head(10).to_dict(orient="records") # limit for safety
 
81
  except Exception as e:
82
  return {"error": f"Invalid query. Example: 'price > 100'. Details: {str(e)}"}
83
 
 
120
  }
121
  ]
122
 
 
123
  function_map = {
124
  "search_csv": search_csv,
125
  "count_unique": count_unique,
126
  }
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  # --- Chat interface
129
  st.markdown("### Conversation")
 
 
130
  for i, msg in enumerate(st.session_state.messages[1:]): # Skip system message for display
131
  if msg["role"] == "user":
132
  st.markdown(f"<div style='color: #4F8BF9;'><b>User:</b> {msg['content']}</div>", unsafe_allow_html=True)
 
144
  user_input = st.session_state.temp_input
145
  if user_input and user_input.strip():
146
  st.session_state.messages.append({"role": "user", "content": user_input})
147
+ # Limit history for context size (keep system + last 8)
148
+ chat_messages = st.session_state.messages
149
+ if len(chat_messages) > 10:
150
+ chat_messages = [chat_messages[0]] + chat_messages[-9:]
151
+ else:
152
+ chat_messages = chat_messages.copy()
153
  # First OpenAI call: Check for function call
154
  chat_resp = requests.post(
155
  "https://api.openai.com/v1/chat/completions",
 
178
  function_result = function_map[func_name](**args)
179
  else:
180
  function_result = {"error": f"Unknown function: {func_name}"}
 
181
  st.session_state.messages.append({
182
  "role": "function",
183
  "name": func_name,
184
  "content": json.dumps(function_result),
185
  })
186
+ # Limit history again for second call
187
+ followup_messages = st.session_state.messages
188
+ if len(followup_messages) > 12:
189
+ followup_messages = [followup_messages[0]] + followup_messages[-11:]
190
+ else:
191
+ followup_messages = followup_messages.copy()
192
  final_resp = requests.post(
193
  "https://api.openai.com/v1/chat/completions",
194
  headers=HEADERS,
 
202
  )
203
  final_resp.raise_for_status()
204
  answer = final_resp.json()["choices"][0]["message"]["content"]
 
205
  st.session_state.messages.append({"role": "assistant", "content": answer})
206
  else:
 
207
  st.session_state.messages.append({"role": "assistant", "content": msg["content"]})
208
 
 
209
  st.session_state.temp_input = ""
210
 
 
211
  if df is not None:
212
  st.text_input("Your message:", key="temp_input", on_change=send_message)