# RAG / app.py — Gradio RAG demo from PraneshJs's Hugging Face Space (commit a6913bb, ~16.8 kB raw)
# app.py
import os
import asyncio
import json
import hashlib
from io import BytesIO, StringIO
from typing import List, Tuple
import gradio as gr
import numpy as np
import faiss
import requests
import pandas as pd
from sentence_transformers import SentenceTransformer
# file parsing libs
import fitz # PyMuPDF
import docx
from pptx import Presentation
# crawl4ai
from crawl4ai import AsyncWebCrawler
# ---------------- Config ----------------
# API key is read from the environment (or HF Space secrets); never hard-coded.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
# Model slug sent in every OpenRouter chat-completions payload.
OPENROUTER_MODEL = "microsoft/mai-ds-r1:free"
# sentence-transformers checkpoint used for both document and query embeddings.
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
# Embeddings are persisted as compressed .npz files under this directory.
CACHE_DIR = "./cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# sentence-transformers embedder (loads once)
embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
# Global in-memory stores (cleared/updated by UI actions)
DOCS: List[str] = []  # raw extracted text, one entry per indexed document
FILENAMES: List[str] = []  # source label (filename or URL) per document
EMBEDDINGS: np.ndarray = None  # float32 matrix, one row per document; None until first index
FAISS_INDEX = None  # faiss.IndexFlatL2 over EMBEDDINGS; None until first index
CURRENT_CACHE_KEY: str = ""  # cache key of the corpus currently loaded
# ---------------- File extraction helpers ----------------
def extract_text_from_pdf(file_bytes: bytes) -> str:
    """Extract plain text from a PDF given as raw bytes.

    Returns the newline-joined text of all pages, or an error marker string
    (the file's "[... extraction error] ..." convention) on failure.
    """
    try:
        # Context manager guarantees the document handle is closed even if
        # page extraction raises; the original leaked the open document.
        with fitz.open(stream=file_bytes, filetype="pdf") as doc:
            return "\n".join(page.get_text() for page in doc)
    except Exception as e:
        return f"[PDF extraction error] {e}"
def extract_text_from_docx(file_bytes: bytes) -> str:
    """Extract paragraph text from a .docx document given as raw bytes."""
    try:
        document = docx.Document(BytesIO(file_bytes))
        paragraphs = (para.text for para in document.paragraphs)
        return "\n".join(paragraphs)
    except Exception as e:
        return f"[DOCX extraction error] {e}"
def extract_text_from_txt(file_bytes: bytes) -> str:
    """Decode raw bytes as UTF-8 text, silently dropping undecodable bytes."""
    try:
        text = file_bytes.decode("utf-8", errors="ignore")
    except Exception as e:
        return f"[TXT extraction error] {e}"
    return text
def extract_text_from_excel(file_bytes: bytes) -> str:
    """Flatten an Excel workbook into text, one column's cells after another.

    Note: flattening is column-major, so row structure is not preserved.
    """
    try:
        frame = pd.read_excel(BytesIO(file_bytes), dtype=str)
        column_texts = [
            "\n".join(frame[column].fillna("").astype(str).tolist())
            for column in frame.columns
        ]
        return "\n".join(column_texts)
    except Exception as e:
        return f"[EXCEL extraction error] {e}"
def extract_text_from_pptx(file_bytes: bytes) -> str:
    """Collect the text of every text-bearing shape on every slide."""
    try:
        deck = Presentation(BytesIO(file_bytes))
        fragments = [
            shape.text
            for slide in deck.slides
            for shape in slide.shapes
            if hasattr(shape, "text")
        ]
        return "\n".join(fragments)
    except Exception as e:
        return f"[PPTX extraction error] {e}"
