SRA25 committed on
Commit
a48bc9f
Β·
verified Β·
1 Parent(s): 238a322

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +301 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,303 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import uuid
3
+ import hashlib
4
+ from typing import List, Optional, Dict, Any, TypedDict,Generic, TypeVar
5
+ from huggingface_hub import login
6
+ import logging
7
+ import time
8
+ import os
9
+ from dotenv import load_dotenv
10
+ from langgraph_init import upload_documents, submit_feedback, get_conversations,get_conversation_history, chat_with_rag, login_hf
11
+ # --- 1. Configuration ---
12
+ # API_URL = "https://sra25-fastapi-backend.hf.space"
13
+ # API_URL = "http://localhost:8000"
14
 
15
+ load_dotenv()
16
+ login_hf()
17
+ # hf_token = os.getenv("hf_user_token")
18
+ # login(hf_token)
19
+
20
+ # --- 1. Database Setup ---
21
+
22
+
23
# --- 2. Streamlit UI Components and State Management ---
st.set_page_config(page_title="Conversational RAG Chatbot", layout="wide")
st.title("💬 Conversational RAG Chatbot")
st.caption("Powered by a FastAPI backend")

# Seed every session-state key this app relies on, exactly once per session.
# Factories (not values) are used so defaults like the uuid are only built
# when the key is actually missing.
for _key, _factory in (
    ("conversations", list),                      # sidebar conversation list
    ("session_id", lambda: str(uuid.uuid4())),    # current conversation id
    ("messages", list),                           # chat transcript for the UI
    ("retriever_ready", lambda: False),           # docs processed yet?
    ("feedback_given", dict),                     # telemetry_id -> already rated
    ("negative_feedback_for", lambda: None),      # message awaiting a 👎 comment
    ("uploaded_file_hashes", set),                # dedup hashes of uploads
    ("uploaded_files_info", list),                # name/size of accepted uploads
):
    if _key not in st.session_state:
        st.session_state[_key] = _factory()
48
+
49
def get_file_hash(file) -> str:
    """Return a SHA-256 hex digest uniquely identifying an uploaded file.

    The digest mixes the file's name, size, and full content, so two
    uploads are treated as duplicates only when all three match.

    Args:
        file: A Streamlit ``UploadedFile``-like object exposing ``name``,
            ``size``, and ``getvalue()``.

    Returns:
        Hex-encoded SHA-256 digest string.
    """
    hasher = hashlib.sha256()
    hasher.update(file.name.encode("utf-8"))
    hasher.update(str(file.size).encode("utf-8"))
    # getvalue() already materialises the whole upload in memory, so hashing
    # the full content costs nothing extra.  The original hashed only the
    # first 1KB, wrongly flagging as duplicates any files that share a name,
    # size, and 1KB prefix but differ later.
    hasher.update(file.getvalue())
    return hasher.hexdigest()
59
+ # --- 3. Helper Functions for Backend Communication ---
60
+ # def send_documents_to_backend(uploaded_files):
61
+ # try:
62
+ # for file in uploaded_files:
63
+ # process_status = upload_documents(file)
64
+ # return process_status
65
+ # except Exception as e:
66
+ # st.error(f"Error processing documents: {e}")
67
+ # return None
68
+
69
def send_chat_message_to_backend(prompt: str, chat_history: List[Dict[str, Any]]):
    """Send a chat message to the RAG backend and return its response dict.

    Args:
        prompt: The user's question.
        chat_history: Prior messages; each dict carries "role"/"content"
            (extra keys such as "telemetry_id" are stripped before sending).

    Returns:
        The backend's response dict, ``{"empty": "Invalid Question"}`` for a
        blank prompt, or ``None`` when the backend call fails.
    """
    if not prompt.strip():
        return {"empty": "Invalid Question"}

    history_for_api = [
        {"role": msg.get("role"), "content": msg.get("content")}
        for msg in chat_history
    ]
    payload = {
        "user_question": str(prompt),
        "session_id": st.session_state.session_id,
        "chat_history": history_for_api,
    }
    print(f"Sending payload: {payload}")  # Debug print
    try:
        # BUG FIX: the backend call is the failure-prone step, so it must sit
        # inside the try block.  The original wrapped only `return
        # response_dict`, which cannot raise, leaving the except dead code.
        return chat_with_rag(payload)
    except Exception as e:
        st.error("Error communicating with the backend")
        print(f"Error communicating with the backend: {e}")
        return None
