SyedZainAliShah commited on
Commit
e9eb5ef
·
verified ·
1 Parent(s): acf338b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +315 -0
app.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from groq import Groq
4
+ import PyPDF2
5
+ from sentence_transformers import SentenceTransformer
6
+ import numpy as np
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ import json
9
+ from datetime import datetime
10
+ import docx
11
+
12
# Initialize Groq client.
# NOTE(review): if GROQ_API_KEY is unset, Groq() receives api_key=None and
# every completion call will fail at request time — verify deployment config.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Initialize sentence transformer model for embeddings (MiniLM checkpoint).
embedder = SentenceTransformer('all-MiniLM-L6-v2')

# Global in-memory store shared by all handlers (rebuilt on each upload):
#   'chunks'               - chunk texts produced by chunk_text()
#   'embeddings'           - chunk embedding matrix from embedder.encode()
#   'metadata'             - per-chunk dicts: 'page', 'filename', 'chunk_id'
#   'conversation_history' - logged Q/A entries with timestamp and sources
document_store = {
    'chunks': [],
    'embeddings': [],
    'metadata': [],
    'conversation_history': []
}
25
+
26
def extract_text_from_pdf(pdf_file):
    """Extract per-page text from an uploaded PDF.

    Args:
        pdf_file: File-like upload object with a ``.name`` attribute.

    Returns:
        list[dict]: one dict per page with keys 'text', 'page' (1-based),
        and 'filename'. On failure, a single entry with 'page': 0 whose
        'text' carries the error message (downstream chunking skips it).
    """
    # Compute once; the original error path leaked the full temp path here
    # while the success path used the basename — keep them consistent.
    filename = os.path.basename(pdf_file.name)
    try:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        text_data = []
        for page_num, page in enumerate(pdf_reader.pages, start=1):
            # extract_text() can return None for image-only pages;
            # normalize to "" so later ' '.join calls don't raise TypeError.
            text = page.extract_text() or ""
            text_data.append({
                'text': text,
                'page': page_num,
                'filename': filename
            })
        return text_data
    except Exception as e:
        return [{'text': f"Error reading PDF: {str(e)}", 'page': 0, 'filename': filename}]
41
+
42
def extract_text_from_docx(docx_file):
    """Extract text from DOCX file (Enhancement 5).

    Args:
        docx_file: File-like upload object with a ``.name`` attribute.

    Returns:
        list[dict]: a single entry with all paragraph text joined by
        newlines, 'page': 1, and 'filename'. On failure, 'page': 0 with
        the error message as 'text' (downstream chunking skips it).
    """
    # Basename computed once so the error path matches the success path
    # (the original returned the full temp path on error).
    filename = os.path.basename(docx_file.name)
    try:
        doc = docx.Document(docx_file)
        text = '\n'.join([paragraph.text for paragraph in doc.paragraphs])
        return [{
            'text': text,
            'page': 1,  # DOCX has no page concept pre-render; treat as one page
            'filename': filename
        }]
    except Exception as e:
        return [{'text': f"Error reading DOCX: {str(e)}", 'page': 0, 'filename': filename}]