def extract_text_from_csv(file_bytes: bytes) -> str:
try:
f = StringIO(file_bytes.decode("utf-8", errors="ignore"))
df = pd.read_csv(f, dtype=str)
return df.to_string(index=False)
except Exception as e:
return f"[CSV extraction error] {e}"
def extract_text_from_file_tuple(file_tuple) -> Tuple[str, bytes]:
    """Normalize a Gradio upload object into (filename, bytes).

    Handles the shapes seen across Gradio versions: a TemporaryFile-like
    object (has .name and .read), a (name, bytes) pair, or a plain path
    string. Raises ValueError for anything else.
    """
    # Shape 1: file-like object with a name (gradio v3.x TemporaryFile)
    try:
        if hasattr(file_tuple, "read") and hasattr(file_tuple, "name"):
            return os.path.basename(file_tuple.name), file_tuple.read()
    except Exception:
        pass
    # Shape 2: (name, bytes) pair
    try:
        if (
            isinstance(file_tuple, tuple)
            and len(file_tuple) == 2
            and isinstance(file_tuple[1], (bytes, bytearray))
        ):
            name, payload = file_tuple
            return name, bytes(payload)
    except Exception:
        pass
    # Shape 3: filesystem path to an existing file
    try:
        if isinstance(file_tuple, str) and os.path.exists(file_tuple):
            with open(file_tuple, "rb") as fh:
                return os.path.basename(file_tuple), fh.read()
    except Exception:
        pass
    raise ValueError("Unsupported file object passed by Gradio.")
def extract_text_by_ext(filename: str, file_bytes: bytes) -> str:
    """Dispatch to the extractor matching the filename's extension.

    Unknown extensions fall back to plain-text decoding.
    """
    lowered = filename.lower()
    dispatch = (
        ((".pdf",), extract_text_from_pdf),
        ((".docx",), extract_text_from_docx),
        ((".txt",), extract_text_from_txt),
        ((".xlsx", ".xls"), extract_text_from_excel),
        ((".pptx",), extract_text_from_pptx),
        ((".csv",), extract_text_from_csv),
    )
    for suffixes, handler in dispatch:
        if lowered.endswith(suffixes):
            return handler(file_bytes)
    # fallback: treat anything unrecognized as plain text
    return extract_text_from_txt(file_bytes)
# ---------------- Embedding caching helpers ----------------
def make_cache_key_for_files(files: List[Tuple[str, bytes]]) -> str:
    """Derive a deterministic, order-independent cache key for uploaded files.

    Entries are folded in filename order: name, byte length, and a sha256
    digest of the content (digesting per file keeps the outer update cheap
    while remaining content-sensitive).
    """
    outer = hashlib.sha256()
    for fname, payload in sorted(files, key=lambda item: item[0]):
        outer.update(fname.encode("utf-8"))
        outer.update(str(len(payload)).encode("utf-8"))
        outer.update(hashlib.sha256(payload).digest())
    return outer.hexdigest()
def cache_save_embeddings(cache_key: str, embeddings: np.ndarray, filenames: List[str]):
    """Persist embeddings and their filenames to a compressed .npz archive.

    The archive lives under CACHE_DIR and is named by cache_key; the path
    is returned for convenience.
    """
    archive_name = f"{cache_key}.npz"
    target = os.path.join(CACHE_DIR, archive_name)
    np.savez_compressed(target, embeddings=embeddings, filenames=np.array(filenames))
    return target
def cache_load_embeddings(cache_key: str):
    """Load (embeddings, filenames) for cache_key, or None on miss/corruption."""
    path = os.path.join(CACHE_DIR, f"{cache_key}.npz")
    if not os.path.exists(path):
        return None
    try:
        archive = np.load(path, allow_pickle=True)
        return archive["embeddings"], archive["filenames"].tolist()
    except Exception:
        # An unreadable cache file is treated as a miss, not an error.
        return None
# ---------------- FAISS helpers ----------------
def build_faiss_index(embeddings: np.ndarray):
    """(Re)build the module-level FAISS L2 index from an embedding matrix.

    Sets FAISS_INDEX to None (and returns None) when there is nothing to
    index; otherwise stores and returns the new IndexFlatL2.
    """
    global FAISS_INDEX
    if embeddings is None or len(embeddings) == 0:
        FAISS_INDEX = None
        return None
    matrix = embeddings.astype("float32")
    index = faiss.IndexFlatL2(matrix.shape[1])
    index.add(matrix)
    FAISS_INDEX = index
    return index
