import os
import uuid
import tempfile
from dataclasses import dataclass
from functools import lru_cache
from typing import Optional, List, Tuple, Any
import json
import re

import gradio as gr
from pypdf import PdfReader
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct
from sentence_transformers import SentenceTransformer, CrossEncoder
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# -----------------------------
# Config
# -----------------------------
EMBED_MODEL_NAME = "BAAI/bge-large-en-v1.5"    # 1024-dim cosine
RERANK_MODEL_NAME = "BAAI/bge-reranker-large"  # cross-encoder reranker
COLLECTION_NAME = "pdf_rag_en"
DEFAULT_MODELS = [
    "Qwen/Qwen2.5-1.5B-Instruct",           # CPU-friendly
    "microsoft/Phi-3.1-mini-4k-instruct",   # CPU-friendly
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.3",
]

# Strict RAG system prompt
SYSTEM_GUARDRAILS = (
    "You are a strict RAG assistant. Answer ONLY using the provided context. "
    "If the answer is not present, say: 'I don't know based on the provided PDFs.' "
    "Cite sources as [filename p.PAGE]. Keep answers concise and factual."
)


# -----------------------------
# Utilities
# -----------------------------
def read_pdf_to_pages(file_path: str) -> List[Tuple[int, str]]:
    """Extract text from each page of a PDF.

    Returns a list of ``(page_number, text)`` tuples with 1-based page
    numbers. Pages whose text cannot be extracted are kept with empty text
    so page numbering stays aligned with the document.
    """
    reader = PdfReader(file_path)
    pages: List[Tuple[int, str]] = []
    for i, page in enumerate(reader.pages, start=1):
        try:
            text = page.extract_text() or ""
        except Exception:
            # pypdf can fail on malformed/encrypted pages; index them as empty.
            text = ""
        # Strip per-line whitespace and drop blank lines left by extraction.
        text = "\n".join(line.strip() for line in text.splitlines() if line.strip())
        pages.append((i, text))
    return pages


def chunk_text(text: str, words_per_chunk: int = 350, overlap: int = 40) -> List[str]:
    """Simple word-based chunking suitable for English PDFs and BGE limits.

    Consecutive chunks share ``overlap`` words so a sentence straddling a
    boundary remains retrievable from at least one chunk.
    """
    words = text.split()
    if not words:
        return []
    chunks: List[str] = []
    i = 0
    while i < len(words):
        j = min(i + words_per_chunk, len(words))
        chunks.append(" ".join(words[i:j]))
        # Step back by `overlap`, but always move strictly forward so the
        # loop terminates even when overlap >= words_per_chunk or at the tail.
        i = j - overlap if j - overlap > i else j
    return chunks


@dataclass
class RetrievedChunk:
    # One retrieved passage plus the metadata needed to render citations.
    text: str
    score: float
    file: str
    page: int


# -----------------------------
# JSON helpers
# -----------------------------
def read_json_file(path: str):
    """Load JSON from `path`; on any failure return {'__error__': message}."""
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        return {"__error__": str(e)}


# -----------------------------
# Embeddings & Reranker
# -----------------------------
@lru_cache(maxsize=1)
def load_embedder() -> SentenceTransformer:
    """Load (once) the bi-encoder used for indexing and query embedding."""
    return SentenceTransformer(EMBED_MODEL_NAME)


@lru_cache(maxsize=1)
def load_reranker() -> CrossEncoder:
    """Load (once) the cross-encoder used to rerank retrieved chunks."""
    return CrossEncoder(RERANK_MODEL_NAME)


# -----------------------------
# Qdrant (Vector DB)
# -----------------------------
def get_qdrant_client() -> QdrantClient:
    """Build a Qdrant client from QDRANT_URL / QDRANT_API_KEY env secrets.

    Raises:
        RuntimeError: if either secret is missing or blank.
    """
    url = os.getenv("QDRANT_URL", "").strip()
    key = os.getenv("QDRANT_API_KEY", "").strip()
    if not url or not key:
        raise RuntimeError("Missing QDRANT_URL or QDRANT_API_KEY in Secrets.")
    return QdrantClient(url=url, api_key=key, timeout=60)


def ensure_collection(client: QdrantClient, vector_size: int) -> None:
    """Create the cosine-distance collection if it does not exist yet."""
    collections = client.get_collections().collections
    names = {c.name for c in collections}
    if COLLECTION_NAME not in names:
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
        )


