Hesamnasiri committed on
Commit
449a488
·
verified ·
1 Parent(s): 0280254

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +260 -23
app.py CHANGED
@@ -1,12 +1,14 @@
1
  import os
2
  import uuid
3
- import math
4
  import tempfile
5
  from dataclasses import dataclass
6
  from functools import lru_cache
 
 
 
 
7
 
8
  import gradio as gr
9
- import numpy as np
10
  from pypdf import PdfReader
11
 
12
  from qdrant_client import QdrantClient
@@ -49,7 +51,6 @@ def read_pdf_to_pages(file_path: str):
49
  text = page.extract_text() or ""
50
  except Exception:
51
  text = ""
52
- # normalize whitespace
53
  text = "\n".join(line.strip() for line in text.splitlines() if line.strip())
54
  pages.append((i, text))
55
  return pages
@@ -78,13 +79,24 @@ class RetrievedChunk:
78
  page: int
79
 
80
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # -----------------------------
82
  # Embeddings & Reranker
83
  # -----------------------------
84
  @lru_cache(maxsize=1)
85
  def load_embedder():
86
  model = SentenceTransformer(EMBED_MODEL_NAME)
87
- # BGE recommends query instruction (English): "Represent this sentence for searching relevant passages: "
88
  return model
89
 
90
 
@@ -128,7 +140,6 @@ def load_llm(model_name: str, use_4bit: bool = True):
128
  if _pipe is not None and _current_model_name == model_name:
129
  return _pipe
130
 
131
- torch.cuda.empty_cache()
132
  try:
133
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
134
  quant_kwargs = {}
@@ -242,7 +253,7 @@ def retrieve(query: str, top_k: int = 16, score_threshold: float = 0.25):
242
  with_payload=True,
243
  score_threshold=score_threshold,
244
  )
245
- chunks = []
246
  for p in candidates:
247
  payload = p.payload or {}