def search_top_k(query: str, k: int = 3):
    """Return up to k nearest documents for query as metadata dicts.

    Each hit carries the document index, L2 distance, full text, and
    source label. Empty list when no index has been built yet.
    """
    if FAISS_INDEX is None:
        return []
    query_vec = embedder.encode([query], convert_to_numpy=True).astype("float32")
    distances, indices = FAISS_INDEX.search(query_vec, k)
    hits = []
    for dist, doc_idx in zip(distances[0], indices[0]):
        # FAISS pads with -1 when fewer than k vectors are indexed.
        if doc_idx < 0:
            continue
        hits.append(
            {
                "index": int(doc_idx),
                "distance": float(dist),
                "text": DOCS[doc_idx],
                "source": FILENAMES[doc_idx],
            }
        )
    return hits
# ---------------- OpenRouter minimal client ----------------
def openrouter_chat_system_user(system_prompt: str, user_prompt: str):
    """Call OpenRouter chat completions with only `model` + `messages`.

    Sends an optional system message plus the user message, and returns the
    assistant's text. On any failure (missing key, HTTP error, unexpected
    response shape) a descriptive string is returned instead of raising.
    """
    if not OPENROUTER_API_KEY:
        return "[OpenRouter error] OPENROUTER_API_KEY not set."
    endpoint = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
    }
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_prompt})
    payload = {"model": OPENROUTER_MODEL, "messages": messages}
    try:
        resp = requests.post(endpoint, headers=headers, json=payload, timeout=60)
        resp.raise_for_status()
        body = resp.json()
        # OpenAI-like shape: choices[0].message.content (or legacy .text)
        if "choices" in body and len(body["choices"]) > 0:
            first = body["choices"][0]
            if "message" in first and "content" in first["message"]:
                return first["message"]["content"]
            if "text" in first:
                return first["text"]
        # fallback: return the (truncated) raw JSON so callers can debug
        return json.dumps(body, indent=2)[:12000]
    except Exception as e:
        return f"[OpenRouter request error] {e}"
# ---------------- Crawl4AI robust logic ----------------
async def _crawl_async_get_markdown(url: str):
    """Crawl `url` with Crawl4AI and return the best available text.

    Preference order (fields vary across crawl4ai versions):
    fit_markdown / raw_markdown, str(markdown), result.text, result.html,
    then a truncated JSON dump of the whole result object.
    """
    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(url=url)
        # Surface an explicit failure if the crawler reports one.
        if hasattr(result, "success") and result.success is False:
            err = (
                getattr(result, "error_message", None)
                or getattr(result, "error", None)
                or "[Crawl4AI unknown error]"
            )
            return f"[Crawl4AI error] {err}"
        # Structured markdown first.
        md_obj = getattr(result, "markdown", None)
        if md_obj:
            text = getattr(md_obj, "fit_markdown", None) or getattr(md_obj, "raw_markdown", None)
            if text:
                return text
            try:
                return str(md_obj)
            except Exception:
                pass
        # Then plain text or raw HTML.
        fallback = getattr(result, "text", None) or getattr(result, "html", None)
        if fallback:
            return fallback
        # Last resort: a short JSON dump of the result for debugging.
        try:
            return json.dumps(result.__dict__, default=str)[:20000]
        except Exception:
            return "[Crawl4AI returned no usable fields]"
def crawl_url_sync(url: str) -> str:
    """Synchronous wrapper around the async crawler; errors become strings."""
    try:
        markdown = asyncio.run(_crawl_async_get_markdown(url))
    except Exception as exc:
        return f"[Crawl4AI runtime error] {exc}"
    return markdown
