# testRAG / app.py — Hugging Face Space application
# Author: Hesamnasiri — commit 52ea860 (verified)
import os
import uuid
import tempfile
from dataclasses import dataclass
from functools import lru_cache
from typing import Optional, List, Tuple, Any
import json
import re
import gradio as gr
from pypdf import PdfReader
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct
from sentence_transformers import SentenceTransformer, CrossEncoder
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# -----------------------------
# Config
# -----------------------------
EMBED_MODEL_NAME = "BAAI/bge-large-en-v1.5" # 1024-dim cosine
RERANK_MODEL_NAME = "BAAI/bge-reranker-large" # cross-encoder reranker
COLLECTION_NAME = "pdf_rag_en"  # single Qdrant collection shared by the app
# Candidate LLMs offered in the UI dropdown; the first entry is the default.
DEFAULT_MODELS = [
    "Qwen/Qwen2.5-1.5B-Instruct", # CPU-friendly
    "microsoft/Phi-3.1-mini-4k-instruct", # CPU-friendly
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.3",
]
# Strict RAG system prompt
# Shared system prompt: forces answers grounded in retrieved context, with
# [filename p.PAGE] citations and an explicit refusal when nothing matches.
SYSTEM_GUARDRAILS = (
    "You are a strict RAG assistant. Answer ONLY using the provided context. "
    "If the answer is not present, say: 'I don't know based on the provided PDFs.' "
    "Cite sources as [filename p.PAGE]. Keep answers concise and factual."
)
# -----------------------------
# Utilities
# -----------------------------
def read_pdf_to_pages(file_path: str):
    """Extract text page by page from a PDF.

    Returns a list of (page_number, text) tuples with 1-based page numbers.
    Pages whose extraction raises yield an empty string; blank lines are
    dropped and every remaining line is stripped.
    """
    pages = []
    for page_no, page in enumerate(PdfReader(file_path).pages, start=1):
        try:
            raw = page.extract_text() or ""
        except Exception:
            raw = ""  # keep the page entry even when extraction fails
        cleaned = "\n".join(ln.strip() for ln in raw.splitlines() if ln.strip())
        pages.append((page_no, cleaned))
    return pages
def chunk_text(text: str, words_per_chunk: int = 350, overlap: int = 40):
    """Split `text` into overlapping word-based chunks.

    Each chunk holds up to `words_per_chunk` whitespace-separated words and
    consecutive chunks share `overlap` words — simple chunking suitable for
    English PDFs and BGE input-length limits.

    Fix: stop once the final word has been emitted. The original loop could
    append one extra tail chunk that was entirely contained in the previous
    chunk (it stepped back by `overlap` even after covering the last word).
    """
    words = text.split()
    if not words:
        return []
    chunks = []
    n = len(words)
    i = 0
    while i < n:
        j = min(i + words_per_chunk, n)
        chunks.append(" ".join(words[i:j]))
        if j == n:
            break  # last words emitted; avoid a redundant overlap-only tail
        # Step forward, backing up by `overlap` words; the guard prevents a
        # non-advancing index when overlap >= words_per_chunk.
        i = j - overlap if j - overlap > i else j
    return chunks
@dataclass
class RetrievedChunk:
    """One vector-search hit, carrying the chunk text and its provenance."""
    text: str    # chunk text stored in the Qdrant payload
    score: float # similarity score from search (not the rerank score)
    file: str    # source PDF filename
    page: int    # 1-based page number within the source PDF
# -----------------------------
# JSON helpers
# -----------------------------
def read_json_file(path: str):
    """Load and parse the JSON document at `path`.

    Returns the decoded object on success. On any failure (missing file,
    bad encoding, invalid JSON) returns {"__error__": <message>} instead of
    raising, so callers can surface the problem in the UI.
    """
    try:
        with open(path, 'r', encoding='utf-8') as handle:
            data = json.load(handle)
    except Exception as exc:
        return {"__error__": str(exc)}
    return data
