| |
| import os |
| import asyncio |
| import json |
| import hashlib |
| from io import BytesIO, StringIO |
| from typing import List, Tuple |
|
|
| import gradio as gr |
| import numpy as np |
| import faiss |
| import requests |
| import pandas as pd |
| from sentence_transformers import SentenceTransformer |
|
|
| |
| import fitz |
| import docx |
| from pptx import Presentation |
|
|
| |
| from crawl4ai import AsyncWebCrawler |
|
|
| |
# --- Configuration ---
# OpenRouter credentials + model; the API key must come from the environment
# (or HF Secrets when deployed) — it is never hard-coded.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
OPENROUTER_MODEL = "microsoft/mai-ds-r1:free"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
# Embedding cache location; created eagerly so later np.savez_compressed
# calls cannot fail on a missing directory.
CACHE_DIR = "./cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Sentence-transformers encoder, loaded once at import time (first load is
# slow — it may download model weights).
embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)

# --- Module-level corpus state ---
# Mutated in place by upload_and_index / crawl_and_index; each ingest
# REPLACES the previous corpus (documents and crawls are not merged).
DOCS: List[str] = []          # extracted text, one entry per document
FILENAMES: List[str] = []     # source label (filename or URL) per document
EMBEDDINGS: np.ndarray = None # float32 matrix, row i embeds DOCS[i]; None until first ingest
FAISS_INDEX = None            # faiss.IndexFlatL2 over EMBEDDINGS; None until first ingest
CURRENT_CACHE_KEY: str = ""   # cache key of the most recent ingest
|
|
| |
def extract_text_from_pdf(file_bytes: bytes) -> str:
    """Extract plain text from a PDF given its raw bytes.

    Returns the newline-joined text of all pages, or a bracketed error
    marker string if PyMuPDF cannot parse the bytes (callers treat the
    result as document text either way, so this never raises).
    """
    try:
        # FIX: the original leaked the MuPDF document handle; fitz.Document
        # supports the context-manager protocol, which closes it reliably.
        with fitz.open(stream=file_bytes, filetype="pdf") as doc:
            return "\n".join(page.get_text() for page in doc)
    except Exception as e:
        return f"[PDF extraction error] {e}"
|
|
def extract_text_from_docx(file_bytes: bytes) -> str:
    """Return the paragraph text of a .docx document, newline-joined.

    Any parsing failure is reported as a bracketed error marker string
    instead of raising.
    """
    try:
        document = docx.Document(BytesIO(file_bytes))
        paragraph_texts = (paragraph.text for paragraph in document.paragraphs)
        return "\n".join(paragraph_texts)
    except Exception as exc:
        return f"[DOCX extraction error] {exc}"
|
|
def extract_text_from_txt(file_bytes: bytes) -> str:
    """Decode raw bytes as UTF-8 text, silently dropping undecodable bytes.

    Returns a bracketed error marker string on unexpected failure (e.g. a
    non-bytes argument) instead of raising.
    """
    try:
        decoded = file_bytes.decode("utf-8", errors="ignore")
    except Exception as exc:
        return f"[TXT extraction error] {exc}"
    return decoded
|
|
def extract_text_from_excel(file_bytes: bytes) -> str:
    """Flatten an Excel workbook into newline-joined text, column by column.

    Cells are read as strings; missing values become empty strings. A
    bracketed error marker string is returned on any failure.
    """
    try:
        frame = pd.read_excel(BytesIO(file_bytes), dtype=str)
        column_texts = [
            "\n".join(frame[column].fillna("").astype(str).tolist())
            for column in frame.columns
        ]
        return "\n".join(column_texts)
    except Exception as exc:
        return f"[EXCEL extraction error] {exc}"
|
|
def extract_text_from_pptx(file_bytes: bytes) -> str:
    """Collect the text of every text-bearing shape on every slide.

    Returns the newline-joined text, or a bracketed error marker string
    if python-pptx cannot parse the bytes.
    """
    try:
        deck = Presentation(BytesIO(file_bytes))
        collected = []
        for slide in deck.slides:
            # Not every shape carries text (pictures, charts, ...).
            collected.extend(
                shape.text for shape in slide.shapes if hasattr(shape, "text")
            )
        return "\n".join(collected)
    except Exception as exc:
        return f"[PPTX extraction error] {exc}"