# ---------------- Gradio handlers ----------------
def upload_and_index(files):
    """Read uploaded files, embed their text, and build the FAISS index.

    files: list of file objects from Gradio. A content-based cache key is
    computed; cached embeddings are reused when available, otherwise the
    texts are embedded and the result is saved to the cache.
    Returns (status message, JSON preview of the uploads).
    """
    global DOCS, FILENAMES, EMBEDDINGS, CURRENT_CACHE_KEY
    if not files:
        return "No files uploaded.", ""
    # Read files into (name, bytes) pairs and build short previews.
    prepared = []
    previews = []
    for f in files:
        name, b = extract_text_from_file_tuple(f)
        prepared.append((name, b))
        previews.append({"name": name, "size": len(b)})
    cache_key = make_cache_key_for_files(prepared)
    CURRENT_CACHE_KEY = cache_key
    cached = cache_load_embeddings(cache_key)
    if cached:
        emb, filenames = cached
        EMBEDDINGS = np.array(emb)
        FILENAMES = filenames
        # BUG FIX: the cache key is order-independent (files are sorted by
        # name), but the cached FILENAMES/embedding rows were saved in the
        # *original* upload order. Rebuilding DOCS in *this* upload's order
        # could desynchronize texts from their embeddings and source labels.
        # Align DOCS to the cached filename order instead.
        bytes_by_name = {name: b for name, b in prepared}
        DOCS = [
            extract_text_by_ext(name, bytes_by_name.get(name, b""))
            for name in FILENAMES
        ]
        build_faiss_index(EMBEDDINGS)
        return f"Loaded embeddings from cache ({len(FILENAMES)} docs).", json.dumps(previews)
    # Not cached -> extract texts and embed.
    DOCS = []
    FILENAMES = []
    for name, b in prepared:
        DOCS.append(extract_text_by_ext(name, b))
        FILENAMES.append(name)
    EMBEDDINGS = embedder.encode(DOCS, convert_to_numpy=True, show_progress_bar=False).astype("float32")
    cache_save_embeddings(cache_key, EMBEDDINGS, FILENAMES)
    build_faiss_index(EMBEDDINGS)
    return f"Uploaded and indexed {len(DOCS)} documents.", json.dumps(previews)
def crawl_and_index(url: str):
    """Crawl a URL, embed the page text, and build the FAISS index.

    Returns (status message, first 2000 chars of the crawl) for the UI.
    Crawl errors are passed straight through as the status message.
    """
    global DOCS, FILENAMES, EMBEDDINGS, CURRENT_CACHE_KEY
    if not url:
        return "No URL provided.", ""
    crawled = crawl_url_sync(url)
    # Both crawler error paths produce "[Crawl4AI..."-prefixed strings.
    if crawled.startswith("[Crawl4AI"):
        return crawled, ""
    # Cache key is content-addressed: URL plus the crawled text.
    hasher = hashlib.sha256()
    hasher.update(url.encode("utf-8"))
    hasher.update(crawled.encode("utf-8"))
    cache_key = hasher.hexdigest()
    CURRENT_CACHE_KEY = cache_key
    cached = cache_load_embeddings(cache_key)
    if cached:
        emb, filenames = cached
        EMBEDDINGS = np.array(emb)
        FILENAMES = filenames
        DOCS = [crawled]
        build_faiss_index(EMBEDDINGS)
        return f"Crawled and loaded embeddings from cache for {url}", crawled[:2000]
    # Not cached -> embed the single crawled document.
    DOCS = [crawled]
    FILENAMES = [url]
    EMBEDDINGS = embedder.encode(DOCS, convert_to_numpy=True, show_progress_bar=False).astype("float32")
    cache_save_embeddings(cache_key, EMBEDDINGS, FILENAMES)
    build_faiss_index(EMBEDDINGS)
    return f"Crawled and indexed {url}", crawled[:2000]
def ask_question(question: str, system_prompt: str = "", top_k: int = 3):
    """Answer a question over the indexed corpus via retrieval + LLM.

    Retrieves the `top_k` nearest documents, builds a cited context prompt,
    and asks the OpenRouter model. Returns a JSON string with "answer" and
    "sources" keys, or a plain status message on early exit.

    top_k generalizes the previously hard-coded retrieval depth; the default
    of 3 preserves the original behavior for existing callers.
    """
    if not question:
        return "Please enter a question."
    if not DOCS or FAISS_INDEX is None:
        return "No indexed documents. Upload files or crawl a site first."
    results = search_top_k(question, k=top_k)
    if not results:
        return "No relevant documents found."
    # Build the context from the top hits, trimming each to bound prompt size.
    context_blocks = []
    meta = []
    for r in results:
        snippet = r["text"]
        if len(snippet) > 1800:
            snippet = snippet[:1800] + "\n...[truncated]"
        context_blocks.append(f"Source: {r['source']}\n\n{snippet}\n\n---\n")
        meta.append({"source": r["source"], "distance": r["distance"]})
    context = "\n".join(context_blocks)
    user_prompt = f"Use the following context to answer the question, and cite sources from the 'Source:' lines.\n\nContext:\n{context}\nQuestion: {question}\nAnswer:"
    # Call OpenRouter with only model + messages (system & user).
    try:
        answer = openrouter_chat_system_user(system_prompt=system_prompt, user_prompt=user_prompt)
    except Exception as e:
        answer = f"[OpenRouter call failed] {e}"
    return json.dumps({"answer": answer, "sources": meta}, indent=2)