54
+
55
def chunk_text(text_data, chunk_size=500, overlap=50):
    """Split text into overlapping word-based chunks (Enhancement 6).

    Args:
        text_data: list of dicts with 'text', 'page', 'filename' keys.
        chunk_size: maximum number of words per chunk.
        overlap: number of words shared between consecutive chunks.

    Returns:
        tuple[list, list]: parallel (chunks, metadata) lists; each metadata
        dict carries 'page', 'filename', and a 1-based 'chunk_id'.

    Raises:
        ValueError: if overlap >= chunk_size — the window step would be
            zero or negative, so the loop could never advance correctly.
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")

    chunks = []
    metadata = []
    step = chunk_size - overlap  # hoisted loop-invariant stride

    for data in text_data:
        words = data['text'].split()

        for i in range(0, len(words), step):
            chunk = ' '.join(words[i:i + chunk_size])
            if len(chunk.strip()) > 50:  # Only keep meaningful chunks
                chunks.append(chunk)
                metadata.append({
                    'page': data['page'],
                    'filename': data['filename'],
                    # len(chunks) is read after append, so chunk_id is 1-based
                    'chunk_id': len(chunks)
                })

    return chunks, metadata
75
+
76
def create_embeddings(chunks):
    """Encode text chunks into dense vectors via the shared sentence
    transformer (Enhancement 1)."""
    return embedder.encode(chunks)
80
+
81
def process_files(files):
    """Process uploaded files into the global vector store.

    Extracts text from each PDF/DOCX upload, chunks it, embeds the chunks,
    and replaces ``document_store`` wholesale (previous session is cleared).

    Args:
        files: list of Gradio file objects, or None.

    Returns:
        str: Markdown status message for the UI.
    """
    global document_store

    if not files:
        return "❌ Please upload at least one file."

    # Reset the store so a fresh upload fully replaces the previous session.
    document_store = {
        'chunks': [],
        'embeddings': [],
        'metadata': [],
        'conversation_history': []
    }

    all_text_data = []
    file_summaries = []

    for file in files:
        file_ext = os.path.splitext(file.name)[1].lower()

        if file_ext == '.pdf':
            text_data = extract_text_from_pdf(file)
        elif file_ext == '.docx':
            text_data = extract_text_from_docx(file)
        else:
            continue  # silently skip unsupported extensions

        all_text_data.extend(text_data)

        # Generate file summary (Enhancement 2)
        total_text = ' '.join(d['text'] for d in text_data)
        file_summaries.append(
            f"📄 **{os.path.basename(file.name)}** - {len(text_data)} pages, {len(total_text)} characters"
        )

    # Originally a batch of only-unsupported files fell through and
    # reported success with zero chunks; surface it explicitly instead.
    if not all_text_data:
        return "❌ No supported files found. Please upload PDF or DOCX files."

    # Chunk and embed
    chunks, metadata = chunk_text(all_text_data)
    if not chunks:
        # e.g. image-only PDFs or unreadable files: nothing to retrieve from.
        return "❌ No extractable text found in the uploaded files."

    embeddings = create_embeddings(chunks)

    document_store['chunks'] = chunks
    document_store['embeddings'] = embeddings
    document_store['metadata'] = metadata

    summary = f"✅ **Processed {len(files)} file(s)**\n\n" + "\n".join(file_summaries)
    summary += f"\n\n📊 Created {len(chunks)} text chunks for retrieval."

    return summary
126
+
127
def retrieve_relevant_chunks(query, top_k=3):
    """Return the top_k stored chunks most similar to the query.

    Returns (chunks, metadata) as parallel lists ordered by descending
    cosine similarity; both empty when no documents are loaded.
    """
    chunks = document_store['chunks']
    if not chunks:
        return [], []

    # Embed the query and score it against every stored chunk vector.
    query_vec = embedder.encode([query])
    scores = cosine_similarity(query_vec, document_store['embeddings'])[0]

    # Indices of the best-scoring chunks, highest similarity first.
    best = np.argsort(scores)[-top_k:][::-1]

    top_chunks = [chunks[idx] for idx in best]
    top_meta = [document_store['metadata'][idx] for idx in best]
    return top_chunks, top_meta
141
+
142
def generate_answer(query, history):
    """Generate answer using Groq LLM with RAG (Enhancement 3 - Conversational Memory).

    Args:
        query: The user's current question.
        history: Prior chat turns; indexed as pairs h[0]=user, h[1]=assistant
            (Gradio Chatbot pair format — assumed; verify against UI wiring).

    Returns:
        str: Model answer with appended source list, or a status/error string.
        Side effect: appends a log entry to document_store['conversation_history'].
    """
    if not document_store['chunks']:
        return "⚠️ Please upload and process documents first."

    # Retrieve relevant context
    relevant_chunks, metadata = retrieve_relevant_chunks(query, top_k=3)

    if not relevant_chunks:
        return "❌ No relevant information found in the documents."

    # Build context with source references (Enhancement 4)
    context = "\n\n".join([
        f"[Source: {meta['filename']}, Page {meta['page']}]\n{chunk}"
        for chunk, meta in zip(relevant_chunks, metadata)
    ])

    # Build conversation history for context
    history_context = ""
    if history:
        history_context = "\n".join([
            f"User: {h[0]}\nAssistant: {h[1]}"
            for h in history[-3:]  # Last 3 exchanges
        ])

    # Create prompt (history, retrieved context, then the current question)
    prompt = f"""You are a helpful assistant that answers questions based on the provided document context.

Previous conversation:
{history_context}

Context from documents:
{context}

Current question: {query}

Instructions:
- Answer based strictly on the provided context
- If the answer isn't in the context, say so
- Be concise and accurate
- Reference specific sources when relevant

