import os
import uuid
import json
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import Optional, List

import gradio as gr
from pypdf import PdfReader
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct
from sentence_transformers import SentenceTransformer, CrossEncoder
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# -----------------------------
# Config
# -----------------------------
EMBED_MODEL_NAME = "BAAI/bge-large-en-v1.5"    # 1024-dim cosine
RERANK_MODEL_NAME = "BAAI/bge-reranker-large"  # cross-encoder reranker
COLLECTION_NAME = "pdf_rag_en"

DEFAULT_MODELS = [
    "Qwen/Qwen2.5-1.5B-Instruct",        # CPU-friendly
    "microsoft/Phi-3-mini-4k-instruct",  # CPU-friendly
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.3",
]

# Strict RAG system prompt
SYSTEM_GUARDRAILS = (
    "You are a strict RAG assistant. Answer ONLY using the provided context. "
    "If the answer is not present, say: 'I don't know based on the provided PDFs.' "
    "Cite sources as [filename p.PAGE]. Keep answers concise and factual."
)
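
# The Qdrant connection settings are read from environment variables at
# runtime (see get_qdrant_client below). On Hugging Face Spaces these live in
# the Space's Secrets; locally, something like the following works
# (placeholder values):
#   export QDRANT_URL="https://YOUR-CLUSTER.cloud.qdrant.io:6333"
#   export QDRANT_API_KEY="..."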
# -----------------------------
# Utilities
# -----------------------------
def read_pdf_to_pages(file_path: str):
    reader = PdfReader(file_path)
    pages = []
    for i, page in enumerate(reader.pages, start=1):
        try:
            text = page.extract_text() or ""
        except Exception:
            text = ""
        text = "\n".join(line.strip() for line in text.splitlines() if line.strip())
        pages.append((i, text))
    return pages
def chunk_text(text: str, words_per_chunk: int = 350, overlap: int = 40):
    """Simple word-based chunking suitable for English PDFs and BGE limits."""
    words = text.split()
    if not words:
        return []
    chunks = []
    i = 0
    while i < len(words):
        j = min(i + words_per_chunk, len(words))
        chunk = " ".join(words[i:j])
        chunks.append(chunk)
        i = j - overlap if j - overlap > i else j
    return chunks
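
# Worked example (defaults: 350 words per chunk, 40-word overlap): an 800-word
# page yields chunks covering words [0:350], [310:660], [620:800], [760:800],
# so adjacent chunks share 40 words of context.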
# @dataclass is required here: retrieve() below constructs RetrievedChunk
# with keyword arguments.
@dataclass
class RetrievedChunk:
    text: str
    score: float
    file: str
    page: int
# -----------------------------
# JSON helpers
# -----------------------------
def read_json_file(path: str):
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        return {"__error__": str(e)}
# -----------------------------
# Embeddings & Reranker
# -----------------------------
@lru_cache(maxsize=1)
def load_embedder():
    # Cached so repeated calls reuse the same model instead of reloading it.
    return SentenceTransformer(EMBED_MODEL_NAME)

@lru_cache(maxsize=1)
def load_reranker():
    return CrossEncoder(RERANK_MODEL_NAME)
# -----------------------------
# Qdrant (Vector DB)
# -----------------------------
def get_qdrant_client():
    url = os.getenv("QDRANT_URL", "").strip()
    key = os.getenv("QDRANT_API_KEY", "").strip()
    if not url or not key:
        raise RuntimeError("Missing QDRANT_URL or QDRANT_API_KEY in Secrets.")
    return QdrantClient(url=url, api_key=key, timeout=60)

def ensure_collection(client: QdrantClient, vector_size: int):
    collections = client.get_collections().collections
    names = {c.name for c in collections}
    if COLLECTION_NAME not in names:
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
        )
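
# Since embeddings are encoded with normalize_embeddings=True and the
# collection uses Distance.COSINE, Qdrant similarity scores fall in [-1, 1];
# the score_threshold in retrieve() below is compared against these values.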
# -----------------------------
# LLM Loader (with optional 4-bit)
# -----------------------------
_current_model_name = None
_pipe = None

def load_llm(model_name: str, use_4bit: bool = True):
    global _current_model_name, _pipe
    if _pipe is not None and _current_model_name == model_name:
        return _pipe
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
        quant_kwargs = {}
        if use_4bit:
            try:
                # 4-bit loading needs the bitsandbytes backend; fall back to
                # plain fp16 if it is unavailable.
                from transformers import BitsAndBytesConfig
                quant_kwargs = dict(
                    quantization_config=BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16,
                    )
                )
            except Exception:
                quant_kwargs = {}
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            **quant_kwargs,
        )
        _pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            do_sample=False,
        )
        _current_model_name = model_name
        return _pipe
    except Exception as e:
        raise RuntimeError(f"Failed to load model {model_name}: {e}") from e
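
# Two caveats: only one pipeline is cached at a time, so switching models in
# the UI triggers a full reload; and gated checkpoints such as
# meta-llama/Meta-Llama-3.1-8B-Instruct require an accepted license plus an
# HF access token (typically an HF_TOKEN secret) before download succeeds.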
# -----------------------------
# Indexing
# -----------------------------
def _normalize_file_inputs(files):
    if not files:
        return []
    if not isinstance(files, (list, tuple)):
        files = [files]
    paths = []
    for f in files:
        if f is None:
            continue
        try:
            p = str(f)  # Gradio NamedString -> temp filepath
        except Exception:
            p = getattr(f, "name", None) or getattr(f, "path", None)
        if p:
            paths.append(p)
    return paths
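
# With gr.File(type="filepath"), as used in the UI below, Gradio hands back
# plain string paths, so str(f) is effectively a no-op; the getattr fallback
# covers older Gradio versions that pass file-like objects exposing .name.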
def ingest_pdfs(files, collection_name: str = COLLECTION_NAME):
    file_paths = _normalize_file_inputs(files)
    if not file_paths:
        return "No files received. Please upload PDF(s) first."
    embedder = load_embedder()
    vector_dim = embedder.get_sentence_embedding_dimension()
    client = get_qdrant_client()
    ensure_collection(client, vector_dim)
    total_chunks = 0
    for path in file_paths:
        pages = read_pdf_to_pages(path)
        fname = os.path.basename(path)
        payloads, vectors, ids = [], [], []
        for page_num, text in pages:
            if not text:
                continue
            for chunk_idx, chunk in enumerate(chunk_text(text)):
                vec = embedder.encode(
                    chunk,
                    normalize_embeddings=True,
                    convert_to_numpy=True,
                )
                vectors.append(vec)
                ids.append(str(uuid.uuid4()))
                payloads.append({
                    "text": chunk,
                    "file": fname,
                    "page": page_num,
                    "chunk": chunk_idx,
                })
                total_chunks += 1
        if vectors:
            client.upsert(
                collection_name=collection_name,
                points=[PointStruct(id=pid, vector=v.tolist(), payload=pl)
                        for pid, v, pl in zip(ids, vectors, payloads)],
            )
    return f"Indexed {total_chunks} chunks from {len(file_paths)} PDF(s)."
# -----------------------------
# Retrieval + Rerank
# -----------------------------
def retrieve(query: str, top_k: int = 16, score_threshold: float = 0.25):
    embedder = load_embedder()
    client = get_qdrant_client()
    qvec = embedder.encode(query, normalize_embeddings=True, convert_to_numpy=True)
    candidates = client.search(
        collection_name=COLLECTION_NAME,
        query_vector=qvec.tolist(),
        limit=top_k,
        with_payload=True,
        score_threshold=score_threshold,
    )
    chunks: List[RetrievedChunk] = []
    for p in candidates:
        payload = p.payload or {}
        chunks.append(
            RetrievedChunk(
                text=payload.get("text", ""),
                score=float(p.score or 0.0),
                file=payload.get("file", "unknown.pdf"),
                page=int(payload.get("page", 0)),
            )
        )
    return chunks
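
# Retrieval note: the BGE v1.5 model card recommends prefixing short queries
# with "Represent this sentence for searching relevant passages: " for
# retrieval; queries are embedded as-is here, which still works but may lose
# some recall on very terse questions.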
def rerank(query: str, chunks: List[RetrievedChunk], top_n: int = 6):
    if not chunks:
        return []
    reranker = load_reranker()
    pairs = [(query, c.text) for c in chunks]
    scores = reranker.predict(pairs)
    ranked = sorted(zip(chunks, scores), key=lambda x: float(x[1]), reverse=True)
    return [c for c, _ in ranked[:top_n]]
# -----------------------------
# QA Generation (with optional JSON compare)
# -----------------------------
def build_prompt(query: str, contexts: List[RetrievedChunk], json_text: Optional[str] = None):
    context_text = "\n\n".join(c.text for c in contexts)
    json_block = f"\n\nJSON_SPEC:\n{json_text}\n" if json_text else ""
    prompt = (
        f"<s>[SYSTEM]\n{SYSTEM_GUARDRAILS}\n"
        f"If a JSON spec is provided, compare it to the PDF context: identify agreements, conflicts, and missing fields explicitly.\n[/SYSTEM]\n"
        f"[USER]\nQuestion: {query}\n\nContext from PDFs:\n{context_text}{json_block}\n"
        f"Answer in English. If conflicts exist between JSON and PDFs, report them clearly. Include PDF citations like [filename p.PAGE].\n[/USER]\n[ASSISTANT]"
    )
    return prompt
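
# The [SYSTEM]/[USER]/[ASSISTANT] markers above form a generic plain-text
# template rather than any model's native chat format. A more robust
# alternative is tokenizer.apply_chat_template, which renders each model's
# own format (user_text below stands in for the question-plus-context block):
#
#   messages = [{"role": "system", "content": SYSTEM_GUARDRAILS},
#               {"role": "user", "content": user_text}]
#   prompt = tokenizer.apply_chat_template(messages, tokenize=False,
#                                          add_generation_prompt=True)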
def answer_query(query: str, model_name: str, use_4bit: bool, top_k: int, rerank_k: int,
                 max_new_tokens: int, temperature: float, json_path: Optional[str] = None,
                 include_json: bool = False):
    if not query or not query.strip():
        return "Please enter a question.", ""
    retrieved = retrieve(query.strip(), top_k=top_k)
    if not retrieved:
        return "I don't know based on the provided PDFs.", ""
    selected = rerank(query.strip(), retrieved, top_n=rerank_k)
    if not selected:
        return "I don't know based on the provided PDFs.", ""
    json_text = None
    if include_json and json_path:
        obj = read_json_file(json_path)
        if isinstance(obj, dict) and "__error__" in obj:
            json_text = f"__JSON_ERROR__: {obj['__error__']}"
        else:
            try:
                json_text = json.dumps(obj, ensure_ascii=False)
                if len(json_text) > 8000:
                    json_text = json_text[:8000] + "\n... [truncated]"
            except Exception as e:
                json_text = f"__JSON_ERROR__: {e}"
    prompt = build_prompt(query, selected, json_text=json_text)
    pipe = load_llm(model_name, use_4bit=use_4bit)
    # return_full_text=False keeps the echoed prompt out of the answer box.
    out = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=(temperature > 0),
        temperature=temperature,
        return_full_text=False,
    )[0]["generated_text"]
    # De-duplicate citations while preserving retrieval order.
    seen = set()
    uniq = []
    for c in selected:
        ci = f"[{c.file} p.{c.page}]"
        if ci not in seen:
            seen.add(ci)
            uniq.append(ci)
    return out, " ".join(uniq)
# -----------------------------
# Admin helpers
# -----------------------------
def wipe_collection():
    client = get_qdrant_client()
    client.delete_collection(COLLECTION_NAME)
    dim = load_embedder().get_sentence_embedding_dimension()
    ensure_collection(client, dim)
    return "Collection wiped and recreated."

def get_index_stats(sample_limit: int = 64):
    """Return basic collection stats from Qdrant to verify that indexing worked."""
    client = get_qdrant_client()
    try:
        cnt = client.count(collection_name=COLLECTION_NAME, exact=True).count
    except Exception as e:
        return f"Count failed: {e}"
    files = []
    try:
        points, _next_offset = client.scroll(
            collection_name=COLLECTION_NAME,
            limit=sample_limit,
            with_payload=True,
        )
        for p in points or []:
            payload = p.payload or {}
            files.append(payload.get("file") or "unknown.pdf")
    except Exception as e:
        return f"Points: {cnt}. Scroll failed: {e}"
    uniq_files = sorted(set(files))
    return (
        f"Points: {cnt} | Collection: {COLLECTION_NAME} | "
        f"Sample files ({len(uniq_files)}): {', '.join(uniq_files[:10])}"
        f"{' …' if len(uniq_files) > 10 else ''}"
    )
# -----------------------------
# Counterfactual (CF) evaluator helpers
# -----------------------------
def _flatten_first(x):
    while isinstance(x, (list, tuple)) and len(x) == 1 and isinstance(x[0], (list, tuple)):
        x = x[0]
    return x

def parse_cf_input_json(path: str):
    data = read_json_file(path)
    if isinstance(data, dict) and data.get("__error__"):
        return None, f"Load error: {data['__error__']}"
    req = [
        "test_data", "cfs_list", "feature_names", "feature_names_including_target",
        "data_interface", "desired_class",
    ]
    for k in req:
        if k not in data:
            return None, f"Missing key: {k}"
    test_data = _flatten_first(data["test_data"])  # one row
    cfs_list = _flatten_first(data["cfs_list"])    # list of rows
    if not isinstance(cfs_list, (list, tuple)) or not cfs_list:
        return None, "Empty cfs_list"
    feat_inc = data["feature_names_including_target"]
    desired = data["desired_class"]
    outcome_name = data.get("data_interface", {}).get("outcome_name") or data.get("outcome_name", "income")
    return {
        "test_data": test_data,
        "cfs_list": cfs_list,
        "feature_names_including_target": feat_inc,
        "feature_names": data["feature_names"],
        "desired_class": desired,
        "outcome_name": outcome_name,
    }, None
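
# Illustrative shape of the expected CF input JSON (Adult-style values are
# made up for the example; only the keys are required by parse_cf_input_json):
#
#   {
#     "test_data": [[39, "Bachelors", "Tech-support", 40, 0]],
#     "cfs_list": [[39, "Masters", "Tech-support", 40, 1],
#                  [39, "Bachelors", "Exec-managerial", 50, 1]],
#     "feature_names": ["age", "education", "occupation", "hours_per_week"],
#     "feature_names_including_target":
#         ["age", "education", "occupation", "hours_per_week", "income"],
#     "data_interface": {"outcome_name": "income"},
#     "desired_class": 1
#   }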
def build_cf_retrieval_query(test_row, feature_names):
    try:
        fmap = {k: v for k, v in zip(feature_names, test_row)}
    except Exception:
        return "adult income factors by occupation, education, hours per week, marital status"
    keys = ["occupation", "education", "workclass", "marital_status", "age", "hours_per_week", "gender", "race"]
    parts = [f"{k}:{fmap[k]}" for k in keys if k in fmap]
    parts.append("income threshold and probability drivers")
    return ", ".join(map(str, parts))
def get_rag_context_text(query: str, top_k: int, rerank_k: int, max_chars: int = 8000):
    chunks = retrieve(query, top_k=top_k)
    if not chunks:
        return ""
    selected = rerank(query, chunks, top_n=rerank_k)
    lines = [f"{c.text}\n[CIT: {c.file} p.{c.page}]" for c in selected]
    return "\n\n".join(lines)[:max_chars]
def build_cf_prompt(parsed, rag_text: str = "", extra_json_text: str = ""):
    td = parsed["test_data"]
    cfs = parsed["cfs_list"]
    feat_inc = parsed["feature_names_including_target"]
    desired = parsed["desired_class"]
    outcome = parsed["outcome_name"]
    instr = (
        "You are a helpful assistant with deep knowledge of counterfactual explanations, fairness, and causal reasoning.\n\n"
        "You will be given a test data point, candidate counterfactuals, feature names (including target), the desired class,"
        " and real-world context retrieved from documents.\n\n"
        "Goals: (1) choose or propose a counterfactual that flips the class to the desired one, (2) minimize actionable changes,"
        " (3) ensure plausibility given Adult Income and provided context.\n\n"
        "Return only this JSON (no prose): {\"best_cf\": [...], \"explanation\": \"...\"}"
    )
    ctx = ""
    if rag_text:
        ctx += f"\n\nRETRIEVED_CONTEXT:\n{rag_text}"
    if extra_json_text:
        ctx += f"\n\nUPLOADED_JSON_CONTEXT:\n{extra_json_text}"
    user = (
        f"feature_names_including_target: {json.dumps(feat_inc)}\n"
        f"desired_class: {desired}\n"
        f"outcome_name: {outcome}\n"
        f"test_data: {json.dumps(td)}\n"
        f"cfs_list: {json.dumps(cfs)}\n"
        f"{ctx}\n\n"
        "Only output the JSON with keys 'best_cf' and 'explanation'. Ensure 'best_cf' matches the length and order of feature_names_including_target."
    )
    return (
        f"<s>[SYSTEM]\n{SYSTEM_GUARDRAILS}\n{instr}\n[/SYSTEM]\n"
        f"[USER]\n{user}\n[/USER]\n[ASSISTANT]"
    )
def extract_json_object(text: str):
    try:
        obj = json.loads(text)
        if isinstance(obj, dict) and "best_cf" in obj and "explanation" in obj:
            return json.dumps(obj, ensure_ascii=False)
    except Exception:
        pass
    m = re.search(r"\{[\s\S]*\}", text)
    if m:
        try:
            obj = json.loads(m.group(0))
            if isinstance(obj, dict) and "best_cf" in obj and "explanation" in obj:
                return json.dumps(obj, ensure_ascii=False)
        except Exception:
            return "{\n  \"error\": \"Model returned invalid JSON.\"\n}"
    return "{\n  \"error\": \"No JSON object found in model output.\"\n}"
def evaluate_cfs(cf_json_path: Optional[str], use_rag: bool, top_k: int, rerank_k: int,
                 max_new_tokens: int, temperature: float, extra_json_files, include_extra_json: bool,
                 model_name: str, use_4bit: bool):
    if not cf_json_path:
        return "{\n  \"error\": \"No CF input JSON uploaded.\"\n}"
    parsed, err = parse_cf_input_json(cf_json_path)
    if err:
        return json.dumps({"error": err})
    rag_text = ""
    if use_rag:
        q = build_cf_retrieval_query(parsed["test_data"], parsed["feature_names_including_target"][:-1])
        rag_text = get_rag_context_text(q, top_k=int(top_k), rerank_k=int(rerank_k))
    extra_json_text = ""
    if include_extra_json and extra_json_files:
        paths = _normalize_file_inputs(extra_json_files)
        blobs = []
        for p in paths:
            obj = read_json_file(p)
            try:
                blobs.append(json.dumps(obj, ensure_ascii=False))
            except Exception:
                continue
        extra_json_text = ("\n\n".join(blobs))[:6000]
    prompt = build_cf_prompt(parsed, rag_text=rag_text, extra_json_text=extra_json_text)
    pipe = load_llm(model_name, use_4bit=use_4bit)
    # return_full_text=False so the JSON extractor sees only the completion.
    out = pipe(
        prompt,
        max_new_tokens=int(max_new_tokens),
        do_sample=(temperature > 0),
        temperature=float(temperature),
        return_full_text=False,
    )[0]["generated_text"]
    return extract_json_object(out)
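
# evaluate_cfs always returns a JSON string, including on failure, so the
# "Result JSON" textbox in the UI below can be parsed downstream without
# special-casing errors.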
# -----------------------------
# UI
# -----------------------------
with gr.Blocks(title="PDF RAG (ZeroGPU + Qdrant Cloud)") as demo:
    gr.Markdown("""
# PDF RAG (ZeroGPU + Qdrant)
- **Upload** PDFs → **Index** → ask a **question**.
- Answers are strictly grounded in your PDFs. If unknown, the app will say so.
""")
    with gr.Accordion("Setup & Indexing", open=True):
        files = gr.File(file_count="multiple", file_types=[".pdf"], type="filepath")
        idx_btn = gr.Button("Build / Update Index")
        idx_status = gr.Textbox(label="Status", interactive=False)
        wipe_btn = gr.Button("Wipe Index (danger)")
        stats_btn = gr.Button("Index stats")
        stats_box = gr.Textbox(label="Index stats", interactive=False)
    with gr.Accordion("JSON compare (optional)", open=False):
        json_file = gr.File(file_count="single", file_types=[".json"], type="filepath", label="Upload JSON spec")
        include_json = gr.Checkbox(value=False, label="Include JSON in prompt for comparison")
        json_info_box = gr.Textbox(label="JSON status", interactive=False)
    with gr.Row():
        model_name = gr.Dropdown(choices=DEFAULT_MODELS, value=DEFAULT_MODELS[0], label="LLM")
        use_4bit = gr.Checkbox(value=True, label="Use 4-bit quantization (if available)")
    with gr.Row():
        top_k = gr.Slider(4, 32, value=16, step=1, label="Vector top_k")
        rerank_k = gr.Slider(2, 12, value=6, step=1, label="Rerank top_n")
    with gr.Row():
        max_new_tokens = gr.Slider(64, 1024, value=320, step=16, label="Max new tokens")
        temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Temperature")
    question = gr.Textbox(label="Ask a question (English)")
    answer = gr.Textbox(label="Answer", lines=10)
    citations = gr.Textbox(label="Citations", lines=2)

    # CF evaluator UI
    with gr.Accordion("Counterfactual Evaluator", open=False):
        cf_input_json = gr.File(file_count="single", file_types=[".json"], type="filepath", label="Upload CF input JSON (Adult format)")
        cf_extra_jsons = gr.File(file_count="multiple", file_types=[".json"], type="filepath", label="Optional: Additional JSON context")
        include_rag_cf = gr.Checkbox(value=True, label="Use RAG context from indexed PDFs")
        include_extra_json_cf = gr.Checkbox(value=False, label="Include uploaded JSON context in prompt")
        eval_btn = gr.Button("Evaluate Counterfactuals → JSON output")
        result_cf = gr.Textbox(label="Result JSON", lines=10)
    # Wiring
    idx_btn.click(fn=ingest_pdfs, inputs=[files], outputs=[idx_status])
    wipe_btn.click(fn=wipe_collection, inputs=None, outputs=[idx_status])
    stats_btn.click(fn=get_index_stats, inputs=None, outputs=[stats_box])

    def _json_info(path):
        if not path:
            return "No JSON uploaded."
        obj = read_json_file(path)
        if isinstance(obj, dict) and "__error__" in obj:
            return f"JSON error: {obj['__error__']}"
        if isinstance(obj, dict):
            return f"Loaded JSON object with {len(obj)} top-level keys."
        if isinstance(obj, list):
            return f"Loaded JSON array with {len(obj)} items."
        return f"Loaded JSON of type {type(obj).__name__}."

    json_file.change(fn=_json_info, inputs=[json_file], outputs=[json_info_box])
    question.submit(
        fn=answer_query,
        inputs=[question, model_name, use_4bit, top_k, rerank_k, max_new_tokens, temperature, json_file, include_json],
        outputs=[answer, citations],
    )
    eval_btn.click(
        fn=evaluate_cfs,
        inputs=[cf_input_json, include_rag_cf, top_k, rerank_k, max_new_tokens, temperature, cf_extra_jsons, include_extra_json_cf, model_name, use_4bit],
        outputs=[result_cf],
    )

if __name__ == "__main__":
    demo.launch()