# -----------------------------
# LLM Loader (with optional 4-bit)
# -----------------------------
_current_model_name = None  # cache key: name of the currently loaded model
_pipe = None                # cached transformers text-generation pipeline


def load_llm(model_name: str, use_4bit: bool = True):
    """Load (and cache) a text-generation pipeline for `model_name`.

    When `use_4bit` is set, a bitsandbytes 4-bit load is attempted first and
    the loader transparently retries without quantization if that fails
    (e.g. bitsandbytes/CUDA not available on the host).

    Raises:
        RuntimeError: wrapping the underlying error when the model cannot
            be loaded at all.
    """
    global _current_model_name, _pipe
    if _pipe is not None and _current_model_name == model_name:
        return _pipe
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
        # NOTE: building this dict can never raise, so the original
        # try/except around it was dead code; the real failure point is
        # from_pretrained below, which now has an explicit fallback.
        quant_kwargs = (
            dict(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
            if use_4bit
            else {}
        )
        common_kwargs = dict(
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_name, **common_kwargs, **quant_kwargs
            )
        except Exception:
            # 4-bit load failed (e.g. no bitsandbytes); retry unquantized
            # instead of failing outright.
            if not quant_kwargs:
                raise
            model = AutoModelForCausalLM.from_pretrained(model_name, **common_kwargs)
        _pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            do_sample=False,
        )
        _current_model_name = model_name
        return _pipe
    except Exception as e:
        raise RuntimeError(f"Failed to load model {model_name}: {e}")


# -----------------------------
# Indexing
# -----------------------------
def _normalize_file_inputs(files) -> List[str]:
    """Coerce Gradio file inputs (single value or list) into filesystem paths."""
    if not files:
        return []
    if not isinstance(files, (list, tuple)):
        files = [files]
    paths: List[str] = []
    for f in files:
        if f is None:
            continue
        try:
            p = str(f)  # Gradio NamedString -> temp filepath
        except Exception:
            p = getattr(f, "name", None) or getattr(f, "path", None)
        if p:
            paths.append(p)
    return paths


def ingest_pdfs(files, collection_name: str = COLLECTION_NAME) -> str:
    """Chunk, embed and upsert the uploaded PDFs into Qdrant.

    Embeddings are computed in one batched `encode` call per file instead of
    one model call per chunk (the original behavior), which is dramatically
    faster on both CPU and GPU while producing the same vectors.
    """
    file_paths = _normalize_file_inputs(files)
    if not file_paths:
        return "No files received. Please upload PDF(s) first."
    embedder = load_embedder()
    vector_dim = embedder.get_sentence_embedding_dimension()
    client = get_qdrant_client()
    ensure_collection(client, vector_dim)
    total_chunks = 0
    for path in file_paths:
        pages = read_pdf_to_pages(path)
        fname = os.path.basename(path)
        texts: List[str] = []
        payloads: List[dict] = []
        for page_num, text in pages:
            if not text:
                continue
            for chunk_idx, chunk in enumerate(chunk_text(text)):
                texts.append(chunk)
                payloads.append({
                    "text": chunk,
                    "file": fname,
                    "page": page_num,
                    "chunk": chunk_idx,
                })
        if texts:
            # Batch-encode every chunk of this file in a single call.
            vectors = embedder.encode(
                texts,
                normalize_embeddings=True,
                convert_to_numpy=True,
            )
            ids = [str(uuid.uuid4()) for _ in texts]
            client.upsert(
                collection_name=collection_name,
                points=[
                    PointStruct(id=pid, vector=v.tolist(), payload=pl)
                    for pid, v, pl in zip(ids, vectors, payloads)
                ],
            )
            total_chunks += len(texts)
    return f"Indexed ~{total_chunks} chunks from {len(file_paths)} PDF(s)."


