# RAGDemo / app.py - Hugging Face Space file by DrDavis
# Commit ed943e3 ("Create app.py", verified), 13.8 kB
"""
RAG Mini Demo (CPU-friendly)
----------------------------
This Gradio app shows side-by-side answers from:
1) LLM-Only → the model answers the question directly, without retrieved context
2) RAG → the model answers using context retrieved from a small corpus
Stack (all CPU-friendly):
- sentence-transformers/all-MiniLM-L6-v2 for embeddings (vector representations)
- FAISS (CPU) for fast similarity search over vectors
- google/flan-t5-small for generation
- Gradio for the web UI
"""
import io
import os
import re
from dataclasses import dataclass
from typing import List, Tuple

import faiss
import gradio as gr
import numpy as np
# Embedding model (turns text → vectors)
from sentence_transformers import SentenceTransformer
# Text generation pipeline (small, instruction-friendly model)
from transformers import pipeline
# ----------------------------
# App configuration (easy knobs)
# ----------------------------
# Models: a compact sentence embedder and a tiny instruction-tuned generator for CPU Spaces.
EMBED_MODEL_ID: str = "sentence-transformers/all-MiniLM-L6-v2"
GEN_MODEL_ID: str = "google/flan-t5-small"
# Chunking: documents become windows of CHUNK_SIZE characters, and consecutive
# windows share CHUNK_OVERLAP characters (teaching defaults).
CHUNK_SIZE: int = 500
CHUNK_OVERLAP: int = 100
# Retrieval: default number of chunks retrieved for the RAG prompt.
TOP_K: int = 3
# ----------------------------
# Utility functions
# ----------------------------
def normalize_ws(text: str) -> str:
    """
    Collapse each run of whitespace (spaces, tabs, newlines, ...) into one space
    and trim both ends, so stored text carries no layout noise.
    """
    # str.split() with no separator splits on whitespace runs and drops empty
    # pieces at the ends, so joining with " " both collapses and strips.
    return " ".join(text.split())
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """
    Split text into overlapping fixed-size character windows so retrieval can match smaller sections.

    Whitespace is normalized first. Overlap helps avoid 'boundary' problems where a key
    sentence is split between two chunks.

    Args:
        text: Input text (any whitespace layout).
        chunk_size: Maximum number of characters per chunk; must be > 0.
        overlap: Characters shared by consecutive chunks; must satisfy
            0 <= overlap < chunk_size, otherwise the window could never advance.

    Returns:
        The chunks in order; an empty list when ``text`` is empty or whitespace-only.

    Raises:
        ValueError: if ``chunk_size`` or ``overlap`` is out of range.
    """
    # Without these checks, overlap >= chunk_size makes `start` stop advancing
    # and the loop below never terminates.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(f"overlap must be in [0, chunk_size), got overlap={overlap}, chunk_size={chunk_size}")
    text = normalize_ws(text)
    if not text:
        # Do not emit an empty chunk; it would be embedded and indexed as noise.
        return []
    if len(text) <= chunk_size:
        return [text]
    chunks = []
    start = 0
    while start < len(text):
        end = min(len(text), start + chunk_size)
        chunks.append(text[start:end])
        if end == len(text):
            break
        # Move the window forward, but keep 'overlap' characters of the previous chunk.
        start = end - overlap
    return chunks
def read_txt_or_md(file_obj: io.BytesIO, filename: str) -> str:
    """
    Read an uploaded .txt or .md file as UTF-8 text (best effort).

    We restrict to these formats to keep the demo simple and robust on CPU Spaces.

    Args:
        file_obj: Readable stream positioned at the start of the file. A binary stream
            is decoded as UTF-8; a stream whose ``read()`` already returns ``str`` is
            used as-is.
        filename: Name used only to check the extension (case-insensitive).

    Returns:
        The decoded text, or "" for unsupported extensions or unreadable streams.
        A leading UTF-8 byte-order mark is removed, and undecodable bytes are dropped.
    """
    ext = os.path.splitext(filename.lower())[1]
    if ext not in (".txt", ".md"):
        return ""
    try:
        raw = file_obj.read()
    except (OSError, ValueError):
        # OSError: I/O failure; ValueError: read on a closed stream.
        return ""
    if isinstance(raw, str):
        return raw
    if isinstance(raw, (bytes, bytearray)):
        # "utf-8-sig" strips a BOM (common in files saved on Windows) that plain
        # "utf-8" would keep as a stray U+FEFF at the start of the first chunk.
        return raw.decode("utf-8-sig", errors="ignore")
    return ""