# -----------------------------
# Embeddings & Reranker
# -----------------------------
@lru_cache(maxsize=1)
def load_embedder():
    """Return the shared sentence-embedding model (loaded once, then cached)."""
    return SentenceTransformer(EMBED_MODEL_NAME)
@lru_cache(maxsize=1)
def load_reranker():
    """Return the shared cross-encoder reranker (loaded once, then cached)."""
    return CrossEncoder(RERANK_MODEL_NAME)
# -----------------------------
# Qdrant (Vector DB)
# -----------------------------
def get_qdrant_client():
    """Build a Qdrant client from the QDRANT_URL / QDRANT_API_KEY secrets.

    Raises RuntimeError when either environment variable is absent or blank.
    """
    url = os.getenv("QDRANT_URL", "").strip()
    key = os.getenv("QDRANT_API_KEY", "").strip()
    if url and key:
        return QdrantClient(url=url, api_key=key, timeout=60)
    raise RuntimeError("Missing QDRANT_URL or QDRANT_API_KEY in Secrets.")
def ensure_collection(client: QdrantClient, vector_size: int):
    """Create the cosine-distance collection if it does not already exist.

    NOTE(review): this always checks/creates COLLECTION_NAME, even though
    ingest_pdfs accepts a custom collection name — confirm callers only ever
    use the default.
    """
    existing = {c.name for c in client.get_collections().collections}
    if COLLECTION_NAME in existing:
        return
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
    )
# -----------------------------
# LLM Loader (with optional 4-bit)
# -----------------------------
# LLM pipeline cache: the most recently built pipeline plus the
# (model_name, use_4bit) key it was built with.
_current_model_name = None
_pipe = None
def load_llm(model_name: str, use_4bit: bool = True):
    """Load a HF text-generation pipeline for `model_name`, reusing the
    cached one when both the model and the quantization setting match.

    Fixes:
    - The cache key now includes `use_4bit`; previously, toggling the
      quantization checkbox returned the stale pipeline for the same model.
    - The original wrapped a plain dict literal in try/except — a dict
      literal can never raise, so the intended "fall back to full precision"
      path never fired. The fallback now retries `from_pretrained` without
      the 4-bit kwargs when quantized loading fails (e.g. bitsandbytes
      unavailable on CPU hosts).

    Raises RuntimeError if the model cannot be loaded at all.
    """
    global _current_model_name, _pipe
    cache_key = (model_name, use_4bit)
    if _pipe is not None and _current_model_name == cache_key:
        return _pipe
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
        common_kwargs = dict(
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
        if use_4bit:
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    **common_kwargs,
                )
            except Exception:
                # Quantized load unavailable on this host; retry in fp16.
                model = AutoModelForCausalLM.from_pretrained(model_name, **common_kwargs)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_name, **common_kwargs)
        _pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            do_sample=False,
        )
        _current_model_name = cache_key
        return _pipe
    except Exception as e:
        raise RuntimeError(f"Failed to load model {model_name}: {e}") from e