|
|
def extract_text_from_csv(file_bytes: bytes) -> str:
    """Render a CSV file as a plain-text table without the index column.

    Bytes are decoded as UTF-8 (ignoring bad sequences) before parsing;
    failures come back as a bracketed error marker string.
    """
    try:
        decoded = file_bytes.decode("utf-8", errors="ignore")
        frame = pd.read_csv(StringIO(decoded), dtype=str)
        return frame.to_string(index=False)
    except Exception as exc:
        return f"[CSV extraction error] {exc}"
|
|
def extract_text_from_file_tuple(file_tuple) -> Tuple[str, bytes]:
    """
    Normalize a Gradio upload object into a (basename, raw bytes) pair.

    Handles the three shapes different Gradio versions hand back: a
    file-like object with .name/.read, a (filename, bytes) tuple, or a
    plain path string to a file on disk. Raises ValueError otherwise.
    """
    # Shape 1: file-like object exposing both .name and .read().
    try:
        if hasattr(file_tuple, "name") and hasattr(file_tuple, "read"):
            return os.path.basename(file_tuple.name), file_tuple.read()
    except Exception:
        pass

    # Shape 2: a (filename, bytes) pair.
    try:
        if (
            isinstance(file_tuple, tuple)
            and len(file_tuple) == 2
            and isinstance(file_tuple[1], (bytes, bytearray))
        ):
            name, payload = file_tuple
            return name, bytes(payload)
    except Exception:
        pass

    # Shape 3: a path string pointing at an existing file.
    try:
        if isinstance(file_tuple, str) and os.path.exists(file_tuple):
            with open(file_tuple, "rb") as handle:
                return os.path.basename(file_tuple), handle.read()
    except Exception:
        pass

    raise ValueError("Unsupported file object passed by Gradio.")
|
|
def extract_text_by_ext(filename: str, file_bytes: bytes) -> str:
    """Route file bytes to the extractor matching the filename's extension.

    Unrecognized extensions fall back to plain UTF-8 text decoding.
    """
    dispatch = (
        ((".pdf",), extract_text_from_pdf),
        ((".docx",), extract_text_from_docx),
        ((".txt",), extract_text_from_txt),
        ((".xlsx", ".xls"), extract_text_from_excel),
        ((".pptx",), extract_text_from_pptx),
        ((".csv",), extract_text_from_csv),
    )
    lowered = filename.lower()
    for suffixes, extractor in dispatch:
        if lowered.endswith(suffixes):
            return extractor(file_bytes)
    # Default: treat unknown formats as plain text.
    return extract_text_from_txt(file_bytes)
|
|
| |
def make_cache_key_for_files(files: List[Tuple[str, bytes]]) -> str:
    """
    Derive a deterministic cache key from (filename, bytes) pairs.

    Files are folded in filename order so the key is independent of upload
    order; each file contributes its name, its byte length, and the sha256
    digest of its content. Returns a hex string.
    """
    overall = hashlib.sha256()
    for name, payload in sorted(files, key=lambda item: item[0]):
        overall.update(name.encode("utf-8"))
        overall.update(str(len(payload)).encode("utf-8"))
        # Hash the digest rather than the raw bytes to keep updates small.
        overall.update(hashlib.sha256(payload).digest())
    return overall.hexdigest()
|
|
def cache_save_embeddings(cache_key: str, embeddings: np.ndarray, filenames: List[str]):
    """Persist embeddings plus their source names to CACHE_DIR/<key>.npz.

    Returns the path written, so callers can log or inspect it.
    """
    target = os.path.join(CACHE_DIR, f"{cache_key}.npz")
    np.savez_compressed(target, embeddings=embeddings, filenames=np.array(filenames))
    return target
|
|
def cache_load_embeddings(cache_key: str):
    """Load a cached (embeddings, filenames) pair for `cache_key`.

    Returns None when no cache file exists or the file cannot be read —
    a corrupt entry is treated the same as a miss.
    """
    path = os.path.join(CACHE_DIR, f"{cache_key}.npz")
    if not os.path.exists(path):
        return None
    try:
        archive = np.load(path, allow_pickle=True)
        return archive["embeddings"], archive["filenames"].tolist()
    except Exception:
        return None