# ----------------------------
# RAG store: Keeps chunks + FAISS index
# ----------------------------
@dataclass
class RAGStore:
    """
    In-memory retrieval store: chunked documents plus a FAISS index over their embeddings.

    Holds everything needed for retrieval:
    - Source documents (bookkeeping only) and the chunks that are actually searched
    - The embedding model (SentenceTransformer)
    - A FAISS inner-product index over L2-normalized chunk embeddings
    - A local copy of those embeddings (not used by retrieval itself)

    Row i of ``index`` and of ``matrix`` corresponds to ``corpus_chunks[i]``.
    """
    corpus_docs: List[str]  # source documents for bookkeeping (not used in retrieval)
    corpus_chunks: List[str]  # chunked strings actually used for retrieval
    embedder: SentenceTransformer  # embedding model
    d: int  # embedding dimension
    index: faiss.IndexFlatIP  # FAISS index (inner product == cosine on normalized vectors)
    matrix: np.ndarray  # embeddings for all chunks, one row per chunk

    @staticmethod
    def _encode(embedder: SentenceTransformer, texts: List[str]) -> np.ndarray:
        """Embed ``texts`` as L2-normalized rows in the contiguous float32 layout FAISS expects."""
        vecs = embedder.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
        return np.ascontiguousarray(vecs, dtype=np.float32)

    @classmethod
    def create(cls, embedder: SentenceTransformer) -> "RAGStore":
        """
        Build a RAGStore with a tiny seed corpus so the Space works 'out of the box'.
        Students can add more docs later via the UI.
        """
        seed_docs = [
            "Graduation Honors Policy: Students who graduate with a GPA of 3.75 or higher are eligible for Latin honors as specified by the university catalog.",
            "Add/Drop Deadline: The last day to drop a full-semester class without a grade penalty is the end of week 10, unless otherwise specified by the academic calendar.",
            "Library Hours: During fall and spring semesters, the main library is open from 8am to 10pm Monday through Thursday."
        ]
        chunks = [chunk for doc in seed_docs for chunk in chunk_text(doc)]
        # Normalized embeddings make the inner product (IndexFlatIP) equal to cosine similarity.
        embeds = cls._encode(embedder, chunks)
        d = embeds.shape[1]
        index = faiss.IndexFlatIP(d)
        index.add(embeds)
        return cls(
            corpus_docs=seed_docs,
            corpus_chunks=chunks,
            embedder=embedder,
            d=d,
            index=index,
            matrix=embeds,
        )

    def add_documents(self, new_docs: List[str]) -> None:
        """
        Add new documents to the store: clean, chunk, embed, then index them.

        Blank or empty documents are ignored. Nothing is recorded until embedding
        succeeds, so a failed encode leaves the store unchanged.
        """
        clean = [doc for doc in (normalize_ws(x) for x in new_docs if x) if doc]
        if not clean:
            return
        new_chunks = [chunk for doc in clean for chunk in chunk_text(doc)]
        if not new_chunks:
            return
        new_embeds = self._encode(self.embedder, new_chunks)
        # Index first, then extend the chunk list. retrieve() skips any FAISS id
        # that has no chunk yet, so a concurrent query never hits an IndexError.
        self.index.add(new_embeds)
        self.corpus_chunks.extend(new_chunks)
        self.matrix = np.vstack([self.matrix, new_embeds]) if self.matrix is not None else new_embeds
        self.corpus_docs.extend(clean)

    def retrieve(self, query: str, k: int = TOP_K) -> List[Tuple[float, str]]:
        """
        Retrieve the top-k chunks for a user query.

        Steps:
          a) Embed the query
          b) Search FAISS for the nearest chunk vectors
          c) Return (cosine score, chunk_text) pairs, best first

        ``k`` is coerced to int because UI sliders may deliver floats. Returns [] for
        a blank query, k < 1, or an empty corpus.
        """
        k = int(k)
        if not query or not query.strip() or k < 1 or not self.corpus_chunks:
            return []
        q = self._encode(self.embedder, [normalize_ws(query)])
        scores, idxs = self.index.search(q, min(k, len(self.corpus_chunks)))
        chunks = self.corpus_chunks
        hits = []
        for score, idx in zip(scores[0], idxs[0]):
            idx = int(idx)
            # FAISS pads missing results with -1; also skip ids not (yet) mapped to a chunk.
            if 0 <= idx < len(chunks):
                hits.append((float(score), chunks[idx]))
        return hits