# -----------------------------
# Indexing
# -----------------------------
def _normalize_file_inputs(files):
if not files:
return []
if not isinstance(files, (list, tuple)):
files = [files]
paths = []
for f in files:
if f is None:
continue
try:
p = str(f) # Gradio NamedString -> temp filepath
except Exception:
p = getattr(f, "name", None) or getattr(f, "path", None)
if p:
paths.append(p)
return paths
def ingest_pdfs(files, collection_name: str = COLLECTION_NAME):
    """Chunk, embed, and upsert the uploaded PDFs into Qdrant.

    Args:
        files: Gradio file input (single value or list); normalized to paths.
        collection_name: target collection for the upsert.
    Returns a human-readable status string for the UI.

    Improvement: embeddings are now computed with ONE batched
    `embedder.encode(list_of_chunks)` call per file instead of one model
    call per chunk — same vectors and payloads, dramatically faster for
    large PDFs.

    NOTE(review): ensure_collection always targets the default
    COLLECTION_NAME; a custom `collection_name` is not auto-created.
    """
    file_paths = _normalize_file_inputs(files)
    if not file_paths:
        return "No files received. Please upload PDF(s) first."
    embedder = load_embedder()
    vector_dim = embedder.get_sentence_embedding_dimension()
    client = get_qdrant_client()
    ensure_collection(client, vector_dim)
    total_chunks = 0
    for path in file_paths:
        fname = os.path.basename(path)
        texts, payloads = [], []
        for page_num, text in read_pdf_to_pages(path):
            if not text:
                continue  # skip pages with no extractable text
            for chunk_idx, chunk in enumerate(chunk_text(text)):
                texts.append(chunk)
                payloads.append({
                    "text": chunk,
                    "file": fname,
                    "page": page_num,
                    "chunk": chunk_idx,
                })
        if not texts:
            continue
        # Batched encode: one forward pass over all chunks of this file.
        vectors = embedder.encode(
            texts,
            normalize_embeddings=True,
            convert_to_numpy=True,
        )
        client.upsert(
            collection_name=collection_name,
            points=[
                PointStruct(id=str(uuid.uuid4()), vector=vec.tolist(), payload=pl)
                for vec, pl in zip(vectors, payloads)
            ],
        )
        total_chunks += len(texts)
    return f"Indexed ~{total_chunks} chunks from {len(file_paths)} PDF(s)."
# -----------------------------
# Retrieval + Rerank
# -----------------------------
def retrieve(query: str, top_k: int = 16, score_threshold: float = 0.25):
    """Vector-search the collection for `query`.

    Returns up to `top_k` RetrievedChunk hits whose similarity is at least
    `score_threshold`; missing payload fields fall back to safe defaults.
    """
    embedder = load_embedder()
    client = get_qdrant_client()
    qvec = embedder.encode(query, normalize_embeddings=True, convert_to_numpy=True)
    hits = client.search(
        collection_name=COLLECTION_NAME,
        query_vector=qvec.tolist(),
        limit=top_k,
        with_payload=True,
        score_threshold=score_threshold,
    )
    results: List[RetrievedChunk] = []
    for hit in hits:
        payload = hit.payload or {}
        results.append(
            RetrievedChunk(
                text=payload.get("text", ""),
                score=float(hit.score or 0.0),
                file=payload.get("file", "unknown.pdf"),
                page=int(payload.get("page", 0)),
            )
        )
    return results
def rerank(query: str, chunks: List[RetrievedChunk], top_n: int = 6):
    """Re-score `chunks` against `query` with the cross-encoder and return
    the best `top_n` chunks in descending score order."""
    if not chunks:
        return []
    scores = load_reranker().predict([(query, c.text) for c in chunks])
    ranked = sorted(zip(chunks, scores), key=lambda pair: float(pair[1]), reverse=True)
    return [chunk for chunk, _score in ranked[:top_n]]
# -----------------------------
# QA Generation (with optional JSON compare)
# -----------------------------
def build_prompt(query: str, contexts: List[RetrievedChunk], json_text: Optional[str] = None):
    """Assemble the grounded-QA prompt from retrieved chunks, optionally
    appending an uploaded JSON spec for comparison against the PDFs."""
    context_text = "\n\n".join(c.text for c in contexts)
    json_block = f"\n\nJSON_SPEC:\n{json_text}\n" if json_text else ""
    return (
        f"<s>[SYSTEM]\n{SYSTEM_GUARDRAILS}\nIf a JSON spec is provided, compare it to the PDF context: identify agreements, conflicts, and missing fields explicitly.\n[/SYSTEM]\n"
        f"[USER]\nQuestion: {query}\n\nContext from PDFs:\n{context_text}{json_block}\n"
        f"Answer in English. If conflicts exist between JSON and PDFs, report them clearly. Include PDF citations like [filename p.PAGE].\n[/USER]\n[ASSISTANT]"
    )
