Zahid0123 commited on
Commit
ef08035
·
verified ·
1 Parent(s): 80ad68a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +464 -0
app.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import logging
4
+ import tempfile
5
+ from pathlib import Path
6
+ from typing import List, Tuple, Any
7
+ import numpy as np
8
+ import PyPDF2
9
+ from sentence_transformers import SentenceTransformer
10
+ import faiss
11
+ import gradio as gr
12
+ from gtts import gTTS
13
+ import requests
14
+ import math
15
+ import ast
16
+ import json
17
+
18
# Optional dependency: SymPy enables symbolic/numeric math in the calculator.
try:
    import sympy as sp
except Exception:
    SYMPY_OK = False
else:
    SYMPY_OK = True
23
+
24
# Optional dependency: the Groq SDK powers LLM answers when installed.
try:
    from groq import Groq
except ImportError:
    GROQ_OK = False
    print("❌ Groq library not installed!")
else:
    GROQ_OK = True
30
+
31
# Module-wide logger; INFO level so progress/status messages are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
33
+
34
# Read the Groq API key from the environment ONLY. The previous revision
# shipped a hard-coded fallback key in source — that secret is compromised
# and must be rotated; never commit credentials.
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
groq_client = None

# Only attempt client construction when the SDK is importable AND a key
# was actually provided (an empty key would just fail at request time).
if GROQ_OK and GROQ_API_KEY:
    try:
        groq_client = Groq(api_key=GROQ_API_KEY)
        print("✅ Groq client initialized successfully!")
    except Exception as e:
        groq_client = None
        print(f"❌ Groq initialization error: {e}")
44
+
45
# GLOBAL CHAT MEMORY (NO LONGER USED)
# NOTE(review): superseded by the per-session gr.State created in
# create_interface(); kept only so any lingering references don't break.
chat_memory = []
47
+
48
+
49
# Safe evaluation for calculations
class SafeEval(ast.NodeVisitor):
    """AST walker that evaluates arithmetic expressions safely.

    Only numeric literals, the operators + - * / % **, unary +/-, and calls
    to whitelisted math functions (plus abs/round) are permitted. Anything
    else raises ValueError, so untrusted input never reaches eval()/exec().
    """

    # Whitelist: every public name in math, plus two safe builtins.
    ALLOWED_NAMES = {n: getattr(math, n) for n in dir(math) if not n.startswith("__")}
    ALLOWED_NAMES.update({"abs": abs, "round": round})

    def visit(self, node):
        """Recursively evaluate *node*, rejecting anything not whitelisted.

        Raises:
            ValueError: for disallowed names or unsupported node types.
        """
        if isinstance(node, ast.Expression):
            return self.visit(node.body)
        if isinstance(node, ast.BinOp):
            return self._binop(node.op, self.visit(node.left), self.visit(node.right))
        if isinstance(node, ast.UnaryOp):
            return self._unaryop(node.op, self.visit(node.operand))
        # ast.Num is deprecated (since 3.8); ast.Constant covers all numeric
        # literals on every supported Python version, so the old Num branch
        # was redundant and has been dropped.
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Name) and func.id in self.ALLOWED_NAMES:
                args = [self.visit(a) for a in node.args]
                return self.ALLOWED_NAMES[func.id](*args)
            # Disallowed calls fall through to the final ValueError below.
        if isinstance(node, ast.Name):
            if node.id in self.ALLOWED_NAMES:
                return self.ALLOWED_NAMES[node.id]
            raise ValueError(f"Use of name '{node.id}' is not allowed")
        raise ValueError(f"Unsupported expression: {ast.dump(node)}")

    def _binop(self, op, a, b):
        """Apply a whitelisted binary operator to evaluated operands."""
        if isinstance(op, ast.Add): return a + b
        if isinstance(op, ast.Sub): return a - b
        if isinstance(op, ast.Mult): return a * b
        if isinstance(op, ast.Div): return a / b
        if isinstance(op, ast.Mod): return a % b
        if isinstance(op, ast.Pow): return a ** b
        raise ValueError("Unsupported binary operator")

    def _unaryop(self, op, a):
        """Apply a whitelisted unary operator to an evaluated operand."""
        if isinstance(op, ast.UAdd): return +a
        if isinstance(op, ast.USub): return -a
        raise ValueError("Unsupported unary operator")
