Zahid0123 commited on
Commit
8f2c55b
·
verified ·
1 Parent(s): bd04dfb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +319 -0
app.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py - FULLY WORKING AI Research Agent for Hugging Face Spaces
2
+ import os
3
+ import re
4
+ import ast
5
+ import operator
6
+ import logging
7
+ import requests
8
+ import tempfile
9
+ import time
10
+ from pathlib import Path
11
+ from typing import List, Dict, Any
12
+ from datetime import datetime
13
+
14
+ import numpy as np
15
+ from tqdm import tqdm
16
+ import PyPDF2
17
+ from sentence_transformers import SentenceTransformer
18
+ import faiss
19
+ from groq import Groq
20
+ import gradio as gr
21
+ from gtts import gTTS
22
+
23
+ logging.basicConfig(level=logging.INFO)
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # ========================================
27
+ # TOOLS & UTILITIES
28
+ # ========================================
29
class WebSearchTool:
    """Thin client for the DuckDuckGo Instant Answer API."""

    def __init__(self, max_results: int = 5):
        # Cap on how many related topics are kept from a response.
        self.max_results = max_results
        self.base_url = "https://api.duckduckgo.com/"

    def search(self, query: str) -> Dict[str, Any]:
        """Query DuckDuckGo and return {'abstract': str, 'related': [...]}.

        Never raises: any network/parse failure is logged and an empty
        result structure is returned instead.
        """
        query_params = {
            'q': query, 'format': 'json', 'no_redirect': '1',
            'no_html': '1', 'skip_disambig': '1'
        }
        try:
            response = requests.get(self.base_url, params=query_params, timeout=10)
            response.raise_for_status()
            payload = response.json()

            related = []
            for topic in payload.get('RelatedTopics', [])[:self.max_results]:
                # Grouped topics lack a 'Text' key; skip them like the API docs suggest.
                if 'Text' not in topic:
                    continue
                related.append({
                    'text': topic.get('Text', ''),
                    'url': topic.get('FirstURL', ''),
                })

            return {
                'abstract': payload.get('Abstract', '') or payload.get('Answer', ''),
                'related': related,
            }
        except Exception as exc:
            logger.error(f"Web search failed: {exc}")
            return {'abstract': '', 'related': []}
56
+
57
+ # ========================================
58
+ # DOCUMENT PROCESSING
59
+ # ========================================
60
class DocumentProcessor:
    """Loads PDF files from a directory tree and extracts their text."""

    def load_documents(self, data_directory: str) -> List[Dict[str, Any]]:
        """Return a list of {'doc_id', 'content', 'file_path'} dicts.

        Recursively scans *data_directory* for ``*.pdf`` files. Files that
        fail to parse are logged and skipped; pages with no extractable
        text are ignored, and documents with no text at all are dropped.
        """
        root = Path(data_directory)
        loaded: List[Dict[str, Any]] = []
        for pdf_path in root.rglob("*.pdf"):
            try:
                # Extract while the handle is open: PyPDF2 reads pages lazily.
                with open(pdf_path, 'rb') as handle:
                    reader = PyPDF2.PdfReader(handle)
                    page_texts = [page.extract_text() for page in reader.pages]
                full_text = "".join(t + "\n" for t in page_texts if t)
                if full_text.strip():
                    loaded.append({
                        'doc_id': str(pdf_path.relative_to(root)),
                        'content': full_text,
                        'file_path': str(pdf_path),
                    })
            except Exception as exc:
                logger.error(f"Error reading {pdf_path}: {exc}")
        return loaded