# ----------------------------
# Build models (loaded once at startup)
# ----------------------------
# Sentence embedder shared by indexing and query-time retrieval.
embedder = SentenceTransformer(EMBED_MODEL_ID)
# Vector store pre-populated with the built-in seed corpus.
rag = RAGStore.create(embedder=embedder)
# Generator: FLAN-T5 small (chosen to be light enough for CPU), via the text2text pipeline.
generator = pipeline(task="text2text-generation", model=GEN_MODEL_ID)
# ----------------------------
# Generation helpers
# ----------------------------
def generate_llm_only(question: str,
                      max_new_tokens: int = 128,
                      temperature: float = 0.6,
                      top_p: float = 0.9) -> str:
    """
    LLM-only baseline: send the question straight to the generator, with no context.

    Nothing grounds the answer, so it can hallucinate when the question needs specific facts.

    Args:
        question: User question; blank (or None) input returns a prompt to enter one.
        max_new_tokens: Generation length cap (coerced to int; UI sliders may send floats).
        temperature: Sampling temperature; values <= 0 switch to greedy decoding.
        top_p: Nucleus-sampling probability mass, used only when sampling.

    Returns:
        The generated answer text.
    """
    if not question or not question.strip():
        return "Please enter a question."
    temperature = float(temperature)
    gen_kwargs = {"max_new_tokens": int(max_new_tokens)}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p))
    else:
        # transformers raises ValueError for temperature == 0 with do_sample=True, and
        # the UI slider starts at 0.0, so treat it as greedy (deterministic) decoding.
        gen_kwargs["do_sample"] = False
    out = generator(question.strip(), **gen_kwargs)
    return out[0]["generated_text"]
