Commit 134e302 (verified) · Author: Phoenixak99 · Parent: cca01b4

Fix text-generation pipeline and move chat_input to top level

Files changed (6):
  1. Dockerfile +19 -0
  2. README.md +82 -6
  3. app.py +815 -0
  4. demo_instructions.md +27 -0
  5. requirements.txt +11 -0
  6. sample_documents/sample.txt +13 -0
Dockerfile ADDED
@@ -0,0 +1,19 @@
+ FROM python:3.10-slim
+
+ WORKDIR /app
+
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ COPY requirements.txt ./
+ RUN pip3 install --no-cache-dir -r requirements.txt
+
+ COPY . .
+
+ EXPOSE 8501
+
+ HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
+
+ ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
README.md CHANGED
@@ -1,12 +1,88 @@
  ---
  title: RAG System Demo
- emoji: 📉
- colorFrom: yellow
- colorTo: red
- sdk: gradio
- sdk_version: 6.5.1
+ emoji: "\U0001F917"
+ colorFrom: blue
+ colorTo: purple
+ sdk: streamlit
+ sdk_version: "1.28.1"
  app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # RAG System Demo
+
+ A fully functional Retrieval-Augmented Generation (RAG) system built with open-source Hugging Face models. Upload your documents, ask questions, and get AI-generated answers grounded in your content -- with full source attribution.
+
+ ## What It Does
+
+ This demo implements a complete RAG pipeline:
+
+ 1. **Document Ingestion** -- Upload PDF, TXT, DOCX, or CSV files. Text is extracted and split into overlapping chunks.
+ 2. **Semantic Indexing** -- Chunks are embedded with a sentence-transformer model and stored in an in-memory ChromaDB vector store.
+ 3. **Retrieval** -- When you ask a question, the most semantically similar chunks are retrieved using cosine similarity.
+ 4. **Generation** -- Retrieved context is passed to a language model that generates a grounded answer.
+
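The four steps above can be sketched end to end in a few lines. This is a toy illustration, not the app's code: a bag-of-words counter stands in for the `all-MiniLM-L6-v2` embeddings, and a plain list stands in for the ChromaDB collection.

```python
import math
from collections import Counter

def embed(text: str) -> Counter:
    # Toy stand-in for the sentence-transformer embedder: bag-of-words counts.
    return Counter(text.lower().split())

def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[t] * b[t] for t in a)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

# Steps 1-2: ingest "chunks" and index their embeddings.
chunks = [
    "ChromaDB stores dense vectors for semantic search.",
    "Streamlit renders the chat interface.",
]
index = [(chunk, embed(chunk)) for chunk in chunks]

# Step 3: retrieve -- rank chunks by cosine similarity to the query.
query = "Which component stores vectors for search?"
q_vec = embed(query)
top_chunk = max(index, key=lambda item: cosine(q_vec, item[1]))[0]

# Step 4: generate -- the real app passes this context to flan-t5-small.
prompt = f"Context:\n{top_chunk}\n\nQuestion: {query}"
```

The real pipeline differs only in scale: dense model embeddings instead of word counts, and an approximate nearest-neighbor index instead of a linear scan.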
+ ## Features
+
+ - Multi-format document upload (PDF, TXT, DOCX, CSV)
+ - Semantic search with relevance scoring
+ - AI-powered question answering over your documents
+ - Source attribution with similarity scores
+ - Chat-style interface with conversation history
+ - Sample document included for quick testing
+
+ ## Models Used
+
+ | Component | Model | Purpose |
+ |-----------|-------|---------|
+ | Text Generation | `google/flan-t5-small` | Instruction-following seq2seq model for Q&A |
+ | Embeddings | `sentence-transformers/all-MiniLM-L6-v2` | Dense vector embeddings for semantic search |
+ | Vector Store | ChromaDB (in-memory) | Fast approximate nearest neighbor search |
+
+ ## Main Repository
+
+ This Hugging Face Space is a live demo for the full RAG System project:
+
+ [https://github.com/Phoenixak99/RAG-System](https://github.com/Phoenixak99/RAG-System)
+
+ ## Running Locally
+
+ ```bash
+ # Clone the repository
+ git clone https://github.com/Phoenixak99/RAG-System.git
+ cd RAG-System/hf_space
+
+ # Install dependencies
+ pip install -r requirements.txt
+
+ # Run the Streamlit app
+ streamlit run app.py
+ ```
+
+ Or use Docker:
+
+ ```bash
+ docker build -t rag-demo .
+ docker run -p 8501:8501 rag-demo
+ ```
+
+ Then open [http://localhost:8501](http://localhost:8501) in your browser.
+
+ ## Architecture
+
+ ```
+ User Query
+     |
+     v
+ [Embedding Model] --> Query Vector
+     |
+     v
+ [ChromaDB] --> Top-K Similar Chunks
+     |
+     v
+ [flan-t5-small] --> Generated Answer + Source Attribution
+ ```
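The last hop of the diagram, turning retrieved chunks into the model's input, is plain prompt assembly. The sketch below mirrors the prompt format used in `app.py` (the top-3 chunk count and 400-character truncation match the app; treat them as tunable defaults):

```python
def build_prompt(query: str, retrieved: list[str], max_chars: int = 400) -> str:
    # Concatenate the top retrieved chunks (truncated) into one context
    # block, then append the user's question for the seq2seq model.
    context = "\n\n".join(chunk[:max_chars] for chunk in retrieved[:3])
    return (
        "Answer the following question based on the provided context.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}"
    )

prompt = build_prompt("What is RAG?", ["RAG combines retrieval with generation."])
```

Because flan-t5 is instruction-tuned, this single instruction-plus-context string is enough; no chat template is required.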
+
+ ## License
+
+ MIT License -- see the [main repository](https://github.com/Phoenixak99/RAG-System) for details.
app.py ADDED
@@ -0,0 +1,815 @@
+ """
+ RAG System Demo - Streamlit Application
+ A self-contained RAG (Retrieval-Augmented Generation) demo using
+ google/flan-t5-small for text generation and sentence-transformers
+ for semantic search with ChromaDB as the vector store.
+ """
+
+ import io
+ import os
+ import uuid
+ import logging
+ from typing import Optional
+
+ import numpy as np
+ import pandas as pd
+ import streamlit as st
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ # ---------------------------------------------------------------------------
+ # Custom CSS
+ # ---------------------------------------------------------------------------
+ CUSTOM_CSS = """
+ <style>
+ /* General layout */
+ .main .block-container {
+     padding-top: 2rem;
+     padding-bottom: 2rem;
+ }
+
+ /* Source document cards */
+ .source-card {
+     background: var(--background-secondary, #f8f9fa);
+     border: 1px solid var(--border-color, #dee2e6);
+     border-radius: 8px;
+     padding: 12px 16px;
+     margin-bottom: 10px;
+     font-size: 0.9em;
+ }
+
+ .source-card .source-header {
+     font-weight: 600;
+     color: var(--text-color, #1a1a2e);
+     margin-bottom: 4px;
+     display: flex;
+     justify-content: space-between;
+     align-items: center;
+ }
+
+ .source-card .relevance-badge {
+     background: #4361ee;
+     color: white;
+     padding: 2px 8px;
+     border-radius: 12px;
+     font-size: 0.8em;
+     font-weight: 500;
+ }
+
+ .source-card .source-text {
+     color: var(--text-color-secondary, #495057);
+     font-size: 0.88em;
+     line-height: 1.5;
+     margin-top: 6px;
+ }
+
+ /* Document list in sidebar */
+ .doc-item {
+     background: var(--background-secondary, #f0f2f6);
+     border-radius: 6px;
+     padding: 8px 12px;
+     margin-bottom: 6px;
+     font-size: 0.85em;
+ }
+
+ /* Footer */
+ .footer {
+     text-align: center;
+     padding: 20px 0 10px 0;
+     font-size: 0.8em;
+     color: var(--text-color-secondary, #6c757d);
+     border-top: 1px solid var(--border-color, #dee2e6);
+     margin-top: 40px;
+ }
+
+ /* Chat area spacing */
+ .stChatMessage {
+     margin-bottom: 8px;
+ }
+
+ /* Dark mode support */
+ @media (prefers-color-scheme: dark) {
+     .source-card {
+         background: #1e1e2e;
+         border-color: #383850;
+     }
+     .source-card .source-header {
+         color: #e0e0e0;
+     }
+     .source-card .source-text {
+         color: #b0b0c0;
+     }
+     .doc-item {
+         background: #1e1e2e;
+     }
+ }
+ </style>
+ """
+
+
+ # ---------------------------------------------------------------------------
+ # RAG System
+ # ---------------------------------------------------------------------------
+ class RAGSystem:
+     """Core RAG pipeline: document processing, retrieval, and generation."""
+
+     def __init__(self) -> None:
+         """Initialize the RAG system (models loaded separately via cache)."""
+         self.chunk_size = 500
+         self.chunk_overlap = 50
+
+     # ------------------------------------------------------------------
+     # Model loading (cached by Streamlit)
+     # ------------------------------------------------------------------
+     @staticmethod
+     @st.cache_resource
+     def load_models() -> tuple:
+         """Load flan-t5-small and sentence-transformers.
+
+         Returns:
+             Tuple of (text2text-generation pipeline, SentenceTransformer model).
+         """
+         from transformers import pipeline as hf_pipeline
+         from sentence_transformers import SentenceTransformer
+
+         with st.spinner("Loading language model (flan-t5-small)..."):
+             llm = hf_pipeline(
+                 "text2text-generation",  # flan-t5 is seq2seq; "text-generation" would fail
+                 model="google/flan-t5-small",
+                 max_new_tokens=200,
+                 do_sample=False,
+             )
+
+         with st.spinner("Loading embedding model (all-MiniLM-L6-v2)..."):
+             embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+
+         return llm, embedder
+
+     # ------------------------------------------------------------------
+     # Vector store
+     # ------------------------------------------------------------------
+     @staticmethod
+     def setup_vector_store():
+         """Create an in-memory ChromaDB client and collection.
+
+         Returns:
+             Tuple of (chromadb.Client, Collection).
+         """
+         import chromadb
+
+         client = chromadb.Client()
+         collection = client.get_or_create_collection(
+             name="rag_documents",
+             metadata={"hnsw:space": "cosine"},
+         )
+         return client, collection
+
+     # ------------------------------------------------------------------
+     # Text extraction
+     # ------------------------------------------------------------------
+     def extract_text(self, uploaded_file) -> dict:
+         """Extract text content from an uploaded file.
+
+         Args:
+             uploaded_file: Streamlit UploadedFile object.
+
+         Returns:
+             Dict with keys 'filename', 'text', and 'type'.
+
+         Raises:
+             ValueError: If the file type is unsupported or extraction fails.
+         """
+         filename = uploaded_file.name
+         file_ext = os.path.splitext(filename)[1].lower()
+
+         if file_ext == ".txt":
+             text = self._extract_txt(uploaded_file)
+         elif file_ext == ".pdf":
+             text = self._extract_pdf(uploaded_file)
+         elif file_ext == ".docx":
+             text = self._extract_docx(uploaded_file)
+         elif file_ext == ".csv":
+             text = self._extract_csv(uploaded_file)
+         else:
+             raise ValueError(f"Unsupported file type: {file_ext}")
+
+         if not text or not text.strip():
+             raise ValueError(f"No text content could be extracted from {filename}")
+
+         return {
+             "filename": filename,
+             "text": text.strip(),
+             "type": file_ext.lstrip("."),
+         }
+
+     @staticmethod
+     def _extract_txt(uploaded_file) -> str:
+         raw = uploaded_file.read()
+         for encoding in ("utf-8", "latin-1", "cp1252"):
+             try:
+                 return raw.decode(encoding)
+             except (UnicodeDecodeError, AttributeError):
+                 continue
+         return raw.decode("utf-8", errors="replace")
+
+     @staticmethod
+     def _extract_pdf(uploaded_file) -> str:
+         try:
+             from PyPDF2 import PdfReader
+         except ImportError as exc:
+             raise ValueError("PyPDF2 is required for PDF processing") from exc
+
+         reader = PdfReader(io.BytesIO(uploaded_file.read()))
+         pages = []
+         for page in reader.pages:
+             page_text = page.extract_text()
+             if page_text:
+                 pages.append(page_text)
+         return "\n\n".join(pages)
+
+     @staticmethod
+     def _extract_docx(uploaded_file) -> str:
+         try:
+             from docx import Document
+         except ImportError as exc:
+             raise ValueError("python-docx is required for DOCX processing") from exc
+
+         doc = Document(io.BytesIO(uploaded_file.read()))
+         paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
+         return "\n\n".join(paragraphs)
+
+     @staticmethod
+     def _extract_csv(uploaded_file) -> str:
+         df = pd.read_csv(uploaded_file)
+         rows = []
+         for _, row in df.iterrows():
+             parts = [f"{col}: {val}" for col, val in row.items() if pd.notna(val)]
+             rows.append(". ".join(parts))
+         return "\n\n".join(rows)
+
+     # ------------------------------------------------------------------
+     # Chunking
+     # ------------------------------------------------------------------
+     def chunk_text(
+         self, text: str, chunk_size: int = 500, overlap: int = 50
+     ) -> list[str]:
+         """Split text into overlapping chunks, breaking at sentence boundaries.
+
+         Args:
+             text: The full document text.
+             chunk_size: Maximum characters per chunk.
+             overlap: Number of overlapping characters between chunks.
+
+         Returns:
+             List of text chunks.
+         """
+         if not text or not text.strip():
+             return []
+
+         sentences = self._split_sentences(text)
+         chunks: list[str] = []
+         current_chunk: list[str] = []
+         current_length = 0
+
+         for sentence in sentences:
+             sentence = sentence.strip()
+             if not sentence:
+                 continue
+
+             sentence_len = len(sentence)
+
+             if current_length + sentence_len > chunk_size and current_chunk:
+                 chunks.append(" ".join(current_chunk))
+
+                 # Keep tail sentences for overlap
+                 overlap_chunk: list[str] = []
+                 overlap_len = 0
+                 for s in reversed(current_chunk):
+                     if overlap_len + len(s) > overlap:
+                         break
+                     overlap_chunk.insert(0, s)
+                     overlap_len += len(s)
+
+                 current_chunk = overlap_chunk
+                 current_length = overlap_len
+
+             current_chunk.append(sentence)
+             current_length += sentence_len
+
+         if current_chunk:
+             chunks.append(" ".join(current_chunk))
+
+         return chunks
+
+     @staticmethod
+     def _split_sentences(text: str) -> list[str]:
+         """Naive sentence splitter on '.', '!', '?'."""
+         import re
+
+         sentences = re.split(r"(?<=[.!?])\s+", text)
+         return [s for s in sentences if s.strip()]
+
+     # ------------------------------------------------------------------
+     # Document ingestion
+     # ------------------------------------------------------------------
+     def add_document(self, doc_data: dict, embedder, collection) -> int:
+         """Chunk, embed, and add a document to the vector store.
+
+         Args:
+             doc_data: Dict with 'filename' and 'text'.
+             embedder: SentenceTransformer instance.
+             collection: ChromaDB collection.
+
+         Returns:
+             Number of chunks added.
+         """
+         chunks = self.chunk_text(doc_data["text"], self.chunk_size, self.chunk_overlap)
+         if not chunks:
+             return 0
+
+         embeddings = embedder.encode(chunks, show_progress_bar=False)
+         ids = [f"{doc_data['filename']}_{uuid.uuid4().hex[:8]}" for _ in chunks]
+         metadatas = [
+             {
+                 "source": doc_data["filename"],
+                 "chunk_index": i,
+                 "total_chunks": len(chunks),
+             }
+             for i in range(len(chunks))
+         ]
+
+         collection.add(
+             ids=ids,
+             embeddings=embeddings.tolist(),
+             documents=chunks,
+             metadatas=metadatas,
+         )
+         return len(chunks)
+
+     # ------------------------------------------------------------------
+     # Search
+     # ------------------------------------------------------------------
+     def search(
+         self,
+         query: str,
+         embedder,
+         collection,
+         n_results: int = 5,
+     ) -> list[dict]:
+         """Perform semantic search over the vector store.
+
+         Args:
+             query: User query string.
+             embedder: SentenceTransformer instance.
+             collection: ChromaDB collection.
+             n_results: Maximum number of results to return.
+
+         Returns:
+             List of dicts with 'text', 'source', 'similarity', and 'chunk_index'.
+         """
+         if collection.count() == 0:
+             return []
+
+         query_embedding = embedder.encode([query], show_progress_bar=False)
+         actual_n = min(n_results, collection.count())
+         results = collection.query(
+             query_embeddings=query_embedding.tolist(),
+             n_results=actual_n,
+             include=["documents", "metadatas", "distances"],
+         )
+
+         formatted: list[dict] = []
+         if results and results["documents"]:
+             for doc, meta, dist in zip(
+                 results["documents"][0],
+                 results["metadatas"][0],
+                 results["distances"][0],
+             ):
+                 # ChromaDB cosine distance is in [0, 2]; convert to similarity
+                 similarity = max(0.0, 1.0 - dist)
+                 formatted.append(
+                     {
+                         "text": doc,
+                         "source": meta.get("source", "Unknown"),
+                         "similarity": round(similarity, 4),
+                         "chunk_index": meta.get("chunk_index", 0),
+                     }
+                 )
+
+         formatted.sort(key=lambda x: x["similarity"], reverse=True)
+         return formatted
+
+     # ------------------------------------------------------------------
+     # Response generation
+     # ------------------------------------------------------------------
+     def generate_response(
+         self,
+         query: str,
+         context_docs: list[dict],
+         llm_pipeline,
+     ) -> str:
+         """Generate an answer with flan-t5-small using retrieved context.
+
+         Args:
+             query: User question.
+             context_docs: Retrieved documents from search().
+             llm_pipeline: HuggingFace text2text-generation pipeline.
+
+         Returns:
+             Generated answer string.
+         """
+         if not context_docs:
+             return (
+                 "I don't have any documents to reference. "
+                 "Please upload documents first, then ask your question."
+             )
+
+         # Build context from top results
+         context_parts: list[str] = []
+         for doc in context_docs[:3]:
+             text = doc["text"][:400]
+             context_parts.append(text)
+         context = "\n\n".join(context_parts)
+
+         prompt = (
+             "Answer the following question based on the provided context.\n\n"
+             f"Context:\n{context}\n\n"
+             f"Question: {query}"
+         )
+
+         try:
+             result = llm_pipeline(prompt, max_new_tokens=200)
+             answer = result[0]["generated_text"].strip()
+
+             if not answer or len(answer) < 3:
+                 return self._fallback_response(query, context_docs)
+
+             return answer
+
+         except Exception as exc:
+             logger.error("Generation error: %s", exc)
+             return self._fallback_response(query, context_docs)
+
+     def _fallback_response(self, query: str, context_docs: list[dict]) -> str:
+         """Provide relevant excerpts when the LLM response is inadequate.
+
+         Args:
+             query: User question.
+             context_docs: Retrieved documents.
+
+         Returns:
+             Formatted fallback response with source excerpts.
+         """
+         if not context_docs:
+             return "No relevant information found in the uploaded documents."
+
+         response_parts = ["Here are the most relevant excerpts from your documents:\n"]
+         for i, doc in enumerate(context_docs[:3], 1):
+             excerpt = doc["text"][:300].strip()
+             source = doc["source"]
+             score = doc["similarity"]
+             response_parts.append(
+                 f"**Source {i}** ({source}, relevance: {score:.0%}):\n> {excerpt}...\n"
+             )
+
+         return "\n".join(response_parts)
+
+
+ # ---------------------------------------------------------------------------
+ # Streamlit UI Helpers
+ # ---------------------------------------------------------------------------
+ def init_session_state() -> None:
+     """Initialize all required Streamlit session state variables."""
+     defaults = {
+         "documents": [],
+         "chat_history": [],
+         "embedder": None,
+         "llm_pipeline": None,
+         "chroma_client": None,
+         "collection": None,
+         "rag_system": None,
+         "models_loaded": False,
+         "retrieved_docs": [],
+     }
+     for key, value in defaults.items():
+         if key not in st.session_state:
+             st.session_state[key] = value
+
+
+ def load_models_into_state() -> None:
+     """Load models and vector store into session state if not already loaded."""
+     if st.session_state.models_loaded:
+         return
+
+     rag = RAGSystem()
+     st.session_state.rag_system = rag
+
+     try:
+         llm, embedder = RAGSystem.load_models()
+         st.session_state.llm_pipeline = llm
+         st.session_state.embedder = embedder
+
+         client, collection = rag.setup_vector_store()
+         st.session_state.chroma_client = client
+         st.session_state.collection = collection
+         st.session_state.models_loaded = True
+         logger.info("All models and vector store loaded successfully")
+
+     except Exception as exc:
+         st.error(f"Failed to load models: {exc}")
+         logger.error("Model loading failed: %s", exc)
+
+
+ def render_sidebar() -> None:
+     """Render the sidebar with document upload and management."""
+     with st.sidebar:
+         st.header("Document Management")
+         st.markdown("---")
+
+         # File uploader
+         uploaded_files = st.file_uploader(
+             "Upload Documents",
+             type=["pdf", "txt", "docx", "csv"],
+             accept_multiple_files=True,
+             help="Supported formats: PDF, TXT, DOCX, CSV",
+         )
+
+         if uploaded_files and st.button("Process Documents", type="primary"):
+             process_uploaded_files(uploaded_files)
+
+         st.markdown("---")
+
+         # Document list
+         st.subheader("Loaded Documents")
+         if st.session_state.documents:
+             for doc in st.session_state.documents:
+                 st.markdown(
+                     f'<div class="doc-item">'
+                     f"<strong>{doc['filename']}</strong><br>"
+                     f"<small>{doc['chunks']} chunks | {doc['type'].upper()}</small>"
+                     f"</div>",
+                     unsafe_allow_html=True,
+                 )
+
+             collection = st.session_state.collection
+             if collection is not None:
+                 st.caption(f"Total chunks in vector store: {collection.count()}")
+
+             if st.button("Clear All Documents", type="secondary"):
+                 clear_all_documents()
+                 st.rerun()
+         else:
+             st.info("No documents loaded yet. Upload files above to get started.")
+
+         st.markdown("---")
+
+         # Sample document loader
+         sample_path = os.path.join(
+             os.path.dirname(__file__), "sample_documents", "sample.txt"
+         )
+         if os.path.exists(sample_path):
+             if st.button("Load Sample Document"):
+                 load_sample_document(sample_path)
+                 st.rerun()
+
+         # Info
+         st.markdown("---")
+         st.markdown(
+             "**Models:**\n"
+             "- LLM: `google/flan-t5-small`\n"
+             "- Embeddings: `all-MiniLM-L6-v2`\n"
+             "- Vector Store: ChromaDB"
+         )
+
+
+ def process_uploaded_files(uploaded_files) -> None:
+     """Process a list of uploaded files and add to the vector store."""
+     rag = st.session_state.rag_system
+     embedder = st.session_state.embedder
+     collection = st.session_state.collection
+
+     if rag is None or embedder is None or collection is None:
+         st.error("Models not loaded. Please wait for initialization.")
+         return
+
+     progress = st.sidebar.progress(0)
+     total = len(uploaded_files)
+
+     for idx, uploaded_file in enumerate(uploaded_files):
+         try:
+             with st.spinner(f"Processing {uploaded_file.name}..."):
+                 doc_data = rag.extract_text(uploaded_file)
+                 chunk_count = rag.add_document(doc_data, embedder, collection)
+
+             st.session_state.documents.append(
+                 {
+                     "filename": doc_data["filename"],
+                     "type": doc_data["type"],
+                     "chunks": chunk_count,
+                     "text_length": len(doc_data["text"]),
+                 }
+             )
+             st.sidebar.success(f"Added {uploaded_file.name} ({chunk_count} chunks)")
+
+         except Exception as exc:
+             st.sidebar.error(f"Error processing {uploaded_file.name}: {exc}")
+             logger.error("File processing error: %s", exc)
+
+         progress.progress((idx + 1) / total)
+
+     progress.empty()
+
+
+ def load_sample_document(sample_path: str) -> None:
+     """Load the bundled sample document into the vector store."""
+     rag = st.session_state.rag_system
+     embedder = st.session_state.embedder
+     collection = st.session_state.collection
+
+     if rag is None or embedder is None or collection is None:
+         st.error("Models not loaded yet.")
+         return
+
+     try:
+         with open(sample_path, "r", encoding="utf-8") as f:
+             text = f.read()
+
+         doc_data = {
+             "filename": "sample.txt",
+             "text": text,
+             "type": "txt",
+         }
+
+         chunk_count = rag.add_document(doc_data, embedder, collection)
+         st.session_state.documents.append(
+             {
+                 "filename": "sample.txt",
+                 "type": "txt",
+                 "chunks": chunk_count,
+                 "text_length": len(text),
+             }
+         )
+         st.sidebar.success(f"Loaded sample document ({chunk_count} chunks)")
+
+     except Exception as exc:
+         st.sidebar.error(f"Error loading sample: {exc}")
+         logger.error("Sample loading error: %s", exc)
+
+
+ def clear_all_documents() -> None:
+     """Clear all documents and reset the vector store."""
+     st.session_state.documents = []
+     st.session_state.chat_history = []
+     st.session_state.retrieved_docs = []
+
+     rag = st.session_state.rag_system
+     if rag is not None:
+         try:
+             client, collection = rag.setup_vector_store()
+             st.session_state.chroma_client = client
+             st.session_state.collection = collection
+         except Exception as exc:
+             logger.error("Error resetting vector store: %s", exc)
+
+     st.sidebar.success("All documents cleared.")
+
+
+ def render_retrieved_docs(docs: list[dict]) -> None:
+     """Render retrieved source documents in the right column."""
+     if not docs:
+         st.info("Ask a question to see relevant source documents here.")
+         return
+
+     for i, doc in enumerate(docs[:5], 1):
+         similarity_pct = f"{doc['similarity']:.0%}"
+         excerpt = doc["text"][:250].strip()
+         if len(doc["text"]) > 250:
+             excerpt += "..."
+
+         st.markdown(
+             f'<div class="source-card">'
+             f'<div class="source-header">'
+             f"<span>Source {i}: {doc['source']}</span>"
+             f'<span class="relevance-badge">{similarity_pct}</span>'
+             f"</div>"
+             f'<div class="source-text">{excerpt}</div>'
+             f"</div>",
+             unsafe_allow_html=True,
+         )
+
+
+ def render_chat_area() -> None:
+     """Render the chat message history."""
+     # Display chat history
+     for message in st.session_state.chat_history:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+
+ def handle_user_query(query: str) -> None:
+     """Process a user query: search, generate, and display results."""
+     rag = st.session_state.rag_system
+     embedder = st.session_state.embedder
+     collection = st.session_state.collection
+     llm_pipeline = st.session_state.llm_pipeline
+
+     if rag is None or embedder is None or llm_pipeline is None:
+         st.error("Models are not loaded yet. Please wait.")
+         return
+
+     # Display user message
+     with st.chat_message("user"):
+         st.markdown(query)
+
+     st.session_state.chat_history.append({"role": "user", "content": query})
+
+     # Retrieve and generate
+     with st.chat_message("assistant"):
+         with st.spinner("Searching documents and generating response..."):
+             # Search
+             if collection is not None and collection.count() > 0:
+                 context_docs = rag.search(query, embedder, collection, n_results=5)
+             else:
+                 context_docs = []
+
+             st.session_state.retrieved_docs = context_docs
+
+             # Generate
+             response = rag.generate_response(query, context_docs, llm_pipeline)
+
+         # Display response
+         st.markdown(response)
+
+         # Show source attribution inline
+         if context_docs:
+             with st.expander("View Sources", expanded=False):
+                 for i, doc in enumerate(context_docs[:3], 1):
+                     score = doc["similarity"]
+                     source = doc["source"]
+                     st.caption(f"Source {i}: {source} (relevance: {score:.0%})")
+
+     st.session_state.chat_history.append({"role": "assistant", "content": response})
+
+
+ # ---------------------------------------------------------------------------
+ # Main Application
+ # ---------------------------------------------------------------------------
+ def main() -> None:
+     """Entry point for the Streamlit RAG demo application."""
+     st.set_page_config(
+         page_title="RAG System Demo",
+         page_icon="🔍",
+         layout="wide",
+         initial_sidebar_state="expanded",
+     )
+
+     # Inject custom CSS
+     st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
+
+     # Initialize
+     init_session_state()
+     load_models_into_state()
+
+     # Header
+     st.title("RAG System Demo")
+     st.caption(
+         "Upload documents and ask questions -- powered by open-source models on "
+         "Hugging Face."
+     )
+
+     # Sidebar
+     render_sidebar()
+
+     # Chat input must be at the top level (not inside columns/tabs/expanders)
+     user_query = st.chat_input("Ask a question about your documents...")
+     if user_query:
+         handle_user_query(user_query)
+
+     # Main content: two columns (chat | retrieved docs)
+     col_chat, col_docs = st.columns([2, 1])
+
+     with col_chat:
+         st.subheader("Chat")
+         render_chat_area()
+
+     with col_docs:
+         st.subheader("Retrieved Documents")
+         render_retrieved_docs(st.session_state.retrieved_docs)
+
+     # Footer
+     st.markdown(
+         '<div class="footer">'
+         "Powered by "
+         "<strong>google/flan-t5-small</strong> | "
+         "<strong>sentence-transformers/all-MiniLM-L6-v2</strong> | "
+         "<strong>ChromaDB</strong> | "
+         "<strong>Streamlit</strong>"
+         "</div>",
+         unsafe_allow_html=True,
+     )
+
+
+ if __name__ == "__main__":
+     main()
demo_instructions.md ADDED
@@ -0,0 +1,27 @@
+ # RAG System Demo -- Instructions
+
+ ## Getting Started
+
+ 1. **Wait for models to load.** The first launch downloads and caches `google/flan-t5-small` and `all-MiniLM-L6-v2`. This may take a minute on the initial run.
+
+ 2. **Upload documents.** Use the sidebar file uploader to add one or more documents (PDF, TXT, DOCX, or CSV). Click **Process Documents** to ingest them into the vector store. Alternatively, click **Load Sample Document** to use the bundled sample about AI and machine learning.
+
+ 3. **Ask questions.** Type a question in the chat input at the bottom of the main area. The system will:
+    - Search for the most relevant chunks in your uploaded documents.
+    - Pass the top results as context to the language model.
+    - Return a generated answer with source attribution.
+
+ 4. **Review sources.** The right column shows the retrieved document chunks ranked by relevance score. Click **View Sources** in the chat to see which chunks informed each answer.
+
+ ## Tips
+
+ - Upload multiple documents and ask comparative questions.
+ - Shorter, focused questions tend to produce better answers.
+ - The relevance score (0-100%) indicates how closely a chunk matches your query.
+ - Use **Clear All Documents** in the sidebar to reset and start fresh.
+
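The relevance percentage shown in the UI is derived from ChromaDB's cosine distance. A minimal sketch of the conversion the app applies (clamped so scores never go negative):

```python
def relevance_score(cosine_distance: float) -> float:
    # ChromaDB returns cosine distance in [0, 2]; the app maps it to a
    # similarity in [0, 1] and displays it as a percentage.
    return max(0.0, 1.0 - cosine_distance)

print(f"{relevance_score(0.25):.0%}")  # distance 0.25 -> displayed as "75%"
```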
+ ## Limitations
+
+ - `flan-t5-small` is a compact model. For complex reasoning, a larger model would perform better.
+ - Very long documents are split into chunks; some context may be lost at chunk boundaries.
+ - The vector store is in-memory and resets when the app restarts.
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ streamlit>=1.28.0
+ transformers>=4.35.0
+ torch>=2.0.0
+ sentence-transformers>=2.2.0
+ chromadb>=0.4.15
+ PyPDF2>=3.0.1
+ python-docx>=1.1.0
+ pandas>=2.0.0
+ numpy>=1.24.0
+ accelerate>=0.24.0
+ sentencepiece>=0.1.99
sample_documents/sample.txt ADDED
@@ -0,0 +1,13 @@
+ Artificial Intelligence and Machine Learning: A Comprehensive Overview
+
+ Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines that can perform tasks that typically require human intelligence. These tasks include visual perception, speech recognition, decision-making, and language translation. AI has evolved significantly since its inception in the 1950s and continues to transform industries worldwide.
+
+ Machine Learning (ML) is a subset of AI that focuses on developing algorithms and statistical models that enable computer systems to improve their performance on a specific task through experience. Instead of being explicitly programmed, ML systems learn patterns from data. The three main types of machine learning are supervised learning, unsupervised learning, and reinforcement learning.
+
+ Deep Learning is a subset of machine learning that uses artificial neural networks with multiple layers (hence "deep") to model and understand complex patterns in data. Deep learning has achieved remarkable success in areas such as image recognition, natural language processing, and game playing. Convolutional Neural Networks (CNNs) are particularly effective for image-related tasks, while Recurrent Neural Networks (RNNs) and Transformers excel at sequential data processing.
+
+ Natural Language Processing (NLP) is a field at the intersection of computer science, artificial intelligence, and linguistics. It focuses on the interaction between computers and humans through natural language. Recent advances in NLP, particularly with transformer-based models like BERT and GPT, have led to significant improvements in tasks such as text classification, sentiment analysis, machine translation, and question answering.
+
+ Applications of AI span numerous domains. In healthcare, AI assists in disease diagnosis, drug discovery, and personalized treatment plans. In finance, it powers fraud detection, algorithmic trading, and risk assessment. Autonomous vehicles rely on AI for perception, planning, and decision-making. In education, AI enables personalized learning experiences and automated grading systems.
+
+ The ethical considerations surrounding AI are increasingly important. Issues such as bias in AI systems, privacy concerns, job displacement, and the need for transparent and explainable AI are actively being researched and debated. Responsible AI development requires careful consideration of these factors to ensure that AI technologies benefit society as a whole.