Answer:"""

    try:
        # Call Groq API (low temperature for factual, grounded answers)
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model="llama3-8b-8192",
            temperature=0.3,
            max_tokens=1024,
        )

        answer = chat_completion.choices[0].message.content

        # Add source references to answer (Enhancement 4)
        sources = "\n\n📚 **Sources:**\n" + "\n".join([
            f"- {meta['filename']} (Page {meta['page']})"
            for meta in metadata
        ])

        full_answer = answer + sources

        # Log query (Enhancement 8); 'answer' is logged WITHOUT the source
        # footer so exports contain the raw model output.
        document_store['conversation_history'].append({
            'timestamp': datetime.now().isoformat(),
            'query': query,
            'answer': answer,
            'sources': [f"{m['filename']}_p{m['page']}" for m in metadata]
        })

        return full_answer

    except Exception as e:
        # Broad catch keeps the UI responsive on API/network failures;
        # the error is surfaced as chat text rather than raised.
        return f"❌ Error generating answer: {str(e)}"
222
+
223
def download_chat_history():
    """Export conversation history as JSON (Enhancement 7).

    Returns:
        str | None: path of the written ``chat_history.json``, or None when
        there is no history (Gradio renders None as no file).
    """
    if not document_store['conversation_history']:
        return None

    history_file = "chat_history.json"
    # Explicit utf-8 + ensure_ascii=False: answers may contain emoji/non-ASCII,
    # which would crash (or be escaped unreadably) under a non-UTF-8 locale
    # default such as cp1252 on Windows.
    with open(history_file, 'w', encoding='utf-8') as f:
        json.dump(document_store['conversation_history'], f, indent=2, ensure_ascii=False)

    return history_file
233
+
234
def clear_history():
    """Reset the stored conversation log.

    Returns a (chatbot_value, status_message) pair: None empties the
    Chatbot widget, the string is shown in the status textbox.
    """
    document_store['conversation_history'] = []
    status = "🗑️ History cleared!"
    return None, status
238
+
239
# Build Gradio Interface
def handle_query(query, history):
    """Run one RAG turn: call generate_answer once, append the
    (question, answer) pair to the chat history, and clear the input box.

    Replaces the original wiring, which pushed a bare string into the
    Chatbot component and then invoked generate_answer a SECOND time inside
    a .then() lambda — doubling every LLM call and history log entry.
    """
    history = history or []
    if not query or not query.strip():
        return history, ""
    answer = generate_answer(query, history)
    return history + [[query, answer]], ""

with gr.Blocks(title="Enhanced RAG Chatbot", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🤖 Enhanced RAG-Based Chatbot
    Upload PDF/DOCX files and ask questions about their content!

    **Features:**
    - ✅ Multiple file support (PDF & DOCX)
    - ✅ Semantic embeddings with sentence-transformers
    - ✅ Document preview & summaries
    - ✅ Conversational memory
    - ✅ Source references with page numbers
    - ✅ Download chat history
    """)

    with gr.Row():
        with gr.Column(scale=1):
            file_upload = gr.File(
                label="Upload Documents (PDF/DOCX)",
                file_count="multiple",
                file_types=[".pdf", ".docx"]
            )
            process_btn = gr.Button("📂 Process Documents", variant="primary")
            process_output = gr.Markdown(label="Processing Status")

            gr.Markdown("### 💾 Chat History")
            download_btn = gr.Button("⬇️ Download History")
            download_file = gr.File(label="Download")
            clear_btn = gr.Button("🗑️ Clear History")
            clear_msg = gr.Textbox(label="Status", interactive=False)

        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation", height=500)
            query_input = gr.Textbox(
                label="Ask a question",
                placeholder="Type your question here...",
                lines=2
            )
            submit_btn = gr.Button("🚀 Ask", variant="primary")

    # Event handlers
    process_btn.click(
        fn=process_files,
        inputs=[file_upload],
        outputs=[process_output]
    )

    # Single handler updates the chat AND clears the textbox in one call.
    submit_btn.click(
        fn=handle_query,
        inputs=[query_input, chatbot],
        outputs=[chatbot, query_input]
    )
    # Pressing Enter in the textbox behaves like clicking Ask.
    query_input.submit(
        fn=handle_query,
        inputs=[query_input, chatbot],
        outputs=[chatbot, query_input]
    )

    download_btn.click(
        fn=download_chat_history,
        outputs=[download_file]
    )

    clear_btn.click(
        fn=clear_history,
        outputs=[chatbot, clear_msg]
    )

    gr.Markdown("""
    ---
    ### 📖 How RAG Works:
    1. **Retrieval**: Finds relevant text chunks from uploaded documents using semantic similarity
    2. **Augmentation**: Combines retrieved context with your question
    3. **Generation**: Uses Groq LLM to generate accurate answers based on the context
    """)

if __name__ == "__main__":
    demo.launch()