Inframat-x committed on
Commit
65df9cc
·
verified ·
1 Parent(s): 2287ebf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -26
app.py CHANGED
@@ -5,6 +5,7 @@
5
  # - Predictor: safe model caching + safe feature alignment
6
  # - Stable categoricals ("NA"); no over-strict completeness gate
7
  # - Fixed [[PAGE=...]] regex
 
8
  # ================================================================
9
 
10
  # ---------------------- Runtime flags (HF-safe) ----------------------
@@ -14,7 +15,7 @@ os.environ["TRANSFORMERS_NO_FLAX"] = "1"
14
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
15
 
16
  # ------------------------------- Imports ------------------------------
17
- import re, joblib, warnings, json, traceback
18
  from pathlib import Path
19
  from typing import List, Dict, Any
20
 
@@ -548,9 +549,27 @@ def compose_extractive(selected: List[Dict[str, Any]]) -> str:
548
  return ""
549
  return " ".join(f"{s['sent']} ({s['doc']}, p.{s['page']})" for s in selected)
550
 
551
- def synthesize_with_llm(question: str, sentence_lines: List[str], model: str = None, temperature: float = 0.2) -> str:
552
- if not LLM_AVAILABLE:
 
 
 
 
 
 
 
 
 
 
 
 
553
  return None
 
 
 
 
 
 
554
  client = OpenAI(api_key=OPENAI_API_KEY)
555
  model = model or OPENAI_MODEL
556
  SYSTEM_PROMPT = (
@@ -573,9 +592,22 @@ def synthesize_with_llm(question: str, sentence_lines: List[str], model: str = N
573
  ],
574
  temperature=temperature,
575
  )
576
- return getattr(resp, "output_text", None) or str(resp)
 
 
 
 
 
 
 
 
 
 
 
 
 
577
  except Exception:
578
- return None
579
 
