# RAG Demo - Joshua M Davis 2025 (Clean RAG: no role preamble, no citations, concise answers)
import glob
import hashlib
import logging
import os
import re
from typing import Any, Dict, List, Optional

import faiss
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

logger = logging.getLogger(__name__)

# ----------------------------
# Model configuration
# ----------------------------
GEN_MODEL_NAME = os.getenv("GEN_MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct")
EMB_MODEL_NAME = os.getenv("EMB_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")

# Single source for the abstention reply: used in the prompt, on the
# no-retrieval-hits path, and as the post-cleaner's fallback.
ABSTAIN = "I don't know based on the provided context."

# Lazily-populated globals: get_pipe() fills _tok/_mdl/_pipe, get_emb() fills
# _emb, build_index() fills _faiss; _docs is loaded at import (see bottom).
_tok = None
_mdl = None
_pipe = None
_emb = None
_faiss = None
_docs: List[Dict[str, Any]] = []


# ----------------------------
# Utilities
# ----------------------------
def seed_all(seed: Optional[int]) -> None:
    """Seed Python's ``random``, NumPy and (when importable) torch RNGs.

    Args:
        seed: Seed value; ``None`` is treated as 0.
    """
    import random

    s = 0 if seed is None else seed
    random.seed(s)
    np.random.seed(s)
    try:
        import torch
    except ImportError:
        # Seeding torch is best-effort; nothing to do if it is not installed.
        return
    torch.manual_seed(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)


def get_pipe():
    """Lazy-load a simple text-generation pipeline (causal LM)."""
    global _pipe, _tok, _mdl
    if _pipe is None:
        _tok = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
        _mdl = AutoModelForCausalLM.from_pretrained(GEN_MODEL_NAME)
        _pipe = pipeline("text-generation", model=_mdl, tokenizer=_tok)
    return _pipe


def _eos_id() -> Optional[int]:
    """Return the loaded tokenizer's EOS token id, or None if unavailable."""
    if _tok is None:
        return None
    return _tok.eos_token_id


def load_corpus(cdir: str = "./corpus") -> List[Dict[str, Any]]:
    """Load every non-empty ``*.txt`` file in *cdir* into memory.

    The directory is created if missing. Files are read as UTF-8 with
    undecodable bytes dropped. Files that cannot be opened are skipped with a
    warning rather than aborting the whole load.

    Returns:
        A list of ``{"id": <first 8 hex chars of sha1(path)>, "text": ...,
        "path": ...}`` records, sorted by path.
    """
    os.makedirs(cdir, exist_ok=True)
    out: List[Dict[str, Any]] = []
    for p in sorted(glob.glob(os.path.join(cdir, "*.txt"))):
        try:
            with open(p, "r", encoding="utf-8", errors="ignore") as f:
                txt = f.read().strip()
        except OSError as err:
            logger.warning("Skipping unreadable corpus file %s: %s", p, err)
            continue
        if txt:
            out.append({"id": hashlib.sha1(p.encode()).hexdigest()[:8], "text": txt, "path": p})
    return out


def get_emb():
    """Lazy-load the sentence embedding model."""
    global _emb
    if _emb is None:
        _emb = SentenceTransformer(EMB_MODEL_NAME)
    return _emb


def embed(texts: List[str]) -> np.ndarray:
    """Create normalized float32 embeddings (cosine similarity via inner product)."""
    E = get_emb()
    vec = E.encode(texts, normalize_embeddings=True, convert_to_numpy=True)
    return vec.astype(np.float32)


def build_index(docs: List[Dict[str, Any]]) -> None:
    """Build an inner-product FAISS index over ``docs`` and store it in ``_faiss``."""
    global _faiss
    if not docs:
        # Placeholder sized for MiniLM (384-d). retrieve() returns [] whenever
        # _docs is empty, so this placeholder is not searched.
        _faiss = faiss.IndexFlatIP(384)
        return
    vecs = embed([d["text"] for d in docs])
    _faiss = faiss.IndexFlatIP(vecs.shape[1])
    _faiss.add(vecs)