82
+
83
class DocumentChunker:
    """Splits documents into overlapping, roughly fixed-size text chunks."""

    def __init__(self, chunk_size=512, chunk_overlap=50):
        self.chunk_size = chunk_size        # target characters per chunk
        self.chunk_overlap = chunk_overlap  # characters shared by neighbouring chunks

    def chunk_documents(self, documents: List[Dict]) -> List[Dict]:
        """Return chunk dicts ({'chunk_id', 'content', 'doc_id', 'source_file'}).

        Whitespace is collapsed first; cuts prefer a word boundary in the
        second half of the window; chunks of <= 50 characters are dropped.
        """
        produced: List[Dict] = []
        for document in documents:
            flat = re.sub(r'\s+', ' ', document['content']).strip()
            pos = 0
            while pos < len(flat):
                cut = pos + self.chunk_size
                window = flat[pos:cut]
                # Only shorten to a word boundary when there is more text after
                # this window, and only if the boundary keeps the chunk sizeable.
                if cut < len(flat):
                    boundary = window.rfind(' ')
                    if boundary > self.chunk_size // 2:
                        cut = pos + boundary
                produced.append({
                    'chunk_id': f"{document['doc_id']}_{pos}",
                    'content': flat[pos:cut].strip(),
                    'doc_id': document['doc_id'],
                    'source_file': document['file_path'],
                })
                # Step back by the overlap so adjacent chunks share context.
                pos = cut - self.chunk_overlap
        return [c for c in produced if len(c['content']) > 50]
110
+
111
class EmbeddingGenerator:
    """Wraps a SentenceTransformer model for chunk and query embeddings."""

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        # Model weights are downloaded/cached on first construction.
        self.model = SentenceTransformer(model_name)

    def generate(self, chunks: List[Dict]) -> np.ndarray:
        """Embed the 'content' of every chunk; returns an (n_chunks, dim) array."""
        contents = [chunk['content'] for chunk in chunks]
        return self.model.encode(
            contents,
            batch_size=32,
            show_progress_bar=False,
            convert_to_numpy=True,
        )

    def query_embedding(self, query: str) -> np.ndarray:
        """Embed a single query string; returns a 1-D vector."""
        encoded = self.model.encode([query], convert_to_numpy=True)
        return encoded[0]
121
+
122
+ # ========================================
123
+ # RETRIEVER
124
+ # ========================================
125
class DocumentRetriever:
    """Cosine-similarity retriever: an inner-product FAISS index over
    L2-normalized chunk embeddings."""

    def __init__(self):
        self.chunks = []    # chunk dicts, row-aligned with the FAISS index
        self.index = None   # faiss.IndexFlatIP, created by build_index()
        self.embedder = EmbeddingGenerator()

    def build_index(self, chunks: List[Dict], embeddings: np.ndarray):
        """Build a flat inner-product index over row-normalized embeddings.

        With unit-norm vectors, inner product equals cosine similarity.
        """
        self.chunks = chunks
        dim = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dim)
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        self.index.add((embeddings / norms).astype('float32'))

    def search(self, query: str, k: int = 8) -> List[Dict]:
        """Return up to *k* chunk dicts, each annotated with 'similarity'.

        Returns [] when no index has been built yet.
        """
        # Explicit None check: SWIG-wrapped FAISS objects should not be
        # relied on for truthiness.
        if self.index is None:
            return []
        q_emb = self.embedder.query_embedding(query)
        q_norm = q_emb / np.linalg.norm(q_emb)
        scores, indices = self.index.search(q_norm.reshape(1, -1).astype('float32'), k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing results with -1; the previous
            # `idx < len(self.chunks)` check let -1 through and silently
            # returned chunks[-1]. Guard the lower bound as well.
            if 0 <= idx < len(self.chunks):
                hit = self.chunks[idx].copy()
                hit['similarity'] = float(score)
                results.append(hit)
        return results
151
+
152
+ # ========================================
153
+ # AGENT TOOLS
154
+ # ========================================
155
class AgenticTools:
    """Small tool belt the agent can call: a safe calculator and web search."""

    # Whitelisted arithmetic operators for calculator expressions.
    _BIN_OPS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
    }
    _UNARY_OPS = {
        ast.UAdd: operator.pos,
        ast.USub: operator.neg,
    }

    def __init__(self):
        # BUGFIX: this attribute was named `self.web_search`, which shadowed
        # the web_search() method below and made it uncallable on instances.
        self._search_tool = WebSearchTool()

    def _eval_node(self, node):
        """Recursively evaluate a whitelisted arithmetic AST node."""
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in self._BIN_OPS:
            return self._BIN_OPS[type(node.op)](
                self._eval_node(node.left), self._eval_node(node.right)
            )
        if isinstance(node, ast.UnaryOp) and type(node.op) in self._UNARY_OPS:
            return self._UNARY_OPS[type(node.op)](self._eval_node(node.operand))
        raise ValueError("Unsupported expression")

    def calculator(self, expr: str):
        """Safely evaluate a basic arithmetic expression.

        Returns {"success": True, "result": <number>} on success, or
        {"success": False, "result": "Invalid math"} on any failure.

        BUGFIX: the previous implementation passed an AST node straight to
        eval(), which raises TypeError, so every call returned "Invalid
        math". This walks the AST with the operator module instead.
        """
        try:
            sanitized = re.sub(r'[^0-9+\-*/(). ]', '', expr)
            tree = ast.parse(sanitized, mode='eval')
            return {"success": True, "result": self._eval_node(tree.body)}
        except Exception:
            return {"success": False, "result": "Invalid math"}

    def web_search(self, query: str):
        """Run a DuckDuckGo search; always reports success with the raw result."""
        return {"success": True, "result": self._search_tool.search(query)}