def answer_query(query: str, model_name: str, use_4bit: bool, top_k: int, rerank_k: int,
                 max_new_tokens: int, temperature: float, json_path: Optional[str] = None,
                 include_json: bool = False):
    """Answer `query` strictly from the indexed PDFs.

    Optionally serializes the JSON spec at `json_path` into the prompt for a
    PDF-vs-JSON comparison. Returns (answer_text, citations_string); falls
    back to the guardrail refusal when retrieval/reranking yields nothing.

    Fix: `return_full_text=False` is passed to the pipeline — HF
    text-generation pipelines echo the prompt by default, so the previous
    code returned the entire prompt glued to the completion.
    """
    if not query or not query.strip():
        return "Please enter a question.", ""
    retrieved = retrieve(query.strip(), top_k=top_k)
    if not retrieved:
        return "I don't know based on the provided PDFs.", ""
    selected = rerank(query.strip(), retrieved, top_n=rerank_k)
    if not selected:
        return "I don't know based on the provided PDFs.", ""
    json_text = None
    if include_json and json_path:
        obj = read_json_file(json_path)
        if isinstance(obj, dict) and "__error__" in obj:
            json_text = f"__JSON_ERROR__: {obj['__error__']}"
        else:
            try:
                json_text = json.dumps(obj, ensure_ascii=False)
                # Keep the prompt bounded; very large specs are truncated.
                if len(json_text) > 8000:
                    json_text = json_text[:8000] + "\n... [truncated]"
            except Exception as e:
                json_text = f"__JSON_ERROR__: {e}"
    prompt = build_prompt(query, selected, json_text=json_text)
    pipe = load_llm(model_name, use_4bit=use_4bit)
    out = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=(temperature > 0),
        temperature=temperature,
        return_full_text=False,  # return only the completion, not the echoed prompt
    )[0]["generated_text"]
    # De-duplicate citations while preserving first-seen order.
    uniq = list(dict.fromkeys(f"[{c.file} p.{c.page}]" for c in selected))
    return out, " ".join(uniq)
# -----------------------------
# Admin helpers
# -----------------------------
def wipe_collection():
    """Drop the Qdrant collection and immediately recreate it empty."""
    client = get_qdrant_client()
    client.delete_collection(COLLECTION_NAME)
    ensure_collection(client, load_embedder().get_sentence_embedding_dimension())
    return "Collection wiped and recreated."
def get_index_stats(sample_limit: int = 64):
    """Return basic collection stats from Qdrant to verify that indexing worked."""
    client = get_qdrant_client()
    try:
        cnt = client.count(collection_name=COLLECTION_NAME, exact=True).count
    except Exception as e:
        return f"Count failed: {e}"
    files = []
    try:
        # Sample up to `sample_limit` points just to list source filenames.
        points, _next_offset = client.scroll(
            collection_name=COLLECTION_NAME,
            limit=sample_limit,
            with_payload=True,
        )
        files = [(p.payload or {}).get("file") or "unknown.pdf" for p in points or []]
    except Exception as e:
        return f"Points: {cnt}. Scroll failed: {e}"
    uniq_files = sorted(set(files))
    suffix = ' …' if len(uniq_files) > 10 else ''
    return f"Points: {cnt} | Collection: {COLLECTION_NAME} | Sample files ({len(uniq_files)}): {', '.join(uniq_files[:10])}{suffix}"