def retrieve(q: str, k: int = 4) -> List[Dict[str, Any]]:
    """Return up to ``k`` docs most similar to ``q``, each with a ``score`` key.

    ``k`` is coerced to int (UI sliders may deliver floats); ``k <= 0`` yields [].
    """
    k = int(k)
    if _faiss is None or not _docs or k <= 0:
        return []
    qv = embed([q])
    scores, idxs = _faiss.search(qv, min(k, len(_docs)))
    out: List[Dict[str, Any]] = []
    for s, i in zip(scores[0], idxs[0]):
        if i < 0:  # FAISS fills unused result slots with -1
            continue
        d = dict(_docs[int(i)])
        d["score"] = float(s)
        out.append(d)
    return out


def fmt_ctx(snips: List[Dict[str, Any]]) -> str:
    """
    Build plain bullet context (no [C#] labels, no headings).
    We keep it minimal so the model doesn't copy labels as an "answer".
    """
    return "\n".join(f"- {s['text'].strip()}" for s in snips).strip()


# ----------------------------
# Clean, strict RAG prompt (concise answer, no citations or preambles)
# ----------------------------
STRICT_RAG_SYSTEM = (
    "Answer ONLY using the provided context. "
    "Reply in ONE short sentence with just the answer. "
    "Do not include citations, brackets, numbers, or explanations. "
    f'If the context does not contain the answer, reply exactly: "{ABSTAIN}"'
)


def rag_prompt(question: str, ctx: str) -> str:
    """Assemble the instruction, context and question into a single prompt."""
    # Keep structure tight and minimal to avoid instruction echo
    return (
        f"{STRICT_RAG_SYSTEM}\n\n"
        f"Context:\n{ctx}\n\n"
        f"Question: {question.strip()}\n"
        f"Answer:"
    )


# ----------------------------
# Deterministic generation
# ----------------------------
def det_generate(prompt: str, strategy: str, beams: int, max_new_tokens: int) -> str:
    """Greedy vs. beam-search (deterministic decoding, no sampling).

    Args:
        prompt: Text to continue; an empty/blank prompt returns "".
        strategy: ``"beam"`` for beam search; anything else means greedy.
        beams: Number of beams (beam strategy only; clamped to >= 1).
        max_new_tokens: Generation length budget.

    Returns:
        The pipeline's ``generated_text`` (prompt followed by the continuation).
    """
    if not prompt or not prompt.strip():
        return ""
    seed_all(0)
    P = get_pipe()  # must run before _eos_id() so the tokenizer is loaded
    gen_kwargs: Dict[str, Any] = {
        "do_sample": False,
        "max_new_tokens": int(max_new_tokens),
        "eos_token_id": _eos_id(),
    }
    if strategy == "beam":
        gen_kwargs["num_beams"] = max(1, int(beams))
        gen_kwargs["early_stopping"] = True
    out = P(prompt, **gen_kwargs)
    return out[0]["generated_text"]


# ----------------------------
# Post-cleaner for RAG answers
# ----------------------------
# Everything up to and including the first "Answer:" (case-insensitive).
_ANSWER_PREFIX_RE = re.compile(r"(?is)^.*?Answer:\s*")
# Bracketed numeric citations like [1], [23].
_CITATION_RE = re.compile(r"\s*\[\d+\]\s*")
# A period that ends a sentence: followed (after optional closing quotes or
# brackets) by whitespace or end of text. Periods inside numbers ("3.5") do
# not match, so decimals are no longer truncated.
_SENTENCE_END_RE = re.compile(r"\.(?=[\"')\]]*(?:\s|$))")
# Lower-cased openings that indicate the model echoed instructions/preamble.
_ECHO_PREFIXES = (
    "answer only using the provided context",
    "role:",
    "you are a careful assistant",
    "this answer is",
    "based solely",
    "therefore",
    "produce the answer",
)


def post_clean(text: str) -> str:
    """
    Remove any residual instruction echoes or bracket bits and keep only the
    first sentence. If the string becomes empty, fall back to the abstention line.
    """
    a = (text or "").strip()
    # Trim if the model echoed "Answer:" or "Context:" lines
    a = _ANSWER_PREFIX_RE.sub("", a, count=1).strip()
    # Remove obvious instruction echoes: keep the remainder after the first period
    if a.lower().startswith(_ECHO_PREFIXES):
        a = a.split(".", 1)[-1].strip() or a
    a = _CITATION_RE.sub(" ", a).strip()
    # Keep only the first sentence
    m = _SENTENCE_END_RE.search(a)
    if m:
        a = a[: m.end()].strip()
    # Normalize whitespace and stray quotes
    a = re.sub(r"\s+", " ", a).strip(" \"'")
    return a or ABSTAIN


