jeonghin commited on
Commit
c793fd0
·
verified ·
1 Parent(s): fd150df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -1
app.py CHANGED
@@ -120,16 +120,140 @@ def handle_userinput(user_question):
120
  )
121
 
122
 
123
- def chat(slug):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  """
125
  Manages the chat interface in the Streamlit application, handling the conversation
126
  flow and displaying the chat history.
 
127
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  text_chunks = get_text_chunks(get_pdf_text(slug))
130
  vectorstore = get_vectorstore(text_chunks)
131
  st.session_state.conversation = get_conversation_chain(vectorstore)
132
 
 
 
 
 
 
 
 
133
  if len(st.session_state.messages) == 1:
134
  message = st.session_state.messages[0]
135
  with st.chat_message(message["role"]):
 
120
  )
121
 
122
 
123
+ def get_user_chat_count(user_id):
124
+ """
125
+ Retrieves the chat count for the user from the MySQL database.
126
+ """
127
+ try:
128
+ conn = mysql.connector.connect(
129
+ user=os.getenv("SQL_USER"),
130
+ password=os.getenv("SQL_PWD"),
131
+ host=os.getenv("SQL_HOST"),
132
+ database="Birdseye_DB",
133
+ )
134
+ cursor = conn.cursor()
135
+
136
+ cursor.execute("SELECT count FROM Chat WHERE user_id = %s", (user_id,))
137
+ result = cursor.fetchone()
138
+ if result:
139
+ return result[0]
140
+ else:
141
+ # Insert a new row for the user if not found
142
+ cursor.execute(
143
+ "INSERT INTO Chat (user_id, count) VALUES (%s, %s)", (user_id, 0)
144
+ )
145
+ conn.commit()
146
+ return 0
147
+ except mysql.connector.Error as err:
148
+ st.error(f"Error: {err}")
149
+ return None
150
+ finally:
151
+ if conn.is_connected():
152
+ cursor.close()
153
+ conn.close()
154
+
155
+
156
+ def increment_user_chat_count(user_id):
157
+ """
158
+ Increments the chat count for the user in the MySQL database.
159
+ """
160
+ try:
161
+ conn = mysql.connector.connect(
162
+ user=os.getenv("SQL_USER"),
163
+ password=os.getenv("SQL_PWD"),
164
+ host=os.getenv("SQL_HOST"),
165
+ database="Birdseye_DB",
166
+ )
167
+ cursor = conn.cursor()
168
+
169
+ cursor.execute(
170
+ "UPDATE Chat SET count = count + 1 WHERE user_id = %s", (user_id,)
171
+ )
172
+ conn.commit()
173
+ except mysql.connector.Error as err:
174
+ st.error(f"Error: {err}")
175
+ finally:
176
+ if conn.is_connected():
177
+ cursor.close()
178
+ conn.close()
179
+
180
+
181
+ def is_user_in_unlimited_chat_group(user_id):
182
+ """
183
+ Checks if the user belongs to the 'Unlimited Chat' group.
184
+ """
185
+ try:
186
+ conn = mysql.connector.connect(
187
+ user=os.getenv("SQL_USER"),
188
+ password=os.getenv("SQL_PWD"),
189
+ host=os.getenv("SQL_HOST"),
190
+ database="Birdseye_DB",
191
+ )
192
+ cursor = conn.cursor()
193
+
194
+ cursor.execute(
195
+ """
196
+ SELECT 1
197
+ FROM auth_user_groups
198
+ JOIN auth_group ON auth_user_groups.group_id = auth_group.id
199
+ WHERE auth_user_groups.user_id = %s AND auth_group.name = 'Unlimited Chat'
200
+ """,
201
+ (user_id,),
202
+ )
203
+ return cursor.fetchone() is not None
204
+ except mysql.connector.Error as err:
205
+ st.error(f"Error: {err}")
206
+ return False
207
+ finally:
208
+ if conn.is_connected():
209
+ cursor.close()
210
+ conn.close()
211
+
212
+
213
+ def chat(slug, user_id):
214
  """
215
  Manages the chat interface in the Streamlit application, handling the conversation
216
  flow and displaying the chat history.
217
+ Restricts chat based on user group and chat count.
218
  """
219
+ try:
220
+ conn = mysql.connector.connect(
221
+ user=os.getenv("SQL_USER"),
222
+ password=os.getenv("SQL_PWD"),
223
+ host=os.getenv("SQL_HOST"),
224
+ database="Birdseye_DB",
225
+ )
226
+ cursor = conn.cursor()
227
+
228
+ # Execute a query
229
+ cursor.execute("SELECT ocr_text FROM birdseye_temp WHERE slug = %s", (slug,))
230
+
231
+ # Fetch the results
232
+ rows = cursor.fetchall()
233
+ text = ""
234
+ for row in rows:
235
+ if row[0]:
236
+ text += row[0]
237
+
238
+ except mysql.connector.Error as err:
239
+ st.error(f"Error: {err}")
240
+ return
241
+ finally:
242
+ if conn.is_connected():
243
+ cursor.close()
244
+ conn.close()
245
 
246
  text_chunks = get_text_chunks(get_pdf_text(slug))
247
  vectorstore = get_vectorstore(text_chunks)
248
  st.session_state.conversation = get_conversation_chain(vectorstore)
249
 
250
+ # Check if the user can chat
251
+ if not is_user_in_unlimited_chat_group(user_id):
252
+ user_chat_count = get_user_chat_count(user_id)
253
+ if user_chat_count is None or user_chat_count >= 20:
254
+ st.write("You have reached your chat limit.")
255
+ return
256
+
257
  if len(st.session_state.messages) == 1:
258
  message = st.session_state.messages[0]
259
  with st.chat_message(message["role"]):