92
+
93
+
94
def safe_calc_eval(expr: str):
    """Evaluate a math expression string safely.

    Tries SymPy first (symbolic + numeric) when available, then falls back
    to the whitelisted AST evaluator (SafeEval).

    Args:
        expr: The arithmetic expression to evaluate.

    Returns:
        (True, result_string) on success, (False, error_message) on failure.
    """
    expr = expr.strip()
    if SYMPY_OK:
        try:
            result = sp.sympify(expr)
            # Prefer a plain float when the expression is fully numeric.
            numeric = None
            try:
                numeric = float(result.evalf())
            except (TypeError, ValueError):
                numeric = None
            if numeric is not None:
                return True, str(numeric)
            return True, str(result)
        except Exception:
            # Any SymPy failure (parse error, etc.): fall through to the
            # AST evaluator instead of silently swallowing with a bare except.
            pass
    try:
        node = ast.parse(expr, mode='eval')
        se = SafeEval()
        val = se.visit(node)
        return True, str(val)
    except Exception as e:
        return False, f"Calc error: {e}"
116
+
117
+
118
# Simple web search
def web_search(query: str, max_results: int = 3) -> List[dict]:
    """Scrape DuckDuckGo's HTML endpoint for search results.

    Best-effort: network failures yield an empty list rather than raising,
    since web search is an optional enhancement for the agent.

    Args:
        query: Search terms.
        max_results: Maximum number of results to return.

    Returns:
        A list of {"title": ..., "snippet": ...} dicts (possibly empty).
    """
    try:
        resp = requests.get(
            "https://html.duckduckgo.com/html/",
            params={"q": query},
            timeout=10,
            headers={"User-Agent": "Mozilla/5.0"},
        )
        resp.raise_for_status()
    except requests.RequestException:
        # Narrowed from a bare except: only network/HTTP errors are expected here.
        return []
    results = []
    # Crude splitting on DuckDuckGo's CSS class names; brittle by design —
    # any parse failure degrades to empty title/snippet fields.
    parts = resp.text.split('result__a')
    for part in parts[1:max_results + 1]:
        try:
            title = part.split('>')[1].split('<')[0]
        except IndexError:
            title = ""
        try:
            snippet = part.split('result__snippet')[1].split('>')[1].split('<')[0]
        except IndexError:
            snippet = ""
        results.append({"title": title, "snippet": snippet})
    return results
