| """ |
| RAG Mini Demo (CPU-friendly) |
| ---------------------------- |
| This Gradio app shows side-by-side answers from: |
| 1) LLM-Only β the model answers directly from the question |
| 2) RAG β the model answers using retrieved context from a small corpus |
| |
| Stack (all CPU-friendly): |
| - sentence-transformers/all-MiniLM-L6-v2 for embeddings (vector representations) |
| - FAISS (CPU) for fast similarity search over vectors |
| - google/flan-t5-small for generation |
| - Gradio for the web UI |
| """ |
|
|
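
# The demo assumes these packages are installed; exact version pins are not
# part of this file, so treat the line below as illustrative, not canonical:
#   pip install gradio sentence-transformers faiss-cpu transformers torch numpy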

import io
import os
import re

import faiss
import gradio as gr
import numpy as np
from dataclasses import dataclass
from typing import List, Tuple

from sentence_transformers import SentenceTransformer
from transformers import pipeline

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
GEN_MODEL_ID = "google/flan-t5-small"

CHUNK_SIZE = 500      # max characters per chunk
CHUNK_OVERLAP = 100   # characters shared between consecutive chunks
TOP_K = 3             # default number of chunks to retrieve

# ---------------------------------------------------------------------------
# Text helpers
# ---------------------------------------------------------------------------

def normalize_ws(text: str) -> str:
    """
    Normalize whitespace so we don't store noisy text.
    Replaces runs of spaces/newlines with a single space and strips the ends.
    """
    return re.sub(r"\s+", " ", text).strip()

def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """
    Split long text into overlapping chunks so retrieval can match smaller sections.
    Overlap helps avoid 'boundary' problems where a key sentence is split
    between two chunks.
    """
    text = normalize_ws(text)
    if len(text) <= chunk_size:
        return [text]

    # Guard: if overlap >= chunk_size the window would never advance.
    overlap = min(overlap, chunk_size - 1)

    chunks = []
    start = 0
    while start < len(text):
        end = min(len(text), start + chunk_size)
        chunks.append(text[start:end])
        if end == len(text):
            break
        # Step forward, keeping `overlap` characters of shared context.
        start = end - overlap
    return chunks
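
# Worked example (illustrative sizes, not the defaults above): with
# chunk_size=10 and overlap=3, a 16-character string splits as
#   chunk_text("abcdefghijklmnop", 10, 3) -> ["abcdefghij", "hijklmnop"]
# The second chunk re-reads "hij", so a sentence straddling the boundary
# still appears intact in at least one chunk.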

def read_txt_or_md(file_obj: io.BytesIO, filename: str) -> str:
    """
    Read .txt or .md files as UTF-8 text.
    We restrict to these formats to keep the demo simple and robust on CPU Spaces.
    """
    ext = os.path.splitext(filename.lower())[1]
    if ext not in (".txt", ".md"):
        return ""
    try:
        return file_obj.read().decode("utf-8", errors="ignore")
    except Exception:
        return ""

# ---------------------------------------------------------------------------
# RAG store
# ---------------------------------------------------------------------------

@dataclass
class RAGStore:
    """
    Holds everything needed for retrieval:
      - Original docs and chunked docs
      - The embedding model (SentenceTransformer)
      - A FAISS index built over the chunk embeddings
      - A local copy of the embeddings for possible future use (not strictly required)
    """
    corpus_docs: List[str]
    corpus_chunks: List[str]
    embedder: SentenceTransformer
    d: int
    index: faiss.IndexFlatIP
    matrix: np.ndarray

    @classmethod
    def create(cls, embedder: SentenceTransformer):
        """
        Build a RAGStore with a tiny seed corpus so the Space works 'out of the box'.
        Students can add more docs later via the UI.
        """
        seed_docs = [
            "Graduation Honors Policy: Students who graduate with a GPA of 3.75 or higher are eligible for Latin honors as specified by the university catalog.",
            "Add/Drop Deadline: The last day to drop a full-semester class without a grade penalty is the end of week 10, unless otherwise specified by the academic calendar.",
            "Library Hours: During fall and spring semesters, the main library is open from 8am to 10pm Monday through Thursday.",
        ]

        chunks = []
        for doc in seed_docs:
            chunks.extend(chunk_text(doc))

        # L2-normalized embeddings let inner product act as cosine similarity.
        embeds = embedder.encode(chunks, convert_to_numpy=True, normalize_embeddings=True)

        d = embeds.shape[1]
        index = faiss.IndexFlatIP(d)
        index.add(embeds)

        return cls(
            corpus_docs=seed_docs,
            corpus_chunks=chunks,
            embedder=embedder,
            d=d,
            index=index,
            matrix=embeds,
        )
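
    # Why IndexFlatIP gives cosine similarity here (illustrative numbers):
    # for unit-length vectors a and b, the inner product equals the cosine
    # of the angle between them, e.g.
    #   a = [1.0, 0.0]; b = [0.8, 0.6]     # both have L2 norm 1.0
    #   np.dot(a, b) -> 0.8                # = cosine similarity
    # so higher FAISS scores mean more semantically similar chunks.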

    def add_documents(self, new_docs: List[str]):
        """
        Add new documents to the store:
          1) Clean and append to the corpus
          2) Chunk
          3) Embed
          4) Add embeddings to FAISS and the local matrix
        """
        clean = [normalize_ws(x) for x in new_docs if x and normalize_ws(x)]
        if not clean:
            return

        self.corpus_docs.extend(clean)

        new_chunks = []
        for doc in clean:
            new_chunks.extend(chunk_text(doc))
        if not new_chunks:
            return

        new_embeds = self.embedder.encode(new_chunks, convert_to_numpy=True, normalize_embeddings=True)
        self.index.add(new_embeds)

        self.matrix = np.vstack([self.matrix, new_embeds]) if self.matrix is not None else new_embeds
        self.corpus_chunks.extend(new_chunks)

    def retrieve(self, query: str, k: int = TOP_K) -> List[Tuple[float, str]]:
        """
        Retrieve the top-k chunks for a user query.
        Steps:
          a) Embed the query
          b) Search FAISS for the nearest chunk vectors
          c) Return (score, chunk_text) pairs
        """
        if not query.strip() or len(self.corpus_chunks) == 0:
            return []

        k = int(k)  # sliders may deliver floats
        q = self.embedder.encode([normalize_ws(query)], convert_to_numpy=True, normalize_embeddings=True)
        scores, idxs = self.index.search(q, min(k, len(self.corpus_chunks)))

        hits = []
        for score, idx in zip(scores[0], idxs[0]):
            if idx == -1:  # FAISS pads with -1 when fewer results exist than k
                continue
            hits.append((float(score), self.corpus_chunks[idx]))
        return hits
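
# Example retrieval (scores are hypothetical, shown only to illustrate the
# return shape):
#   rag.retrieve("When can I drop a class?", k=2)
#   -> [(0.71, "Add/Drop Deadline: The last day to drop ..."),
#       (0.23, "Graduation Honors Policy: Students who graduate ...")]
# Scores are cosine similarities; the demo simply surfaces the top-k hits
# rather than applying a fixed relevance cutoff.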

# ---------------------------------------------------------------------------
# Model initialization (runs once at startup)
# ---------------------------------------------------------------------------

embedder = SentenceTransformer(EMBED_MODEL_ID)
rag = RAGStore.create(embedder)

# flan-t5-small is a seq2seq model, hence the text2text-generation task.
generator = pipeline("text2text-generation", model=GEN_MODEL_ID)
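
# Optional smoke test (illustrative; outputs vary between runs):
#   generator("Question: What is the capital of France? Answer:")[0]["generated_text"]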

# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------

def generate_llm_only(question: str,
                      max_new_tokens: int = 128,
                      temperature: float = 0.6,
                      top_p: float = 0.9) -> str:
    """
    LLM-only: send the question directly to the generator without context.
    This is the baseline; it can hallucinate when the question requires
    specific facts.
    """
    if not question.strip():
        return "Please enter a question."
    # Sampling with temperature 0 is invalid in transformers; fall back to
    # greedy decoding when the slider is at 0.
    do_sample = float(temperature) > 0.0
    gen_kwargs = dict(max_new_tokens=int(max_new_tokens), do_sample=do_sample)
    if do_sample:
        gen_kwargs.update(temperature=float(temperature), top_p=float(top_p))
    out = generator(question.strip(), **gen_kwargs)
    return out[0]["generated_text"]

def generate_rag(question: str,
                 k: int = TOP_K,
                 max_new_tokens: int = 128,
                 temperature: float = 0.6,
                 top_p: float = 0.9):
    """
    RAG: retrieve the top-k chunks, then build a prompt that *forces* the model
    to use only the provided context (and say "I don't know" if it's missing).
    Returns (answer, retrieved_hits).
    """
    if not question.strip():
        return "Please enter a question.", []

    hits = rag.retrieve(question, k=k)
    if not hits:
        context = ""
    else:
        # Number the chunks so the prompt (and the UI) can reference them.
        context = "\n\n".join([f"[{i+1}] {c}" for i, (_, c) in enumerate(hits)])

    prompt = (
        "You are a careful assistant. Use ONLY the context to answer. "
        "If the answer is not in the context, say you don't know.\n\n"
        f"Context:\n{context}\n\nQuestion: {question.strip()}\nAnswer:"
    )

    # Same temperature guard as the LLM-only path.
    do_sample = float(temperature) > 0.0
    gen_kwargs = dict(max_new_tokens=int(max_new_tokens), do_sample=do_sample)
    if do_sample:
        gen_kwargs.update(temperature=float(temperature), top_p=float(top_p))
    out = generator(prompt, **gen_kwargs)
    answer = out[0]["generated_text"]
    return answer, hits
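
# Shape of the assembled prompt (chunk text abridged for illustration):
#
#   You are a careful assistant. Use ONLY the context to answer. If the
#   answer is not in the context, say you don't know.
#
#   Context:
#   [1] Graduation Honors Policy: Students who graduate with a GPA of 3.75 ...
#
#   Question: What GPA do I need for Latin honors?
#   Answer: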

# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

with gr.Blocks(fill_height=True, analytics_enabled=False) as demo:
    gr.Markdown(
        "# Retrieval-Augmented Generation (RAG) Mini Demo\n"
        "Ask a question on the right. Compare **LLM-only** vs **RAG-grounded** answers. "
        "Add your own documents on the left and re-ask your question.\n\n"
        "_Tip: keep answers short for CPU. This demo may be incorrect; always verify facts._"
    )

    with gr.Row():
        # Left column: corpus management
        with gr.Column(scale=1):
            gr.Markdown("### Corpus\nPaste text or upload .txt/.md to add to the knowledge base.")
            paste_box = gr.Textbox(lines=8, label="Paste text (optional)")
            upload = gr.File(label="Upload .txt or .md", file_types=[".txt", ".md"], file_count="multiple")
            add_btn = gr.Button("Add to Corpus", variant="secondary")
            corpus_count = gr.Markdown(f"**Chunks indexed:** {len(rag.corpus_chunks)}")

        # Right column: question plus side-by-side answers
        with gr.Column(scale=2):
            question = gr.Textbox(label="Your question",
                                  placeholder="Example: What GPA do I need for Latin honors?",
                                  lines=3)

            with gr.Row():
                # Baseline: no retrieval
                with gr.Column():
                    gr.Markdown("#### LLM-Only")
                    max_new_llm = gr.Slider(32, 256, value=128, step=8, label="Max new tokens")
                    temp_llm = gr.Slider(0.0, 1.5, value=0.6, step=0.05, label="Temperature")
                    topp_llm = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                    llm_btn = gr.Button("Generate (LLM-Only)")
                    llm_out = gr.Textbox(label="LLM-Only Answer", lines=8)

                # Grounded: retrieval + context-constrained prompt
                with gr.Column():
                    gr.Markdown("#### RAG-Grounded")
                    topk = gr.Slider(1, 8, value=3, step=1, label="Top-K chunks")
                    max_new_rag = gr.Slider(32, 256, value=128, step=8, label="Max new tokens")
                    temp_rag = gr.Slider(0.0, 1.5, value=0.6, step=0.05, label="Temperature")
                    topp_rag = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                    rag_btn = gr.Button("Generate (RAG)")
                    rag_out = gr.Textbox(label="RAG Answer", lines=8)
                    retrieved = gr.Markdown("")

    def _add_to_corpus(pasted: str, files: List) -> str:
        """
        Gather pasted text and uploaded files, read/clean them, add to the RAG
        store, and return an updated chunk count for the UI label.
        """
        docs = []
        if pasted and pasted.strip():
            docs.append(pasted)

        if files:
            for f in files:
                try:
                    with open(f.name, "rb") as fh:
                        content = read_txt_or_md(io.BytesIO(fh.read()), f.name)
                    if content:
                        docs.append(content)
                except Exception:
                    # Skip unreadable files rather than failing the whole batch.
                    continue

        if docs:
            rag.add_documents(docs)
        return f"**Chunks indexed:** {len(rag.corpus_chunks)}"

    def _llm_only(q, mx, t, p):
        """Thin wrapper to pass UI slider values into the LLM-only generator."""
        return generate_llm_only(q, mx, t, p)

    def _rag(q, k, mx, t, p):
        """
        Thin wrapper to invoke RAG, then pretty-print the retrieved chunks
        with similarity scores under the answer.
        """
        ans, hits = generate_rag(q, k, mx, t, p)
        if hits:
            md = "##### Retrieved Chunks\n" + "\n".join(
                [f"- (score={score:.3f}) {chunk}" for score, chunk in hits]
            )
        else:
            md = "_No chunks retrieved._"
        return ans, md

    # Wire buttons to handlers; `inputs` map positionally to handler arguments.
    add_btn.click(_add_to_corpus, inputs=[paste_box, upload], outputs=[corpus_count])
    llm_btn.click(_llm_only, inputs=[question, max_new_llm, temp_llm, topp_llm], outputs=[llm_out])
    rag_btn.click(_rag, inputs=[question, topk, max_new_rag, temp_rag, topp_rag], outputs=[rag_out, retrieved])

if __name__ == "__main__":
    demo.launch()
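
# On Hugging Face Spaces the default launch() settings are fine; for local
# testing, demo.launch(share=True) creates a temporary public link (optional,
# not required for the demo).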