|
|
| |
def build_faiss_index(embeddings: np.ndarray):
    """(Re)build the global flat-L2 FAISS index over `embeddings`.

    Stores the index in the FAISS_INDEX global (None when there is nothing
    to index) and also returns it for convenience.
    """
    global FAISS_INDEX
    if embeddings is None or len(embeddings) == 0:
        FAISS_INDEX = None
        return None
    vectors = embeddings.astype("float32")  # FAISS requires float32
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    FAISS_INDEX = index
    return index
|
|
def search_top_k(query: str, k: int = 3):
    """Return up to `k` nearest indexed documents for `query`.

    Each hit is a dict with the corpus index, L2 distance, document text,
    and source name. Returns [] when no index has been built yet.
    """
    if FAISS_INDEX is None:
        return []
    query_vec = embedder.encode([query], convert_to_numpy=True).astype("float32")
    distances, indices = FAISS_INDEX.search(query_vec, k)
    hits = []
    for dist, doc_idx in zip(distances[0], indices[0]):
        # FAISS pads results with -1 when fewer than k vectors exist.
        if doc_idx < 0:
            continue
        hits.append(
            {
                "index": int(doc_idx),
                "distance": float(dist),
                "text": DOCS[doc_idx],
                "source": FILENAMES[doc_idx],
            }
        )
    return hits
|
|
| |
def openrouter_chat_system_user(system_prompt: str, user_prompt: str):
    """
    Call the OpenRouter chat-completions endpoint with a minimal payload.

    Only `model` and `messages` (optional system + user) are sent — no
    sampling parameters. Returns the assistant text on success, a truncated
    JSON dump when the response shape is unrecognized, and a bracketed
    error string on missing API key or any transport/HTTP failure.
    """
    if not OPENROUTER_API_KEY:
        return "[OpenRouter error] OPENROUTER_API_KEY not set."

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_prompt})

    try:
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {OPENROUTER_API_KEY}",
                "Content-Type": "application/json",
            },
            json={"model": OPENROUTER_MODEL, "messages": messages},
            timeout=60,
        )
        response.raise_for_status()
        body = response.json()
        if "choices" in body and len(body["choices"]) > 0:
            first = body["choices"][0]
            if "message" in first and "content" in first["message"]:
                return first["message"]["content"]
            if "text" in first:
                return first["text"]
        # Unknown schema: surface a truncated view of the raw response.
        return json.dumps(body, indent=2)[:12000]
    except Exception as e:
        return f"[OpenRouter request error] {e}"
|
|
| |
async def _crawl_async_get_markdown(url: str):
    """
    Crawl `url` with Crawl4AI and return the best available text view.

    Preference order: fit markdown, raw markdown, stringified markdown
    object, plain text, raw HTML, then a JSON dump of the result object.
    Failures are reported as bracketed strings rather than raised.
    """
    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(url=url)

        # Explicit failure reported by the crawler itself.
        if hasattr(result, "success") and result.success is False:
            reason = (
                getattr(result, "error_message", None)
                or getattr(result, "error", None)
                or "[Crawl4AI unknown error]"
            )
            return f"[Crawl4AI error] {reason}"

        # Preferred: the markdown view (its shape varies across versions).
        markdown = getattr(result, "markdown", None)
        if markdown:
            content = (
                getattr(markdown, "fit_markdown", None)
                or getattr(markdown, "raw_markdown", None)
            )
            if content:
                return content
            try:
                return str(markdown)
            except Exception:
                pass

        # Fallbacks: plain text, then raw HTML.
        content = getattr(result, "text", None) or getattr(result, "html", None)
        if content:
            return content

        # Last resort: dump whatever fields the result object carries.
        try:
            return json.dumps(result.__dict__, default=str)[:20000]
        except Exception:
            return "[Crawl4AI returned no usable fields]"
|
|
def crawl_url_sync(url: str) -> str:
    """Blocking wrapper around the async crawl; errors become strings."""
    try:
        text = asyncio.run(_crawl_async_get_markdown(url))
    except Exception as exc:
        # e.g. asyncio.run raises if an event loop is already running here.
        return f"[Crawl4AI runtime error] {exc}"
    return text