169
+
170
+ # ========================================
171
+ # MAIN AGENT
172
+ # ========================================
173
class AgenticRAGAgent:
    """RAG agent: indexes uploaded PDFs, answers questions with a Groq LLM,
    and renders each answer to speech with gTTS."""

    def __init__(self):
        self.retriever = None  # built lazily by upload_pdfs()
        # Groq client is only created when the secret is present; answer_query
        # degrades to an explanatory message otherwise.
        api_key = os.getenv("GROQ_API_KEY")
        self.groq = Groq(api_key=api_key) if api_key else None
        self.tools = AgenticTools()

    def clean_for_speech(self, text: str) -> str:
        """Strip markdown markup (bold/italics/code/links) and collapse
        whitespace so TTS reads plain prose."""
        text = re.sub(r'\*\*|\*|_|`|\[.*?\]|\(.*?\)', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def generate_voice(self, text: str):
        """Render *text* to an MP3 file; returns its path, or None on failure."""
        if not text or not text.strip():
            return None
        clean = self.clean_for_speech(text)
        try:
            tts = gTTS(text=clean, lang='en')
            # BUGFIX: saving into a still-open NamedTemporaryFile fails on
            # Windows (the open handle locks the file). mkstemp + close
            # yields a closed, writable path on every platform.
            fd, audio_path = tempfile.mkstemp(suffix=".mp3")
            os.close(fd)
            tts.save(audio_path)
            return audio_path
        except Exception as e:
            logger.error(f"TTS failed: {e}")
            return None

    def upload_pdfs(self, files):
        """Save uploaded PDFs into ./sample_data and (re)build the index.

        Returns a human-readable status string for the UI.
        """
        if not files:
            return "No files uploaded."

        os.makedirs("sample_data", exist_ok=True)
        saved = []
        for file in files:
            # Gradio may hand us file-like objects (with .name/.read) or
            # plain path strings, depending on version/configuration.
            name = getattr(file, 'name', file if isinstance(file, str) else '')
            if not str(name).lower().endswith(".pdf"):
                continue
            dest = os.path.join("sample_data", os.path.basename(str(name)))
            # BUGFIX: the old fallback wrote the *object itself* to a binary
            # file (TypeError for path strings). Prefer copying bytes from
            # the on-disk path; fall back to .read() for file-like objects.
            if isinstance(name, str) and os.path.exists(name):
                with open(name, 'rb') as src:
                    data = src.read()
            elif callable(getattr(file, 'read', None)):
                data = file.read()
            else:
                data = file  # last resort: assume raw bytes
            with open(dest, "wb") as f:
                f.write(data)
            saved.append(dest)

        if not saved:
            return "No valid PDF files."

        # Process documents: extract text, chunk, embed, index.
        processor = DocumentProcessor()
        chunker = DocumentChunker()
        docs = processor.load_documents("sample_data")
        chunks = chunker.chunk_documents(docs)
        # BUGFIX: guard the empty case — embeddings.shape[1] raises
        # IndexError when no chunk text could be extracted.
        if not chunks:
            return "No extractable text found in the uploaded PDFs."
        embedder = EmbeddingGenerator()
        embeddings = embedder.generate(chunks)

        self.retriever = DocumentRetriever()
        self.retriever.build_index(chunks, embeddings)

        return f"Loaded {len(saved)} PDFs → {len(chunks)} chunks ready! Ask anything."

    def answer_query(self, query: str, history: list):
        """Answer *query* with retrieved context.

        Returns (updated_history, audio_path_or_None); history is a list of
        [user, assistant] pairs as expected by gr.Chatbot.
        """
        if not query.strip():
            return history, None

        if not history:
            history = []

        # Canned greeting path — no retrieval needed.
        if query.strip().lower() in ["hi", "hello", "hey", "hola"]:
            resp = "Hello! I'm your AI Research Agent with voice answers. Upload PDFs and ask complex questions!"
            history.append([query, resp])
            audio = self.generate_voice(resp)
            return history, audio

        if not self.retriever:
            resp = "Please upload at least one PDF document first!"
            history.append([query, resp])
            return history, None

        # Retrieve top chunks, then keep the 5 best (truncated) as context.
        docs = self.retriever.search(query, k=8)
        context = "\n\n".join([d['content'][:1000] for d in docs[:5]])

        prompt = f"""You are an expert research assistant.
Context from documents:
{context}

Question: {query}
Provide a clear, accurate, and concise answer. Use bullet points if helpful."""

        try:
            if not self.groq:
                answer = "GROQ_API_KEY not set. Set it in Secrets."
            else:
                resp = self.groq.chat.completions.create(
                    model="llama-3.1-70b-versatile",
                    messages=[{"role": "user", "content": prompt}],
                    temperature=0.3,
                    max_tokens=800
                )
                answer = resp.choices[0].message.content.strip()
        except Exception as e:
            answer = f"LLM Error: {str(e)}"

        history.append([query, answer])
        audio = self.generate_voice(answer)
        return history, audio
273
+
274
+ # ========================================
275
+ # GRADIO APP
276
+ # ========================================
277
def create_app():
    """Assemble the Gradio UI: chat plus voice output on the left, PDF
    upload and status on the right, wired to one AgenticRAGAgent."""
    agent = AgenticRAGAgent()

    with gr.Blocks(title="AI Research Agent - RAG + Voice", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# 🤖 AI Research Agent (Agentic RAG + Voice)
Upload PDFs → Ask complex questions → Get answers with **voice**
""")

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=600)
                msg = gr.Textbox(placeholder="Ask about your documents...", label="Your Question")
                with gr.Row():
                    send = gr.Button("Send", variant="primary")
                    clear = gr.Button("Clear Chat")
                audio_out = gr.Audio(label="Voice Response", autoplay=True, interactive=False)

            with gr.Column(scale=1):
                gr.Markdown("### Upload Research PDFs")
                file_input = gr.Files(file_types=[".pdf"], file_count="multiple")
                status = gr.Textbox(label="Status", interactive=False, lines=4)

        def respond(message, chat_history):
            # Clears the textbox, refreshes the chat, and hands back the clip.
            updated, clip = agent.answer_query(message, chat_history)
            return "", updated, clip

        # Both Enter in the textbox and the Send button submit a question.
        msg.submit(respond, [msg, chatbot], [msg, chatbot, audio_out])
        send.click(respond, [msg, chatbot], [msg, chatbot, audio_out])
        clear.click(lambda: ([], None), outputs=[chatbot, audio_out])
        file_input.change(agent.upload_pdfs, file_input, status)

        gr.Markdown("**Secret Required:** Add `GROQ_API_KEY` in Space Secrets (free at console.groq.com)")

    return demo
313
+
314
+ # ========================================
315
+ # LAUNCH
316
+ # ========================================
317
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the Hugging Face Spaces default port.
    demo = create_app()
    demo.launch(server_name="0.0.0.0", server_port=7860)