BetaGen committed on
Commit
ccd7243
·
verified ·
1 Parent(s): eaf6a20

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +442 -0
app.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import PyPDF2
4
+ import docx
5
+ from io import BytesIO
6
+ import numpy as np
7
+ import pandas as pd
8
+ from sentence_transformers import SentenceTransformer
9
+ import faiss
10
+ import pickle
11
+ from groq import Groq
12
+ from typing import List, Tuple
13
+ import re
14
+
15
# Page-level Streamlit settings; set_page_config is the first Streamlit call
# made by this script.
_PAGE_SETTINGS = {
    "page_title": "πŸ€– Smart RAG Assistant",
    "page_icon": "🧠",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}

# Stylesheet injected once per run: header banner, chat bubbles, info cards
# and button styling used by the markup rendered in main().
_CUSTOM_CSS = """
<style>
.main-header {
    text-align: center;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    padding: 2rem;
    border-radius: 10px;
    margin-bottom: 2rem;
    color: white;
}

.chat-message {
    padding: 1rem;
    border-radius: 10px;
    margin: 1rem 0;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}

.user-message {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    margin-left: 20%;
}

.bot-message {
    background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
    color: white;
    margin-right: 20%;
}

.sidebar-info {
    background: #f0f2f6;
    padding: 1rem;
    border-radius: 10px;
    border-left: 4px solid #667eea;
}

.doc-info {
    background: #e8f4fd;
    padding: 1rem;
    border-radius: 10px;
    border: 1px solid #b3d9ff;
    margin: 1rem 0;
}

.stButton > button {
    width: 100%;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    color: white;
    border: none;
    padding: 0.5rem 1rem;
    border-radius: 10px;
    font-weight: bold;
}

.stButton > button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
</style>
"""