# -----------------------------
# Retrieval + Rerank
# -----------------------------
def retrieve(query: str, top_k: int = 16, score_threshold: float = 0.25) -> List[RetrievedChunk]:
    """Vector-search the collection and wrap hits as RetrievedChunk objects."""
    embedder = load_embedder()
    client = get_qdrant_client()
    qvec = embedder.encode(query, normalize_embeddings=True, convert_to_numpy=True)
    candidates = client.search(
        collection_name=COLLECTION_NAME,
        query_vector=qvec.tolist(),
        limit=top_k,
        with_payload=True,
        score_threshold=score_threshold,
    )
    chunks: List[RetrievedChunk] = []
    for p in candidates:
        payload = p.payload or {}
        chunks.append(
            RetrievedChunk(
                text=payload.get("text", ""),
                score=float(p.score or 0.0),
                file=payload.get("file", "unknown.pdf"),
                page=int(payload.get("page", 0)),
            )
        )
    return chunks


def rerank(query: str, chunks: List[RetrievedChunk], top_n: int = 6) -> List[RetrievedChunk]:
    """Rescore candidates with the cross-encoder and keep the best `top_n`."""
    if not chunks:
        return []
    reranker = load_reranker()
    pairs = [(query, c.text) for c in chunks]
    scores = reranker.predict(pairs)
    ranked = sorted(zip(chunks, scores), key=lambda x: float(x[1]), reverse=True)
    return [c for c, _ in ranked[:top_n]]
# -----------------------------
# QA Generation (with optional JSON compare)
# -----------------------------
def build_prompt(query: str, contexts: List[RetrievedChunk], json_text: Optional[str] = None) -> str:
    """Assemble the grounded QA prompt: guardrails + PDF context + question."""
    context_text = "\n\n".join([c.text for c in contexts])
    json_block = f"\n\nJSON_SPEC:\n{json_text}\n" if json_text else ""
    prompt = (
        f"[SYSTEM]\n{SYSTEM_GUARDRAILS}\nIf a JSON spec is provided, compare it to the PDF context: identify agreements, conflicts, and missing fields explicitly.\n[/SYSTEM]\n"
        f"[USER]\nQuestion: {query}\n\nContext from PDFs:\n{context_text}{json_block}\n"
        f"Answer in English. If conflicts exist between JSON and PDFs, report them clearly. Include PDF citations like [filename p.PAGE].\n[/USER]\n[ASSISTANT]"
    )
    return prompt


def _generate(pipe, prompt: str, max_new_tokens: int, temperature: float) -> str:
    """Run the text-generation pipeline and return ONLY the completion.

    Fixes two defects of the original inline calls:
    - ``return_full_text=False`` so the prompt is not echoed back in the
      output (transformers defaults to returning prompt + completion).
    - ``temperature`` is only forwarded when sampling; transformers
      rejects/warns on temperature=0.0.
    """
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": temperature > 0,
        "return_full_text": False,
    }
    if temperature > 0:
        gen_kwargs["temperature"] = float(temperature)
    return pipe(prompt, **gen_kwargs)[0]["generated_text"]


def answer_query(query: str, model_name: str, use_4bit: bool, top_k: int,
                 rerank_k: int, max_new_tokens: int, temperature: float,
                 json_path: Optional[str] = None, include_json: bool = False):
    """Answer `query` strictly from indexed PDFs.

    Returns ``(answer_text, citations_string)``. Previously the displayed
    answer contained the entire prompt (guardrails, context, question)
    because the pipeline's default ``return_full_text=True`` was used; the
    shared `_generate` helper now strips the prompt.
    """
    if not query or not query.strip():
        return "Please enter a question.", ""
    retrieved = retrieve(query.strip(), top_k=top_k)
    if not retrieved:
        return "I don't know based on the provided PDFs.", ""
    selected = rerank(query.strip(), retrieved, top_n=rerank_k)
    if not selected:
        return "I don't know based on the provided PDFs.", ""
    json_text = None
    if include_json and json_path:
        obj = read_json_file(json_path)
        if isinstance(obj, dict) and "__error__" in obj:
            json_text = f"__JSON_ERROR__: {obj['__error__']}"
        else:
            try:
                json_text = json.dumps(obj, ensure_ascii=False)
                # Cap the JSON spec so it cannot crowd out the PDF context.
                if len(json_text) > 8000:
                    json_text = json_text[:8000] + "\n... [truncated]"
            except Exception as e:
                json_text = f"__JSON_ERROR__: {e}"
    prompt = build_prompt(query, selected, json_text=json_text)
    pipe = load_llm(model_name, use_4bit=use_4bit)
    out = _generate(pipe, prompt, max_new_tokens, temperature)
    # De-duplicate citations while preserving rerank order.
    uniq = list(dict.fromkeys(f"[{c.file} p.{c.page}]" for c in selected))
    return out, " ".join(uniq)