|
|
| |
def upload_and_index(files):
    """
    Ingest uploaded files: extract text, embed, build the FAISS index.

    A deterministic cache key is computed from the file contents; on a
    cache hit the embeddings are reloaded from disk instead of recomputed.
    Document text is always re-extracted (only embeddings are cached).
    Replaces the module-level corpus state (DOCS/FILENAMES/EMBEDDINGS).

    Returns (status message, JSON preview of the uploaded files).
    """
    global DOCS, FILENAMES, EMBEDDINGS, CURRENT_CACHE_KEY

    if not files:
        return "No files uploaded.", ""

    # Normalize every upload into (name, bytes) and build a small preview.
    prepared = []
    previews = []
    for f in files:
        name, payload = extract_text_from_file_tuple(f)
        prepared.append((name, payload))
        previews.append({"name": name, "size": len(payload)})

    cache_key = make_cache_key_for_files(prepared)
    CURRENT_CACHE_KEY = cache_key

    cached = cache_load_embeddings(cache_key)
    if cached:
        emb, filenames = cached
        EMBEDDINGS = np.array(emb)
        FILENAMES = filenames
        # BUGFIX: the cached FILENAMES preserve the order from the upload
        # that produced the embeddings, which may differ from the current
        # upload order. Rebuild DOCS keyed by name so DOCS[i] stays aligned
        # with EMBEDDINGS[i]/FILENAMES[i] regardless of upload order.
        text_by_name = {
            name: extract_text_by_ext(name, payload) for name, payload in prepared
        }
        DOCS = [text_by_name.get(name, "") for name in FILENAMES]
        build_faiss_index(EMBEDDINGS)
        return f"Loaded embeddings from cache ({len(FILENAMES)} docs).", json.dumps(previews)

    # Cache miss: extract text, embed, persist, index.
    DOCS = []
    FILENAMES = []
    for name, payload in prepared:
        DOCS.append(extract_text_by_ext(name, payload))
        FILENAMES.append(name)

    emb = embedder.encode(DOCS, convert_to_numpy=True, show_progress_bar=False).astype("float32")
    EMBEDDINGS = emb
    cache_save_embeddings(cache_key, EMBEDDINGS, FILENAMES)
    build_faiss_index(EMBEDDINGS)

    return f"Uploaded and indexed {len(DOCS)} documents.", json.dumps(previews)
|
|
def crawl_and_index(url: str):
    """
    Crawl a single URL, embed the page text, index it, and cache embeddings.

    Replaces the module-level corpus state with the single crawled page.
    Returns (status message, first 2000 chars of the crawled text).
    """
    global DOCS, FILENAMES, EMBEDDINGS, CURRENT_CACHE_KEY

    if not url:
        return "No URL provided.", ""

    crawled = crawl_url_sync(url)
    # The crawl helpers report every failure as a "[Crawl4AI..." string.
    if crawled.startswith("[Crawl4AI"):
        return crawled, ""

    # The cache key binds both the URL and the exact crawled content.
    digest = hashlib.sha256()
    digest.update(url.encode("utf-8"))
    digest.update(crawled.encode("utf-8"))
    cache_key = digest.hexdigest()
    CURRENT_CACHE_KEY = cache_key

    cached = cache_load_embeddings(cache_key)
    if cached:
        emb, filenames = cached
        EMBEDDINGS = np.array(emb)
        FILENAMES = filenames
        DOCS = [crawled]
        build_faiss_index(EMBEDDINGS)
        return f"Crawled and loaded embeddings from cache for {url}", crawled[:2000]

    # Cache miss: embed the page, persist, and index it.
    DOCS = [crawled]
    FILENAMES = [url]
    EMBEDDINGS = embedder.encode(DOCS, convert_to_numpy=True, show_progress_bar=False).astype("float32")
    cache_save_embeddings(cache_key, EMBEDDINGS, FILENAMES)
    build_faiss_index(EMBEDDINGS)
    return f"Crawled and indexed {url}", crawled[:2000]
