DrDavis commited on
Commit
ed943e3
Β·
verified Β·
1 Parent(s): f3e76de

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +350 -0
app.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG Mini Demo (CPU-friendly)
3
+ ----------------------------
4
+ This Gradio app shows side-by-side answers from:
5
+ 1) LLM-Only β†’ the model answers directly from the question
6
+ 2) RAG β†’ the model answers using retrieved context from a small corpus
7
+
8
+ Stack (all CPU-friendly):
9
+ - sentence-transformers/all-MiniLM-L6-v2 for embeddings (vector representations)
10
+ - FAISS (CPU) for fast similarity search over vectors
11
+ - google/flan-t5-small for generation
12
+ - Gradio for the web UI
13
+ """
14
+
15
+ import gradio as gr
16
+ import os, io, re, faiss
17
+ from typing import List, Tuple
18
+ from dataclasses import dataclass
19
+
20
+ # Embedding model (turns text β†’ vectors)
21
+ from sentence_transformers import SentenceTransformer
22
+ # Text generation pipeline (small, instruction-friendly model)
23
+ from transformers import pipeline
24
+
25
+ # ----------------------------
26
+ # App configuration (easy knobs)
27
+ # ----------------------------
28
+ EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2" # small, high-quality sentence embeddings
29
+ GEN_MODEL_ID = "google/flan-t5-small" # tiny generator for CPU Spaces
30
+
31
+ # Chunking settings for splitting long documents
32
+ CHUNK_SIZE = 500 # characters per chunk (teaching default)
33
+ CHUNK_OVERLAP = 100 # characters of overlap between consecutive chunks
34
+ TOP_K = 3 # how many chunks to retrieve for the RAG prompt
35
+
36
+ # ----------------------------
37
+ # Utility functions
38
+ # ----------------------------
39
+ def normalize_ws(text: str) -> str:
40
+ """
41
+ Normalize whitespace so we don't store noisy text.
42
+ Replaces multiple spaces/newlines with a single space, strips ends.
43
+ """
44
+ return re.sub(r"\s+", " ", text).strip()
45
+
46
+ def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
47
+ """
48
+ Split long text into overlapping chunks so that retrieval can match smaller sections.
49
+ Overlap helps avoid 'boundary' problems where a key sentence is split between two chunks.
50
+ """
51
+ text = normalize_ws(text)
52
+ if len(text) <= chunk_size:
53
+ return [text]
54
+
55
+ chunks = []
56
+ start = 0
57
+ while start < len(text):
58
+ end = min(len(text), start + chunk_size)
59
+ chunks.append(text[start:end])
60
+ if end == len(text):
61
+ break
62
+ # move the window forward, but keep 'overlap' characters of the previous chunk
63
+ start = max(0, end - overlap)
64
+ return chunks
65
+
66
+ def read_txt_or_md(file_obj: io.BytesIO, filename: str) -> str:
67
+ """
68
+ Read .txt or .md files as UTF-8 text.
69
+ We restrict to these formats to keep the demo simple and robust on CPU Spaces.
70
+ """
71
+ ext = os.path.splitext(filename.lower())[1]
72
+ if ext not in [".txt", ".md"]:
73
+ return ""
74
+ try:
75
+ content = file_obj.read().decode("utf-8", errors="ignore")
76
+ return content
77
+ except Exception:
78
+ return ""
79
+
80
+ # ----------------------------
81
+ # RAG store: Keeps chunks + FAISS index
82
+ # ----------------------------
83
+ @dataclass
84
+ class RAGStore:
85
+ """
86
+ Holds everything needed for retrieval:
87
+ - Original docs and chunked docs
88
+ - The embedding model (SentenceTransformer)
89
+ - A FAISS index built over the chunk embeddings
90
+ - A local copy of embeddings for possible future use (not strictly required)
91
+ """
92
+ corpus_docs: List[str] # raw documents for bookkeeping (not used in retrieval)
93
+ corpus_chunks: List[str] # chunked strings actually used for retrieval
94
+ embedder: SentenceTransformer # embedding model
95
+ d: int # embedding dimension
96
+ index: faiss.IndexFlatIP # FAISS index (Inner Product = cosine when normalized)
97
+ matrix: any # numpy array of embeddings for all chunks
98
+
99
+ @classmethod
100
+ def create(cls, embedder: SentenceTransformer):
101
+ """
102
+ Build a RAGStore with a tiny seed corpus so the Space works 'out of the box'.
103
+ Students can add more docs later via the UI.
104
+ """
105
+ seed_docs = [
106
+ "Graduation Honors Policy: Students who graduate with a GPA of 3.75 or higher are eligible for Latin honors as specified by the university catalog.",
107
+ "Add/Drop Deadline: The last day to drop a full-semester class without a grade penalty is the end of week 10, unless otherwise specified by the academic calendar.",
108
+ "Library Hours: During fall and spring semesters, the main library is open from 8am to 10pm Monday through Thursday."
109
+ ]
110
+
111
+ # Chunk the seed docs
112
+ chunks = []
113
+ for doc in seed_docs:
114
+ chunks.extend(chunk_text(doc))
115
+
116
+ # Embed all chunks (normalize to enable cosine similarity via Inner Product)
117
+ embeds = embedder.encode(chunks, convert_to_numpy=True, normalize_embeddings=True)
118
+
119
+ # Build a FAISS index: IndexFlatIP = inner product (dot product)
120
+ # With normalized vectors, dot product == cosine similarity
121
+ d = embeds.shape[1]
122
+ index = faiss.IndexFlatIP(d)
123
+ index.add(embeds)
124
+
125
+ return cls(
126
+ corpus_docs=seed_docs,
127
+ corpus_chunks=chunks,
128
+ embedder=embedder,
129
+ d=d,
130
+ index=index,
131
+ matrix=embeds
132
+ )
133
+
134
+ def add_documents(self, new_docs: List[str]):
135
+ """
136
+ Add new documents to the store:
137
+ 1) Clean and append to corpus
138
+ 2) Chunk
139
+ 3) Embed
140
+ 4) Add embeddings to FAISS and local matrix
141
+ """
142
+ clean = [normalize_ws(x) for x in new_docs if x and normalize_ws(x)]
143
+ if not clean:
144
+ return
145
+
146
+ self.corpus_docs.extend(clean)
147
+
148
+ # Re-chunk new docs
149
+ new_chunks = []
150
+ for doc in clean:
151
+ new_chunks.extend(chunk_text(doc))
152
+ if not new_chunks:
153
+ return
154
+
155
+ # Embed and add to FAISS
156
+ new_embeds = self.embedder.encode(new_chunks, convert_to_numpy=True, normalize_embeddings=True)
157
+ self.index.add(new_embeds)
158
+
159
+ # Also update our local embedding matrix and chunk list
160
+ import numpy as np
161
+ self.matrix = np.vstack([self.matrix, new_embeds]) if self.matrix is not None else new_embeds
162
+ self.corpus_chunks.extend(new_chunks)
163
+
164
+ def retrieve(self, query: str, k: int = TOP_K) -> List[Tuple[float, str]]:
165
+ """
166
+ Retrieve top-k chunks for a user query.
167
+ Steps:
168
+ a) Embed the query
169
+ b) Search FAISS for nearest chunk vectors
170
+ c) Return (score, chunk_text) pairs
171
+ """
172
+ if not query.strip() or len(self.corpus_chunks) == 0:
173
+ return []
174
+
175
+ q = self.embedder.encode([normalize_ws(query)], convert_to_numpy=True, normalize_embeddings=True)
176
+ scores, idxs = self.index.search(q, min(k, len(self.corpus_chunks)))
177
+
178
+ hits = []
179
+ for score, idx in zip(scores[0], idxs[0]):
180
+ if idx == -1: # safety if FAISS returns -1
181
+ continue
182
+ hits.append((float(score), self.corpus_chunks[idx]))
183
+ return hits
184
+
185
+ # ----------------------------
186
+ # Build models (loaded once at startup)
187
+ # ----------------------------
188
+ embedder = SentenceTransformer(EMBED_MODEL_ID)
189
+ rag = RAGStore.create(embedder)
190
+
191
+ # Generator: FLAN-T5 small for CPU
192
+ generator = pipeline("text2text-generation", model=GEN_MODEL_ID)
193
+
194
+ # ----------------------------
195
+ # Generation helpers
196
+ # ----------------------------
197
+ def generate_llm_only(question: str,
198
+ max_new_tokens: int = 128,
199
+ temperature: float = 0.6,
200
+ top_p: float = 0.9) -> str:
201
+ """
202
+ LLM-only: send the question directly to the generator without context.
203
+ This is our baseline; can hallucinate if question requires specific facts.
204
+ """
205
+ if not question.strip():
206
+ return "Please enter a question."
207
+ out = generator(
208
+ question.strip(),
209
+ max_new_tokens=int(max_new_tokens),
210
+ do_sample=True,
211
+ temperature=float(temperature),
212
+ top_p=float(top_p),
213
+ )
214
+ return out[0]["generated_text"]
215
+
216
+ def generate_rag(question: str,
217
+ k: int = TOP_K,
218
+ max_new_tokens: int = 128,
219
+ temperature: float = 0.6,
220
+ top_p: float = 0.9):
221
+ """
222
+ RAG: retrieve top-k chunks, then build a prompt that *forces* the model
223
+ to use only the provided context (and say "I don't know" if missing).
224
+ Returns (answer, retrieved_hits).
225
+ """
226
+ if not question.strip():
227
+ return "Please enter a question.", []
228
+
229
+ # 1) Retrieve
230
+ hits = rag.retrieve(question, k=k)
231
+ if not hits:
232
+ context = ""
233
+ else:
234
+ # Pretty-print with indices so students can see the grounding
235
+ context = "\n\n".join([f"[{i+1}] {c}" for i, (_, c) in enumerate(hits)])
236
+
237
+ # 2) Build grounded prompt
238
+ prompt = (
239
+ "You are a careful assistant. Use ONLY the context to answer. "
240
+ "If the answer is not in the context, say you don't know.\n\n"
241
+ f"Context:\n{context}\n\nQuestion: {question.strip()}\nAnswer:"
242
+ )
243
+
244
+ # 3) Generate
245
+ out = generator(
246
+ prompt,
247
+ max_new_tokens=int(max_new_tokens),
248
+ do_sample=True,
249
+ temperature=float(temperature),
250
+ top_p=float(top_p),
251
+ )
252
+ answer = out[0]["generated_text"]
253
+ return answer, hits
254
+
255
+ # ----------------------------
256
+ # Gradio UI
257
+ # ----------------------------
258
+ with gr.Blocks(fill_height=True, analytics_enabled=False) as demo:
259
+ gr.Markdown(
260
+ "# πŸ”Ž Retrieval-Augmented Generation (RAG) β€” Mini Demo\n"
261
+ "Ask a question on the right. Compare **LLM-only** vs **RAG-grounded** answers. "
262
+ "Add your own documents on the left and re-ask your question.\n\n"
263
+ "_Tip: keep answers short for CPU. This demo may be incorrect; always verify facts._"
264
+ )
265
+
266
+ with gr.Row():
267
+ # Left column: manage the corpus (paste/upload and index)
268
+ with gr.Column(scale=1):
269
+ gr.Markdown("### πŸ“š Corpus\nPaste text or upload .txt/.md to add to the knowledge base.")
270
+ paste_box = gr.Textbox(lines=8, label="Paste text (optional)")
271
+ upload = gr.File(label="Upload .txt or .md", file_types=[".txt", ".md"], file_count="multiple")
272
+ add_btn = gr.Button("Add to Corpus", variant="secondary")
273
+ corpus_count = gr.Markdown(f"**Chunks indexed:** {len(rag.corpus_chunks)}")
274
+
275
+ # Right column: Q&A with two panels (LLM-only vs RAG)
276
+ with gr.Column(scale=2):
277
+ question = gr.Textbox(label="Your question",
278
+ placeholder="Example: What GPA do I need for Latin honors?",
279
+ lines=3)
280
+
281
+ with gr.Row():
282
+ # LLM-only panel
283
+ with gr.Column():
284
+ gr.Markdown("#### πŸ€– LLM-Only")
285
+ max_new_llm = gr.Slider(32, 256, value=128, step=8, label="Max new tokens")
286
+ temp_llm = gr.Slider(0.0, 1.5, value=0.6, step=0.05, label="Temperature")
287
+ topp_llm = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
288
+ llm_btn = gr.Button("Generate (LLM-Only)")
289
+ llm_out = gr.Textbox(label="LLM-Only Answer", lines=8)
290
+
291
+ # RAG panel
292
+ with gr.Column():
293
+ gr.Markdown("#### πŸ“Ž RAG-Grounded")
294
+ topk = gr.Slider(1, 8, value=3, step=1, label="Top-K chunks")
295
+ max_new_rag = gr.Slider(32, 256, value=128, step=8, label="Max new tokens")
296
+ temp_rag = gr.Slider(0.0, 1.5, value=0.6, step=0.05, label="Temperature")
297
+ topp_rag = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
298
+ rag_btn = gr.Button("Generate (RAG)")
299
+ rag_out = gr.Textbox(label="RAG Answer", lines=8)
300
+ retrieved = gr.Markdown("") # shows retrieved chunks + scores
301
+
302
+ # ------------- Button callbacks (Python functions wired to UI) -------------
303
+ def _add_to_corpus(pasted: str, files: List[gr.File]) -> str:
304
+ """
305
+ Gather pasted text and uploaded files, read/clean them, add to the RAG store,
306
+ and return an updated chunk count for the UI label.
307
+ """
308
+ docs = []
309
+ if pasted and pasted.strip():
310
+ docs.append(pasted)
311
+
312
+ if files:
313
+ for f in files:
314
+ try:
315
+ with open(f.name, "rb") as fh:
316
+ content = read_txt_or_md(io.BytesIO(fh.read()), f.name)
317
+ if content:
318
+ docs.append(content)
319
+ except Exception:
320
+ # Ignore unreadable files to keep class happy-path smooth
321
+ continue
322
+
323
+ if docs:
324
+ rag.add_documents(docs)
325
+ return f"**Chunks indexed:** {len(rag.corpus_chunks)}"
326
+
327
+ def _llm_only(q, mx, t, p):
328
+ """Thin wrapper to pass UI slider values into the LLM-only generator."""
329
+ return generate_llm_only(q, mx, t, p)
330
+
331
+ def _rag(q, k, mx, t, p):
332
+ """
333
+ Thin wrapper to invoke RAG, then pretty-print the retrieved chunks
334
+ with similarity scores under the answer.
335
+ """
336
+ ans, hits = generate_rag(q, k, mx, t, p)
337
+ if hits:
338
+ md = "##### Retrieved Chunks\n" + "\n".join([f"- (score={score:.3f}) {chunk}" for score, chunk in hits])
339
+ else:
340
+ md = "_No chunks retrieved._"
341
+ return ans, md
342
+
343
+ # Wire UI events to functions
344
+ add_btn.click(_add_to_corpus, inputs=[paste_box, upload], outputs=[corpus_count])
345
+ llm_btn.click(_llm_only, inputs=[question, max_new_llm, temp_llm, topp_llm], outputs=[llm_out])
346
+ rag_btn.click(_rag, inputs=[question, topk, max_new_rag, temp_rag, topp_rag], outputs=[rag_out, retrieved])
347
+
348
+ # Standard Gradio launcher
349
+ if __name__ == "__main__":
350
+ demo.launch()