# -----------------------------
# Admin helpers
# -----------------------------
def wipe_collection() -> str:
    """Delete and immediately recreate the collection (destructive)."""
    client = get_qdrant_client()
    client.delete_collection(COLLECTION_NAME)
    dim = load_embedder().get_sentence_embedding_dimension()
    ensure_collection(client, dim)
    return "Collection wiped and recreated."


def get_index_stats(sample_limit: int = 64) -> str:
    """Return basic collection stats from Qdrant to verify that indexing worked."""
    client = get_qdrant_client()
    try:
        cnt = client.count(collection_name=COLLECTION_NAME, exact=True).count
    except Exception as e:
        return f"Count failed: {e}"
    files: List[str] = []
    try:
        # Only the first page of results is sampled; next_offset is unused.
        points, next_offset = client.scroll(
            collection_name=COLLECTION_NAME,
            limit=sample_limit,
            with_payload=True,
        )
        for p in points or []:
            payload = p.payload or {}
            fn = payload.get("file") or "unknown.pdf"
            files.append(fn)
    except Exception as e:
        return f"Points: {cnt}. Scroll failed: {e}"
    uniq_files = sorted(set(files))
    return f"Points: {cnt} | Collection: {COLLECTION_NAME} | Sample files ({len(uniq_files)}): {', '.join(uniq_files[:10])}{' …' if len(uniq_files)>10 else ''}"


# -----------------------------
# Counterfactual (CF) evaluator helpers
# -----------------------------
def _flatten_first(x):
    """Unwrap singleton nesting: [[row]] -> [row] (repeatedly)."""
    while isinstance(x, (list, tuple)) and len(x) == 1 and isinstance(x[0], (list, tuple)):
        x = x[0]
    return x


def parse_cf_input_json(path: str):
    """Validate and normalize the CF input JSON.

    Returns ``(parsed_dict, None)`` on success or ``(None, error_message)``
    on failure.
    """
    data = read_json_file(path)
    if isinstance(data, dict) and data.get("__error__"):
        return None, f"Load error: {data['__error__']}"
    req = [
        "test_data", "cfs_list", "feature_names",
        "feature_names_including_target", "data_interface", "desired_class"
    ]
    for k in req:
        if k not in data:
            return None, f"Missing key: {k}"
    test_data = _flatten_first(data["test_data"])  # one row
    cfs_list = _flatten_first(data["cfs_list"])    # list of rows
    if not isinstance(cfs_list, (list, tuple)) or not cfs_list:
        return None, "Empty cfs_list"
    feat_inc = data["feature_names_including_target"]
    desired = data["desired_class"]
    outcome_name = data.get("data_interface", {}).get("outcome_name") or data.get("outcome_name", "income")
    return {
        "test_data": test_data,
        "cfs_list": cfs_list,
        "feature_names_including_target": feat_inc,
        "feature_names": data["feature_names"],
        "desired_class": desired,
        "outcome_name": outcome_name,
    }, None


def build_cf_retrieval_query(test_row, feature_names) -> str:
    """Build a compact retrieval query from the test row's salient features.

    Falls back to a generic Adult-Income query if the row and feature names
    cannot be zipped together.
    """
    try:
        fmap = {k: v for k, v in zip(feature_names, test_row)}
    except Exception:
        return "adult income factors by occupation, education, hours per week, marital status"
    keys = ["occupation", "education", "workclass", "marital_status", "age", "hours_per_week", "gender", "race"]
    parts = [f"{k}:{fmap[k]}" for k in keys if k in fmap]
    parts.append("income threshold and probability drivers")
    return ", ".join(map(str, parts))


def get_rag_context_text(query: str, top_k: int, rerank_k: int, max_chars: int = 8000) -> str:
    """Retrieve + rerank and flatten the winning chunks into one cited string."""
    chunks = retrieve(query, top_k=top_k)
    if not chunks:
        return ""
    selected = rerank(query, chunks, top_n=rerank_k)
    lines = [f"{c.text}\n[CIT: {c.file} p.{c.page}]" for c in selected]
    return "\n\n".join(lines)[:max_chars]