|
|
def ask_question(question: str, system_prompt: str = ""):
    """
    Answer `question` via RAG: retrieve the top-3 documents, then ask the LLM.

    Returns a JSON string {"answer": ..., "sources": [...]} on success, or
    a plain message string when there is no question, nothing indexed, or
    no retrieval hits.
    """
    if not question:
        return "Please enter a question."
    if not DOCS or FAISS_INDEX is None:
        return "No indexed documents. Upload files or crawl a site first."

    results = search_top_k(question, k=3)
    if not results:
        return "No relevant documents found."

    # Assemble the prompt context; truncate each snippet to keep it bounded.
    context_blocks = []
    meta = []
    for hit in results:
        snippet = hit["text"]
        if len(snippet) > 1800:
            snippet = snippet[:1800] + "\n...[truncated]"
        context_blocks.append(f"Source: {hit['source']}\n\n{snippet}\n\n---\n")
        meta.append({"source": hit["source"], "distance": hit["distance"]})

    context = "\n".join(context_blocks)
    user_prompt = f"Use the following context to answer the question, and cite sources from the 'Source:' lines.\n\nContext:\n{context}\nQuestion: {question}\nAnswer:"

    # The helper already converts errors to strings; this except is a
    # belt-and-braces guard so the UI always gets a printable answer.
    try:
        answer = openrouter_chat_system_user(system_prompt=system_prompt, user_prompt=user_prompt)
    except Exception as e:
        answer = f"[OpenRouter call failed] {e}"

    return json.dumps({"answer": answer, "sources": meta}, indent=2)
|
|
| |
# --- Gradio UI (built at import time; launched only under __main__) ---
with gr.Blocks(title="AI Ally (Gradio) — Crawl4AI + OpenRouter + FAISS") as demo:
    gr.Markdown("# AI Ally — Document & Website QA\nCrawl4AI for websites, local file uploads for docs. FAISS retrieval + sentence-transformers embeddings. OpenRouter used for generation (only model + messages).")

    # Tab 1: upload local documents, index them, then ask questions.
    with gr.Tab("Documents"):
        with gr.Row():
            file_input = gr.File(label="Upload files", file_count="multiple", file_types=[".pdf", ".docx", ".txt", ".xlsx", ".pptx", ".csv"])
            upload_btn = gr.Button("Upload & Index")
        with gr.Row():
            upload_status = gr.Textbox(label="Status", interactive=False)
            preview_box = gr.Textbox(label="Uploads (preview JSON)", interactive=False)
        upload_btn.click(upload_and_index, inputs=[file_input], outputs=[upload_status, preview_box])

        gr.Markdown("### Ask about the indexed documents")
        q = gr.Textbox(label="Question", lines=3)
        sys_prompt = gr.Textbox(label="Optional System Prompt (sent to LLM)", lines=2, value="You are a helpful assistant.")
        ask_btn = gr.Button("Ask")
        answer_out = gr.Textbox(label="Answer JSON", interactive=False)
        ask_btn.click(ask_question, inputs=[q, sys_prompt], outputs=[answer_out])

    # Tab 2: crawl a URL and ask about it. NOTE(review): both tabs share the
    # same module-level index — ingesting here replaces the Documents corpus.
    with gr.Tab("Website Crawl"):
        with gr.Row():
            url = gr.Textbox(label="URL to crawl (starting URL)")
            crawl_btn = gr.Button("Crawl & Index")
        with gr.Row():
            crawl_status = gr.Textbox(label="Status", interactive=False)
            crawl_preview = gr.Textbox(label="Crawl preview (first 2k chars)", interactive=False)
        crawl_btn.click(crawl_and_index, inputs=[url], outputs=[crawl_status, crawl_preview])

        gr.Markdown("### Ask about the crawled site")
        q2 = gr.Textbox(label="Question", lines=3)
        sys_prompt2 = gr.Textbox(label="Optional System Prompt (sent to LLM)", lines=2, value="You are a helpful assistant.")
        ask_btn2 = gr.Button("Ask site")
        answer_out2 = gr.Textbox(label="Answer JSON", interactive=False)
        ask_btn2.click(ask_question, inputs=[q2, sys_prompt2], outputs=[answer_out2])

    # Tab 3: static configuration info for the operator.
    with gr.Tab("Settings / Info"):
        gr.Markdown(f"- OpenRouter model: `{OPENROUTER_MODEL}`")
        gr.Markdown(f"- Embedding model: `{EMBEDDING_MODEL_NAME}`")
        gr.Markdown("Set `OPENROUTER_API_KEY` in your environment or HF Secrets before deploying.")
        gr.Markdown("Cache directory: `" + CACHE_DIR + "`")

    gr.Markdown("----\nNotes: This app saves embeddings to `./cache/` using a deterministic cache key. OpenRouter calls include only `model` + `messages` (system + user) as requested.")

# Bind on all interfaces so the app is reachable inside a container / Space.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|