RAVENOCC commited on
Commit
172fd5d
Β·
verified Β·
1 Parent(s): 8a4bd37

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +282 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,284 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from dotenv import load_dotenv
3
+ from PyPDF2 import PdfReader
4
+ from langchain_groq import ChatGroq
5
+ from langchain.text_splitter import CharacterTextSplitter
6
+ from langchain_community.embeddings import HuggingFaceEmbeddings
7
+ from langchain_community.vectorstores import FAISS
8
+ from langchain_community.llms import HuggingFaceHub
9
+ from langchain.prompts import ChatPromptTemplate
10
+ from langchain_core.output_parsers import StrOutputParser
11
+ from langchain_core.runnables import RunnablePassthrough
12
+ from htmlTemplates import css, bot_template, user_template
13
+ import os
14
 
15
+
16
+ def get_pdf_text(pdf_docs):
17
+ text = ""
18
+ for pdf in pdf_docs:
19
+ pdf_reader = PdfReader(pdf)
20
+ for page in pdf_reader.pages:
21
+ text += page.extract_text()
22
+ return text
23
+
24
+ def get_text_chunks(text):
25
+ text_splitter = CharacterTextSplitter(
26
+ separator="\n",
27
+ chunk_size=1000,
28
+ chunk_overlap=200,
29
+ length_function=len
30
+ )
31
+ chunks = text_splitter.split_text(text)
32
+ return chunks
33
+
34
+ def get_vector_store(text_chunks):
35
+ try:
36
+ model_name = "BAAI/bge-small-en"
37
+ model_kwargs = {'device': 'cpu'}
38
+ encode_kwargs = {"normalize_embeddings": True}
39
+
40
+ embeddings = HuggingFaceEmbeddings(
41
+ model_name=model_name,
42
+ model_kwargs=model_kwargs,
43
+ encode_kwargs=encode_kwargs,
44
+ cache_folder="/tmp/huggingface_cache"
45
+ )
46
+
47
+ vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
48
+ return vectorstore
49
+ except Exception as e:
50
+ st.error(f"Error creating vector store: {str(e)}")
51
+ return None
52
+
53
+ def get_conversation_chain(vectorstore, api_key):
54
+ if not api_key:
55
+ st.error("Please provide a valid Groq API key.")
56
+ return None
57
+
58
+ try:
59
+ # Set the API key in environment for this session
60
+ os.environ["GROQ_API_KEY"] = api_key
61
+
62
+ llm = ChatGroq(
63
+ model="llama3-8b-8192",
64
+ temperature=0,
65
+ api_key=api_key
66
+ )
67
+
68
+ # Create the prompt template
69
+ prompt = ChatPromptTemplate.from_messages([
70
+ ("system", """You are a helpful assistant answering questions based on the provided documents.
71
+ Answer the question using only the context provided.
72
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
73
+ Keep your answers focused and relevant to the question."""),
74
+ ("human", """Context: {context}
75
+
76
+ Question: {question}
77
+
78
+ Answer: """)
79
+ ])
80
+
81
+ # Create the retrieval chain
82
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
83
+
84
+ # Define the chain
85
+ chain = (
86
+ {"context": retriever, "question": RunnablePassthrough()}
87
+ | prompt
88
+ | llm
89
+ | StrOutputParser()
90
+ )
91
+
92
+ return chain
93
+
94
+ except Exception as e:
95
+ st.error(f"Failed to initialize Groq model: {str(e)}")
96
+ st.info("Please check if your API key is valid. Get your API key from: https://console.groq.com/keys")
97
+ return None
98
+
99
+ def handle_user_input(user_question):
100
+ if st.session_state.conversation is None:
101
+ st.warning("Please upload and process documents first.")
102
+ return
103
+
104
+ try:
105
+ # Invoke the chain with the question
106
+ response = st.session_state.conversation.invoke(user_question)
107
+
108
+ # Update chat history
109
+ if 'chat_history' not in st.session_state:
110
+ st.session_state.chat_history = []
111
+
112
+ # Add the new messages to chat history
113
+ st.session_state.chat_history.append(("user", user_question))
114
+ st.session_state.chat_history.append(("bot", response))
115
+
116
+ # Display chat history
117
+ for sender, message in st.session_state.chat_history:
118
+ if sender == "user":
119
+ st.write(user_template.replace("{{MSG}}", message), unsafe_allow_html=True)
120
+ else:
121
+ st.write(bot_template.replace("{{MSG}}", message), unsafe_allow_html=True)
122
+
123
+ except Exception as e:
124
+ st.error(f"An error occurred while processing your question: {str(e)}")
125
+ st.info("This might be due to an invalid API key or network issues.")
126
+
127
+ def main():
128
+ load_dotenv()
129
+
130
+ # Set environment variables for HuggingFace cache
131
+ os.environ['HUGGINGFACE_HUB_CACHE'] = '/tmp/huggingface_cache'
132
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
133
+
134
+ # Create cache directory
135
+ os.makedirs('/tmp/huggingface_cache', exist_ok=True)
136
+
137
+ if 'user_template' not in globals():
138
+ global user_template
139
+ user_template = '''
140
+ <div class="chat-message user">
141
+ <div class="avatar">
142
+ <img src="https://i.ibb.co/rdZC7LZ/user.png">
143
+ </div>
144
+ <div class="message">{{MSG}}</div>
145
+ </div>
146
+ '''
147
+
148
+ if 'bot_template' not in globals():
149
+ global bot_template
150
+ bot_template = '''
151
+ <div class="chat-message bot">
152
+ <div class="avatar">
153
+ <img src="https://i.ibb.co/cN0nmSj/robot.png">
154
+ </div>
155
+ <div class="message">{{MSG}}</div>
156
+ </div>
157
+ '''
158
+
159
+ st.set_page_config(page_title='Chat with PDFs', page_icon=":books:")
160
+ st.write(css, unsafe_allow_html=True)
161
+
162
+ # Initialize session state
163
+ if "conversation" not in st.session_state:
164
+ st.session_state.conversation = None
165
+
166
+ if "chat_history" not in st.session_state:
167
+ st.session_state.chat_history = []
168
+
169
+ if "groq_api_key" not in st.session_state:
170
+ st.session_state.groq_api_key = ""
171
+
172
+ st.header('PDF ChatBot πŸ“š')
173
+
174
+ # API Key Input Section
175
+ st.sidebar.header("πŸ”‘ API Configuration")
176
+
177
+ # API Key input
178
+ groq_api_key = st.sidebar.text_input(
179
+ "Enter your Groq API Key:",
180
+ type="password",
181
+ value=st.session_state.groq_api_key,
182
+ help="Get your free API key from https://console.groq.com/keys"
183
+ )
184
+
185
+ # Update session state
186
+ if groq_api_key:
187
+ st.session_state.groq_api_key = groq_api_key
188
+ st.sidebar.success("βœ… API Key provided!")
189
+ else:
190
+ st.sidebar.warning("⚠️ Please enter your Groq API key to continue.")
191
+ st.sidebar.info("Get your free API key from: https://console.groq.com/keys")
192
+
193
+ st.sidebar.markdown("---")
194
+
195
+ # Sidebar for PDF upload
196
+ st.sidebar.subheader("πŸ“„ Upload Documents")
197
+ pdf_docs = st.sidebar.file_uploader(
198
+ "Upload your PDFs here and click 'Process'",
199
+ accept_multiple_files=True,
200
+ type=['pdf']
201
+ )
202
+
203
+ # Process button
204
+ if st.sidebar.button('πŸš€ Process Documents'):
205
+ if not groq_api_key:
206
+ st.sidebar.error("❌ Please enter your Groq API key first!")
207
+ st.error("Please provide your Groq API key in the sidebar to continue.")
208
+ return
209
+
210
+ if not pdf_docs:
211
+ st.sidebar.warning("πŸ“‹ Please upload at least one PDF document.")
212
+ return
213
+
214
+ with st.spinner("Processing documents... This may take a few minutes for the first run."):
215
+ try:
216
+ # Get PDF text
217
+ raw_text = get_pdf_text(pdf_docs)
218
+
219
+ if not raw_text.strip():
220
+ st.error("❌ No text could be extracted from the PDFs. Please check if the PDFs contain readable text.")
221
+ return
222
+
223
+ st.info(f"βœ… Extracted {len(raw_text)} characters from {len(pdf_docs)} PDF(s)")
224
+
225
+ # Get text chunks
226
+ text_chunks = get_text_chunks(raw_text)
227
+ st.info(f"βœ… Created {len(text_chunks)} text chunks")
228
+
229
+ # Create vector store
230
+ with st.spinner("Creating embeddings..."):
231
+ vectorstore = get_vector_store(text_chunks)
232
+
233
+ if vectorstore is None:
234
+ st.error("❌ Failed to create vector store. Please try again.")
235
+ return
236
+
237
+ st.info("βœ… Vector store created successfully")
238
+
239
+ # Create conversation chain
240
+ with st.spinner("Initializing conversation chain..."):
241
+ conversation = get_conversation_chain(vectorstore, groq_api_key)
242
+
243
+ if conversation is None:
244
+ st.error("❌ Failed to create conversation chain. Please check your API key.")
245
+ return
246
+
247
+ st.session_state.conversation = conversation
248
+ st.success("πŸŽ‰ Documents processed successfully! You can now ask questions.")
249
+
250
+ except Exception as e:
251
+ st.error(f"❌ An error occurred: {str(e)}")
252
+ st.info("Please check your API key and try again.")
253
+
254
+ # Main chat interface
255
+ st.subheader("πŸ’¬ Ask Questions About Your Documents")
256
+
257
+ if not groq_api_key:
258
+ st.info("πŸ‘† Please enter your Groq API key in the sidebar to get started.")
259
+ st.info("πŸ”— Get your free API key from: https://console.groq.com/keys")
260
+ elif st.session_state.conversation is None:
261
+ st.info("πŸ“€ Upload and process your PDF documents using the sidebar to start chatting.")
262
+ else:
263
+ user_question = st.text_input(
264
+ "Your question:",
265
+ placeholder="Ask anything about your uploaded documents..."
266
+ )
267
+
268
+ if user_question:
269
+ handle_user_input(user_question)
270
+
271
+ # Display instructions
272
+ if not groq_api_key or st.session_state.conversation is None:
273
+ st.markdown("---")
274
+ st.markdown("### πŸ“‹ How to Use:")
275
+ st.markdown("""
276
+ 1. **Get API Key**: Visit [Groq Console](https://console.groq.com/keys) to get your free API key
277
+ 2. **Enter API Key**: Paste your API key in the sidebar
278
+ 3. **Upload PDFs**: Upload one or more PDF documents
279
+ 4. **Process**: Click 'Process Documents' to analyze your PDFs
280
+ 5. **Chat**: Ask questions about your documents!
281
+ """)
282
+
283
+ if __name__ == "__main__":
284
+ main()