def build_cf_prompt(parsed, rag_text: str = "", extra_json_text: str = "") -> str:
    """Build the counterfactual-evaluation prompt from parsed CF input."""
    td = parsed["test_data"]
    cfs = parsed["cfs_list"]
    feat_inc = parsed["feature_names_including_target"]
    desired = parsed["desired_class"]
    outcome = parsed["outcome_name"]
    instr = (
        "You are a helpful assistant with deep knowledge of counterfactual explanations, fairness, and causal reasoning.\n\n"
        "You will be given a test data point, candidate counterfactuals, feature names (including target), the desired class,"
        " and real-world context retrieved from documents.\n\n"
        "Goals: (1) choose or propose a counterfactual that flips the class to the desired one, (2) minimize actionable changes,"
        " (3) ensure plausibility given Adult Income and provided context.\n\n"
        "Return only this JSON (no prose): {\"best_cf\": [...], \"explanation\": \"...\"}"
    )
    ctx = ""
    if rag_text:
        ctx += f"\n\nRETRIEVED_CONTEXT:\n{rag_text}"
    if extra_json_text:
        ctx += f"\n\nUPLOADED_JSON_CONTEXT:\n{extra_json_text}"
    user = (
        f"feature_names_including_target: {json.dumps(feat_inc)}\n"
        f"desired_class: {desired}\n"
        f"outcome_name: {outcome}\n"
        f"test_data: {json.dumps(td)}\n"
        f"cfs_list: {json.dumps(cfs)}\n"
        f"{ctx}\n\n"
        "Only output the JSON with keys 'best_cf' and 'explanation'. Ensure 'best_cf' matches the length and order of feature_names_including_target."
    )
    return (
        f"[SYSTEM]\n{SYSTEM_GUARDRAILS}\n{instr}\n[/SYSTEM]\n"
        f"[USER]\n{user}\n[/USER]\n[ASSISTANT]"
    )


def extract_json_object(text: str) -> str:
    """Pull the {'best_cf', 'explanation'} JSON object out of model output.

    Tries the whole text first, then the widest {...} span; returns a JSON
    error object string when nothing valid is found.
    """
    try:
        obj = json.loads(text)
        if isinstance(obj, dict) and "best_cf" in obj and "explanation" in obj:
            return json.dumps(obj, ensure_ascii=False)
    except Exception:
        pass
    m = re.search(r"\{[\s\S]*\}", text)
    if m:
        try:
            obj = json.loads(m.group(0))
            if isinstance(obj, dict) and "best_cf" in obj and "explanation" in obj:
                return json.dumps(obj, ensure_ascii=False)
        except Exception:
            return "{\n \"error\": \"Model returned invalid JSON.\"\n}"
    return "{\n \"error\": \"No JSON object found in model output.\"\n}"


def evaluate_cfs(cf_json_path: Optional[str], use_rag: bool, top_k: int, rerank_k: int,
                 max_new_tokens: int, temperature: float, extra_json_files,
                 include_extra_json: bool, model_name: str, use_4bit: bool) -> str:
    """End-to-end CF evaluation: parse input, gather context, prompt the LLM.

    Returns a JSON string (either the model's answer or an error object).
    """
    if not cf_json_path:
        return "{\n \"error\": \"No CF input JSON uploaded.\"\n}"
    parsed, err = parse_cf_input_json(cf_json_path)
    if err:
        return json.dumps({"error": err})
    rag_text = ""
    if use_rag:
        # Drop the target column when building the retrieval query.
        q = build_cf_retrieval_query(parsed["test_data"], parsed["feature_names_including_target"][:-1])
        rag_text = get_rag_context_text(q, top_k=int(top_k), rerank_k=int(rerank_k))
    extra_json_text = ""
    if include_extra_json and extra_json_files:
        paths = _normalize_file_inputs(extra_json_files)
        blobs = []
        for p in paths:
            obj = read_json_file(p)
            try:
                blobs.append(json.dumps(obj, ensure_ascii=False))
            except Exception:
                continue
        extra_json_text = ("\n\n".join(blobs))[:6000]
    prompt = build_cf_prompt(parsed, rag_text=rag_text, extra_json_text=extra_json_text)
    pipe = load_llm(model_name, use_4bit=use_4bit)
    # return_full_text=False (inside _generate) is critical here: the prompt
    # itself contains JSON (test_data/cfs_list), which extract_json_object's
    # regex would otherwise match instead of the model's actual answer.
    out = _generate(pipe, prompt, max_new_tokens, temperature)
    return extract_json_object(out)


