anasfsd123 commited on
Commit
6feec7f
Β·
verified Β·
1 Parent(s): 6a202e6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +363 -0
app.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import numpy as np
4
+ import pandas as pd
5
+ from sentence_transformers import SentenceTransformer
6
+ from groq import Groq
7
+ import faiss
8
+ import pickle
9
+ from typing import List, Dict, Tuple
10
+ import PyPDF2
11
+ import docx
12
+ from io import BytesIO
13
+ import time
14
+
15
+ # Initialize Groq client
16
+ def init_groq_client(api_key: str):
17
+ """Initialize Groq client with API key"""
18
+ return Groq(api_key=api_key)
19
+
20
+ # Initialize embedding model
21
+ @st.cache_resource
22
+ def load_embedding_model():
23
+ """Load and cache the sentence transformer model"""
24
+ return SentenceTransformer('all-MiniLM-L6-v2')
25
+
26
+ # Document processing functions
27
+ def extract_text_from_pdf(file):
28
+ """Extract text from PDF file"""
29
+ pdf_reader = PyPDF2.PdfReader(file)
30
+ text = ""
31
+ for page in pdf_reader.pages:
32
+ text += page.extract_text()
33
+ return text
34
+
35
+ def extract_text_from_docx(file):
36
+ """Extract text from DOCX file"""
37
+ doc = docx.Document(file)
38
+ text = ""
39
+ for paragraph in doc.paragraphs:
40
+ text += paragraph.text + "\n"
41
+ return text
42
+
43
+ def extract_text_from_txt(file):
44
+ """Extract text from TXT file"""
45
+ return str(file.read(), "utf-8")
46
+
47
+ def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
48
+ """Split text into overlapping chunks"""
49
+ words = text.split()
50
+ chunks = []
51
+
52
+ for i in range(0, len(words), chunk_size - overlap):
53
+ chunk = ' '.join(words[i:i + chunk_size])
54
+ chunks.append(chunk)
55
+
56
+ if i + chunk_size >= len(words):
57
+ break
58
+
59
+ return chunks
60
+
61
+ # Vector store class
62
+ class VectorStore:
63
+ def __init__(self, embedding_model):
64
+ self.embedding_model = embedding_model
65
+ self.documents = []
66
+ self.embeddings = []
67
+ self.index = None
68
+
69
+ def add_documents(self, documents: List[str]):
70
+ """Add documents to the vector store"""
71
+ self.documents.extend(documents)
72
+
73
+ # Generate embeddings
74
+ new_embeddings = self.embedding_model.encode(documents)
75
+
76
+ if len(self.embeddings) == 0:
77
+ self.embeddings = new_embeddings
78
+ else:
79
+ self.embeddings = np.vstack([self.embeddings, new_embeddings])
80
+
81
+ # Build/update FAISS index
82
+ self._build_index()
83
+
84
+ def _build_index(self):
85
+ """Build FAISS index for similarity search"""
86
+ if len(self.embeddings) > 0:
87
+ dimension = self.embeddings.shape[1]
88
+ self.index = faiss.IndexFlatIP(dimension) # Inner product for similarity
89
+
90
+ # Normalize embeddings for cosine similarity
91
+ normalized_embeddings = self.embeddings / np.linalg.norm(
92
+ self.embeddings, axis=1, keepdims=True
93
+ )
94
+ self.index.add(normalized_embeddings.astype('float32'))
95
+
96
+ def search(self, query: str, top_k: int = 3) -> List[Tuple[str, float]]:
97
+ """Search for similar documents"""
98
+ if self.index is None or len(self.documents) == 0:
99
+ return []
100
+
101
+ # Encode query
102
+ query_embedding = self.embedding_model.encode([query])
103
+ query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=1, keepdims=True)
104
+
105
+ # Search
106
+ scores, indices = self.index.search(query_embedding.astype('float32'), top_k)
107
+
108
+ results = []
109
+ for score, idx in zip(scores[0], indices[0]):
110
+ if idx < len(self.documents):
111
+ results.append((self.documents[idx], float(score)))
112
+
113
+ return results
114
+
115
+ def save(self, filepath: str):
116
+ """Save vector store to file"""
117
+ data = {
118
+ 'documents': self.documents,
119
+ 'embeddings': self.embeddings.tolist() if len(self.embeddings) > 0 else []
120
+ }
121
+ with open(filepath, 'wb') as f:
122
+ pickle.dump(data, f)
123
+
124
+ def load(self, filepath: str):
125
+ """Load vector store from file"""
126
+ with open(filepath, 'rb') as f:
127
+ data = pickle.load(f)
128
+
129
+ self.documents = data['documents']
130
+ if data['embeddings']:
131
+ self.embeddings = np.array(data['embeddings'])
132
+ self._build_index()
133
+
134
+ # RAG class
135
+ class RAGSystem:
136
+ def __init__(self, groq_client, embedding_model):
137
+ self.groq_client = groq_client
138
+ self.vector_store = VectorStore(embedding_model)
139
+
140
+ def add_documents(self, documents: List[str]):
141
+ """Add documents to the knowledge base"""
142
+ self.vector_store.add_documents(documents)
143
+
144
+ def query(self, question: str, model: str = "llama-3.3-70b-versatile", top_k: int = 3) -> Dict:
145
+ """Answer a question using RAG"""
146
+ # Retrieve relevant documents
147
+ retrieved_docs = self.vector_store.search(question, top_k=top_k)
148
+
149
+ if not retrieved_docs:
150
+ return {
151
+ "answer": "I don't have any relevant information to answer your question.",
152
+ "sources": [],
153
+ "confidence": 0.0
154
+ }
155
+
156
+ # Prepare context
157
+ context = "\n\n".join([doc for doc, score in retrieved_docs])
158
+
159
+ # Create prompt
160
+ prompt = f"""Based on the following context, answer the question. If the answer is not in the context, say "I don't have enough information to answer this question."
161
+
162
+ Context:
163
+ {context}
164
+
165
+ Question: {question}
166
+
167
+ Answer:"""
168
+
169
+ try:
170
+ # Get response from Groq
171
+ chat_completion = self.groq_client.chat.completions.create(
172
+ messages=[
173
+ {
174
+ "role": "user",
175
+ "content": prompt,
176
+ }
177
+ ],
178
+ model=model,
179
+ temperature=0.1,
180
+ max_tokens=1000,
181
+ )
182
+
183
+ answer = chat_completion.choices[0].message.content
184
+
185
+ return {
186
+ "answer": answer,
187
+ "sources": [{"text": doc[:200] + "...", "score": score}
188
+ for doc, score in retrieved_docs],
189
+ "confidence": max([score for _, score in retrieved_docs]) if retrieved_docs else 0.0
190
+ }
191
+
192
+ except Exception as e:
193
+ return {
194
+ "answer": f"Error generating response: {str(e)}",
195
+ "sources": [],
196
+ "confidence": 0.0
197
+ }
198
+
199
+ # Streamlit App
200
+ def main():
201
+ st.set_page_config(
202
+ page_title="RAG App with Groq",
203
+ page_icon="πŸ€–",
204
+ layout="wide",
205
+ initial_sidebar_state="expanded"
206
+ )
207
+
208
+ st.title("πŸ€– RAG App with Groq & Sentence Transformers")
209
+ st.markdown("Ask questions about your documents using open-source models!")
210
+
211
+ # Sidebar
212
+ st.sidebar.header("βš™οΈ Configuration")
213
+
214
+ # API Key input
215
+ api_key = st.sidebar.text_input(
216
+ "Groq API Key",
217
+ value=os.getenv("GROQ_API_KEY", ""),
218
+ type="password",
219
+ help="Enter your Groq API key"
220
+ )
221
+
222
+ # Model selection
223
+ model_options = [
224
+ "llama-3.3-70b-versatile",
225
+ "llama-3.1-70b-versatile",
226
+ "llama-3.1-8b-instant",
227
+ "mixtral-8x7b-32768"
228
+ ]
229
+ selected_model = st.sidebar.selectbox("Select Model", model_options)
230
+
231
+ # Number of retrieved documents
232
+ top_k = st.sidebar.slider("Number of retrieved documents", 1, 10, 3)
233
+
234
+ # Initialize components
235
+ if api_key:
236
+ try:
237
+ groq_client = init_groq_client(api_key)
238
+ embedding_model = load_embedding_model()
239
+
240
+ # Initialize session state
241
+ if 'rag_system' not in st.session_state:
242
+ st.session_state.rag_system = RAGSystem(groq_client, embedding_model)
243
+
244
+ # Main content area
245
+ col1, col2 = st.columns([1, 1])
246
+
247
+ with col1:
248
+ st.header("πŸ“ Document Upload")
249
+
250
+ uploaded_files = st.file_uploader(
251
+ "Upload your documents",
252
+ type=['pdf', 'docx', 'txt'],
253
+ accept_multiple_files=True,
254
+ help="Supported formats: PDF, DOCX, TXT"
255
+ )
256
+
257
+ if uploaded_files:
258
+ if st.button("Process Documents", type="primary"):
259
+ with st.spinner("Processing documents..."):
260
+ all_chunks = []
261
+
262
+ for file in uploaded_files:
263
+ # Extract text based on file type
264
+ if file.type == "application/pdf":
265
+ text = extract_text_from_pdf(file)
266
+ elif file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
267
+ text = extract_text_from_docx(file)
268
+ elif file.type == "text/plain":
269
+ text = extract_text_from_txt(file)
270
+ else:
271
+ st.error(f"Unsupported file type: {file.type}")
272
+ continue
273
+
274
+ # Chunk the text
275
+ chunks = chunk_text(text, chunk_size=500, overlap=50)
276
+ all_chunks.extend(chunks)
277
+
278
+ st.success(f"βœ… Processed {file.name}: {len(chunks)} chunks")
279
+
280
+ # Add to RAG system
281
+ if all_chunks:
282
+ st.session_state.rag_system.add_documents(all_chunks)
283
+ st.success(f"πŸŽ‰ Added {len(all_chunks)} chunks to knowledge base!")
284
+
285
+ # Display document stats
286
+ if hasattr(st.session_state.rag_system, 'vector_store') and len(st.session_state.rag_system.vector_store.documents) > 0:
287
+ st.info(f"πŸ“Š Knowledge Base: {len(st.session_state.rag_system.vector_store.documents)} chunks")
288
+
289
+ with col2:
290
+ st.header("πŸ’¬ Ask Questions")
291
+
292
+ # Chat interface
293
+ if "messages" not in st.session_state:
294
+ st.session_state.messages = []
295
+
296
+ # Display chat history
297
+ for message in st.session_state.messages:
298
+ with st.chat_message(message["role"]):
299
+ st.write(message["content"])
300
+ if message["role"] == "assistant" and "sources" in message:
301
+ with st.expander("πŸ“š Sources"):
302
+ for i, source in enumerate(message["sources"]):
303
+ st.write(f"**Source {i+1}** (Score: {source['score']:.3f})")
304
+ st.write(source["text"])
305
+
306
+ # Chat input
307
+ if prompt := st.chat_input("Ask a question about your documents..."):
308
+ # Add user message
309
+ st.session_state.messages.append({"role": "user", "content": prompt})
310
+
311
+ with st.chat_message("user"):
312
+ st.write(prompt)
313
+
314
+ # Generate response
315
+ with st.chat_message("assistant"):
316
+ with st.spinner("Thinking..."):
317
+ response = st.session_state.rag_system.query(
318
+ prompt,
319
+ model=selected_model,
320
+ top_k=top_k
321
+ )
322
+
323
+ st.write(response["answer"])
324
+
325
+ # Show sources
326
+ if response["sources"]:
327
+ with st.expander("πŸ“š Sources"):
328
+ for i, source in enumerate(response["sources"]):
329
+ st.write(f"**Source {i+1}** (Score: {source['score']:.3f})")
330
+ st.write(source["text"])
331
+
332
+ # Add to chat history
333
+ st.session_state.messages.append({
334
+ "role": "assistant",
335
+ "content": response["answer"],
336
+ "sources": response["sources"]
337
+ })
338
+
339
+ # Clear chat button
340
+ if st.button("πŸ—‘οΈ Clear Chat"):
341
+ st.session_state.messages = []
342
+ st.rerun()
343
+
344
+ except Exception as e:
345
+ st.error(f"Error initializing components: {str(e)}")
346
+
347
+ else:
348
+ st.warning("Please enter your Groq API key in the sidebar to get started.")
349
+
350
+ # Footer
351
+ st.sidebar.markdown("---")
352
+ st.sidebar.markdown(
353
+ """
354
+ **About this app:**
355
+ - Uses Groq for fast inference
356
+ - Sentence Transformers for embeddings
357
+ - FAISS for vector search
358
+ - Supports PDF, DOCX, TXT files
359
+ """
360
+ )
361
+
362
+ if __name__ == "__main__":
363
+ main()