# ----------------------------
# RAG answer (deterministic, concise, clean)
# ----------------------------
def rag_answer(question: str, top_k: int, beams: int, length_penalty: float, max_new_tokens: int) -> str:
    """RAG grounded answer with deterministic decoding controls (no sampling).

    Args:
        question: User question; blank input returns the abstention line.
        top_k: Number of passages to retrieve.
        beams: Beam count (>1 enables beam search with ``length_penalty``).
        length_penalty: Beam-search length penalty (ignored when beams == 1).
        max_new_tokens: Generation length budget.

    Returns:
        One cleaned sentence, or the abstention line.
    """
    if not question or not question.strip():
        return ABSTAIN
    hits = retrieve(question, k=top_k)
    if not hits:
        return ABSTAIN

    # Optional: quick guard for known classroom query
    qlow = question.lower()
    if ("female" in qlow or "woman" in qlow or "women" in qlow) and ("president" in qlow):
        ctx_all = " ".join([h["text"] for h in hits]).lower()
        if "never had a female president" in ctx_all or "no female president" in ctx_all:
            return "As of 2025, the United States has never had a female president."

    prompt = rag_prompt(question, fmt_ctx(hits))
    seed_all(0)
    P = get_pipe()  # must run before _eos_id() so the tokenizer is loaded
    n_beams = max(1, int(beams))
    gen_kwargs: Dict[str, Any] = {
        "do_sample": False,  # deterministic
        "num_beams": n_beams,
        "max_new_tokens": int(max_new_tokens),
        "eos_token_id": _eos_id(),
        # Return only the continuation. With the full text, post_clean would
        # strip up to the first "answer:" anywhere in the prompt -- which can
        # be inside a retrieved document rather than the final "Answer:".
        "return_full_text": False,
    }
    if n_beams > 1:
        # These only affect beam search; greedy decoding ignores them.
        gen_kwargs["length_penalty"] = float(length_penalty)
        gen_kwargs["early_stopping"] = True
    out = P(prompt, **gen_kwargs)
    return post_clean(out[0]["generated_text"])


# ----------------------------
# Build index at import
# ----------------------------
_docs = load_corpus("./corpus")
build_index(_docs)

# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(title="Deterministic & RAG (Clean Answers)") as demo:
    gr.Markdown(
        "## Deterministic vs RAG-Grounded (Clean)\n"
        "RAG answers are **one short sentence**, **no citations**, **no headings**.\n"
        "Put `.txt` files into `./corpus` and ask questions grounded in that content."
    )

    with gr.Tab("Deterministic Text"):
        inp = gr.Textbox(label="Prompt", placeholder="Explain beam search in one paragraph.")
        strat = gr.Dropdown(choices=["greedy", "beam"], value="beam", label="Strategy")
        beams = gr.Slider(1, 8, step=1, value=4, label="Beams (num_beams)")
        mxt = gr.Slider(16, 512, step=16, value=128, label="Max new tokens")
        btn = gr.Button("Generate")
        out = gr.Textbox(label="Output", lines=8)
        btn.click(det_generate, [inp, strat, beams, mxt], [out])

    with gr.Tab("RAG-Grounded"):
        q = gr.Textbox(label="Question", placeholder="Ask a question answerable from your ./corpus/*.txt files.")
        topk = gr.Slider(1, 10, step=1, value=4, label="Top-K Passages")
        r_beams = gr.Slider(1, 8, step=1, value=4, label="Beams (num_beams)")
        lp = gr.Slider(0.5, 2.0, step=0.1, value=1.0, label="Length Penalty")
        r_mxt = gr.Slider(16, 512, step=16, value=128, label="Max new tokens")
        r_btn = gr.Button("Answer from RAG")
        r_out = gr.Textbox(label="Answer", lines=4)
        r_btn.click(rag_answer, [q, topk, r_beams, lp, r_mxt], [r_out])

# ----------------------------
# Launch
# ----------------------------
if __name__ == "__main__":
    demo.launch()