91
+
92
def send_feedback_to_backend(telemetry_entry_id: str, feedback_score: int, feedback_text: Optional[str] = None):
    """Submit user feedback for a single assistant response.

    Args:
        telemetry_entry_id: Backend telemetry id of the rated response.
        feedback_score: 1 for positive, -1 for negative feedback.
        feedback_text: Optional free-text comment (negative feedback only).
    """
    payload = {
        "session_id": st.session_state.session_id,
        "telemetry_entry_id": telemetry_entry_id,
        "feedback_score": feedback_score,
        "feedback_text": feedback_text,
    }
    try:
        # Return value was never used, so it is no longer captured; only the
        # absence of an exception matters here.
        submit_feedback(payload)
        st.toast("Feedback submitted! Thank you.")
    except Exception as e:
        st.error(f"Error submitting feedback: {e}")
107
+
108
def get_conversations_from_backend() -> list:
    """Fetch the list of all stored conversations from the backend.

    Returns:
        The backend's conversation list, or an empty list when the call
        fails (the error is surfaced in the sidebar).
    """
    try:
        conversations = get_conversations()
    except Exception as exc:
        st.sidebar.error(f"Error fetching conversations: {exc}")
        return []
    return conversations
118
+
119
def get_conversation_history_from_backend(session_id: str):
    """Fetch all messages belonging to one conversation.

    Args:
        session_id: Identifier of the conversation to load.

    Returns:
        The message history returned by the backend, or ``None`` on failure.
    """
    try:
        history = get_conversation_history(session_id)
    except Exception as exc:
        st.error(f"Error loading conversation history: {exc}")
        return None
    return history
129
+
130
def handle_positive_feedback(telemetry_id):
    """Record a thumbs-up for the response identified by *telemetry_id*.

    Sends score +1 to the backend and marks the message as rated so the
    feedback buttons are hidden on the next rerun.
    """
    send_feedback_to_backend(telemetry_id, 1)
    st.session_state.feedback_given[telemetry_id] = True
134
+
135
+
136
def handle_negative_feedback_comment_submit(telemetry_id, comment_text):
    """Record a thumbs-down plus its optional comment for *telemetry_id*.

    Sends score -1 with the comment, marks the message as rated, and closes
    the open comment box by clearing ``negative_feedback_for``.
    """
    send_feedback_to_backend(telemetry_id, -1, comment_text)
    st.session_state.feedback_given[telemetry_id] = True
    st.session_state.negative_feedback_for = None
141
+
142
+
143
def refresh_conversations():
    """Reload the sidebar's conversation list from the backend into state."""
    st.session_state.conversations = get_conversations_from_backend()
146
+
147
# --- 4. Sidebar for Document Upload and Conversation History ---
with st.sidebar:
    st.header("Upload Documents")
    uploaded_files = st.file_uploader(
        "Upload Text, PDF, Docx files:",
        type=["txt", "pdf", "docx"],
        accept_multiple_files=True,
        key="file_uploader"
    )
    if st.button("Process Documents", key="process_docs_button"):
        if uploaded_files:
            # Partition the selection into unseen files and duplicates using
            # the content-based hash kept in session state.
            new_uploaded_files = []
            newly_added_files_info = []
            duplicate_file = []
            for file in uploaded_files:
                file_hash = get_file_hash(file)
                if file_hash not in st.session_state.uploaded_file_hashes:
                    st.session_state.uploaded_file_hashes.add(file_hash)
                    new_uploaded_files.append(file)
                    newly_added_files_info.append({"name": file.name, "size": file.size})
                else:
                    st.warning(f"File '{file.name}' has already been uploaded.")
                    duplicate_file.append(file)

            if new_uploaded_files:
                st.session_state.uploaded_files_info.extend(newly_added_files_info)
                with st.spinner("Processing documents..."):
                    process_status = upload_documents(new_uploaded_files)
                    if process_status:
                        st.session_state.retriever_ready = True
                        st.success(process_status)
                        st.session_state.messages = []
                        refresh_conversations()  # sql query need to be added here
                    else:
                        st.session_state.retriever_ready = False
                        st.error("Failed to process documents.")
            elif duplicate_file:
                # Every selected file is a known duplicate: re-process them
                # anyway so the knowledge base can be rebuilt after a restart.
                with st.spinner("Processing documents..."):
                    process_status = upload_documents(duplicate_file)
                    if process_status:
                        st.session_state.retriever_ready = True
                        st.success(process_status)
                        st.session_state.messages = []
                        refresh_conversations()
                    else:
                        st.session_state.retriever_ready = False
                        st.error("Failed to process documents.")
        else:
            st.warning("Please upload at least one document to process.")

    st.markdown("---")
    st.header("Conversations")
    if st.button("➕ New Chat", key="new_chat_button", use_container_width=True, type="primary"):
        # Start a fresh conversation: new id, clean transcript and feedback.
        st.session_state.session_id = str(uuid.uuid4())
        st.session_state.messages = []
        st.session_state.feedback_given = {}
        st.session_state.negative_feedback_for = None
        refresh_conversations()
        st.rerun()

    refresh_conversations()

    if st.session_state.conversations:
        for conv in st.session_state.conversations:
            if st.button(
                conv["title"],
                key=f"conv_{conv['session_id']}",
                use_container_width=True
            ):
                if st.session_state.session_id != conv["session_id"]:
                    st.session_state.session_id = conv["session_id"]
                    history = get_conversation_history_from_backend(conv["session_id"])
                    # BUG FIX: the original used `history != [] or history != None`,
                    # which is always true, so a failed load (None) was stored
                    # into `messages` and crashed the chat render loop.
                    if history:
                        st.session_state.messages = history
                        st.session_state.feedback_given = {
                            msg.get("telemetry_id"): True
                            for msg in history if msg.get("telemetry_id")
                        }
                    else:
                        st.session_state.messages = []
                    st.session_state.negative_feedback_for = None
                    st.rerun()
