PraneshJs committed on
Commit
a6913bb
verified
1 Parent(s): e5231be

Create app.py

Files changed (1)
  1. app.py +442 -0
app.py ADDED
# app.py
import os
import asyncio
import json
import hashlib
from io import BytesIO, StringIO
from typing import List, Optional, Tuple

import gradio as gr
import numpy as np
import faiss
import requests
import pandas as pd
from sentence_transformers import SentenceTransformer

# file parsing libs
import fitz  # PyMuPDF
import docx
from pptx import Presentation

# crawl4ai
from crawl4ai import AsyncWebCrawler

# ---------------- Config ----------------
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
OPENROUTER_MODEL = "microsoft/mai-ds-r1:free"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
CACHE_DIR = "./cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# sentence-transformers embedder (loads once)
embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Global in-memory stores (cleared/updated by UI actions)
DOCS: List[str] = []
FILENAMES: List[str] = []
EMBEDDINGS: Optional[np.ndarray] = None
FAISS_INDEX = None
CURRENT_CACHE_KEY: str = ""

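# Note: OPENROUTER_API_KEY is read once at import time, so it must be set before
# the process starts, e.g. `export OPENROUTER_API_KEY=...` in the shell or as a
# Secret on Hugging Face Spaces; a changed key only takes effect after a restart.
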
# ---------------- File extraction helpers ----------------
def extract_text_from_pdf(file_bytes: bytes) -> str:
    try:
        doc = fitz.open(stream=file_bytes, filetype="pdf")
        pages = [page.get_text() for page in doc]
        return "\n".join(pages)
    except Exception as e:
        return f"[PDF extraction error] {e}"

def extract_text_from_docx(file_bytes: bytes) -> str:
    try:
        f = BytesIO(file_bytes)
        doc = docx.Document(f)
        return "\n".join([p.text for p in doc.paragraphs])
    except Exception as e:
        return f"[DOCX extraction error] {e}"

def extract_text_from_txt(file_bytes: bytes) -> str:
    try:
        return file_bytes.decode("utf-8", errors="ignore")
    except Exception as e:
        return f"[TXT extraction error] {e}"

def extract_text_from_excel(file_bytes: bytes) -> str:
    try:
        f = BytesIO(file_bytes)
        df = pd.read_excel(f, dtype=str)
        parts = []
        for col in df.columns:
            parts.append("\n".join(df[col].fillna("").astype(str).tolist()))
        return "\n".join(parts)
    except Exception as e:
        return f"[EXCEL extraction error] {e}"

def extract_text_from_pptx(file_bytes: bytes) -> str:
    try:
        f = BytesIO(file_bytes)
        prs = Presentation(f)
        texts = []
        for slide in prs.slides:
            for shape in slide.shapes:
                if hasattr(shape, "text"):
                    texts.append(shape.text)
        return "\n".join(texts)
    except Exception as e:
        return f"[PPTX extraction error] {e}"

def extract_text_from_csv(file_bytes: bytes) -> str:
    try:
        f = StringIO(file_bytes.decode("utf-8", errors="ignore"))
        df = pd.read_csv(f, dtype=str)
        return df.to_string(index=False)
    except Exception as e:
        return f"[CSV extraction error] {e}"

def extract_text_from_file_tuple(file_tuple) -> Tuple[str, bytes]:
    """
    Accepts a Gradio file object/tuple and returns (filename, bytes).
    Robust to multiple Gradio versions.
    """
    # Gradio v3.x passes a TemporaryFile-like object with .name & .read()
    try:
        if hasattr(file_tuple, "name") and hasattr(file_tuple, "read"):
            filename = os.path.basename(file_tuple.name)
            file_bytes = file_tuple.read()
            return filename, file_bytes
    except Exception:
        pass
    # other shapes: a (name, bytes) tuple
    try:
        if isinstance(file_tuple, tuple) and len(file_tuple) == 2 and isinstance(file_tuple[1], (bytes, bytearray)):
            return file_tuple[0], bytes(file_tuple[1])
    except Exception:
        pass
    # fallback if a path string is provided
    try:
        if isinstance(file_tuple, str) and os.path.exists(file_tuple):
            with open(file_tuple, "rb") as fh:
                return os.path.basename(file_tuple), fh.read()
    except Exception:
        pass
    raise ValueError("Unsupported file object passed by Gradio.")

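# Note: newer Gradio releases typically hand uploaded files over as temp-file
# *paths* (plain strings), which is what the path fallback above handles; some
# versions also accept gr.File(type="binary") to receive bytes directly (an
# assumption worth verifying against the installed Gradio version).
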
def extract_text_by_ext(filename: str, file_bytes: bytes) -> str:
    name = filename.lower()
    if name.endswith(".pdf"):
        return extract_text_from_pdf(file_bytes)
    if name.endswith(".docx"):
        return extract_text_from_docx(file_bytes)
    if name.endswith(".txt"):
        return extract_text_from_txt(file_bytes)
    if name.endswith(".xlsx") or name.endswith(".xls"):
        return extract_text_from_excel(file_bytes)
    if name.endswith(".pptx"):
        return extract_text_from_pptx(file_bytes)
    if name.endswith(".csv"):
        return extract_text_from_csv(file_bytes)
    # fallback: try plain text
    return extract_text_from_txt(file_bytes)

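# Dispatch sketch (hypothetical path; not executed anywhere in this module):
# with open("notes.docx", "rb") as fh:
#     text = extract_text_by_ext("notes.docx", fh.read())
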
# ---------------- Embedding caching helpers ----------------
def make_cache_key_for_files(files: List[Tuple[str, bytes]]) -> str:
    """
    Create a deterministic cache key from filenames, sizes, and the sha256 of
    each file's content. Sorting by name makes the key order-independent.
    """
    h = hashlib.sha256()
    for name, b in sorted(files, key=lambda x: x[0]):
        h.update(name.encode("utf-8"))
        h.update(str(len(b)).encode("utf-8"))
        # hash the full content, then fold its fixed-size digest into the key
        h.update(hashlib.sha256(b).digest())
    return h.hexdigest()

def cache_save_embeddings(cache_key: str, embeddings: np.ndarray, filenames: List[str]):
    path = os.path.join(CACHE_DIR, f"{cache_key}.npz")
    np.savez_compressed(path, embeddings=embeddings, filenames=np.array(filenames))
    return path

def cache_load_embeddings(cache_key: str):
    path = os.path.join(CACHE_DIR, f"{cache_key}.npz")
    if not os.path.exists(path):
        return None
    try:
        arr = np.load(path, allow_pickle=True)
        embeddings = arr["embeddings"]
        filenames = arr["filenames"].tolist()
        return embeddings, filenames
    except Exception:
        return None

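# Round-trip sketch (hypothetical in-memory "files"; defined but never called):
def _demo_cache_roundtrip():
    files = [("a.txt", b"hello"), ("b.txt", b"world")]
    key = make_cache_key_for_files(files)
    vecs = embedder.encode(["hello", "world"], convert_to_numpy=True)
    cache_save_embeddings(key, vecs, [name for name, _ in files])
    return cache_load_embeddings(key)  # -> (embeddings, filenames), or None on failure
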
# ---------------- FAISS helpers ----------------
def build_faiss_index(embeddings: np.ndarray):
    global FAISS_INDEX
    if embeddings is None or len(embeddings) == 0:
        FAISS_INDEX = None
        return None
    emb = embeddings.astype("float32")
    dim = emb.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(emb)
    FAISS_INDEX = index
    return index

def search_top_k(query: str, k: int = 3):
    if FAISS_INDEX is None:
        return []
    q_emb = embedder.encode([query], convert_to_numpy=True).astype("float32")
    D, I = FAISS_INDEX.search(q_emb, k)
    results = []
    for dist, idx in zip(D[0], I[0]):
        if idx < 0:  # FAISS pads with -1 when fewer than k vectors are indexed
            continue
        results.append({
            "index": int(idx),
            "distance": float(dist),
            "text": DOCS[idx],
            "source": FILENAMES[idx]
        })
    return results

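# Note: IndexFlatL2 ranks by Euclidean distance; sentence-transformers vectors
# are more commonly compared by cosine similarity. An equivalent cosine index
# (a sketch of the alternative, not what this app uses) normalises the vectors
# and switches to inner product:
def _build_cosine_index(embeddings: np.ndarray):
    emb = np.ascontiguousarray(embeddings, dtype="float32").copy()
    faiss.normalize_L2(emb)                  # in-place unit-length normalisation
    index = faiss.IndexFlatIP(emb.shape[1])  # inner product == cosine on unit vectors
    index.add(emb)
    return index
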
# ---------------- OpenRouter minimal client ----------------
def openrouter_chat_system_user(system_prompt: str, user_prompt: str):
    """
    Sends only a 'model' and 'messages' payload (system + user) to OpenRouter;
    no sampling parameters such as max_tokens or temperature are included.
    """
    if not OPENROUTER_API_KEY:
        return "[OpenRouter error] OPENROUTER_API_KEY not set."

    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {"Authorization": f"Bearer {OPENROUTER_API_KEY}", "Content-Type": "application/json"}
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_prompt})

    payload = {"model": OPENROUTER_MODEL, "messages": messages}
    try:
        r = requests.post(url, headers=headers, json=payload, timeout=60)
        r.raise_for_status()
        obj = r.json()
        # Expecting an OpenAI-like structure: choices[0].message.content
        if "choices" in obj and len(obj["choices"]) > 0:
            choice = obj["choices"][0]
            if "message" in choice and "content" in choice["message"]:
                return choice["message"]["content"]
            if "text" in choice:
                return choice["text"]
        # fallback: return the (truncated) raw JSON for debugging
        return json.dumps(obj, indent=2)[:12000]
    except Exception as e:
        return f"[OpenRouter request error] {e}"

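# Usage sketch (makes a live request; requires OPENROUTER_API_KEY to be set):
# reply = openrouter_chat_system_user("You are a helpful assistant.", "Say hello.")
# print(reply)
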
# ---------------- Crawl4AI robust logic ----------------
async def _crawl_async_get_markdown(url: str):
    # uses default crawler settings; adjust with a run config if needed
    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(url=url)
        # prefer a success flag if present
        if hasattr(result, "success") and result.success is False:
            # attempt to surface the error
            err = getattr(result, "error_message", None) or getattr(result, "error", None) or "[Crawl4AI unknown error]"
            return f"[Crawl4AI error] {err}"

        # try structured markdown first
        md_obj = getattr(result, "markdown", None)
        if md_obj:
            # try common subfields observed in different versions
            text = getattr(md_obj, "fit_markdown", None) or getattr(md_obj, "raw_markdown", None)
            if text:
                return text
            # fall back to str(md_obj)
            try:
                return str(md_obj)
            except Exception:
                pass

        # fall back to plain text or raw HTML
        text = getattr(result, "text", None) or getattr(result, "html", None)
        if text:
            return text
        # last resort: JSON-ify the entire result (truncated)
        try:
            return json.dumps(result.__dict__, default=str)[:20000]
        except Exception:
            return "[Crawl4AI returned no usable fields]"

def crawl_url_sync(url: str) -> str:
    try:
        return asyncio.run(_crawl_async_get_markdown(url))
    except Exception as e:
        return f"[Crawl4AI runtime error] {e}"

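# Note: asyncio.run() raises "cannot be called from a running event loop" if
# this ever executes inside async context. A defensive variant (a sketch, not
# wired into the handlers below) runs the coroutine on a private loop instead:
def _run_coro_on_private_loop(coro):
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()
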
# ---------------- Gradio handlers ----------------
def upload_and_index(files):
    """
    files: list of file objects from Gradio. Extract bytes, compute a cache key,
    and try to load embeddings from cache; if not found, compute and save them.
    """
    global DOCS, FILENAMES, EMBEDDINGS, CURRENT_CACHE_KEY

    if not files:
        return "No files uploaded.", ""

    # read files into a list of (name, bytes)
    prepared = []
    previews = []
    for f in files:
        name, b = extract_text_from_file_tuple(f)
        prepared.append((name, b))
        # short preview
        previews.append({"name": name, "size": len(b)})

    cache_key = make_cache_key_for_files(prepared)
    CURRENT_CACHE_KEY = cache_key

    # Try to load existing embeddings
    cached = cache_load_embeddings(cache_key)
    if cached:
        emb, filenames = cached
        EMBEDDINGS = np.array(emb)
        FILENAMES = filenames
        # Rebuild DOCS: the textual content is still needed (not just embeddings).
        # Extract in the cached filename order so DOCS[i] lines up with
        # EMBEDDINGS[i] even if the same files arrive in a different order.
        text_by_name = {name: extract_text_by_ext(name, b) for name, b in prepared}
        DOCS = [text_by_name[name] for name in FILENAMES]
        build_faiss_index(EMBEDDINGS)
        return f"Loaded embeddings from cache ({len(FILENAMES)} docs).", json.dumps(previews)

    # Not cached -> extract texts and embed
    DOCS = []
    FILENAMES = []
    for name, b in prepared:
        txt = extract_text_by_ext(name, b)
        DOCS.append(txt)
        FILENAMES.append(name)

    # Compute embeddings
    emb = embedder.encode(DOCS, convert_to_numpy=True, show_progress_bar=False).astype("float32")
    EMBEDDINGS = emb
    # Save to cache
    cache_save_embeddings(cache_key, EMBEDDINGS, FILENAMES)
    # Build the FAISS index
    build_faiss_index(EMBEDDINGS)

    return f"Uploaded and indexed {len(DOCS)} documents.", json.dumps(previews)

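# Note: each file is embedded as a single vector, and all-MiniLM-L6-v2 truncates
# its input (around 256 word-pieces), so only the start of a long document is
# really represented. A fixed-size chunker like this sketch (not used by the
# app) is the usual remedy: embed and index chunks instead of whole files.
def _chunk_text(text: str, size: int = 1000, overlap: int = 200) -> List[str]:
    step = size - overlap
    return [text[start:start + size] for start in range(0, max(len(text), 1), step)]
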
def crawl_and_index(url: str):
    global DOCS, FILENAMES, EMBEDDINGS, CURRENT_CACHE_KEY
    if not url:
        return "No URL provided.", ""

    crawled = crawl_url_sync(url)
    if crawled.startswith("[Crawl4AI"):
        return crawled, ""

    # create a cache key based on the URL and its content
    key_hash = hashlib.sha256()
    key_hash.update(url.encode("utf-8"))
    key_hash.update(crawled.encode("utf-8"))
    cache_key = key_hash.hexdigest()
    CURRENT_CACHE_KEY = cache_key

    cached = cache_load_embeddings(cache_key)
    if cached:
        emb, filenames = cached
        EMBEDDINGS = np.array(emb)
        FILENAMES = filenames
        DOCS = [crawled]
        build_faiss_index(EMBEDDINGS)
        return f"Crawled and loaded embeddings from cache for {url}", crawled[:2000]

    # Not cached -> index
    DOCS = [crawled]
    FILENAMES = [url]
    emb = embedder.encode(DOCS, convert_to_numpy=True, show_progress_bar=False).astype("float32")
    EMBEDDINGS = emb
    cache_save_embeddings(cache_key, EMBEDDINGS, FILENAMES)
    build_faiss_index(EMBEDDINGS)
    return f"Crawled and indexed {url}", crawled[:2000]

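# Note: DOCS, FILENAMES and EMBEDDINGS are module-level globals shared with
# upload_and_index, so crawling replaces whatever the "Documents" tab indexed
# (and vice versa): the app holds exactly one index at a time.
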
def ask_question(question: str, system_prompt: str = ""):
    if not question:
        return "Please enter a question."
    if not DOCS or FAISS_INDEX is None:
        return "No indexed documents. Upload files or crawl a site first."

    topk = 3
    results = search_top_k(question, k=topk)
    if not results:
        return "No relevant documents found."

    # prepare context from the top results (trim each snippet)
    context_blocks = []
    meta = []
    for r in results:
        snippet = r["text"]
        if len(snippet) > 1800:
            snippet = snippet[:1800] + "\n...[truncated]"
        context_blocks.append(f"Source: {r['source']}\n\n{snippet}\n\n---\n")
        meta.append({"source": r["source"], "distance": r["distance"]})

    context = "\n".join(context_blocks)
    user_prompt = f"Use the following context to answer the question, and cite sources from the 'Source:' lines.\n\nContext:\n{context}\nQuestion: {question}\nAnswer:"

    # Call OpenRouter with only model + messages (system & user)
    try:
        answer = openrouter_chat_system_user(system_prompt=system_prompt, user_prompt=user_prompt)
    except Exception as e:
        answer = f"[OpenRouter call failed] {e}"

    out = {"answer": answer, "sources": meta}
    return json.dumps(out, indent=2)

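# The returned string is JSON of this shape (illustrative values only):
# {
#   "answer": "...model output...",
#   "sources": [{"source": "report.pdf", "distance": 0.83}]
# }
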
# ---------------- Gradio UI ----------------
with gr.Blocks(title="AI Ally (Gradio) — Crawl4AI + OpenRouter + FAISS") as demo:
    gr.Markdown("# AI Ally — Document & Website QA\nCrawl4AI for websites, local file uploads for docs. FAISS retrieval + sentence-transformers embeddings. OpenRouter is used for generation (only model + messages).")

    with gr.Tab("Documents"):
        with gr.Row():
            file_input = gr.File(label="Upload files", file_count="multiple", file_types=[".pdf", ".docx", ".txt", ".xlsx", ".pptx", ".csv"])
            upload_btn = gr.Button("Upload & Index")
        with gr.Row():
            upload_status = gr.Textbox(label="Status", interactive=False)
            preview_box = gr.Textbox(label="Uploads (preview JSON)", interactive=False)
        upload_btn.click(upload_and_index, inputs=[file_input], outputs=[upload_status, preview_box])

        gr.Markdown("### Ask about the indexed documents")
        q = gr.Textbox(label="Question", lines=3)
        sys_prompt = gr.Textbox(label="Optional System Prompt (sent to LLM)", lines=2, value="You are a helpful assistant.")
        ask_btn = gr.Button("Ask")
        answer_out = gr.Textbox(label="Answer JSON", interactive=False)
        ask_btn.click(ask_question, inputs=[q, sys_prompt], outputs=[answer_out])

    with gr.Tab("Website Crawl"):
        with gr.Row():
            url = gr.Textbox(label="URL to crawl (starting URL)")
            crawl_btn = gr.Button("Crawl & Index")
        with gr.Row():
            crawl_status = gr.Textbox(label="Status", interactive=False)
            crawl_preview = gr.Textbox(label="Crawl preview (first 2k chars)", interactive=False)
        crawl_btn.click(crawl_and_index, inputs=[url], outputs=[crawl_status, crawl_preview])

        gr.Markdown("### Ask about the crawled site")
        q2 = gr.Textbox(label="Question", lines=3)
        sys_prompt2 = gr.Textbox(label="Optional System Prompt (sent to LLM)", lines=2, value="You are a helpful assistant.")
        ask_btn2 = gr.Button("Ask site")
        answer_out2 = gr.Textbox(label="Answer JSON", interactive=False)
        ask_btn2.click(ask_question, inputs=[q2, sys_prompt2], outputs=[answer_out2])

    with gr.Tab("Settings / Info"):
        gr.Markdown(f"- OpenRouter model: `{OPENROUTER_MODEL}`")
        gr.Markdown(f"- Embedding model: `{EMBEDDING_MODEL_NAME}`")
        gr.Markdown("Set `OPENROUTER_API_KEY` in your environment or HF Secrets before deploying.")
        gr.Markdown("Cache directory: `" + CACHE_DIR + "`")

    gr.Markdown("----\nNotes: This app saves embeddings to `./cache/` using a deterministic cache key. OpenRouter calls include only `model` + `messages` (system + user).")

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)