PraneshJs committed on
Commit a9a9358 · verified · 1 Parent(s): a8404b2

fixed system prompt and response in txt

Files changed (1):
  1. app.py +119 -234
app.py CHANGED
@@ -3,6 +3,7 @@ import os
  import asyncio
  import json
  import hashlib
  from io import BytesIO, StringIO
  from typing import List, Tuple

@@ -12,13 +13,9 @@ import faiss
  import requests
  import pandas as pd
  from sentence_transformers import SentenceTransformer
-
- # file parsing libs
  import fitz  # PyMuPDF
  import docx
  from pptx import Presentation
-
- # crawl4ai
  from crawl4ai import AsyncWebCrawler

  # ---------------- Config ----------------
@@ -26,24 +23,39 @@ OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
  OPENROUTER_MODEL = "nvidia/nemotron-nano-12b-v2-vl:free"
  EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
  CACHE_DIR = "./cache"
  os.makedirs(CACHE_DIR, exist_ok=True)

- # sentence-transformers embedder (loads once)
  embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)

- # Global in-memory stores (cleared/updated by UI actions)
  DOCS: List[str] = []
  FILENAMES: List[str] = []
  EMBEDDINGS: np.ndarray = None
  FAISS_INDEX = None
  CURRENT_CACHE_KEY: str = ""

  # ---------------- File extraction helpers ----------------
  def extract_text_from_pdf(file_bytes: bytes) -> str:
      try:
          doc = fitz.open(stream=file_bytes, filetype="pdf")
-         pages = [page.get_text() for page in doc]
-         return "\n".join(pages)
      except Exception as e:
          return f"[PDF extraction error] {e}"

@@ -51,7 +63,7 @@ def extract_text_from_docx(file_bytes: bytes) -> str:
      try:
          f = BytesIO(file_bytes)
          doc = docx.Document(f)
-         return "\n".join([p.text for p in doc.paragraphs])
      except Exception as e:
          return f"[DOCX extraction error] {e}"

@@ -65,10 +77,7 @@ def extract_text_from_excel(file_bytes: bytes) -> str:
      try:
          f = BytesIO(file_bytes)
          df = pd.read_excel(f, dtype=str)
-         parts = []
-         for col in df.columns:
-             parts.append("\n".join(df[col].fillna("").astype(str).tolist()))
-         return "\n".join(parts)
      except Exception as e:
          return f"[EXCEL extraction error] {e}"

@@ -94,90 +103,57 @@ def extract_text_from_csv(file_bytes: bytes) -> str:
      return f"[CSV extraction error] {e}"

  def extract_text_from_file_tuple(file_tuple) -> Tuple[str, bytes]:
-     """
-     Accepts a Gradio file object/tuple and returns (filename, bytes).
-     Robust to multiple gradio versions.
-     """
-     # gradio v3.x passes TemporaryFile-like object with .name & .read()
      try:
          if hasattr(file_tuple, "name") and hasattr(file_tuple, "read"):
-             filename = os.path.basename(file_tuple.name)
-             file_bytes = file_tuple.read()
-             return filename, file_bytes
-     except Exception:
-         pass
-     # other shapes: tuple or dict-like
-     try:
-         # file_tuple may be (name, bytes)
-         if isinstance(file_tuple, tuple) and len(file_tuple) == 2 and isinstance(file_tuple[1], (bytes, bytearray)):
-             return file_tuple[0], bytes(file_tuple[1])
-     except Exception:
-         pass
-     # fallback if path string provided
-     try:
-         if isinstance(file_tuple, str) and os.path.exists(file_tuple):
-             with open(file_tuple, "rb") as fh:
-                 return os.path.basename(file_tuple), fh.read()
      except Exception:
          pass
      raise ValueError("Unsupported file object passed by Gradio.")

  def extract_text_by_ext(filename: str, file_bytes: bytes) -> str:
      name = filename.lower()
-     if name.endswith(".pdf"):
-         return extract_text_from_pdf(file_bytes)
-     if name.endswith(".docx"):
-         return extract_text_from_docx(file_bytes)
-     if name.endswith(".txt"):
-         return extract_text_from_txt(file_bytes)
-     if name.endswith(".xlsx") or name.endswith(".xls"):
-         return extract_text_from_excel(file_bytes)
-     if name.endswith(".pptx"):
-         return extract_text_from_pptx(file_bytes)
-     if name.endswith(".csv"):
-         return extract_text_from_csv(file_bytes)
-     # fallback: try plain text
      return extract_text_from_txt(file_bytes)

- # ---------------- Embedding caching helpers ----------------
  def make_cache_key_for_files(files: List[Tuple[str, bytes]]) -> str:
-     """
-     Create a deterministic cache key based on filenames + sizes + sha256 of each file content.
-     """
      h = hashlib.sha256()
      for name, b in sorted(files, key=lambda x: x[0]):
-         h.update(name.encode("utf-8"))
-         h.update(str(len(b)).encode("utf-8"))
-         # update with small digest to keep speed; still robust
          h.update(hashlib.sha256(b).digest())
      return h.hexdigest()

  def cache_save_embeddings(cache_key: str, embeddings: np.ndarray, filenames: List[str]):
-     path = os.path.join(CACHE_DIR, f"{cache_key}.npz")
-     np.savez_compressed(path, embeddings=embeddings, filenames=np.array(filenames))
-     return path

  def cache_load_embeddings(cache_key: str):
      path = os.path.join(CACHE_DIR, f"{cache_key}.npz")
-     if not os.path.exists(path):
-         return None
      try:
          arr = np.load(path, allow_pickle=True)
-         embeddings = arr["embeddings"]
-         filenames = arr["filenames"].tolist()
-         return embeddings, filenames
      except Exception:
          return None

- # ---------------- FAISS helpers ----------------
  def build_faiss_index(embeddings: np.ndarray):
      global FAISS_INDEX
      if embeddings is None or len(embeddings) == 0:
          FAISS_INDEX = None
          return None
      emb = embeddings.astype("float32")
-     dim = emb.shape[1]
-     index = faiss.IndexFlatL2(dim)
      index.add(emb)
      FAISS_INDEX = index
      return index
@@ -187,84 +163,63 @@ def search_top_k(query: str, k: int = 3):
          return []
      q_emb = embedder.encode([query], convert_to_numpy=True).astype("float32")
      D, I = FAISS_INDEX.search(q_emb, k)
-     results = []
-     for dist, idx in zip(D[0], I[0]):
-         if idx < 0:
-             continue
-         results.append({
-             "index": int(idx),
-             "distance": float(dist),
-             "text": DOCS[idx],
-             "source": FILENAMES[idx]
-         })
-     return results
-
- # ---------------- OpenRouter minimal client ----------------
- def openrouter_chat_system_user(system_prompt: str, user_prompt: str):
      """
-     Sends only 'model' and 'messages' payload (system + user) to OpenRouter,
-     per your requirement (no max_tokens, temperature, etc).
      """
      if not OPENROUTER_API_KEY:
-         return "[OpenRouter error] OPENROUTER_API_KEY not set."

      url = "https://openrouter.ai/api/v1/chat/completions"
-     headers = {"Authorization": f"Bearer {OPENROUTER_API_KEY}", "Content-Type": "application/json"}
-     messages = []
-     if system_prompt:
-         messages.append({"role": "system", "content": system_prompt})
-     messages.append({"role": "user", "content": user_prompt})

-     payload = {"model": OPENROUTER_MODEL, "messages": messages}
      try:
          r = requests.post(url, headers=headers, json=payload, timeout=60)
          r.raise_for_status()
          obj = r.json()
-         # Expecting OpenAI-like structure: choices[0].message.content
-         if "choices" in obj and len(obj["choices"]) > 0:
              choice = obj["choices"][0]
              if "message" in choice and "content" in choice["message"]:
-                 return choice["message"]["content"]
-             if "text" in choice:
-                 return choice["text"]
-         # fallback: return entire partial json for debugging
-         return json.dumps(obj, indent=2)[:12000]
      except Exception as e:
          return f"[OpenRouter request error] {e}"

- # ---------------- Crawl4AI robust logic ----------------
  async def _crawl_async_get_markdown(url: str):
-     # uses default crawler settings; adjust with run config if needed
      async with AsyncWebCrawler() as crawler:
          result = await crawler.arun(url=url)
-         # prefer a success flag if present
          if hasattr(result, "success") and result.success is False:
-             # attempt to surface error
-             err = getattr(result, "error_message", None) or getattr(result, "error", None) or "[Crawl4AI unknown error]"
-             return f"[Crawl4AI error] {err}"
-
-         # try structured markdown first
          md_obj = getattr(result, "markdown", None)
          if md_obj:
-             # try common subfields observed in different versions
-             text = getattr(md_obj, "fit_markdown", None) or getattr(md_obj, "raw_markdown", None)
-             if text:
-                 return text
-             # fallback to str(md_obj)
-             try:
-                 return str(md_obj)
-             except Exception:
-                 pass
-
-         # fallback to text or html
-         text = getattr(result, "text", None) or getattr(result, "html", None)
-         if text:
-             return text
-         # last resort: jsonify entire result (short)
-         try:
-             return json.dumps(result.__dict__, default=str)[:20000]
-         except Exception:
-             return "[Crawl4AI returned no usable fields]"

  def crawl_url_sync(url: str) -> str:
      try:
@@ -272,78 +227,40 @@ def crawl_url_sync(url: str) -> str:
      except Exception as e:
          return f"[Crawl4AI runtime error] {e}"

- # ---------------- Gradio handlers ----------------
  def upload_and_index(files):
-     """
-     files: list of file objects from Gradio. We'll extract bytes, compute cache key,
-     try to load embeddings from cache; if not found, compute embeddings and save.
-     """
      global DOCS, FILENAMES, EMBEDDINGS, CURRENT_CACHE_KEY
-
      if not files:
          return "No files uploaded.", ""
-
-     # read files into list of (name, bytes)
-     prepared = []
-     previews = []
-     for f in files:
-         name, b = extract_text_from_file_tuple(f)
-         prepared.append((name, b))
-         # short preview
-         previews.append({"name": name, "size": len(b)})
-
      cache_key = make_cache_key_for_files(prepared)
      CURRENT_CACHE_KEY = cache_key
-
-     # Try load existing embeddings
      cached = cache_load_embeddings(cache_key)
      if cached:
          emb, filenames = cached
          EMBEDDINGS = np.array(emb)
          FILENAMES = filenames
-         # Rebuild DOCS array: we still need textual content (not just embeddings)
-         DOCS = []
-         for name, b in prepared:
-             DOCS.append(extract_text_by_ext(name, b))
-         # Build faiss index
          build_faiss_index(EMBEDDINGS)
          return f"Loaded embeddings from cache ({len(FILENAMES)} docs).", json.dumps(previews)
-
-     # Not cached -> extract texts and embed
-     DOCS = []
-     FILENAMES = []
-     for name, b in prepared:
-         txt = extract_text_by_ext(name, b)
-         DOCS.append(txt)
-         FILENAMES.append(name)
-
-     # Compute embeddings
-     emb = embedder.encode(DOCS, convert_to_numpy=True, show_progress_bar=False).astype("float32")
-     EMBEDDINGS = emb
-     # Save to cache
      cache_save_embeddings(cache_key, EMBEDDINGS, FILENAMES)
-     # Build faiss
      build_faiss_index(EMBEDDINGS)
-
      return f"Uploaded and indexed {len(DOCS)} documents.", json.dumps(previews)

  def crawl_and_index(url: str):
      global DOCS, FILENAMES, EMBEDDINGS, CURRENT_CACHE_KEY
      if not url:
          return "No URL provided.", ""
-
      crawled = crawl_url_sync(url)
      if crawled.startswith("[Crawl4AI"):
          return crawled, ""
-
-     # create a cache key based on url and content
-     key_hash = hashlib.sha256()
-     key_hash.update(url.encode("utf-8"))
-     key_hash.update(crawled.encode("utf-8"))
-     cache_key = key_hash.hexdigest()
-     CURRENT_CACHE_KEY = cache_key
-
-     cached = cache_load_embeddings(cache_key)
      if cached:
          emb, filenames = cached
          EMBEDDINGS = np.array(emb)
@@ -351,92 +268,60 @@ def crawl_and_index(url: str):
          DOCS = [crawled]
          build_faiss_index(EMBEDDINGS)
          return f"Crawled and loaded embeddings from cache for {url}", crawled[:2000]
-
-     # Not cached -> index
-     DOCS = [crawled]
-     FILENAMES = [url]
-     emb = embedder.encode(DOCS, convert_to_numpy=True, show_progress_bar=False).astype("float32")
-     EMBEDDINGS = emb
-     cache_save_embeddings(cache_key, EMBEDDINGS, FILENAMES)
      build_faiss_index(EMBEDDINGS)
      return f"Crawled and indexed {url}", crawled[:2000]

- def ask_question(question: str, system_prompt: str = ""):
      if not question:
          return "Please enter a question."
      if not DOCS or FAISS_INDEX is None:
-         return "No indexed documents. Upload files or crawl a site first."
-
-     topk = 3
-     results = search_top_k(question, k=topk)
      if not results:
          return "No relevant documents found."

-     # prepare context from top results (trim each)
-     context_blocks = []
-     meta = []
-     for r in results:
-         snippet = r["text"]
-         if len(snippet) > 1800:
-             snippet = snippet[:1800] + "\n...[truncated]"
-         context_blocks.append(f"Source: {r['source']}\n\n{snippet}\n\n---\n")
-         meta.append({"source": r["source"], "distance": r["distance"]})
-
-     context = "\n".join(context_blocks)
-     user_prompt = f"Use the following context to answer the question, and cite sources from the 'Source:' lines.\n\nContext:\n{context}\nQuestion: {question}\nAnswer:"
-
-     # Call OpenRouter with only model + messages (system & user)
-     try:
-         answer = openrouter_chat_system_user(system_prompt=system_prompt, user_prompt=user_prompt)
-     except Exception as e:
-         answer = f"[OpenRouter call failed] {e}"
-
-     out = {"answer": answer, "sources": meta}
-     return json.dumps(out, indent=2)

  # ---------------- Gradio UI ----------------
- with gr.Blocks(title="AI Ally (Gradio) — Crawl4AI + OpenRouter + FAISS") as demo:
-     gr.Markdown("# AI Ally — Document & Website QA\nCrawl4AI for websites, local file uploads for docs. FAISS retrieval + sentence-transformers embeddings. OpenRouter used for generation (only model + messages).")

      with gr.Tab("Documents"):
-         with gr.Row():
-             file_input = gr.File(label="Upload files", file_count="multiple", file_types=[".pdf", ".docx", ".txt", ".xlsx", ".pptx", ".csv"])
-             upload_btn = gr.Button("Upload & Index")
-         with gr.Row():
-             upload_status = gr.Textbox(label="Status", interactive=False)
-             preview_box = gr.Textbox(label="Uploads (preview JSON)", interactive=False)
          upload_btn.click(upload_and_index, inputs=[file_input], outputs=[upload_status, preview_box])

-         gr.Markdown("### Ask about the indexed documents")
-         q = gr.Textbox(label="Question", lines=5)
-         sys_prompt = gr.Textbox(label="Optional System Prompt (sent to LLM)", lines=5, value="You are a helpful assistant.")
          ask_btn = gr.Button("Ask")
-         answer_out = gr.Textbox(label="Answer JSON", interactive=False, lines=15)
-         ask_btn.click(ask_question, inputs=[q, sys_prompt], outputs=[answer_out])

      with gr.Tab("Website Crawl"):
-         with gr.Row():
-             url = gr.Textbox(label="URL to crawl (starting URL)")
-             crawl_btn = gr.Button("Crawl & Index")
-         with gr.Row():
-             crawl_status = gr.Textbox(label="Status", interactive=False)
-             crawl_preview = gr.Textbox(label="Crawl preview (first 2k chars)", interactive=False)
          crawl_btn.click(crawl_and_index, inputs=[url], outputs=[crawl_status, crawl_preview])

-         gr.Markdown("### Ask about the crawled site")
-         q2 = gr.Textbox(label="Question", lines=5)
-         sys_prompt2 = gr.Textbox(label="Optional System Prompt (sent to LLM)", lines=10, value="You are a helpful assistant.")
          ask_btn2 = gr.Button("Ask site")
-         answer_out2 = gr.Textbox(label="Answer JSON", interactive=False, lines=15)
-         ask_btn2.click(ask_question, inputs=[q2, sys_prompt2], outputs=[answer_out2])

      with gr.Tab("Settings / Info"):
-         gr.Markdown(f"- OpenRouter model: `{OPENROUTER_MODEL}`")
          gr.Markdown(f"- Embedding model: `{EMBEDDING_MODEL_NAME}`")
-         gr.Markdown("Set `OPENROUTER_API_KEY` in your environment or HF Secrets before deploying.")
-         gr.Markdown("Cache directory: `" + CACHE_DIR + "`")
-
-     gr.Markdown("----\nNotes: This app saves embeddings to `./cache/` using a deterministic cache key. OpenRouter calls include only `model` + `messages` (system + user) as requested.")

  if __name__ == "__main__":
-     demo.launch(server_name="0.0.0.0", server_port=7860)

app.py (after changes)
  import asyncio
  import json
  import hashlib
+ import shutil
  from io import BytesIO, StringIO
  from typing import List, Tuple

  import requests
  import pandas as pd
  from sentence_transformers import SentenceTransformer
  import fitz  # PyMuPDF
  import docx
  from pptx import Presentation
  from crawl4ai import AsyncWebCrawler

  # ---------------- Config ----------------
  OPENROUTER_MODEL = "nvidia/nemotron-nano-12b-v2-vl:free"
  EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
  CACHE_DIR = "./cache"
+ SYSTEM_PROMPT = "You are a helpful assistant."
  os.makedirs(CACHE_DIR, exist_ok=True)

  embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)

  DOCS: List[str] = []
  FILENAMES: List[str] = []
  EMBEDDINGS: np.ndarray = None
  FAISS_INDEX = None
  CURRENT_CACHE_KEY: str = ""

+
+ # ---------------- Periodic cache cleanup ----------------
+ async def clear_cache_every_5min():
+     while True:
+         await asyncio.sleep(300)  # 5 minutes
+         try:
+             if os.path.exists(CACHE_DIR):
+                 shutil.rmtree(CACHE_DIR)
+             os.makedirs(CACHE_DIR, exist_ok=True)
+             print("🧹 Cache cleared successfully.")
+         except Exception as e:
+             print(f"[Cache cleanup error] {e}")
+
+ # Launch the cleaner in background
+ asyncio.get_event_loop().create_task(clear_cache_every_5min())
+
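One caveat with the line above: at import time there is usually no running event loop, so `asyncio.get_event_loop().create_task(...)` is fragile; `get_event_loop()` is deprecated outside a running loop on Python 3.10+, and a task created on a loop that never runs will never fire. A minimal loop-independent sketch, assuming the same CACHE_DIR and 5-minute interval (not part of the commit):

    # Sketch: periodic cache cleanup on a daemon thread instead of an asyncio task.
    import os
    import shutil
    import threading
    import time

    def _clear_cache_forever(cache_dir: str = "./cache", interval_s: int = 300) -> None:
        while True:
            time.sleep(interval_s)
            try:
                if os.path.exists(cache_dir):
                    shutil.rmtree(cache_dir)
                os.makedirs(cache_dir, exist_ok=True)
            except Exception as e:
                print(f"[Cache cleanup error] {e}")

    # daemon=True keeps the thread from blocking interpreter shutdown
    threading.Thread(target=_clear_cache_forever, daemon=True).start()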
  # ---------------- File extraction helpers ----------------
  def extract_text_from_pdf(file_bytes: bytes) -> str:
      try:
          doc = fitz.open(stream=file_bytes, filetype="pdf")
+         return "\n".join(page.get_text() for page in doc)
      except Exception as e:
          return f"[PDF extraction error] {e}"

  def extract_text_from_docx(file_bytes: bytes) -> str:
      try:
          f = BytesIO(file_bytes)
          doc = docx.Document(f)
+         return "\n".join(p.text for p in doc.paragraphs)
      except Exception as e:
          return f"[DOCX extraction error] {e}"

  def extract_text_from_excel(file_bytes: bytes) -> str:
      try:
          f = BytesIO(file_bytes)
          df = pd.read_excel(f, dtype=str)
+         return "\n".join("\n".join(df[col].fillna("").astype(str).tolist()) for col in df.columns)
      except Exception as e:
          return f"[EXCEL extraction error] {e}"

      return f"[CSV extraction error] {e}"

  def extract_text_from_file_tuple(file_tuple) -> Tuple[str, bytes]:
      try:
          if hasattr(file_tuple, "name") and hasattr(file_tuple, "read"):
+             return os.path.basename(file_tuple.name), file_tuple.read()
      except Exception:
          pass
+     if isinstance(file_tuple, tuple) and len(file_tuple) == 2 and isinstance(file_tuple[1], (bytes, bytearray)):
+         return file_tuple[0], bytes(file_tuple[1])
+     if isinstance(file_tuple, str) and os.path.exists(file_tuple):
+         with open(file_tuple, "rb") as fh:
+             return os.path.basename(file_tuple), fh.read()
      raise ValueError("Unsupported file object passed by Gradio.")

  def extract_text_by_ext(filename: str, file_bytes: bytes) -> str:
      name = filename.lower()
+     if name.endswith(".pdf"): return extract_text_from_pdf(file_bytes)
+     if name.endswith(".docx"): return extract_text_from_docx(file_bytes)
+     if name.endswith(".txt"): return extract_text_from_txt(file_bytes)
+     if name.endswith((".xlsx", ".xls")): return extract_text_from_excel(file_bytes)
+     if name.endswith(".pptx"): return extract_text_from_pptx(file_bytes)
+     if name.endswith(".csv"): return extract_text_from_csv(file_bytes)
      return extract_text_from_txt(file_bytes)
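For illustration, the upload helper above accepts three input shapes; a quick sketch (the temp file is hypothetical, purely for the demo):

    # Sketch: the three shapes extract_text_from_file_tuple handles.
    import tempfile

    print(extract_text_from_file_tuple(("a.txt", b"hello")))   # (name, bytes) tuple

    with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tf:
        tf.write(b"hello")
        path = tf.name
    print(extract_text_from_file_tuple(path))                  # filesystem path string

    print(extract_text_from_file_tuple(open(path, "rb")))      # object with .name and .read()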
 
+
+ # ---------------- Cache + FAISS helpers ----------------
  def make_cache_key_for_files(files: List[Tuple[str, bytes]]) -> str:
      h = hashlib.sha256()
      for name, b in sorted(files, key=lambda x: x[0]):
+         h.update(name.encode())
+         h.update(str(len(b)).encode())
          h.update(hashlib.sha256(b).digest())
      return h.hexdigest()
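Because the loop hashes files in sorted(name) order, the key is independent of upload order but sensitive to content; a quick property check with made-up files:

    # Sketch: cache keys are order-independent and content-sensitive.
    a = [("a.txt", b"hello"), ("b.txt", b"world")]
    b = [("b.txt", b"world"), ("a.txt", b"hello")]
    assert make_cache_key_for_files(a) == make_cache_key_for_files(b)
    assert make_cache_key_for_files(a) != make_cache_key_for_files([("a.txt", b"hello!")])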

  def cache_save_embeddings(cache_key: str, embeddings: np.ndarray, filenames: List[str]):
+     np.savez_compressed(os.path.join(CACHE_DIR, f"{cache_key}.npz"), embeddings=embeddings, filenames=np.array(filenames))

  def cache_load_embeddings(cache_key: str):
      path = os.path.join(CACHE_DIR, f"{cache_key}.npz")
+     if not os.path.exists(path): return None
      try:
          arr = np.load(path, allow_pickle=True)
+         return arr["embeddings"], arr["filenames"].tolist()
      except Exception:
          return None

  def build_faiss_index(embeddings: np.ndarray):
      global FAISS_INDEX
      if embeddings is None or len(embeddings) == 0:
          FAISS_INDEX = None
          return None
      emb = embeddings.astype("float32")
+     index = faiss.IndexFlatL2(emb.shape[1])
      index.add(emb)
      FAISS_INDEX = index
      return index

  def search_top_k(query: str, k: int = 3):
      if FAISS_INDEX is None:
          return []
      q_emb = embedder.encode([query], convert_to_numpy=True).astype("float32")
      D, I = FAISS_INDEX.search(q_emb, k)
+     return [{"index": int(i), "distance": float(d), "text": DOCS[i], "source": FILENAMES[i]} for d, i in zip(D[0], I[0]) if i >= 0]
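A self-contained sketch of what `IndexFlatL2.search` returns, using toy vectors rather than the app's embeddings:

    # Sketch: FAISS IndexFlatL2 round trip on random data.
    import faiss
    import numpy as np

    emb = np.random.rand(4, 8).astype("float32")   # 4 "documents", dimension 8
    index = faiss.IndexFlatL2(emb.shape[1])
    index.add(emb)

    D, I = index.search(emb[:1], k=3)  # query with the first vector itself
    # D holds squared L2 distances, I the row ids; the top hit is the query
    # itself at distance ~0. FAISS pads missing results with -1 when the index
    # has fewer than k rows, which is why search_top_k filters on i >= 0.
    print(I[0], D[0])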
+
+ # ---------------- OpenRouter Client ----------------
+ def openrouter_chat_system_user(user_prompt: str):
      """
+     Sends the user prompt to OpenRouter and expects a plain-text response.
      """
      if not OPENROUTER_API_KEY:
+         return "[OpenRouter error] Missing OPENROUTER_API_KEY."

      url = "https://openrouter.ai/api/v1/chat/completions"
+     headers = {
+         "Authorization": f"Bearer {OPENROUTER_API_KEY}",
+         "Content-Type": "application/json",
+     }
+
+     # Tell the model explicitly to reply as plain text only
+     payload = {
+         "model": OPENROUTER_MODEL,
+         "messages": [
+             {"role": "system", "content": SYSTEM_PROMPT + " Always respond in plain text. Avoid JSON or markdown formatting."},
+             {"role": "user", "content": user_prompt},
+         ],
+     }

      try:
          r = requests.post(url, headers=headers, json=payload, timeout=60)
          r.raise_for_status()
          obj = r.json()
+
+         # Safely extract plain text
+         if "choices" in obj and obj["choices"]:
              choice = obj["choices"][0]
              if "message" in choice and "content" in choice["message"]:
+                 text = choice["message"]["content"]
+                 # Strip code fences and the literal token "json" from the reply
+                 text = text.strip().replace("```", "").replace("json", "")
+                 return text
+             elif "text" in choice:
+                 return choice["text"].strip()
+         return "[OpenRouter] Unexpected response format."
+
      except Exception as e:
          return f"[OpenRouter request error] {e}"

+
+ # ---------------- Crawl4AI Logic ----------------
  async def _crawl_async_get_markdown(url: str):
      async with AsyncWebCrawler() as crawler:
          result = await crawler.arun(url=url)
          if hasattr(result, "success") and result.success is False:
+             return f"[Crawl4AI error] {getattr(result, 'error_message', '[Unknown error]')}"
          md_obj = getattr(result, "markdown", None)
          if md_obj:
+             return getattr(md_obj, "fit_markdown", None) or getattr(md_obj, "raw_markdown", None) or str(md_obj)
+         return getattr(result, "text", None) or getattr(result, "html", None) or "[Crawl4AI returned no usable fields]"

  def crawl_url_sync(url: str) -> str:
      try:
          return asyncio.run(_crawl_async_get_markdown(url))
      except Exception as e:
          return f"[Crawl4AI runtime error] {e}"

+
+ # ---------------- Gradio Handlers ----------------
  def upload_and_index(files):
      global DOCS, FILENAMES, EMBEDDINGS, CURRENT_CACHE_KEY
      if not files:
          return "No files uploaded.", ""
+     # Read each file exactly once; a second .read() on the same handle returns b"".
+     prepared = [extract_text_from_file_tuple(f) for f in files]
+     previews = [{"name": n, "size": len(b)} for n, b in prepared]
      cache_key = make_cache_key_for_files(prepared)
      CURRENT_CACHE_KEY = cache_key
      cached = cache_load_embeddings(cache_key)
      if cached:
          emb, filenames = cached
          EMBEDDINGS = np.array(emb)
          FILENAMES = filenames
+         DOCS = [extract_text_by_ext(n, b) for n, b in prepared]
          build_faiss_index(EMBEDDINGS)
          return f"Loaded embeddings from cache ({len(FILENAMES)} docs).", json.dumps(previews)
+     DOCS, FILENAMES = zip(*[(extract_text_by_ext(n, b), n) for n, b in prepared])
+     EMBEDDINGS = embedder.encode(DOCS, convert_to_numpy=True, show_progress_bar=False).astype("float32")
      cache_save_embeddings(cache_key, EMBEDDINGS, FILENAMES)
      build_faiss_index(EMBEDDINGS)
      return f"Uploaded and indexed {len(DOCS)} documents.", json.dumps(previews)

  def crawl_and_index(url: str):
      global DOCS, FILENAMES, EMBEDDINGS, CURRENT_CACHE_KEY
      if not url:
          return "No URL provided.", ""
      crawled = crawl_url_sync(url)
      if crawled.startswith("[Crawl4AI"):
          return crawled, ""
+     key_hash = hashlib.sha256((url + crawled).encode()).hexdigest()
+     CURRENT_CACHE_KEY = key_hash
+     cached = cache_load_embeddings(key_hash)
      if cached:
          emb, filenames = cached
          EMBEDDINGS = np.array(emb)
          DOCS = [crawled]
          build_faiss_index(EMBEDDINGS)
          return f"Crawled and loaded embeddings from cache for {url}", crawled[:2000]
+     DOCS, FILENAMES = [crawled], [url]
+     EMBEDDINGS = embedder.encode(DOCS, convert_to_numpy=True, show_progress_bar=False).astype("float32")
+     cache_save_embeddings(key_hash, EMBEDDINGS, FILENAMES)
      build_faiss_index(EMBEDDINGS)
      return f"Crawled and indexed {url}", crawled[:2000]

+ def ask_question(question: str):
      if not question:
          return "Please enter a question."
      if not DOCS or FAISS_INDEX is None:
+         return "No indexed data found."
+     results = search_top_k(question, k=3)
      if not results:
          return "No relevant documents found."
+     context = "\n".join(f"Source: {r['source']}\n\n{r['text'][:1800]}\n---\n" for r in results)
+     user_prompt = f"Use the following context to answer the question.\n\nContext:\n{context}\nQuestion: {question}\nAnswer:"
+     return openrouter_chat_system_user(user_prompt)
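For a sense of what the model receives, the assembled prompt with one made-up retrieval hit:

    # Sketch: the prompt ask_question builds, shown with an invented result.
    results = [{"source": "notes.txt", "text": "FAISS is a vector search library."}]
    context = "\n".join(f"Source: {r['source']}\n\n{r['text'][:1800]}\n---\n" for r in results)
    user_prompt = f"Use the following context to answer the question.\n\nContext:\n{context}\nQuestion: What is FAISS?\nAnswer:"
    print(user_prompt)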
 
  # ---------------- Gradio UI ----------------
+ with gr.Blocks(title="AI Ally — Crawl4AI + OpenRouter + FAISS") as demo:
+     gr.Markdown("# 🤖 AI Ally — Document & Website QA\nCrawl4AI for websites, file uploads for docs. FAISS retrieval + sentence-transformers + OpenRouter LLM.")

      with gr.Tab("Documents"):
+         file_input = gr.File(label="Upload files", file_count="multiple",
+                              file_types=[".pdf", ".docx", ".txt", ".xlsx", ".pptx", ".csv"])
+         upload_btn = gr.Button("Upload & Index")
+         upload_status = gr.Textbox(label="Status", interactive=False)
+         preview_box = gr.Textbox(label="Uploads (preview JSON)", interactive=False)
          upload_btn.click(upload_and_index, inputs=[file_input], outputs=[upload_status, preview_box])

+         gr.Markdown("### Ask about your documents")
+         q = gr.Textbox(label="Question", lines=3)
          ask_btn = gr.Button("Ask")
+         answer_out = gr.Textbox(label="Answer", interactive=False, lines=15)
+         ask_btn.click(ask_question, inputs=[q], outputs=[answer_out])

      with gr.Tab("Website Crawl"):
+         url = gr.Textbox(label="URL to crawl")
+         crawl_btn = gr.Button("Crawl & Index")
+         crawl_status = gr.Textbox(label="Status", interactive=False)
+         crawl_preview = gr.Textbox(label="Crawl preview", interactive=False)
          crawl_btn.click(crawl_and_index, inputs=[url], outputs=[crawl_status, crawl_preview])

+         q2 = gr.Textbox(label="Question", lines=3)
          ask_btn2 = gr.Button("Ask site")
+         answer_out2 = gr.Textbox(label="Answer", interactive=False, lines=15)
+         ask_btn2.click(ask_question, inputs=[q2], outputs=[answer_out2])

      with gr.Tab("Settings / Info"):
+         gr.Markdown(f"- Model: `{OPENROUTER_MODEL}`")
          gr.Markdown(f"- Embedding model: `{EMBEDDING_MODEL_NAME}`")
+         gr.Markdown("- Cache clears automatically every 5 minutes.")
+         gr.Markdown(f"- System prompt is fixed internally: `{SYSTEM_PROMPT}`")

  if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)