580
  def rag_reply(
581
  question: str,
@@ -590,41 +622,141 @@ def rag_reply(
590
  w_bm25: float = W_BM25_DEFAULT,
591
  w_emb: float = W_EMB_DEFAULT
592
  ) -> str:
 
 
 
 
 
593
  hits = hybrid_search(question, k=k, w_tfidf=w_tfidf, w_bm25=w_bm25, w_emb=w_emb)
594
- if hits is None or hits.empty:
595
- return "No indexed PDFs found. Upload PDFs to the 'papers/' folder and reload the Space."
596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
597
  selected = mmr_select_sentences(question, hits, top_n=int(n_sentences), pool_per_chunk=6, lambda_div=0.7)
598
  header_cites = "; ".join(f"{Path(r['doc_path']).name} (p.{_extract_page(r['text'])})" for _, r in hits.head(6).iterrows())
599
  srcs = {Path(r['doc_path']).name for _, r in hits.iterrows()}
600
  coverage_note = "" if len(srcs) >= 3 else f"\n\n> Note: Only {len(srcs)} unique source(s) contributed. Add more PDFs or increase Top-K."
601
 
 
 
 
 
 
 
 
 
 
 
 
 
 
602
  if strict_quotes_only:
603
  if not selected:
604
- return f"**Quoted Passages:**\n\n---\n" + "\n\n".join(hits['text'].tolist()[:2]) + f"\n\n**Citations:** {header_cites}{coverage_note}"
605
- msg = "**Quoted Passages:**\n- " + "\n- ".join(f"{s['sent']} ({s['doc']}, p.{s['page']})" for s in selected)
606
- msg += f"\n\n**Citations:** {header_cites}{coverage_note}"
607
- if include_passages:
608
- msg += "\n\n---\n" + "\n\n".join(hits['text'].tolist()[:2])
609
- return msg
610
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
611
  extractive = compose_extractive(selected)
 
 
612
  if use_llm and selected:
613
  lines = [f"{s['sent']} ({s['doc']}, p.{s['page']})" for s in selected]
614
- llm_text = synthesize_with_llm(question, lines, model=model, temperature=temperature)
 
 
 
 
615
  if llm_text:
616
- msg = f"**Answer (LLM synthesis):** {llm_text}\n\n**Citations:** {header_cites}{coverage_note}"
617
  if include_passages:
618
- msg += "\n\n---\n" + "\n\n".join(hits['text'].tolist()[:2])
619
- return msg
620
-
621
- if not extractive:
622
- return f"**Answer:** Here are relevant passages.\n\n**Citations:** {header_cites}{coverage_note}\n\n---\n" + "\n\n".join(hits['text'].tolist()[:2])
623
-
624
- msg = f"**Answer:** {extractive}\n\n**Citations:** {header_cites}{coverage_note}"
625
- if include_passages:
626
- msg += "\n\n---\n" + "\n\n".join(hits['text'].tolist()[:2])
627
- return msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
628
 
629
  def rag_chat_fn(message, history, top_k, n_sentences, include_passages,
630
  use_llm, model_name, temperature, strict_quotes_only,
 
5
  # - Predictor: safe model caching + safe feature alignment
6
  # - Stable categoricals ("NA"); no over-strict completeness gate
7
  # - Fixed [[PAGE=...]] regex
8
+ # - NEW: Lightweight instrumentation (JSONL logs per RAG turn)
9
  # ================================================================
10
 
11
  # ---------------------- Runtime flags (HF-safe) ----------------------
 
15
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
 
17
  # ------------------------------- Imports ------------------------------
18
+ import re, joblib, warnings, json, traceback, time, uuid
19
  from pathlib import Path
20
  from typing import List, Dict, Any
21
 
 
549
  return ""
550
  return " ".join(f"{s['sent']} ({s['doc']}, p.{s['page']})" for s in selected)
551
 
552
+ # ========================= NEW: Instrumentation helpers =========================
553
+ LOG_PATH = ARTIFACT_DIR / "rag_logs.jsonl"
554
+ OPENAI_IN_COST_PER_1K = float(os.getenv("OPENAI_COST_IN_PER_1K", "0"))
555
+ OPENAI_OUT_COST_PER_1K = float(os.getenv("OPENAI_COST_OUT_PER_1K", "0"))
556
+
557
+ def _safe_write_jsonl(path: Path, record: dict):
558
+ try:
559
+ with open(path, "a", encoding="utf-8") as f:
560
+ f.write(json.dumps(record, ensure_ascii=False) + "\n")
561
+ except Exception as e:
562
+ print("[Log] write failed:", e)
563
+
564
def _calc_cost_usd(prompt_toks, completion_toks):
    """Estimate the USD cost of one OpenAI call from its token counts.

    Uses the module-level per-1K-token rates (OPENAI_IN_COST_PER_1K /
    OPENAI_OUT_COST_PER_1K, both 0 by default). Returns None when either
    count is unknown, so callers can log "cost unavailable" honestly.
    """
    if prompt_toks is None or completion_toks is None:
        return None
    input_cost = (prompt_toks / 1000.0) * OPENAI_IN_COST_PER_1K
    output_cost = (completion_toks / 1000.0) * OPENAI_OUT_COST_PER_1K
    return input_cost + output_cost
568
+
569
+ # ----------------- Modified to return (text, usage_dict) -----------------
570
+ def synthesize_with_llm(question: str, sentence_lines: List[str], model: str = None, temperature: float = 0.2):
571
+ if not LLM_AVAILABLE:
572
+ return None, None
573
  client = OpenAI(api_key=OPENAI_API_KEY)
574
  model = model or OPENAI_MODEL
575
  SYSTEM_PROMPT = (
 
592
  ],
593
  temperature=temperature,
594
  )
595
+ # Try to extract text
596
+ out_text = getattr(resp, "output_text", None) or str(resp)
597
+ # Try to extract usage (prompt_tokens, completion_tokens)
598
+ usage = None
599
+ try:
600
+ u = getattr(resp, "usage", None)
601
+ if u:
602
+ # Newer SDKs: resp.usage has attributes or dict-like
603
+ pt = getattr(u, "prompt_tokens", None) if hasattr(u, "prompt_tokens") else u.get("prompt_tokens", None)
604
+ ct = getattr(u, "completion_tokens", None) if hasattr(u, "completion_tokens") else u.get("completion_tokens", None)
605
+ usage = {"prompt_tokens": pt, "completion_tokens": ct}
606
+ except Exception:
607
+ usage = None
608
+ return out_text, usage
609
  except Exception:
610
+ return None, None
611
 
612
  def rag_reply(
613
  question: str,
 
622
  w_bm25: float = W_BM25_DEFAULT,
623
  w_emb: float = W_EMB_DEFAULT
624
  ) -> str:
625
+ run_id = str(uuid.uuid4())
626
+ t0_total = time.time()
627
+ t0_retr = time.time()
628
+
629
+ # --- Retrieval ---
630
  hits = hybrid_search(question, k=k, w_tfidf=w_tfidf, w_bm25=w_bm25, w_emb=w_emb)
631
+ t1_retr = time.time()
632
+ latency_ms_retriever = int((t1_retr - t0_retr) * 1000)
633
 
634
+ if hits is None or hits.empty:
635
+ final = "No indexed PDFs found. Upload PDFs to the 'papers/' folder and reload the Space."
636
+ # Minimal log on miss
637
+ record = {
638
+ "run_id": run_id,
639
+ "ts": int(time.time()*1000),
640
+ "inputs": {
641
+ "question": question, "top_k": int(k), "n_sentences": int(n_sentences),
642
+ "w_tfidf": float(w_tfidf), "w_bm25": float(w_bm25), "w_emb": float(w_emb),
643
+ "use_llm": bool(use_llm), "model": model, "temperature": float(temperature)
644
+ },
645
+ "retrieval": {"hits": [], "latency_ms_retriever": latency_ms_retriever},
646
+ "output": {"final_answer": final, "used_sentences": []},
647
+ "latency_ms_total": int((time.time()-t0_total)*1000),
648
+ "openai": None
649
+ }
650
+ _safe_write_jsonl(LOG_PATH, record)
651
+ return final
652
+
653
+ # Select sentences
654
  selected = mmr_select_sentences(question, hits, top_n=int(n_sentences), pool_per_chunk=6, lambda_div=0.7)
655
  header_cites = "; ".join(f"{Path(r['doc_path']).name} (p.{_extract_page(r['text'])})" for _, r in hits.head(6).iterrows())
656
  srcs = {Path(r['doc_path']).name for _, r in hits.iterrows()}
657
  coverage_note = "" if len(srcs) >= 3 else f"\n\n> Note: Only {len(srcs)} unique source(s) contributed. Add more PDFs or increase Top-K."
658
 
659
+ # Prepare retrieval list for logging
660
+ retr_list = []
661
+ for _, r in hits.iterrows():
662
+ retr_list.append({
663
+ "doc": Path(r["doc_path"]).name,
664
+ "page": _extract_page(r["text"]),
665
+ "score_tfidf": float(r.get("score_tfidf", 0.0)),
666
+ "score_bm25": float(r.get("score_bm25", 0.0)),
667
+ "score_dense": float(r.get("score_dense", 0.0)),
668
+ "combo_score": float(r.get("score", 0.0)),
669
+ })
670
+
671
+ # Strict quotes only (no LLM)
672
  if strict_quotes_only:
673
  if not selected:
674
+ final = f"**Quoted Passages:**\n\n---\n" + "\n\n".join(hits['text'].tolist()[:2]) + f"\n\n**Citations:** {header_cites}{coverage_note}"
675
+ else:
676
+ final = "**Quoted Passages:**\n- " + "\n- ".join(f"{s['sent']} ({s['doc']}, p.{s['page']})" for s in selected)
677
+ final += f"\n\n**Citations:** {header_cites}{coverage_note}"
678
+ if include_passages:
679
+ final += "\n\n---\n" + "\n\n".join(hits['text'].tolist()[:2])
680
+
681
+ record = {
682
+ "run_id": run_id,
683
+ "ts": int(time.time()*1000),
684
+ "inputs": {
685
+ "question": question, "top_k": int(k), "n_sentences": int(n_sentences),
686
+ "w_tfidf": float(w_tfidf), "w_bm25": float(w_bm25), "w_emb": float(w_emb),
687
+ "use_llm": False, "model": None, "temperature": float(temperature)
688
+ },
689
+ "retrieval": {"hits": retr_list, "latency_ms_retriever": latency_ms_retriever},
690
+ "output": {
691
+ "final_answer": final,
692
+ "used_sentences": [{"sent": s["sent"], "doc": s["doc"], "page": s["page"]} for s in selected]
693
+ },
694
+ "latency_ms_total": int((time.time()-t0_total)*1000),
695
+ "openai": None
696
+ }
697
+ _safe_write_jsonl(LOG_PATH, record)
698
+ return final
699
+
700
+ # Extractive or LLM synthesis
701
  extractive = compose_extractive(selected)
702
+ llm_usage = None
703
+ llm_latency_ms = None
704
  if use_llm and selected:
705
  lines = [f"{s['sent']} ({s['doc']}, p.{s['page']})" for s in selected]
706
+ t0_llm = time.time()
707
+ llm_text, llm_usage = synthesize_with_llm(question, lines, model=model, temperature=temperature)
708
+ t1_llm = time.time()
709
+ llm_latency_ms = int((t1_llm - t0_llm) * 1000)
710
+
711
  if llm_text:
712
+ final = f"**Answer (LLM synthesis):** {llm_text}\n\n**Citations:** {header_cites}{coverage_note}"
713
  if include_passages:
714
+ final += "\n\n---\n" + "\n\n".join(hits['text'].tolist()[:2])
715
+ else:
716
+ # fall back to extractive
717
+ if not extractive:
718
+ final = f"**Answer:** Here are relevant passages.\n\n**Citations:** {header_cites}{coverage_note}\n\n---\n" + "\n\n".join(hits['text'].tolist()[:2])
719
+ else:
720
+ final = f"**Answer:** {extractive}\n\n**Citations:** {header_cites}{coverage_note}"
721
+ if include_passages:
722
+ final += "\n\n---\n" + "\n\n".join(hits['text'].tolist()[:2])
723
+ else:
724
+ if not extractive:
725
+ final = f"**Answer:** Here are relevant passages.\n\n**Citations:** {header_cites}{coverage_note}\n\n---\n" + "\n\n".join(hits['text'].tolist()[:2])
726
+ else:
727
+ final = f"**Answer:** {extractive}\n\n**Citations:** {header_cites}{coverage_note}"
728
+ if include_passages:
729
+ final += "\n\n---\n" + "\n\n".join(hits['text'].tolist()[:2])
730
+
731
+ # --------- Log full run ---------
732
+ prompt_toks = llm_usage.get("prompt_tokens") if llm_usage else None
733
+ completion_toks = llm_usage.get("completion_tokens") if llm_usage else None
734
+ cost_usd = _calc_cost_usd(prompt_toks, completion_toks)
735
+
736
+ total_ms = int((time.time() - t0_total) * 1000)
737
+ record = {
738
+ "run_id": run_id,
739
+ "ts": int(time.time()*1000),
740
+ "inputs": {
741
+ "question": question, "top_k": int(k), "n_sentences": int(n_sentences),
742
+ "w_tfidf": float(w_tfidf), "w_bm25": float(w_bm25), "w_emb": float(w_emb),
743
+ "use_llm": bool(use_llm), "model": model, "temperature": float(temperature)
744
+ },
745
+ "retrieval": {"hits": retr_list, "latency_ms_retriever": latency_ms_retriever},
746
+ "output": {
747
+ "final_answer": final,
748
+ "used_sentences": [{"sent": s["sent"], "doc": s["doc"], "page": s["page"]} for s in selected]
749
+ },
750
+ "latency_ms_total": total_ms,
751
+ "latency_ms_llm": llm_latency_ms,
752
+ "openai": {
753
+ "prompt_tokens": prompt_toks,
754
+ "completion_tokens": completion_toks,
755
+ "cost_usd": cost_usd
756
+ } if use_llm else None
757
+ }
758
+ _safe_write_jsonl(LOG_PATH, record)
759
+ return final
760
 
761
  def rag_chat_fn(message, history, top_k, n_sentences, include_passages,
762
  use_llm, model_name, temperature, strict_quotes_only,