st.set_page_config(**_PAGE_SETTINGS)
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
85
+
86
class RAGSystem:
    """Retrieval-augmented generation pipeline used by the Streamlit app.

    The object keeps four pieces of state: a SentenceTransformer embedding
    model, a FAISS inner-product index over L2-normalised chunk embeddings,
    the chunk texts (row ``i`` of the index corresponds to ``documents[i]``),
    and a Groq client used to generate answers.

    Methods report failures to the UI through ``st.error`` and signal them
    to the caller with a falsy/empty return value instead of raising.
    """

    def __init__(self):
        self.embedding_model = None  # loaded lazily on first indexing call
        self.index = None  # FAISS index built by create_embeddings_and_index
        self.documents = []  # chunk texts, aligned with the index rows
        self.groq_client = None  # set by setup_groq_client

    @st.cache_resource
    def load_embedding_model(_self):
        """Load the ``all-MiniLM-L6-v2`` sentence-transformer model.

        The parameter is named ``_self`` because Streamlit's cache skips
        hashing arguments whose name starts with an underscore.

        Returns:
            The loaded model, or ``None`` if loading failed (the error is
            shown in the UI).
        """
        try:
            return SentenceTransformer('all-MiniLM-L6-v2')
        except Exception as e:
            st.error(f"Error loading embedding model: {str(e)}")
            return None

    def setup_groq_client(self, api_key: str):
        """Create the Groq client from ``api_key``.

        Returns:
            ``True`` when the client object was created, ``False`` otherwise.
        """
        try:
            self.groq_client = Groq(api_key=api_key)
            return True
        except Exception as e:
            st.error(f"Error setting up Groq client: {str(e)}")
            return False

    def extract_text_from_pdf(self, pdf_file) -> str:
        """Return the text of every page of ``pdf_file``, one page per line group.

        Args:
            pdf_file: A binary file-like object exposing ``read()``.

        Returns:
            The concatenated page texts (each followed by a newline), or ``""``
            if the PDF could not be read.
        """
        try:
            reader = PyPDF2.PdfReader(BytesIO(pdf_file.read()))
            # A page without a text layer may yield None; treat it as empty
            # instead of failing the whole document.
            return "".join(
                (page.extract_text() or "") + "\n" for page in reader.pages
            )
        except Exception as e:
            st.error(f"Error reading PDF: {str(e)}")
            return ""

    def extract_text_from_docx(self, docx_file) -> str:
        """Return the paragraph texts of ``docx_file`` joined by newlines.

        Args:
            docx_file: A binary file-like object exposing ``read()``.

        Returns:
            The concatenated paragraph texts (each followed by a newline), or
            ``""`` if the document could not be read.
        """
        try:
            document = docx.Document(BytesIO(docx_file.read()))
            return "".join(
                paragraph.text + "\n" for paragraph in document.paragraphs
            )
        except Exception as e:
            st.error(f"Error reading DOCX: {str(e)}")
            return ""

    @staticmethod
    def _overlap_tail(chunk: str, overlap: int) -> str:
        """Return up to ``overlap`` trailing characters of ``chunk``, word-aligned.

        Returns ``""`` when ``overlap`` is not positive, when the chunk is no
        longer than ``overlap`` (carrying it over would duplicate the whole
        chunk), or when no word boundary exists inside the tail. A non-empty
        result always ends with a single space so the next sentence can be
        appended directly.
        """
        if overlap <= 0 or len(chunk) <= overlap:
            return ""
        start = len(chunk) - overlap
        if not chunk[start - 1].isspace():
            # The cut landed inside a word: skip ahead to the next full word.
            boundary = chunk.find(" ", start)
            if boundary == -1:
                return ""
            start = boundary + 1
        tail = chunk[start:].lstrip()
        return tail + " " if tail else ""

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split ``text`` into sentence-based chunks that overlap slightly.

        Sentences are found by splitting on runs of ``.``, ``!`` and ``?``;
        each sentence is re-terminated with ``". "``. Sentences are packed
        into a chunk while the running length stays below ``chunk_size``.
        When a chunk is closed, its last ``overlap`` characters (aligned to a
        word boundary) are repeated at the start of the next chunk so that
        context spanning a chunk border is not lost.

        Args:
            text: The text to split.
            chunk_size: Soft upper bound on chunk length in characters. A
                single sentence longer than this still becomes its own chunk.
            overlap: Number of trailing characters to carry into the next
                chunk. Clamped to ``[0, chunk_size // 2]``.

        Returns:
            The list of chunks; empty when ``text`` has no sentence content.
        """
        # Keep the carried-over tail from dominating the following chunk.
        overlap = max(0, min(overlap, chunk_size // 2))

        chunks = []
        current_chunk = ""

        for sentence in re.split(r'[.!?]+', text):
            sentence = sentence.strip()
            if not sentence:
                continue

            if len(current_chunk) + len(sentence) < chunk_size:
                current_chunk += sentence + ". "
                continue

            carried = ""
            if current_chunk:
                finished = current_chunk.strip()
                chunks.append(finished)
                carried = self._overlap_tail(finished, overlap)
            current_chunk = carried + sentence + ". "

        if current_chunk:
            chunks.append(current_chunk.strip())

        return chunks

    def create_embeddings_and_index(self, documents: List[str]):
        """Embed ``documents`` and build the FAISS index over them.

        Embeddings are L2-normalised and stored in an inner-product index, so
        search scores are cosine similarities. ``self.index`` and
        ``self.documents`` are replaced together, and only after the new
        index has been built successfully.

        Args:
            documents: The chunk texts to index.

        Returns:
            ``True`` on success; ``False`` if there is nothing to index, the
            model could not be loaded, or embedding/indexing failed.
        """
        if not documents:
            st.error("Error creating embeddings: no text chunks to index")
            return False

        if self.embedding_model is None:
            self.embedding_model = self.load_embedding_model()

        if self.embedding_model is None:
            return False

        try:
            # FAISS works on C-contiguous float32 arrays; convert before
            # normalising so the in-place normalisation sees the right dtype.
            embeddings = np.ascontiguousarray(
                self.embedding_model.encode(documents, show_progress_bar=True),
                dtype='float32',
            )
            faiss.normalize_L2(embeddings)

            index = faiss.IndexFlatIP(embeddings.shape[1])  # inner product
            index.add(embeddings)
        except Exception as e:
            st.error(f"Error creating embeddings: {str(e)}")
            return False

        self.index = index
        self.documents = documents
        return True

    def retrieve_relevant_docs(self, query: str, k: int = 3) -> List[Tuple[str, float]]:
        """Return the chunks most similar to ``query``.

        Args:
            query: The user question.
            k: Maximum number of chunks to return.

        Returns:
            A list of ``(chunk_text, score)`` pairs in the order FAISS returns
            them; empty if nothing is indexed or the search fails.
        """
        if self.embedding_model is None or self.index is None:
            return []

        # Never ask for more neighbours than there are chunks.
        k = min(k, len(self.documents))
        if k <= 0:
            return []

        try:
            query_embedding = np.ascontiguousarray(
                self.embedding_model.encode([query]), dtype='float32'
            )
            faiss.normalize_L2(query_embedding)

            scores, indices = self.index.search(query_embedding, k)

            # FAISS pads missing results with label -1; without the lower
            # bound that would silently index the last chunk.
            return [
                (self.documents[idx], float(score))
                for score, idx in zip(scores[0], indices[0])
                if 0 <= idx < len(self.documents)
            ]
        except Exception as e:
            st.error(f"Error retrieving documents: {str(e)}")
            return []

    def generate_answer(self, query: str, context: str, model: str = "llama-3.3-70b-versatile") -> str:
        """Ask the Groq chat model to answer ``query`` using ``context``.

        Args:
            query: The user question.
            context: Retrieved chunk text to ground the answer in.
            model: Groq model identifier.

        Returns:
            The model's answer. Failures are returned as a human-readable
            ``"Error ..."`` string rather than raised.
        """
        if not self.groq_client:
            return "Error: Groq client not initialized"

        prompt = (
            "Based on the following context, please answer the question "
            "accurately and concisely. If the answer cannot be found in the "
            "context, please say so.\n\n"
            f"Context:\n{context}\n\n"
            f"Question: {query}\n\n"
            "Answer:"
        )

        try:
            chat_completion = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on the provided context. Be accurate and concise.",
                    },
                    {
                        "role": "user",
                        "content": prompt,
                    },
                ],
                model=model,
                temperature=0.3,
                max_tokens=1000,
            )
            return chat_completion.choices[0].message.content
        except Exception as e:
            return f"Error generating answer: {str(e)}"
243
def main():
    """Render the Streamlit page: sidebar configuration, chat and status.

    State that must survive Streamlit reruns lives in ``st.session_state``:
    ``rag_system`` (the RAGSystem instance) and ``chat_history`` (a list of
    ``(role, message)`` tuples where role is ``"user"`` or ``"assistant"``).
    """
    import html  # local import: only needed to escape chat text below

    # Header
    st.markdown(
        '<div class="main-header">\n'
        "<h1>πŸ€– Smart RAG Assistant</h1>\n"
        "<p>Upload documents and ask questions - powered by Groq & Sentence Transformers</p>\n"
        "</div>",
        unsafe_allow_html=True,
    )

    # Initialize per-session state
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = RAGSystem()

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    rag = st.session_state.rag_system

    # Sidebar
    with st.sidebar:
        st.markdown("## βš™οΈ Configuration")

        # Pre-fill from the environment; an empty field means "not configured"
        # instead of silently using a placeholder string as the key.
        api_key = st.text_input(
            "πŸ”‘ Groq API Key",
            type="password",
            value=os.environ.get("GROQ_API_KEY", ""),
            help="Enter your Groq API key",
        )

        if api_key and rag.setup_groq_client(api_key):
            st.success("βœ… Groq client configured!")

        st.markdown("---")

        # Model selection
        model_options = [
            "llama-3.3-70b-versatile",
            "llama-3.1-70b-versatile",
            "llama-3.1-8b-instant",
            "mixtral-8x7b-32768",
        ]
        selected_model = st.selectbox("πŸ€– Select Model", model_options)

        st.markdown("---")

        # Document upload
        st.markdown("## πŸ“ Document Upload")
        uploaded_files = st.file_uploader(
            "Upload documents",
            type=['pdf', 'docx', 'txt'],
            accept_multiple_files=True,
            help="Upload PDF, DOCX, or TXT files",
        )

        if uploaded_files and st.button("πŸš€ Process Documents"):
            with st.spinner("Processing documents..."):
                texts = []
                doc_info = []

                for file in uploaded_files:
                    if file.type == "application/pdf":
                        text = rag.extract_text_from_pdf(file)
                        icon = "πŸ“„"
                    elif file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
                        text = rag.extract_text_from_docx(file)
                        icon = "πŸ“"
                    else:  # txt
                        # Replace undecodable bytes rather than crashing on
                        # files that are not valid UTF-8.
                        text = file.read().decode("utf-8", errors="replace")
                        icon = "πŸ“„"

                    doc_info.append(f"{icon} {file.name} ({len(text)} chars)")
                    texts.append(text + "\n\n")

                # Chunk the combined text
                chunks = rag.chunk_text("".join(texts))

                if not chunks:
                    st.warning("⚠️ No text could be extracted from the uploaded documents.")
                elif rag.create_embeddings_and_index(chunks):
                    st.success(f"βœ… Processed {len(chunks)} chunks from {len(uploaded_files)} documents!")

                    # Show document info
                    st.markdown("### πŸ“Š Processed Documents:")
                    for info in doc_info:
                        st.markdown(f"- {info}")

        # Clear chat history
        if st.button("πŸ—‘οΈ Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()

    # Main content area
    col1, col2 = st.columns([2, 1])

    with col1:
        st.markdown("## πŸ’¬ Chat with your documents")

        # Display chat history
        with st.container():
            for role, message in st.session_state.chat_history:
                if role == "user":
                    css_class, label = "user-message", "πŸ™‹β€β™‚οΈ You:"
                else:
                    css_class, label = "bot-message", "πŸ€– Assistant:"

                # Both the question and the model output are untrusted text
                # going into raw HTML: escape them, and turn newlines into
                # <br> so a blank line cannot terminate the HTML block.
                safe_message = html.escape(str(message)).replace("\n", "<br>")
                st.markdown(
                    f'<div class="chat-message {css_class}">\n'
                    f"<strong>{label}</strong><br>{safe_message}\n"
                    "</div>",
                    unsafe_allow_html=True,
                )

        # Query input. A form submits on Enter or on the button and, with
        # clear_on_submit, resets the field afterwards. Reacting to the text
        # value alone would answer the same question again on every rerun.
        with st.form("query_form", clear_on_submit=True):
            query = st.text_input(
                "Ask a question about your documents:",
                placeholder="e.g., What is the main topic discussed in the documents?",
                key="query_input",
            )
            col_send, _ = st.columns([3, 1])
            with col_send:
                send_button = st.form_submit_button("πŸ“€ Send")

        if send_button and query:
            if not rag.documents:
                st.warning("⚠️ Please upload and process documents first!")
            elif not api_key:
                st.warning("⚠️ Please enter your Groq API key!")
            else:
                with st.spinner("Searching and generating answer..."):
                    # Retrieve relevant documents
                    relevant_docs = rag.retrieve_relevant_docs(query, k=3)

                    if relevant_docs:
                        # Combine context
                        context = "\n\n".join(doc for doc, _score in relevant_docs)

                        # Generate answer
                        answer = rag.generate_answer(query, context, selected_model)

                        # Add to chat history
                        st.session_state.chat_history.append(("user", query))
                        st.session_state.chat_history.append(("assistant", answer))

                        # Rerun so the history rendered above includes the
                        # new exchange.
                        st.rerun()
                    else:
                        st.error("No relevant documents found for your query.")

    with col2:
        st.markdown("## πŸ“ˆ System Status")

        # System info
        if rag.documents:
            st.markdown(
                '<div class="doc-info">\n'
                "<h4>πŸ“š Knowledge Base</h4>\n"
                f"<p><strong>Documents:</strong> {len(rag.documents)} chunks</p>\n"
                "<p><strong>Status:</strong> βœ… Ready</p>\n"
                f"<p><strong>Model:</strong> {selected_model}</p>\n"
                "</div>",
                unsafe_allow_html=True,
            )
        else:
            st.markdown(
                '<div class="doc-info">\n'
                "<h4>πŸ“š Knowledge Base</h4>\n"
                "<p><strong>Status:</strong> ❌ No documents loaded</p>\n"
                "<p>Upload documents to get started!</p>\n"
                "</div>",
                unsafe_allow_html=True,
            )

        # Instructions
        st.markdown(
            '<div class="sidebar-info">\n'
            "<h4>πŸ“‹ How to use:</h4>\n"
            "<ol>\n"
            "<li>Enter your Groq API key</li>\n"
            "<li>Upload documents (PDF, DOCX, TXT)</li>\n"
            '<li>Click "Process Documents"</li>\n'
            "<li>Ask questions about your documents</li>\n"
            "</ol>\n"
            "</div>",
            unsafe_allow_html=True,
        )

        # Features
        st.markdown(
            '<div class="sidebar-info">\n'
            "<h4>✨ Features:</h4>\n"
            "<ul>\n"
            "<li>πŸš€ Fast inference with Groq</li>\n"
            "<li>🧠 Smart document chunking</li>\n"
            "<li>πŸ” Semantic search</li>\n"
            "<li>πŸ’¬ Chat history</li>\n"
            "<li>πŸ“± Responsive design</li>\n"
            "</ul>\n"
            "</div>",
            unsafe_allow_html=True,
        )


if __name__ == "__main__":
    main()