# -----------------------------
# Counterfactual (CF) evaluator helpers
# -----------------------------
def _flatten_first(x):
while isinstance(x, (list, tuple)) and len(x) == 1 and isinstance(x[0], (list, tuple)):
x = x[0]
return x
def parse_cf_input_json(path: str):
    """Validate and unpack a counterfactual-evaluator input JSON.

    Returns (parsed_dict, None) on success, or (None, error_message) when
    the file cannot be loaded, a required key is missing, or cfs_list is
    empty.
    """
    data = read_json_file(path)
    if isinstance(data, dict) and data.get("__error__"):
        return None, f"Load error: {data['__error__']}"
    required = [
        "test_data", "cfs_list", "feature_names", "feature_names_including_target",
        "data_interface", "desired_class"
    ]
    for key in required:
        if key not in data:
            return None, f"Missing key: {key}"
    test_data = _flatten_first(data["test_data"])  # one row
    cfs_list = _flatten_first(data["cfs_list"])  # list of rows
    if not isinstance(cfs_list, (list, tuple)) or not cfs_list:
        return None, "Empty cfs_list"
    # Prefer the nested data_interface outcome name, then a top-level one.
    outcome_name = data.get("data_interface", {}).get("outcome_name") or data.get("outcome_name", "income")
    return {
        "test_data": test_data,
        "cfs_list": cfs_list,
        "feature_names_including_target": data["feature_names_including_target"],
        "feature_names": data["feature_names"],
        "desired_class": data["desired_class"],
        "outcome_name": outcome_name,
    }, None
def build_cf_retrieval_query(test_row, feature_names):
    """Build a retrieval query string from a test row's salient features.

    Pairs `feature_names` with `test_row` values and emits "name:value"
    fragments for a fixed set of salient Adult-Income features; falls back
    to a generic query when the row cannot be zipped against the names.
    """
    try:
        row_map = dict(zip(feature_names, test_row))
    except Exception:
        return "adult income factors by occupation, education, hours per week, marital status"
    salient = ["occupation", "education", "workclass", "marital_status", "age", "hours_per_week", "gender", "race"]
    parts = [f"{name}:{row_map[name]}" for name in salient if name in row_map]
    parts.append("income threshold and probability drivers")
    return ", ".join(map(str, parts))
def get_rag_context_text(query: str, top_k: int, rerank_k: int, max_chars: int = 8000):
    """Retrieve + rerank chunks for `query` and format them as one context
    string, each chunk tagged with its citation, capped at `max_chars`."""
    candidates = retrieve(query, top_k=top_k)
    if not candidates:
        return ""
    best = rerank(query, candidates, top_n=rerank_k)
    blocks = [f"{c.text}\n[CIT: {c.file} p.{c.page}]" for c in best]
    return "\n\n".join(blocks)[:max_chars]
def build_cf_prompt(parsed, rag_text: str = "", extra_json_text: str = ""):
    """Compose the counterfactual-evaluator prompt from the parsed input,
    optionally appending retrieved-PDF and uploaded-JSON context blocks."""
    instr = (
        "You are a helpful assistant with deep knowledge of counterfactual explanations, fairness, and causal reasoning.\n\n"
        "You will be given a test data point, candidate counterfactuals, feature names (including target), the desired class,"
        " and real-world context retrieved from documents.\n\n"
        "Goals: (1) choose or propose a counterfactual that flips the class to the desired one, (2) minimize actionable changes,"
        " (3) ensure plausibility given Adult Income and provided context.\n\n"
        "Return only this JSON (no prose): {\"best_cf\": [...], \"explanation\": \"...\"}"
    )
    ctx = ""
    if rag_text:
        ctx += f"\n\nRETRIEVED_CONTEXT:\n{rag_text}"
    if extra_json_text:
        ctx += f"\n\nUPLOADED_JSON_CONTEXT:\n{extra_json_text}"
    user = (
        f"feature_names_including_target: {json.dumps(parsed['feature_names_including_target'])}\n"
        f"desired_class: {parsed['desired_class']}\n"
        f"outcome_name: {parsed['outcome_name']}\n"
        f"test_data: {json.dumps(parsed['test_data'])}\n"
        f"cfs_list: {json.dumps(parsed['cfs_list'])}\n"
        f"{ctx}\n\n"
        "Only output the JSON with keys 'best_cf' and 'explanation'. Ensure 'best_cf' matches the length and order of feature_names_including_target."
    )
    return (
        f"<s>[SYSTEM]\n{SYSTEM_GUARDRAILS}\n{instr}\n[/SYSTEM]\n"
        f"[USER]\n{user}\n[/USER]\n[ASSISTANT]"
    )