# ---------------- Gradio UI ----------------
with gr.Blocks(title="AI Ally (Gradio) — Crawl4AI + OpenRouter + FAISS") as demo:
    gr.Markdown("# AI Ally — Document & Website QA\nCrawl4AI for websites, local file uploads for docs. FAISS retrieval + sentence-transformers embeddings. OpenRouter used for generation (only model + messages).")

    # --- Tab 1: upload documents and ask questions about them ---
    with gr.Tab("Documents"):
        with gr.Row():
            doc_files = gr.File(label="Upload files", file_count="multiple", file_types=[".pdf", ".docx", ".txt", ".xlsx", ".pptx", ".csv"])
            index_docs_btn = gr.Button("Upload & Index")
        with gr.Row():
            doc_status = gr.Textbox(label="Status", interactive=False)
            doc_preview = gr.Textbox(label="Uploads (preview JSON)", interactive=False)
        index_docs_btn.click(upload_and_index, inputs=[doc_files], outputs=[doc_status, doc_preview])
        gr.Markdown("### Ask about the indexed documents")
        doc_question = gr.Textbox(label="Question", lines=3)
        doc_sys_prompt = gr.Textbox(label="Optional System Prompt (sent to LLM)", lines=2, value="You are a helpful assistant.")
        doc_ask_btn = gr.Button("Ask")
        doc_answer = gr.Textbox(label="Answer JSON", interactive=False)
        doc_ask_btn.click(ask_question, inputs=[doc_question, doc_sys_prompt], outputs=[doc_answer])

    # --- Tab 2: crawl a website and ask questions about it ---
    with gr.Tab("Website Crawl"):
        with gr.Row():
            crawl_url_box = gr.Textbox(label="URL to crawl (starting URL)")
            crawl_index_btn = gr.Button("Crawl & Index")
        with gr.Row():
            site_status = gr.Textbox(label="Status", interactive=False)
            site_preview = gr.Textbox(label="Crawl preview (first 2k chars)", interactive=False)
        crawl_index_btn.click(crawl_and_index, inputs=[crawl_url_box], outputs=[site_status, site_preview])
        gr.Markdown("### Ask about the crawled site")
        site_question = gr.Textbox(label="Question", lines=3)
        site_sys_prompt = gr.Textbox(label="Optional System Prompt (sent to LLM)", lines=2, value="You are a helpful assistant.")
        site_ask_btn = gr.Button("Ask site")
        site_answer = gr.Textbox(label="Answer JSON", interactive=False)
        site_ask_btn.click(ask_question, inputs=[site_question, site_sys_prompt], outputs=[site_answer])

    # --- Tab 3: static configuration info ---
    with gr.Tab("Settings / Info"):
        gr.Markdown(f"- OpenRouter model: `{OPENROUTER_MODEL}`")
        gr.Markdown(f"- Embedding model: `{EMBEDDING_MODEL_NAME}`")
        gr.Markdown("Set `OPENROUTER_API_KEY` in your environment or HF Secrets before deploying.")
        gr.Markdown("Cache directory: `" + CACHE_DIR + "`")
        gr.Markdown("----\nNotes: This app saves embeddings to `./cache/` using a deterministic cache key. OpenRouter calls include only `model` + `messages` (system + user) as requested.")

if __name__ == "__main__":
    # Bind to all interfaces for container / HF Space deployment.
    demo.launch(server_name="0.0.0.0", server_port=7860)