230
+
231
# --- 5. Main Chat Interface ---
# Re-render the full transcript on every rerun (Streamlit is stateless).
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

        # BUG FIX: look the id up with .get — user messages have no
        # 'telemetry_id' key, and direct indexing raised KeyError.
        telemetry_id = message.get("telemetry_id")

        # Feedback buttons only for assistant responses not yet rated.
        if message["role"] == "assistant" and telemetry_id and not st.session_state.feedback_given.get(telemetry_id, False):
            col1, col2 = st.columns(2)
            with col1:
                st.button("👍", key=f"positive_{telemetry_id}", on_click=handle_positive_feedback, args=(telemetry_id,))
            with col2:
                if st.button("👎", key=f"negative_{telemetry_id}"):
                    st.session_state.negative_feedback_for = telemetry_id
                    st.rerun()

        # Comment box appears only under the message the user thumbed down.
        if telemetry_id and st.session_state.negative_feedback_for == telemetry_id:
            with st.container():
                comment = st.text_area(
                    "Please provide some details (optional):",
                    key=f"feedback_text_{telemetry_id}"
                )
                if st.button("Submit Comment", key=f"submit_feedback_button_{telemetry_id}"):
                    handle_negative_feedback_comment_submit(telemetry_id, comment)
258
+
259
# --- 6. Chat input for new questions ---
if st.session_state.retriever_ready:
    if prompt := st.chat_input("Ask me anything about the uploaded documents..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response_data = send_chat_message_to_backend(prompt, st.session_state.messages)
                if response_data:
                    # BUG FIX: the original special-case branches fell through
                    # into the generic handling, rendering the answer twice for
                    # restricted responses and appending a spurious "Sorry..."
                    # assistant message for rejected (blank) prompts.  The
                    # branches below are mutually exclusive.
                    if response_data.get("empty"):
                        # Prompt was rejected client-side; nothing to persist.
                        st.markdown(response_data.get("empty"))
                    else:
                        ai_response = response_data.get("ai_response", "Sorry, I couldn't generate a response.")
                        telemetry_id = response_data.get("telemetry_entry_id")

                        st.markdown(ai_response)
                        if response_data.get("is_restricted"):
                            # Surface the moderation reason alongside the answer.
                            st.markdown(response_data.get("moderation_reason"))

                        st.session_state.messages.append({
                            "role": "assistant",
                            "content": ai_response,
                            "telemetry_id": telemetry_id
                        })

                        refresh_conversations()

                        if telemetry_id:
                            col1, col2 = st.columns(2)
                            with col1:
                                st.button("👍", key=f"positive_{telemetry_id}", on_click=handle_positive_feedback, args=(telemetry_id,))
                            with col2:
                                if st.button("👎", key=f"negative_{telemetry_id}"):
                                    st.session_state.negative_feedback_for = telemetry_id
                                    st.rerun()
                else:
                    st.markdown("An error occurred.")
else:
    st.info("Please upload and process documents to start chatting.")