def extract_json_object(text: str):
    """Pull a {"best_cf": ..., "explanation": ...} object out of model output.

    First tries the whole text as JSON, then the widest {...} span found by
    a greedy regex. Returns an error-JSON string when the span is invalid
    JSON or no suitable object is present.
    """
    try:
        whole = json.loads(text)
    except Exception:
        whole = None
    if isinstance(whole, dict) and "best_cf" in whole and "explanation" in whole:
        return json.dumps(whole, ensure_ascii=False)
    match = re.search(r"\{[\s\S]*\}", text)
    if match:
        try:
            candidate = json.loads(match.group(0))
        except Exception:
            return "{\n \"error\": \"Model returned invalid JSON.\"\n}"
        if isinstance(candidate, dict) and "best_cf" in candidate and "explanation" in candidate:
            return json.dumps(candidate, ensure_ascii=False)
    return "{\n \"error\": \"No JSON object found in model output.\"\n}"
def evaluate_cfs(cf_json_path: Optional[str], use_rag: bool, top_k: int, rerank_k: int,
                 max_new_tokens: int, temperature: float, extra_json_files, include_extra_json: bool,
                 model_name: str, use_4bit: bool):
    """Run the counterfactual evaluator and return a JSON string for the UI.

    Parses the uploaded CF input, optionally gathers RAG context and extra
    JSON context, prompts the LLM, and extracts the JSON answer.

    Fix: `return_full_text=False` — HF text-generation pipelines echo the
    prompt by default, and the prompt itself contains a JSON example whose
    braces would satisfy extract_json_object's greedy regex, shadowing the
    model's actual answer.
    """
    if not cf_json_path:
        return "{\n \"error\": \"No CF input JSON uploaded.\"\n}"
    parsed, err = parse_cf_input_json(cf_json_path)
    if err:
        return json.dumps({"error": err})
    rag_text = ""
    if use_rag:
        # Exclude the trailing target column when building the retrieval query.
        q = build_cf_retrieval_query(parsed["test_data"], parsed["feature_names_including_target"][:-1])
        rag_text = get_rag_context_text(q, top_k=int(top_k), rerank_k=int(rerank_k))
    extra_json_text = ""
    if include_extra_json and extra_json_files:
        blobs = []
        for p in _normalize_file_inputs(extra_json_files):
            obj = read_json_file(p)
            try:
                blobs.append(json.dumps(obj, ensure_ascii=False))
            except Exception:
                continue  # skip unserializable blobs, keep the rest
        extra_json_text = ("\n\n".join(blobs))[:6000]
    prompt = build_cf_prompt(parsed, rag_text=rag_text, extra_json_text=extra_json_text)
    pipe = load_llm(model_name, use_4bit=use_4bit)
    out = pipe(
        prompt,
        max_new_tokens=int(max_new_tokens),
        do_sample=(temperature > 0),
        temperature=float(temperature),
        return_full_text=False,  # completion only — never the echoed prompt
    )[0]["generated_text"]
    return extract_json_object(out)
