import os
import glob
import hashlib
import re
from typing import List, Dict, Any, Optional

import numpy as np
import faiss
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer

# Model names, overridable via environment variables.
GEN_MODEL_NAME = os.getenv("GEN_MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct")
EMB_MODEL_NAME = os.getenv("EMB_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")

# Lazily initialized singletons: tokenizer, generator model, generation
# pipeline, sentence embedder, FAISS index, and the in-memory corpus.
_tok = None
_mdl = None
_pipe = None
_emb = None
_faiss = None
_docs: List[Dict[str, Any]] = []

def seed_all(seed: Optional[int]) -> None:
    """Seed Python and (when available) torch RNGs for reproducible runs."""
    import random

    s = 0 if seed is None else seed
    random.seed(s)
    try:
        import torch

        torch.manual_seed(s)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(s)
    except Exception:
        # torch is optional at this point; skip silently if unavailable.
        pass

def get_pipe():
    """Lazy-load a simple text-generation pipeline (causal LM)."""
    global _pipe, _tok, _mdl
    if _pipe is None:
        _tok = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
        _mdl = AutoModelForCausalLM.from_pretrained(GEN_MODEL_NAME)
        _pipe = pipeline("text-generation", model=_mdl, tokenizer=_tok)
    return _pipe

def load_corpus(cdir: str = "./corpus") -> List[Dict[str, Any]]:
    """Load *.txt corpus files into memory."""
    os.makedirs(cdir, exist_ok=True)
    out: List[Dict[str, Any]] = []
    for p in sorted(glob.glob(os.path.join(cdir, "*.txt"))):
        try:
            with open(p, "r", encoding="utf-8", errors="ignore") as f:
                txt = f.read().strip()
            if txt:
                out.append({"id": hashlib.sha1(p.encode()).hexdigest()[:8], "text": txt, "path": p})
        except Exception:
            # Skip unreadable files rather than failing the whole load.
            pass
    return out

def get_emb():
    """Lazy-load the sentence embedding model."""
    global _emb
    if _emb is None:
        _emb = SentenceTransformer(EMB_MODEL_NAME)
    return _emb

def embed(texts: List[str]) -> np.ndarray:
    """Create normalized embeddings (cosine similarity via inner product)."""
    E = get_emb()
    vec = E.encode(texts, normalize_embeddings=True, convert_to_numpy=True)
    return vec.astype(np.float32)
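
# Example (a minimal sketch): embed() returns unit-normalized vectors, so the
# inner product used by the FAISS index below equals cosine similarity.
#   a, b = embed(["hello world", "hello there"])
#   cos = float(np.dot(a, b))  # in [-1, 1], identical to an IndexFlatIP score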

def build_index(docs: List[Dict[str, Any]]) -> None:
    """Build an inner-product FAISS index over the corpus."""
    global _faiss
    if not docs:
        # 384 is the embedding dimension of all-MiniLM-L6-v2; an empty index
        # keeps retrieve() well-defined before any documents exist.
        _faiss = faiss.IndexFlatIP(384)
        return
    V = embed([d["text"] for d in docs])
    _faiss = faiss.IndexFlatIP(V.shape[1])
    _faiss.add(V)
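
# Note: IndexFlatIP performs exact brute-force search, which is fine for small
# corpora. Rebuilding from scratch keeps FAISS row i aligned with _docs[i],
# which retrieve() relies on when mapping result ids back to documents.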

def retrieve(q: str, k: int = 4) -> List[Dict[str, Any]]:
    """Return top-k docs with similarity scores."""
    global _docs, _faiss
    if _faiss is None or not _docs:
        return []
    qv = embed([q])
    # int(k): Gradio sliders may deliver floats, and FAISS expects an int.
    scores, idxs = _faiss.search(qv, min(int(k), len(_docs)))
    out: List[Dict[str, Any]] = []
    for s, i in zip(scores[0], idxs[0]):
        if i < 0:
            # FAISS pads with -1 when fewer than k results are available.
            continue
        d = dict(_docs[i])
        d["score"] = float(s)
        out.append(d)
    return out
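
# Usage sketch (hypothetical corpus file, id, and score, for illustration only):
#   retrieve("What is the capital of France?", k=2)
#   -> [{"id": "1a2b3c4d", "text": "Paris is the capital of France.",
#        "path": "./corpus/capitals.txt", "score": 0.71}, ...]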

def fmt_ctx(snips: List[Dict[str, Any]]) -> str:
    """
    Build plain bullet context (no [C#] labels, no headings).
    We keep it minimal so the model doesn't copy labels as an "answer".
    """
    lines: List[str] = []
    for s in snips:
        lines.append(f"- {s['text'].strip()}")
    return "\n".join(lines).strip()

STRICT_RAG_SYSTEM = (
    "Answer ONLY using the provided context. "
    "Reply in ONE short sentence with just the answer. "
    "Do not include citations, brackets, numbers, or explanations. "
    "If the context does not contain the answer, reply exactly: "
    "\"I don't know based on the provided context.\""
)

def rag_prompt(question: str, ctx: str) -> str:
    """Assemble the strict RAG prompt: instructions, then context, then question."""
    return (
        f"{STRICT_RAG_SYSTEM}\n\n"
        f"Context:\n{ctx}\n\n"
        f"Question: {question.strip()}\n"
        f"Answer:"
    )
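
# A rendered prompt looks like this (bullets come from fmt_ctx):
#
#   Answer ONLY using the provided context. Reply in ONE short sentence ...
#
#   Context:
#   - <passage 1 text>
#   - <passage 2 text>
#
#   Question: <user question>
#   Answer: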

def det_generate(prompt: str, strategy: str, beams: int, max_new_tokens: int) -> str:
    """Greedy vs. beam-search (deterministic decoding, no sampling)."""
    seed_all(0)
    P = get_pipe()
    # Cast slider values: Gradio may deliver floats, but generate() wants ints.
    beams = int(beams)
    max_new_tokens = int(max_new_tokens)
    eos = _tok.eos_token_id if _tok and _tok.eos_token_id is not None else None
    if strategy == "beam":
        out = P(
            prompt,
            do_sample=False,
            num_beams=max(1, beams),
            early_stopping=True,
            max_new_tokens=max_new_tokens,
            eos_token_id=eos,
        )
    else:
        out = P(
            prompt,
            do_sample=False,
            max_new_tokens=max_new_tokens,
            eos_token_id=eos,
        )
    return out[0]["generated_text"]
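
# Note: greedy decoding is the num_beams=1 special case of beam search; both
# branches are deterministic because do_sample=False and the seed is fixed.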

def post_clean(text: str) -> str:
    """
    Remove any residual instruction echoes or bracket bits and keep only the
    first sentence. If the string becomes empty, fall back to the abstention line.
    """
    a = text.strip()

    # Drop the echoed prompt up to and including the first "Answer:".
    a = re.sub(r"(?is)^.*?Answer:\s*", "", a).strip()

    # If the model parrots an instruction, cut to the sentence after it.
    bad_starts = [
        "answer only using the provided context",
        "role:",
        "you are a careful assistant",
        "this answer is",
        "based solely",
        "therefore",
        "produce the answer",
    ]
    lower = a.lower()
    for bs in bad_starts:
        if lower.startswith(bs):
            a = a.split(".", 1)[-1].strip() or a
            break

    # Strip bracketed citation markers such as [1].
    a = re.sub(r"\s*\[\d+\]\s*", " ", a).strip()

    # Keep only the first sentence.
    if "." in a:
        a = a.split(".", 1)[0].strip() + "."

    # Normalize whitespace and stray surrounding quotes.
    a = re.sub(r"\s+", " ", a).strip(" \"'")

    if not a:
        a = "I don't know based on the provided context."
    return a
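
# Example (hypothetical model output -> cleaned answer):
#   post_clean("...Context...\nQuestion: ...\nAnswer: Paris is the capital. It has...")
#   -> "Paris is the capital."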

def rag_answer(question: str, top_k: int, beams: int, length_penalty: float, max_new_tokens: int) -> str:
    """RAG-grounded answer with deterministic decoding controls (no sampling)."""
    hits = retrieve(question, k=top_k)
    if not hits:
        return "I don't know based on the provided context."

    # Hard-coded guardrail: small models often flip negations found in the
    # context, so this known case is answered directly from the retrieved text.
    qlow = question.lower()
    if ("female" in qlow or "woman" in qlow or "women" in qlow) and ("president" in qlow):
        ctx_all = " ".join([h["text"] for h in hits]).lower()
        if "never had a female president" in ctx_all or "no female president" in ctx_all:
            return "As of 2025, the United States has never had a female president."

    ctx = fmt_ctx(hits)
    prompt = rag_prompt(question, ctx)

    seed_all(0)
    P = get_pipe()
    out = P(
        prompt,
        do_sample=False,
        num_beams=max(1, int(beams)),
        length_penalty=float(length_penalty),
        early_stopping=True,
        max_new_tokens=int(max_new_tokens),
        eos_token_id=_tok.eos_token_id if _tok and _tok.eos_token_id is not None else None,
    )
    raw = out[0]["generated_text"]
    return post_clean(raw)

# Load the corpus and build the index once at import time.
_docs = load_corpus("./corpus")
build_index(_docs)
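
# A minimal sketch (not wired into the UI; the name is a hypothetical addition):
# refresh the corpus and rebuild the index at runtime instead of restarting.
def refresh_corpus(cdir: str = "./corpus") -> int:
    """Reload *.txt files from cdir, rebuild the FAISS index, return doc count."""
    global _docs
    _docs = load_corpus(cdir)
    build_index(_docs)
    return len(_docs)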

with gr.Blocks(title="Deterministic & RAG (Clean Answers)") as demo:
    gr.Markdown(
        "## Deterministic vs RAG-Grounded (Clean)\n"
        "RAG answers are **one short sentence**, **no citations**, **no headings**.\n"
        "Put `.txt` files into `./corpus` and ask questions grounded in that content."
    )

    with gr.Tab("Deterministic Text"):
        inp = gr.Textbox(label="Prompt", placeholder="Explain beam search in one paragraph.")
        strat = gr.Dropdown(choices=["greedy", "beam"], value="beam", label="Strategy")
        beams = gr.Slider(1, 8, step=1, value=4, label="Beams (num_beams)")
        mxt = gr.Slider(16, 512, step=16, value=128, label="Max new tokens")
        btn = gr.Button("Generate")
        out = gr.Textbox(label="Output", lines=8)
        btn.click(det_generate, [inp, strat, beams, mxt], [out])

    with gr.Tab("RAG-Grounded"):
        q = gr.Textbox(label="Question", placeholder="Ask a question answerable from your ./corpus/*.txt files.")
        topk = gr.Slider(1, 10, step=1, value=4, label="Top-K Passages")
        r_beams = gr.Slider(1, 8, step=1, value=4, label="Beams (num_beams)")
        lp = gr.Slider(0.5, 2.0, step=0.1, value=1.0, label="Length Penalty")
        r_mxt = gr.Slider(16, 512, step=16, value=128, label="Max new tokens")
        r_btn = gr.Button("Answer from RAG")
        r_out = gr.Textbox(label="Answer", lines=4)
        r_btn.click(rag_answer, [q, topk, r_beams, lp, r_mxt], [r_out])


if __name__ == "__main__":
    demo.launch()