248
  chunks.append(
@@ -256,7 +267,7 @@ def retrieve(query: str, top_k: int = 16, score_threshold: float = 0.25):
256
  return chunks
257
 
258
 
259
- def rerank(query: str, chunks, top_n: int = 6):
260
  if not chunks:
261
  return []
262
  reranker = load_reranker()
@@ -268,45 +279,55 @@ def rerank(query: str, chunks, top_n: int = 6):
268
 
269
 
270
  # -----------------------------
271
- # Generation
272
  # -----------------------------
273
 
274
- def build_prompt(query: str, contexts):
275
  context_text = "\n\n".join([c.text for c in contexts])
 
276
  prompt = (
277
- f"<s>[SYSTEM]\n{SYSTEM_GUARDRAILS}\n[/SYSTEM]\n" # works fine across these instruct models
278
- f"[USER]\nQuestion: {query}\n\nContext:\n{context_text}\n\n"
279
- f"Answer in English, with citations at the end like [filename p.PAGE].\n[/USER]\n[ASSISTANT]"
280
  )
281
  return prompt
282
 
283
 
284
- def answer_query(query: str, model_name: str, use_4bit: bool, top_k: int, rerank_k: int, max_new_tokens: int, temperature: float):
 
 
285
  if not query or not query.strip():
286
  return "Please enter a question.", ""
287
 
288
- # 1) retrieve
289
  retrieved = retrieve(query.strip(), top_k=top_k)
290
  if not retrieved:
291
  return "I don't know based on the provided PDFs.", ""
292
 
293
- # 2) rerank
294
  selected = rerank(query.strip(), retrieved, top_n=rerank_k)
295
  if not selected:
296
  return "I don't know based on the provided PDFs.", ""
297
 
298
- # 3) build prompt
299
- prompt = build_prompt(query, selected)
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
- # 4) LLM generate
302
  pipe = load_llm(model_name, use_4bit=use_4bit)
303
  out = pipe(prompt, max_new_tokens=max_new_tokens, do_sample=(temperature > 0), temperature=temperature)[0]["generated_text"]
304
 
305
- # 5) citations
306
  cits = []
307
  for c in selected:
308
  cits.append(f"[{c.file} p.{c.page}]")
309
- # unique preserve order
310
  seen = set()
311
  uniq = []
312
  for ci in cits:
@@ -324,12 +345,185 @@ def answer_query(query: str, model_name: str, use_4bit: bool, top_k: int, rerank
324
  def wipe_collection():
325
  client = get_qdrant_client()
326
  client.delete_collection(COLLECTION_NAME)
327
- # recreate with correct dim
328
  dim = load_embedder().get_sentence_embedding_dimension()
329
  ensure_collection(client, dim)
330
  return "Collection wiped and recreated."
331
 
332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  # -----------------------------
334
  # UI
335
  # -----------------------------
@@ -345,6 +539,13 @@ with gr.Blocks(title="PDF RAG (ZeroGPU + Qdrant Cloud)") as demo:
345
  idx_btn = gr.Button("Build / Update Index")
346
  idx_status = gr.Textbox(label="Status", interactive=False)
347
  wipe_btn = gr.Button("Wipe Index (danger)")
 
 
 
 
 
 
 
348
 
349
  with gr.Row():
350
  model_name = gr.Dropdown(choices=DEFAULT_MODELS, value=DEFAULT_MODELS[0], label="LLM")
@@ -362,14 +563,50 @@ with gr.Blocks(title="PDF RAG (ZeroGPU + Qdrant Cloud)") as demo:
362
  answer = gr.Textbox(label="Answer", lines=10)
363
  citations = gr.Textbox(label="Citations", lines=2)
364
 
 
 
 
 
 
 
 
 
 
 
365
  idx_btn.click(fn=ingest_pdfs, inputs=[files], outputs=[idx_status])
366
  wipe_btn.click(fn=wipe_collection, inputs=None, outputs=[idx_status])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
368
  question.submit(
369
  fn=answer_query,
370
- inputs=[question, model_name, use_4bit, top_k, rerank_k, max_new_tokens, temperature],
371
  outputs=[answer, citations]
372
  )
373
 
 
 
 
 
 
 
374
  if __name__ == "__main__":
375
- demo.launch()
 
1
  import os
2
  import uuid
 
3
  import tempfile
4
  from dataclasses import dataclass
5
  from functools import lru_cache
6
+ from typing import Optional, List, Tuple, Any
7
+
8
+ import json
9
+ import re
10
 
11
  import gradio as gr
 
12
  from pypdf import PdfReader
13
 
14
  from qdrant_client import QdrantClient
 
51
  text = page.extract_text() or ""
52
  except Exception:
53
  text = ""
 
54
  text = "\n".join(line.strip() for line in text.splitlines() if line.strip())
55
  pages.append((i, text))
56
  return pages
 
79
  page: int
80
 
81
 
82
+ # -----------------------------
83
+ # JSON helpers
84
+ # -----------------------------
85
+
86
def read_json_file(path: str):
    """Load and return the JSON document stored at *path*.

    This helper never raises. On any failure (missing file, bad encoding,
    malformed JSON, ...) it returns ``{"__error__": "<message>"}`` so callers
    can surface the problem to the user instead of crashing.

    The file is decoded as ``utf-8-sig`` so that a leading UTF-8 byte-order
    mark is skipped; plain UTF-8 files decode exactly as before. Without this
    a BOM-prefixed file is rejected by ``json.load``.
    """
    try:
        with open(path, 'r', encoding='utf-8-sig') as f:
            return json.load(f)
    except Exception as e:
        # Deliberately broad: the contract is "return an error dict, never raise".
        return {"__error__": str(e)}
92
+
93
+
94
  # -----------------------------
95
  # Embeddings & Reranker
96
  # -----------------------------
97
  @lru_cache(maxsize=1)
98
  def load_embedder():
99
  model = SentenceTransformer(EMBED_MODEL_NAME)
 
100
  return model
101
 
102
 
 
140
  if _pipe is not None and _current_model_name == model_name:
141
  return _pipe
142
 
 
143
  try:
144
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
145
  quant_kwargs = {}
 
253
  with_payload=True,
254
  score_threshold=score_threshold,
255
  )
256
+ chunks: List[RetrievedChunk] = []
257
  for p in candidates:
258
  payload = p.payload or {}
259
  chunks.append(
 
267
  return chunks
268
 
269
 
270
+ def rerank(query: str, chunks: List[RetrievedChunk], top_n: int = 6):
271
  if not chunks:
272
  return []
273
  reranker = load_reranker()
 
279
 
280
 
281
  # -----------------------------
282
+ # QA Generation (with optional JSON compare)
283
  # -----------------------------
284
 
285
def build_prompt(query: str, contexts: List[RetrievedChunk], json_text: Optional[str] = None):
    """Assemble the QA prompt sent to the LLM.

    The prompt has a system section (the module guardrails plus the JSON
    comparison instruction), a user section holding the question, the text of
    every chunk in *contexts*, and - only when *json_text* is truthy - a
    ``JSON_SPEC`` block, followed by the answer instructions.
    """
    passages = [chunk.text for chunk in contexts]
    context_text = "\n\n".join(passages)

    if json_text:
        json_block = f"\n\nJSON_SPEC:\n{json_text}\n"
    else:
        json_block = ""

    system_part = (
        f"<s>[SYSTEM]\n{SYSTEM_GUARDRAILS}\nIf a JSON spec is provided, compare it to the PDF context: identify agreements, conflicts, and missing fields explicitly.\n[/SYSTEM]\n"
    )
    user_part = f"[USER]\nQuestion: {query}\n\nContext from PDFs:\n{context_text}{json_block}\n"
    closing_part = (
        "Answer in English. If conflicts exist between JSON and PDFs, report them clearly. Include PDF citations like [filename p.PAGE].\n[/USER]\n[ASSISTANT]"
    )
    return system_part + user_part + closing_part
294
 
295
 
296
+ def answer_query(query: str, model_name: str, use_4bit: bool, top_k: int, rerank_k: int,
297
+ max_new_tokens: int, temperature: float, json_path: Optional[str] = None,
298
+ include_json: bool = False):
299
  if not query or not query.strip():
300
  return "Please enter a question.", ""
301
 
 
302
  retrieved = retrieve(query.strip(), top_k=top_k)
303
  if not retrieved:
304
  return "I don't know based on the provided PDFs.", ""
305
 
 
306
  selected = rerank(query.strip(), retrieved, top_n=rerank_k)
307
  if not selected:
308
  return "I don't know based on the provided PDFs.", ""
309
 
310
+ json_text = None
311
+ if include_json and json_path:
312
+ obj = read_json_file(json_path)
313
+ if isinstance(obj, dict) and "__error__" in obj:
314
+ json_text = f"__JSON_ERROR__: {obj['__error__']}"
315
+ else:
316
+ try:
317
+ json_text = json.dumps(obj, ensure_ascii=False)
318
+ if len(json_text) > 8000:
319
+ json_text = json_text[:8000] + "\n... [truncated]"
320
+ except Exception as e:
321
+ json_text = f"__JSON_ERROR__: {e}"
322
+
323
+ prompt = build_prompt(query, selected, json_text=json_text)
324
 
 
325
  pipe = load_llm(model_name, use_4bit=use_4bit)
326
  out = pipe(prompt, max_new_tokens=max_new_tokens, do_sample=(temperature > 0), temperature=temperature)[0]["generated_text"]
327
 
 
328
  cits = []
329
  for c in selected:
330
  cits.append(f"[{c.file} p.{c.page}]")
 
331
  seen = set()
332
  uniq = []
333
  for ci in cits:
 
345
  def wipe_collection():
346
  client = get_qdrant_client()
347
  client.delete_collection(COLLECTION_NAME)
 
348
  dim = load_embedder().get_sentence_embedding_dimension()
349
  ensure_collection(client, dim)
350
  return "Collection wiped and recreated."
351
 
352
 
353
def get_index_stats(sample_limit: int = 64):
    """Summarize the Qdrant collection as a one-line status string.

    Reports the exact point count and the distinct ``file`` payload values
    seen in a sample of at most *sample_limit* points (first ten names are
    listed). If counting or scrolling fails, the returned string describes
    the failure instead.
    """
    client = get_qdrant_client()

    try:
        total = client.count(collection_name=COLLECTION_NAME, exact=True).count
    except Exception as exc:
        return f"Count failed: {exc}"

    try:
        # scroll() returns (points, next_offset); only the first page is sampled.
        records, _ = client.scroll(
            collection_name=COLLECTION_NAME,
            limit=sample_limit,
            with_payload=True,
        )
        sampled = [
            (record.payload or {}).get("file") or "unknown.pdf"
            for record in records or []
        ]
    except Exception as exc:
        return f"Points: {total}. Scroll failed: {exc}"

    names = sorted(set(sampled))
    shown = ", ".join(names[:10])
    suffix = " …" if len(names) > 10 else ""
    return f"Points: {total} | Collection: {COLLECTION_NAME} | Sample files ({len(names)}): {shown}{suffix}"
377
+
378
+
379
+ # -----------------------------
380
+ # Counterfactual (CF) evaluator helpers
381
+ # -----------------------------
382
+
383
+ def _flatten_first(x):
384
+ while isinstance(x, (list, tuple)) and len(x) == 1 and isinstance(x[0], (list, tuple)):
385
+ x = x[0]
386
+ return x
387
+
388
+
389
def parse_cf_input_json(path: str):
    """Load and validate a counterfactual-evaluator input file.

    Returns ``(parsed, None)`` on success or ``(None, message)`` on failure.
    ``parsed`` is a dict with the keys ``test_data``, ``cfs_list``,
    ``feature_names_including_target``, ``feature_names``, ``desired_class``
    and ``outcome_name``.

    The top-level JSON value must be an object. Other JSON types are reported
    as an error rather than crashing the membership checks.
    """
    data = read_json_file(path)
    if isinstance(data, dict) and data.get("__error__"):
        return None, f"Load error: {data['__error__']}"
    if not isinstance(data, dict):
        return None, f"Expected a JSON object at top level, got {type(data).__name__}"

    required = (
        "test_data", "cfs_list", "feature_names", "feature_names_including_target",
        "data_interface", "desired_class",
    )
    for key in required:
        if key not in data:
            return None, f"Missing key: {key}"

    test_data = _flatten_first(data["test_data"])  # one row
    cfs_list = _flatten_first(data["cfs_list"])  # list of rows
    if not isinstance(cfs_list, (list, tuple)) or not cfs_list:
        return None, "Empty cfs_list"
    # With exactly one counterfactual the unwrapping above also removes the
    # "list of rows" level and leaves a single flat row of scalars; wrap it
    # again so cfs_list is always a list of rows.
    if not isinstance(cfs_list[0], (list, tuple, dict)):
        cfs_list = [cfs_list]

    # data_interface is expected to be an object; tolerate other types.
    interface = data["data_interface"]
    outcome_name = interface.get("outcome_name") if isinstance(interface, dict) else None
    outcome_name = outcome_name or data.get("outcome_name", "income")

    return {
        "test_data": test_data,
        "cfs_list": cfs_list,
        "feature_names_including_target": data["feature_names_including_target"],
        "feature_names": data["feature_names"],
        "desired_class": data["desired_class"],
        "outcome_name": outcome_name,
    }, None
415
+
416
+
417
def build_cf_retrieval_query(test_row, feature_names):
    """Build a comma-separated retrieval query from selected features of a row.

    *feature_names* and *test_row* are paired positionally. For each of a
    fixed list of feature names present in that pairing, a ``name:value`` term
    is emitted, followed by a fixed closing phrase. If the pairing cannot be
    built, a generic fallback query is returned.
    """
    try:
        by_name = dict(zip(feature_names, test_row))
    except Exception:
        return "adult income factors by occupation, education, hours per week, marital status"

    wanted = (
        "occupation", "education", "workclass", "marital_status",
        "age", "hours_per_week", "gender", "race",
    )
    terms = []
    for name in wanted:
        if name in by_name:
            terms.append(f"{name}:{by_name[name]}")
    terms.append("income threshold and probability drivers")
    return ", ".join(terms)
426
+
427
+
428
def get_rag_context_text(query: str, top_k: int, rerank_k: int, max_chars: int = 8000):
    """Retrieve and rerank chunks for *query* and return them as one text block.

    Each selected chunk is rendered as its text followed by a
    ``[CIT: file p.page]`` line; chunks are separated by blank lines and the
    joined string is cut to *max_chars* characters. Returns ``""`` when
    retrieval yields nothing.
    """
    candidates = retrieve(query, top_k=top_k)
    if not candidates:
        return ""

    blocks = []
    for hit in rerank(query, candidates, top_n=rerank_k):
        blocks.append(f"{hit.text}\n[CIT: {hit.file} p.{hit.page}]")

    joined = "\n\n".join(blocks)
    return joined[:max_chars]
435
+
436
+
437
def build_cf_prompt(parsed, rag_text: str = "", extra_json_text: str = ""):
    """Build the counterfactual-evaluation prompt.

    *parsed* must provide ``test_data``, ``cfs_list``,
    ``feature_names_including_target``, ``desired_class`` and
    ``outcome_name``. *rag_text* and *extra_json_text* are appended as
    labelled context sections only when non-empty.
    """
    instructions = "".join([
        "You are a helpful assistant with deep knowledge of counterfactual explanations, fairness, and causal reasoning.\n\n",
        "You will be given a test data point, candidate counterfactuals, feature names (including target), the desired class,",
        " and real-world context retrieved from documents.\n\n",
        "Goals: (1) choose or propose a counterfactual that flips the class to the desired one, (2) minimize actionable changes,",
        " (3) ensure plausibility given Adult Income and provided context.\n\n",
        "Return only this JSON (no prose): {\"best_cf\": [...], \"explanation\": \"...\"}",
    ])

    context_sections = []
    if rag_text:
        context_sections.append(f"\n\nRETRIEVED_CONTEXT:\n{rag_text}")
    if extra_json_text:
        context_sections.append(f"\n\nUPLOADED_JSON_CONTEXT:\n{extra_json_text}")
    context = "".join(context_sections)

    names_json = json.dumps(parsed["feature_names_including_target"])
    desired_class = parsed["desired_class"]
    outcome_name = parsed["outcome_name"]
    row_json = json.dumps(parsed["test_data"])
    candidates_json = json.dumps(parsed["cfs_list"])

    user_message = "".join([
        f"feature_names_including_target: {names_json}\n",
        f"desired_class: {desired_class}\n",
        f"outcome_name: {outcome_name}\n",
        f"test_data: {row_json}\n",
        f"cfs_list: {candidates_json}\n",
        f"{context}\n\n",
        "Only output the JSON with keys 'best_cf' and 'explanation'. Ensure 'best_cf' matches the length and order of feature_names_including_target.",
    ])

    system_section = f"<s>[SYSTEM]\n{SYSTEM_GUARDRAILS}\n{instructions}\n[/SYSTEM]\n"
    user_section = f"[USER]\n{user_message}\n[/USER]\n[ASSISTANT]"
    return system_section + user_section
473
+
474
+
475
def extract_json_object(text: str):
    """Extract the model's ``{"best_cf": ..., "explanation": ...}`` answer.

    Every ``{`` in *text* is tried as the start of a JSON object using
    ``json.JSONDecoder.raw_decode``. This tolerates surrounding prose, an
    echoed prompt, and stray braces. A single greedy "first ``{`` to last
    ``}``" match fails whenever more than one brace group is present.

    Returns a JSON string:
    * the last well-formed object containing both ``best_cf`` and
      ``explanation`` (re-serialized with ``ensure_ascii=False``);
    * an ``{"error": "Model returned invalid JSON."}`` document when braces
      are present but nothing between them parses;
    * an ``{"error": "No JSON object found in model output."}`` document
      otherwise.
    """
    invalid_json = "{\n \"error\": \"Model returned invalid JSON.\"\n}"
    not_found = "{\n \"error\": \"No JSON object found in model output.\"\n}"

    decoder = json.JSONDecoder()
    start = text.find("{")
    has_brace_pair = start != -1 and text.find("}", start) != -1

    answer = None
    parsed_any = False
    while start != -1:
        try:
            obj, end = decoder.raw_decode(text, start)
        except (ValueError, RecursionError):
            # Not a JSON object at this brace; try the next one.
            start = text.find("{", start + 1)
            continue
        parsed_any = True
        if isinstance(obj, dict) and "best_cf" in obj and "explanation" in obj:
            # Keep the last match: an echoed prompt precedes the real answer.
            answer = obj
        # Resume after this object so its nested dicts are not rescanned.
        start = text.find("{", end)

    if answer is not None:
        return json.dumps(answer, ensure_ascii=False)
    if has_brace_pair and not parsed_any:
        return invalid_json
    return not_found
491
+
492
+
493
def evaluate_cfs(cf_json_path: Optional[str], use_rag: bool, top_k: int, rerank_k: int,
                 max_new_tokens: int, temperature: float, extra_json_files, include_extra_json: bool,
                 model_name: str, use_4bit: bool):
    """Run the counterfactual evaluator and return its result as a JSON string.

    Steps: parse the uploaded CF input, optionally gather RAG context from
    the indexed PDFs, optionally append extra uploaded JSON files as context,
    prompt the LLM, and extract the ``best_cf``/``explanation`` object from
    its output. Input problems are reported as an ``{"error": ...}`` JSON
    string rather than raised.
    """
    if not cf_json_path:
        return "{\n \"error\": \"No CF input JSON uploaded.\"\n}"
    parsed, err = parse_cf_input_json(cf_json_path)
    if err:
        return json.dumps({"error": err})

    rag_text = ""
    if use_rag:
        # Drop the last name (the target) so names line up with the feature values.
        q = build_cf_retrieval_query(parsed["test_data"], parsed["feature_names_including_target"][:-1])
        rag_text = get_rag_context_text(q, top_k=int(top_k), rerank_k=int(rerank_k))

    extra_json_text = ""
    if include_extra_json and extra_json_files:
        blobs = []
        for p in _normalize_file_inputs(extra_json_files):
            obj = read_json_file(p)
            # A file that failed to load yields an error dict; do not feed
            # that to the model as if it were user-provided context.
            if isinstance(obj, dict) and "__error__" in obj:
                continue
            try:
                blobs.append(json.dumps(obj, ensure_ascii=False))
            except (TypeError, ValueError):
                continue
        extra_json_text = ("\n\n".join(blobs))[:6000]

    prompt = build_cf_prompt(parsed, rag_text=rag_text, extra_json_text=extra_json_text)

    # Coerce once so the sampling flag and the value passed on always agree.
    temperature = float(temperature)
    pipe = load_llm(model_name, use_4bit=use_4bit)
    out = pipe(prompt, max_new_tokens=int(max_new_tokens), do_sample=(temperature > 0), temperature=temperature)[0]["generated_text"]

    # If the pipeline echoes the prompt, drop it: the prompt itself contains a
    # JSON-like template that must not be mistaken for the model's answer.
    if isinstance(out, str) and out.startswith(prompt):
        out = out[len(prompt):]

    return extract_json_object(out)
525
+
526
+
527
  # -----------------------------
528
  # UI
529
  # -----------------------------
 
539
  idx_btn = gr.Button("Build / Update Index")
540
  idx_status = gr.Textbox(label="Status", interactive=False)
541
  wipe_btn = gr.Button("Wipe Index (danger)")
542
+ stats_btn = gr.Button("Index stats")
543
+ stats_box = gr.Textbox(label="Index stats", interactive=False)
544
+
545
+ with gr.Accordion("JSON compare (optional)", open=False):
546
+ json_file = gr.File(file_count="single", file_types=[".json"], type="filepath", label="Upload JSON spec")
547
+ include_json = gr.Checkbox(value=False, label="Include JSON in prompt for comparison")
548
+ json_info_box = gr.Textbox(label="JSON status", interactive=False)
549
 
550
  with gr.Row():
551
  model_name = gr.Dropdown(choices=DEFAULT_MODELS, value=DEFAULT_MODELS[0], label="LLM")
 
563
  answer = gr.Textbox(label="Answer", lines=10)
564
  citations = gr.Textbox(label="Citations", lines=2)
565
 
566
+ # CF evaluator UI
567
+ with gr.Accordion("Counterfactual Evaluator", open=False):
568
+ cf_input_json = gr.File(file_count="single", file_types=[".json"], type="filepath", label="Upload CF input JSON (Adult format)")
569
+ cf_extra_jsons = gr.File(file_count="multiple", file_types=[".json"], type="filepath", label="Optional: Additional JSON context")
570
+ include_rag_cf = gr.Checkbox(value=True, label="Use RAG context from indexed PDFs")
571
+ include_extra_json_cf = gr.Checkbox(value=False, label="Include uploaded JSON context in prompt")
572
+ eval_btn = gr.Button("Evaluate Counterfactuals → JSON output")
573
+ result_cf = gr.Textbox(label="Result JSON", lines=10)
574
+
575
+ # Wiring
576
  idx_btn.click(fn=ingest_pdfs, inputs=[files], outputs=[idx_status])
577
  wipe_btn.click(fn=wipe_collection, inputs=None, outputs=[idx_status])
578
+ stats_btn.click(fn=get_index_stats, inputs=None, outputs=[stats_box])
579
+
580
def _json_info(path):
    """Describe the uploaded JSON file for the status textbox.

    Returns a short message: no upload, a load error, or the top-level type
    of the document together with its key/item count where applicable.
    """
    if not path:
        return "No JSON uploaded."

    loaded = read_json_file(path)
    if isinstance(loaded, dict):
        if "__error__" in loaded:
            return f"JSON error: {loaded['__error__']}"
        return f"Loaded JSON object with {len(loaded)} top-level keys."
    if isinstance(loaded, list):
        return f"Loaded JSON array with {len(loaded)} items."
    return f"Loaded JSON of type {type(loaded).__name__}."
596
+
597
+ json_file.change(fn=_json_info, inputs=[json_file], outputs=[json_info_box])
598
 
599
  question.submit(
600
  fn=answer_query,
601
+ inputs=[question, model_name, use_4bit, top_k, rerank_k, max_new_tokens, temperature, json_file, include_json],
602
  outputs=[answer, citations]
603
  )
604
 
605
+ eval_btn.click(
606
+ fn=evaluate_cfs,
607
+ inputs=[cf_input_json, include_rag_cf, top_k, rerank_k, max_new_tokens, temperature, cf_extra_jsons, include_extra_json_cf, model_name, use_4bit],
608
+ outputs=[result_cf]
609
+ )
610
+
611
  if __name__ == "__main__":
612
+ demo.launch()