# -----------------------------
# UI
# -----------------------------
with gr.Blocks(title="PDF RAG (ZeroGPU + Qdrant Cloud)") as demo:
    gr.Markdown("""
# PDF RAG (ZeroGPU + Qdrant)
- **Upload** PDFs → **Index** → ask a **question**.\
- Answers are strictly grounded in your PDFs. If unknown, the app will say so.
""")
    with gr.Accordion("Setup & Indexing", open=True):
        files = gr.File(file_count="multiple", file_types=[".pdf"], type="filepath")
        idx_btn = gr.Button("Build / Update Index")
        idx_status = gr.Textbox(label="Status", interactive=False)
        wipe_btn = gr.Button("Wipe Index (danger)")
        stats_btn = gr.Button("Index stats")
        stats_box = gr.Textbox(label="Index stats", interactive=False)

    with gr.Accordion("JSON compare (optional)", open=False):
        json_file = gr.File(file_count="single", file_types=[".json"], type="filepath", label="Upload JSON spec")
        include_json = gr.Checkbox(value=False, label="Include JSON in prompt for comparison")
        json_info_box = gr.Textbox(label="JSON status", interactive=False)

    with gr.Row():
        model_name = gr.Dropdown(choices=DEFAULT_MODELS, value=DEFAULT_MODELS[0], label="LLM")
        use_4bit = gr.Checkbox(value=True, label="Use 4-bit quantization (if available)")
    with gr.Row():
        top_k = gr.Slider(4, 32, value=16, step=1, label="Vector top_k")
        rerank_k = gr.Slider(2, 12, value=6, step=1, label="Rerank top_n")
    with gr.Row():
        max_new_tokens = gr.Slider(64, 1024, value=320, step=16, label="Max new tokens")
        temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Temperature")

    question = gr.Textbox(label="Ask a question (English)")
    answer = gr.Textbox(label="Answer", lines=10)
    citations = gr.Textbox(label="Citations", lines=2)

    # CF evaluator UI
    with gr.Accordion("Counterfactual Evaluator", open=False):
        cf_input_json = gr.File(file_count="single", file_types=[".json"], type="filepath", label="Upload CF input JSON (Adult format)")
        cf_extra_jsons = gr.File(file_count="multiple", file_types=[".json"], type="filepath", label="Optional: Additional JSON context")
        include_rag_cf = gr.Checkbox(value=True, label="Use RAG context from indexed PDFs")
        include_extra_json_cf = gr.Checkbox(value=False, label="Include uploaded JSON context in prompt")
        eval_btn = gr.Button("Evaluate Counterfactuals → JSON output")
        result_cf = gr.Textbox(label="Result JSON", lines=10)

    # Wiring
    idx_btn.click(fn=ingest_pdfs, inputs=[files], outputs=[idx_status])
    wipe_btn.click(fn=wipe_collection, inputs=None, outputs=[idx_status])
    stats_btn.click(fn=get_index_stats, inputs=None, outputs=[stats_box])

    def _json_info(path):
        # Short status message describing the uploaded JSON spec.
        if not path:
            return "No JSON uploaded."
        obj = read_json_file(path)
        if isinstance(obj, dict) and "__error__" in obj:
            return f"JSON error: {obj['__error__']}"
        try:
            if isinstance(obj, dict):
                keys = len(obj.keys())
                return f"Loaded JSON object with {keys} top-level keys."
            elif isinstance(obj, list):
                return f"Loaded JSON array with {len(obj)} items."
            else:
                return f"Loaded JSON of type {type(obj).__name__}."
        except Exception:
            return "Loaded JSON."

    json_file.change(fn=_json_info, inputs=[json_file], outputs=[json_info_box])

    question.submit(
        fn=answer_query,
        inputs=[question, model_name, use_4bit, top_k, rerank_k, max_new_tokens, temperature, json_file, include_json],
        outputs=[answer, citations]
    )

    eval_btn.click(
        fn=evaluate_cfs,
        inputs=[cf_input_json, include_rag_cf, top_k, rerank_k, max_new_tokens, temperature, cf_extra_jsons, include_extra_json_cf, model_name, use_4bit],
        outputs=[result_cf]
    )

if __name__ == "__main__":
    demo.launch()