def generate_rag(question: str,
                 k: int = TOP_K,
                 max_new_tokens: int = 128,
                 temperature: float = 0.6,
                 top_p: float = 0.9) -> Tuple[str, List[Tuple[float, str]]]:
    """
    RAG: retrieve the top-k chunks, then build a prompt that instructs the model
    to use only the provided context (and to say it doesn't know if the answer is missing).

    Args:
        question: User question; blank (or None) input returns a prompt to enter one.
        k: Number of chunks to retrieve (coerced to int; UI sliders may send floats).
        max_new_tokens: Generation length cap (coerced to int).
        temperature: Sampling temperature; values <= 0 switch to greedy decoding.
        top_p: Nucleus-sampling probability mass, used only when sampling.

    Returns:
        (answer, retrieved_hits), where hits are (score, chunk_text) pairs.
    """
    if not question or not question.strip():
        return "Please enter a question.", []
    question = question.strip()
    # 1) Retrieve
    hits = rag.retrieve(question, k=int(k))
    # Number the chunks so students can see the grounding; "" when nothing was retrieved.
    context = "\n\n".join(f"[{i}] {chunk}" for i, (_, chunk) in enumerate(hits, start=1))
    # 2) Build grounded prompt
    prompt = (
        "You are a careful assistant. Use ONLY the context to answer. "
        "If the answer is not in the context, say you don't know.\n\n"
        f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
    )
    # 3) Generate
    temperature = float(temperature)
    gen_kwargs = {"max_new_tokens": int(max_new_tokens)}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p))
    else:
        # transformers raises ValueError for temperature == 0 with do_sample=True, and
        # the UI slider starts at 0.0, so treat it as greedy (deterministic) decoding.
        gen_kwargs["do_sample"] = False
    out = generator(prompt, **gen_kwargs)
    return out[0]["generated_text"], hits
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(fill_height=True, analytics_enabled=False) as demo:
    gr.Markdown(
        "# 🔎 Retrieval-Augmented Generation (RAG) — Mini Demo\n"
        "Ask a question on the right. Compare **LLM-only** vs **RAG-grounded** answers. "
        "Add your own documents on the left and re-ask your question.\n\n"
        "_Tip: keep answers short for CPU. This demo may be incorrect; always verify facts._"
    )
    with gr.Row():
        # Left column: manage the corpus (paste/upload and index)
        with gr.Column(scale=1):
            gr.Markdown("### 📚 Corpus\nPaste text or upload .txt/.md to add to the knowledge base.")
            paste_box = gr.Textbox(lines=8, label="Paste text (optional)")
            upload = gr.File(label="Upload .txt or .md", file_types=[".txt", ".md"], file_count="multiple")
            add_btn = gr.Button("Add to Corpus", variant="secondary")
            corpus_count = gr.Markdown(f"**Chunks indexed:** {len(rag.corpus_chunks)}")
        # Right column: Q&A with two panels (LLM-only vs RAG)
        with gr.Column(scale=2):
            question = gr.Textbox(label="Your question",
                                  placeholder="Example: What GPA do I need for Latin honors?",
                                  lines=3)
            with gr.Row():
                # LLM-only panel
                with gr.Column():
                    gr.Markdown("#### 🤖 LLM-Only")
                    max_new_llm = gr.Slider(32, 256, value=128, step=8, label="Max new tokens")
                    temp_llm = gr.Slider(0.0, 1.5, value=0.6, step=0.05, label="Temperature")
                    topp_llm = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                    llm_btn = gr.Button("Generate (LLM-Only)")
                    llm_out = gr.Textbox(label="LLM-Only Answer", lines=8)
                # RAG panel
                with gr.Column():
                    gr.Markdown("#### 📎 RAG-Grounded")
                    # Default comes from the TOP_K config knob so the two cannot drift apart.
                    topk = gr.Slider(1, 8, value=TOP_K, step=1, label="Top-K chunks")
                    max_new_rag = gr.Slider(32, 256, value=128, step=8, label="Max new tokens")
                    temp_rag = gr.Slider(0.0, 1.5, value=0.6, step=0.05, label="Temperature")
                    topp_rag = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                    rag_btn = gr.Button("Generate (RAG)")
                    rag_out = gr.Textbox(label="RAG Answer", lines=8)
                    retrieved = gr.Markdown("")  # shows retrieved chunks + scores

    # ------------- Button callbacks (Python functions wired to UI) -------------
    def _add_to_corpus(pasted: str, files) -> str:
        """
        Gather pasted text and uploaded files, add them to the RAG store, and
        return an updated chunk count for the UI label.

        ``files`` is the list from gr.File(file_count="multiple"). Its items may be
        tempfile-like objects exposing a ``.name`` path or plain path strings,
        depending on the Gradio version, so both forms are accepted.
        """
        docs = []
        if pasted and pasted.strip():
            docs.append(pasted)
        for f in files or []:
            path = getattr(f, "name", f)
            try:
                with open(path, "rb") as fh:
                    content = read_txt_or_md(fh, path)
            except (OSError, TypeError, ValueError):
                # Skip unreadable uploads so one bad file doesn't block the rest.
                continue
            if content:
                docs.append(content)
        if docs:
            rag.add_documents(docs)
        return f"**Chunks indexed:** {len(rag.corpus_chunks)}"

    def _llm_only(q, mx, t, p):
        """Thin wrapper to pass UI slider values into the LLM-only generator."""
        return generate_llm_only(q, mx, t, p)

    def _rag(q, k, mx, t, p):
        """
        Thin wrapper to invoke RAG, then pretty-print the retrieved chunks
        with similarity scores under the answer.
        """
        ans, hits = generate_rag(q, k, mx, t, p)
        if hits:
            lines = [f"- (score={score:.3f}) {chunk}" for score, chunk in hits]
            md = "##### Retrieved Chunks\n" + "\n".join(lines)
        else:
            md = "_No chunks retrieved._"
        return ans, md

    # Wire UI events to functions
    add_btn.click(_add_to_corpus, inputs=[paste_box, upload], outputs=[corpus_count])
    llm_btn.click(_llm_only, inputs=[question, max_new_llm, temp_llm, topp_llm], outputs=[llm_out])
    rag_btn.click(_rag, inputs=[question, topk, max_new_rag, temp_rag, topp_rag], outputs=[rag_out, retrieved])
# Standard Gradio launcher
if __name__ == "__main__":
    # Running the file directly (e.g. `python app.py`) starts the Gradio server.
    demo.launch()