# -----------------------------
# UI
# -----------------------------
# Gradio UI: indexing controls, optional JSON-compare, QA panel, and the
# counterfactual-evaluator accordion, followed by the event wiring.
with gr.Blocks(title="PDF RAG (ZeroGPU + Qdrant Cloud)") as demo:
    gr.Markdown("""
# PDF RAG (ZeroGPU + Qdrant)
- **Upload** PDFs → **Index** → ask a **question**.\
- Answers are strictly grounded in your PDFs. If unknown, the app will say so.
""")
    # --- Indexing controls ---
    with gr.Accordion("Setup & Indexing", open=True):
        files = gr.File(file_count="multiple", file_types=[".pdf"], type="filepath")
        idx_btn = gr.Button("Build / Update Index")
        idx_status = gr.Textbox(label="Status", interactive=False)
        wipe_btn = gr.Button("Wipe Index (danger)")
        stats_btn = gr.Button("Index stats")
        stats_box = gr.Textbox(label="Index stats", interactive=False)
    # --- Optional JSON spec to compare against the PDFs ---
    with gr.Accordion("JSON compare (optional)", open=False):
        json_file = gr.File(file_count="single", file_types=[".json"], type="filepath", label="Upload JSON spec")
        include_json = gr.Checkbox(value=False, label="Include JSON in prompt for comparison")
        json_info_box = gr.Textbox(label="JSON status", interactive=False)
    # --- Generation settings (shared by QA and the CF evaluator) ---
    with gr.Row():
        model_name = gr.Dropdown(choices=DEFAULT_MODELS, value=DEFAULT_MODELS[0], label="LLM")
        use_4bit = gr.Checkbox(value=True, label="Use 4-bit quantization (if available)")
    with gr.Row():
        top_k = gr.Slider(4, 32, value=16, step=1, label="Vector top_k")
        rerank_k = gr.Slider(2, 12, value=6, step=1, label="Rerank top_n")
    with gr.Row():
        max_new_tokens = gr.Slider(64, 1024, value=320, step=16, label="Max new tokens")
        temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Temperature")
    question = gr.Textbox(label="Ask a question (English)")
    answer = gr.Textbox(label="Answer", lines=10)
    citations = gr.Textbox(label="Citations", lines=2)
    # CF evaluator UI
    with gr.Accordion("Counterfactual Evaluator", open=False):
        cf_input_json = gr.File(file_count="single", file_types=[".json"], type="filepath", label="Upload CF input JSON (Adult format)")
        cf_extra_jsons = gr.File(file_count="multiple", file_types=[".json"], type="filepath", label="Optional: Additional JSON context")
        include_rag_cf = gr.Checkbox(value=True, label="Use RAG context from indexed PDFs")
        include_extra_json_cf = gr.Checkbox(value=False, label="Include uploaded JSON context in prompt")
        eval_btn = gr.Button("Evaluate Counterfactuals → JSON output")
        result_cf = gr.Textbox(label="Result JSON", lines=10)
    # Wiring
    idx_btn.click(fn=ingest_pdfs, inputs=[files], outputs=[idx_status])
    wipe_btn.click(fn=wipe_collection, inputs=None, outputs=[idx_status])
    stats_btn.click(fn=get_index_stats, inputs=None, outputs=[stats_box])
    def _json_info(path):
        # Summarize the uploaded JSON file (or surface read_json_file's error).
        if not path:
            return "No JSON uploaded."
        obj = read_json_file(path)
        if isinstance(obj, dict) and "__error__" in obj:
            return f"JSON error: {obj['__error__']}"
        try:
            if isinstance(obj, dict):
                keys = len(obj.keys())
                return f"Loaded JSON object with {keys} top-level keys."
            elif isinstance(obj, list):
                return f"Loaded JSON array with {len(obj)} items."
            else:
                return f"Loaded JSON of type {type(obj).__name__}."
        except Exception:
            return "Loaded JSON."
    json_file.change(fn=_json_info, inputs=[json_file], outputs=[json_info_box])
    # Enter in the question box triggers QA; inputs follow answer_query's
    # positional signature exactly.
    question.submit(
        fn=answer_query,
        inputs=[question, model_name, use_4bit, top_k, rerank_k, max_new_tokens, temperature, json_file, include_json],
        outputs=[answer, citations]
    )
    # Inputs follow evaluate_cfs' positional signature exactly.
    eval_btn.click(
        fn=evaluate_cfs,
        inputs=[cf_input_json, include_rag_cf, top_k, rerank_k, max_new_tokens, temperature, cf_extra_jsons, include_extra_json_cf, model_name, use_4bit],
        outputs=[result_cf]
    )
if __name__ == "__main__":
    demo.launch()