145
+
146
+
147
class AgenticRAGAgent:
    """Retrieval-augmented research agent with optional agentic tools.

    PDFs are chunked, embedded with SentenceTransformers, and indexed in a
    FAISS inner-product index (vectors are L2-normalised, so inner product
    equals cosine similarity). Answers come from the Groq LLM when a client
    is available, optionally annotated by calculator, web-search, fact-check
    and analysis tools; replies can be spoken via gTTS.
    """

    def __init__(self):
        self.chunks = []    # indexed chunk dicts: {"content": str}
        self.index = None   # faiss.IndexFlatIP once PDFs are loaded
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        # LLM generation settings (tunable via update_settings).
        self.temperature = 0.3
        self.max_tokens = 500
        # Document processing settings.
        self.chunk_size = 512
        self.chunk_overlap = 50
        self.retrieval_k = 8
        # Agentic tool toggles.
        self.enable_web_search = True
        self.enable_calculations = True
        self.enable_fact_checking = True
        self.enable_analysis = True
        print("✅ AgenticRAGAgent initialized")

    def remove_emojis(self, text: str) -> str:
        """Remove common emoji/pictograph code-point ranges from *text*."""
        emoji_pattern = re.compile(
            "["
            u"\U0001F600-\U0001F64F"  # emoticons
            u"\U0001F300-\U0001F5FF"  # symbols & pictographs
            u"\U0001F680-\U0001F6FF"  # transport & map symbols
            u"\U0001F1E0-\U0001F1FF"  # regional indicator flags
            u"\U00002702-\U000027B0"  # dingbats
            u"\U000024C2-\U0001F251"  # enclosed characters
            "]+",
            flags=re.UNICODE,
        )
        return emoji_pattern.sub(r'', text)

    def clean_for_voice(self, text: str) -> str:
        """Prepare *text* for TTS: drop emojis, markdown markup, extra whitespace."""
        text = self.remove_emojis(text)
        text = re.sub(r'[\*_`#\[\]]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def generate_voice(self, text: str):
        """Synthesize *text* to an MP3 via gTTS.

        Returns:
            The temp-file path, or None when the text is empty/too short or
            synthesis fails (voice output is best-effort).
        """
        if not text or not text.strip():
            return None
        clean = self.clean_for_voice(text)
        if len(clean) < 5:  # too short to be worth speaking
            return None
        try:
            tts = gTTS(text=clean, lang='en', slow=False)
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
            tts.save(tmp.name)
            return tmp.name
        except Exception as e:
            logger.error(f"Voice generation failed: {e}")
            return None

    def upload_pdfs(self, files):
        """Persist uploaded PDFs, extract and chunk their text, build the index.

        Args:
            files: Gradio file objects (or path strings).

        Returns:
            A human-readable status string for the UI.
        """
        if not files:
            return "No files selected."
        folder = Path("sample_data")
        folder.mkdir(exist_ok=True)
        all_chunks = []
        count = 0
        # Guard against a non-positive stride if overlap >= chunk size.
        step = max(1, self.chunk_size - self.chunk_overlap)
        for file in files:
            filename = str(file.name) if hasattr(file, 'name') else str(file)
            if not filename.lower().endswith('.pdf'):
                continue
            dest = folder / Path(filename).name
            try:
                if hasattr(file, 'read'):
                    content = file.read()
                else:
                    # was: open(filename, 'rb').read() — leaked the handle
                    with open(filename, 'rb') as src:
                        content = src.read()
                with open(dest, "wb") as f:
                    f.write(content)
            except Exception as e:
                # Include the actual filename (previously logged "(unknown)").
                logger.warning(f"Failed to save file {filename}: {e}")
                continue
            text = ""
            try:
                with open(dest, 'rb') as f:
                    reader = PyPDF2.PdfReader(f)
                    for page in reader.pages:
                        t = page.extract_text()
                        if t:
                            text += t + " "
            except Exception as e:
                logger.warning(f"Failed to extract text from {filename}: {e}")
                continue
            if text.strip():
                chunks = [text[i:i + self.chunk_size] for i in range(0, len(text), step)]
                all_chunks.extend({"content": str(c.strip())} for c in chunks if c.strip())
                count += 1
        if not all_chunks:
            return "No readable text found in the PDFs."
        print(f"Creating embeddings for {len(all_chunks)} chunks...")
        vecs = self.embedder.encode([c["content"] for c in all_chunks], show_progress_bar=True)
        # L2-normalise so inner product == cosine similarity.
        vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
        dim = vecs.shape[1]
        self.index = faiss.IndexFlatIP(dim)
        self.index.add(vecs.astype('float32'))
        self.chunks = all_chunks
        status_msg = f"✅ Loaded {count} PDF(s) → {len(all_chunks)} chunks ready!"
        print(status_msg)
        return status_msg

    def detect_math(self, text: str):
        """Extract a candidate arithmetic expression from *text*.

        Heuristic: requires at least one digit and one operator, then strips
        letters/commas/question marks. Returns the expression or None.
        """
        if re.search(r'[0-9]', text) and re.search(r'[\+\-\*\/\^%=]', text):
            expr = re.sub(r'[a-zA-Z,?]+', '', text.strip()).strip()
            return expr if len(expr) > 0 else None
        return None

    def perform_fact_check(self, text: str, context: str) -> str:
        """Mark answer sentences whose key terms mostly appear in *context*.

        A shallow lexical overlap check, not true verification. Returns an
        annotation string, or "" when nothing can be checked.
        """
        if not context or not text:
            return ""
        try:
            claims = [s.strip() for s in text.split('.') if s.strip() and len(s.strip()) > 10]
            verified = []
            for claim in claims[:2]:  # keep the annotation short
                key_terms = [w for w in claim.split() if len(w) > 4]
                matches = sum(1 for term in key_terms if term.lower() in context.lower())
                if matches >= len(key_terms) * 0.5:
                    verified.append(f"✓ {claim[:60]}...")
            if verified:
                return "\n[✅ Fact Check]\n" + "\n".join(verified)
            return ""
        except Exception:
            # Best-effort annotation: never let it break the answer path.
            return ""

    def perform_analysis(self, text: str, context: str, question: str) -> str:
        """Produce lightweight quality metrics about the answer *text*."""
        if not text or len(text) < 20:
            return ""
        analysis = []
        sentence_count = len([s for s in text.split('.') if s.strip()])
        if sentence_count >= 3:
            analysis.append("📊 Comprehensive answer with multiple points")
        context_refs = sum(1 for word in context.split() if len(word) > 5 and word.lower() in text.lower())
        if context_refs > 0:
            analysis.append(f"📄 References from {context_refs} key context terms")
        word_count = len(text.split())
        if word_count > 100:
            analysis.append(f"📝 Detailed response ({word_count} words)")
        elif word_count > 50:
            analysis.append(f"📝 Moderate response ({word_count} words)")
        q_words = [w.lower() for w in question.split() if len(w) > 3]
        answer_relevance = sum(1 for w in q_words if w in text.lower())
        if answer_relevance >= len(q_words) * 0.5:
            analysis.append("✓ Answer directly addresses the question")
        if analysis:
            return "\n[📊 Analysis]\n" + "\n".join(analysis)
        return ""

    def ask(self, question: str, history: List) -> Tuple[List, Any]:
        """Answer *question* with RAG + tools and append to *history*.

        Returns:
            (updated_history, audio_file_path_or_None).
        """
        global groq_client
        if not isinstance(question, str):
            question = str(question) if question else ""
        if not isinstance(history, list):
            history = []
        question = question.strip()
        if not question:
            return history, None

        # Canned replies for greetings and missing index.
        if question.lower() in ["hi", "hello", "hey"]:
            reply = "Hi! I am your AI Research Agent. Upload PDFs and ask questions."
            history.append([question, reply])
            return history, self.generate_voice(reply)

        if not self.index:
            reply = "Please upload a PDF first!"
            history.append([question, reply])
            return history, self.generate_voice(reply)

        # Dense retrieval of the top-k chunks as LLM context.
        try:
            q_vec = self.embedder.encode([question])
            q_vec = q_vec / np.linalg.norm(q_vec)
            D, I = self.index.search(q_vec.astype('float32'), k=self.retrieval_k)
            context_list = [self.chunks[i]["content"] for i in I[0] if i < len(self.chunks)]
            context = "\n\n".join(context_list).strip()
        except Exception as e:
            # Narrowed from a bare except; retrieval failure degrades to no context.
            logger.warning(f"Retrieval failed: {e}")
            context = ""

        prompt = (
            f"Context from documents:\n{context}\n\nQuestion: {question}\nAnswer clearly and accurately:"
            if context
            else f"Question: {question}\nAnswer clearly and accurately:"
        )

        calc_note = ""
        web_note = ""
        fact_note = ""
        analysis_note = ""

        if self.enable_calculations:
            expr = self.detect_math(question)
            if expr:
                ok, res = safe_calc_eval(expr)
                if ok:
                    calc_note = f"\n[🧮 Calculator] {expr} = {res}"

        if self.enable_web_search:
            # Only hit the web for time-sensitive questions.
            keywords = ["latest", "today", "current", "recent", "news"]
            if any(k in question.lower() for k in keywords):
                results = web_search(question)
                if results:
                    web_note = "\n[🌐 Web Sources]:\n" + "\n".join([f"- {r.get('title','')}" for r in results[:2]])

        reply = "Error processing request."

        if groq_client:
            try:
                messages = [{"role": "user", "content": prompt}]
                resp = groq_client.chat.completions.create(
                    model="llama-3.3-70b-versatile",
                    messages=messages,
                    temperature=float(self.temperature),
                    max_tokens=int(self.max_tokens),
                )
                if resp and resp.choices and len(resp.choices) > 0:
                    reply = str(resp.choices[0].message.content).strip()
                else:
                    reply = "No response from API"
            except Exception as e:
                reply = f"Error: {e}"

        # Append tool annotations after the LLM reply.
        if calc_note:
            reply += calc_note
        if web_note:
            reply += web_note
        if self.enable_fact_checking:
            fact_note = self.perform_fact_check(reply, context)
        if fact_note:
            reply += fact_note
        if self.enable_analysis:
            analysis_note = self.perform_analysis(reply, context, question)
        if analysis_note:
            reply += analysis_note

        history.append([question, reply])
        return history, self.generate_voice(reply)

    def update_settings(self, temp, tokens, chunk_size, overlap, k, web, calc, fact, analysis):
        """Apply UI settings and return a confirmation string for the UI."""
        self.temperature = float(temp)
        self.max_tokens = int(tokens)
        self.chunk_size = int(chunk_size)
        self.chunk_overlap = int(overlap)
        self.retrieval_k = int(k)
        self.enable_web_search = bool(web)
        self.enable_calculations = bool(calc)
        self.enable_fact_checking = bool(fact)
        self.enable_analysis = bool(analysis)
        return f"""⚙️ Settings Updated:
• Temperature: {temp}
• Max Tokens: {tokens}
• Chunk Size: {chunk_size}
• Chunk Overlap: {overlap}
• Retrieved Chunks: {k}
• Web Search: {'✅' if web else '❌'}
• Calculator: {'✅' if calc else '❌'}
• Fact Check: {'✅' if fact else '❌'}
• Analysis: {'✅' if analysis else '❌'}"""
378
+
379
+
380
# ===== FIXED: Gradio Interface (session-based memory) =====
def create_interface():
    """Build and return the Gradio Blocks UI wired to a fresh agent instance."""
    agent = AgenticRAGAgent()

    with gr.Blocks(title="AI Research Agent") as interface:

        # Per-session conversation state (replaces the old module-level list).
        chat_memory = gr.State([])

        gr.HTML("""
        <div style="text-align:center;padding:20px;background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);border-radius:15px;">
            <h1 style="color:white;margin:0;">🤖 AI Research Agent - Agentic RAG</h1>
            <p style="color:white;margin:10px 0;">Advanced Multi-Tool Research Assistant with Voice Support 🎤🔊</p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="💬 Chat", height=500)

                with gr.Row():
                    msg = gr.Textbox(placeholder="Ask a complex research question...", scale=4, lines=1)
                    submit_btn = gr.Button("🚀 Send", variant="primary", scale=1)

                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")

                audio_output = gr.Audio(label="🔊 Voice Response", autoplay=True, interactive=False)

            with gr.Column(scale=1):
                gr.HTML("<h3 style='text-align:center;'>📄 Upload Documents</h3>")
                pdf_upload = gr.Files(file_types=[".pdf"], label="Upload PDFs")
                upload_status = gr.Textbox(label="📊 Status", interactive=False, max_lines=10)

                with gr.Accordion("⚙️ AI Parameters", open=False):
                    temperature_slider = gr.Slider(0.0, 1.0, value=0.3, step=0.1, label="🌡️ Temperature")
                    max_tokens_slider = gr.Slider(100, 2000, value=500, step=50, label="📏 Max Tokens")

                with gr.Accordion("📄 Document Processing", open=False):
                    chunk_size_slider = gr.Slider(256, 1024, value=512, step=64, label="📄 Chunk Size")
                    chunk_overlap_slider = gr.Slider(0, 200, value=50, step=10, label="🔗 Chunk Overlap")
                    retrieval_k_slider = gr.Slider(3, 15, value=8, step=1, label="🔍 Retrieved Chunks")

                with gr.Accordion("🛠️ Agentic Tools", open=False):
                    enable_web = gr.Checkbox(value=True, label="🌐 Web Search")
                    enable_calc = gr.Checkbox(value=True, label="🧮 Calculator")
                    enable_fact = gr.Checkbox(value=True, label="✅ Fact Check")
                    enable_analysis = gr.Checkbox(value=True, label="📊 Analysis")

                apply_btn = gr.Button("⚡ Apply Settings", variant="primary", size="lg")
                settings_status = gr.Textbox(label="⚙️ Settings Status", interactive=False, max_lines=10, value="Settings ready.")

        def respond(message, history):
            """Run the agent, then mirror history into Chatbot message dicts."""
            updated_history, audio_file = agent.ask(message, history)
            display_history = []
            for pair in updated_history:
                if isinstance(pair, list) and len(pair) == 2:
                    display_history.append({"role": "user", "content": str(pair[0])})
                    display_history.append({"role": "assistant", "content": str(pair[1])})
            # Clear the textbox, update state, chat view, and audio player.
            return "", updated_history, display_history, audio_file

        def clear_chat():
            """Reset both the session memory and the visible chat."""
            return [], []

        submit_btn.click(respond, inputs=[msg, chat_memory], outputs=[msg, chat_memory, chatbot, audio_output])
        msg.submit(respond, inputs=[msg, chat_memory], outputs=[msg, chat_memory, chatbot, audio_output])
        clear_btn.click(clear_chat, outputs=[chat_memory, chatbot])
        pdf_upload.change(agent.upload_pdfs, inputs=[pdf_upload], outputs=[upload_status])

        apply_btn.click(
            agent.update_settings,
            inputs=[
                temperature_slider, max_tokens_slider, chunk_size_slider,
                chunk_overlap_slider, retrieval_k_slider, enable_web,
                enable_calc, enable_fact, enable_analysis
            ],
            outputs=[settings_status]
        )

    return interface
459
+
460
+
461
if __name__ == "__main__":
    # Entrypoint: build the UI and serve it on all interfaces at port 7860.
    print("🚀 Starting AI Research Agent with Full UI...")
    app = create_interface()
    app.launch(server_name="0.0.0.0", server_port=7860, show_error=True)