# Study_RAG_Final / app.py
# Author: Asalun — commit 3326deb ("Update app.py", verified)
import re
from typing import List, Dict, Any, Tuple
import numpy as np
import gradio as gr
import faiss
from pypdf import PdfReader
import nbformat
from sentence_transformers import SentenceTransformer
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# =========================
# Config
# =========================
# Hugging Face model ids: a small sentence embedder used for retrieval and a
# seq2seq model used for answer / notes / quiz generation.
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
GEN_MODEL_NAME = "google/flan-t5-base"  # chosen as a CPU-friendly baseline

# Chunking / retrieval defaults; the UI sliders start at these values.
DEFAULT_CHUNK_SIZE, DEFAULT_OVERLAP = 900, 150  # both measured in characters
DEFAULT_TOP_K = 4  # chunks retrieved per query
# =========================
# Globals (in-memory)
# =========================
# Models are loaded once, at import time, and reused by every handler below.
embedder = SentenceTransformer(EMBED_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)

# Retrieval state, reassigned by build_index():
#   INDEX  - FAISS inner-product index over EMBEDS (None until indexed)
#   CHUNKS - chunk metadata dicts; row i corresponds to FAISS id i
#   EMBEDS - float32 matrix of normalized chunk embeddings (None until indexed)
INDEX: "faiss.IndexFlatIP | None" = None
CHUNKS: List[Dict[str, Any]] = []
EMBEDS: "np.ndarray | None" = None
# =========================
# Helpers
# =========================
def clean_text(t: str) -> str:
    """Normalize whitespace in ``t``.

    Non-breaking spaces become plain spaces, every run of whitespace collapses
    to a single space, and the result is stripped. Falsy input (``""`` or
    ``None``) yields ``""``.
    """
    if not t:
        return ""
    t = t.replace("\u00a0", " ")
    return re.sub(r"\s+", " ", t).strip()


def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split ``text`` into fixed-width character windows.

    The text is first normalized with ``clean_text``. Consecutive chunks share
    ``overlap`` characters; the final chunk may be shorter than ``chunk_size``.

    Args:
        text: Raw text to split.
        chunk_size: Maximum number of characters per chunk; must be >= 1.
        overlap: Characters shared by neighbouring chunks. Clamped to
            ``[0, chunk_size - 1]`` so the window always moves forward.
            Without the clamp, ``overlap >= chunk_size`` (which the UI sliders
            allow, e.g. chunk 300 / overlap 500) never advances and loops
            forever.

    Returns:
        The chunks in order; an empty list when the cleaned text is empty.

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    text = clean_text(text)
    if not text:
        return []

    overlap = min(max(overlap, 0), chunk_size - 1)
    step = chunk_size - overlap  # always >= 1, so the loop terminates
    n = len(text)
    chunks: List[str] = []
    start = 0
    while True:
        end = min(n, start + chunk_size)
        chunks.append(text[start:end])
        if end == n:
            return chunks
        start += step
def read_pdf(path: str) -> List[Tuple[int, str]]:
    """Extract whitespace-normalized text from each page of the PDF at ``path``.

    Returns ``(page_number, text)`` pairs with 1-based page numbers. Pages
    whose extracted text is empty (e.g. scanned, image-only pages) are
    skipped.
    """
    extracted: List[Tuple[int, str]] = []
    for page_no, page in enumerate(PdfReader(path).pages, start=1):
        body = clean_text(page.extract_text() or "")
        if not body:
            continue
        extracted.append((page_no, body))
    return extracted
def read_ipynb(path: str) -> List[Tuple[int, str, str]]:
    """Extract text from the markdown and code cells of the notebook at ``path``.

    Returns ``(cell_number, cell_type, text)`` triples with 1-based cell
    numbers. Other cell types and cells that are empty after cleaning are
    skipped. Note that ``clean_text`` collapses newlines, so code cells lose
    their line structure.
    """
    notebook = nbformat.read(path, as_version=4)
    kept: List[Tuple[int, str, str]] = []
    for cell_no, cell in enumerate(notebook.cells, start=1):
        kind = cell.get("cell_type")
        if kind not in ("markdown", "code"):
            continue
        body = clean_text(cell.get("source", ""))
        if not body:
            continue
        kept.append((cell_no, kind, body))
    return kept
def _chunk_records(path: str, name: str, chunk_size: int, overlap: int) -> List[Dict[str, Any]]:
    """Read one uploaded file and return its chunk records (text + provenance).

    Files that are neither ``.pdf`` nor ``.ipynb`` yield an empty list.
    """
    lname = name.lower()
    records: List[Dict[str, Any]] = []
    if lname.endswith(".pdf"):
        for page_no, page_text in read_pdf(path):
            for j, ch in enumerate(chunk_text(page_text, chunk_size, overlap), start=1):
                records.append({"text": ch, "source": name, "loc": f"page {page_no} · chunk {j}"})
    elif lname.endswith(".ipynb"):
        for cell_no, cell_type, cell_text in read_ipynb(path):
            for j, ch in enumerate(chunk_text(cell_text, chunk_size, overlap), start=1):
                records.append({"text": ch, "source": name, "loc": f"{cell_type} cell {cell_no} · chunk {j}"})
    return records


def build_index(file_objs, chunk_size: int, overlap: int) -> str:
    """Chunk, embed and index the uploaded files, replacing the global index.

    Args:
        file_objs: Uploaded files, given either as path strings or as objects
            with a ``.name`` attribute holding the path.
        chunk_size: Characters per chunk, forwarded to ``chunk_text``.
        overlap: Overlap between chunks, forwarded to ``chunk_text``.

    Returns:
        A user-facing status message.

    Side effects:
        Reassigns the module globals ``INDEX``, ``CHUNKS`` and ``EMBEDS``.
        They are only published after parsing and embedding have succeeded.
        If a file fails to parse, the exception propagates and the previous,
        self-consistent index stays in place. Otherwise ``INDEX`` ids could
        point past the end of a half-rebuilt ``CHUNKS``.
    """
    global INDEX, CHUNKS, EMBEDS
    if not file_objs:
        INDEX, CHUNKS, EMBEDS = None, [], None
        return "❌ Upload at least 1 PDF or IPYNB."

    new_chunks: List[Dict[str, Any]] = []
    for f in file_objs:
        # Accept both plain path strings and file wrappers exposing ``.name``.
        path = getattr(f, "name", f)
        name = path.split("/")[-1].split("\\")[-1]
        new_chunks.extend(_chunk_records(path, name, chunk_size, overlap))

    if not new_chunks:
        INDEX, CHUNKS, EMBEDS = None, [], None
        return "❌ No readable text found (scanned PDFs will look empty)."

    X = embedder.encode(
        [c["text"] for c in new_chunks],
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    # FAISS needs a C-contiguous float32 matrix.
    embeds = np.ascontiguousarray(X, dtype="float32")
    # Inner product on normalized vectors == cosine similarity.
    index = faiss.IndexFlatIP(embeds.shape[1])
    index.add(embeds)

    # Publish atomically: row i of CHUNKS is FAISS id i.
    INDEX, CHUNKS, EMBEDS = index, new_chunks, embeds
    return f"✅ Indexed {len(file_objs)} files → {len(new_chunks)} chunks."
def retrieve(query: str, k: int) -> List[Dict[str, Any]]:
    """Return up to ``k`` indexed chunks most similar to ``query``, best first.

    Each result is a copy of the chunk record from ``CHUNKS`` plus a
    ``"score"`` key: the inner product of the normalized query and chunk
    embeddings, i.e. cosine similarity. Returns ``[]`` when nothing has been
    indexed yet or when ``k`` < 1.
    """
    if INDEX is None:
        return []
    # FAISS expects k > 0; asking for more than ntotal only pads with -1 ids.
    k = min(int(k), INDEX.ntotal)
    if k < 1:
        return []

    q = embedder.encode([query], normalize_embeddings=True, show_progress_bar=False).astype("float32")
    scores, idxs = INDEX.search(q, k)
    out: List[Dict[str, Any]] = []
    for score, idx in zip(scores[0], idxs[0]):
        if idx < 0:  # FAISS marks missing neighbours with -1
            continue
        out.append({**CHUNKS[idx], "score": float(score)})
    return out
def make_context_snippets(items: List[Dict[str, Any]], max_chars=700) -> str:
    """Format retrieved chunks as numbered source blocks for an LLM prompt.

    Each block is a ``[n] <source> (<loc>)`` header line followed by the chunk
    text. Text longer than ``max_chars`` is cut to ``max_chars`` characters
    and suffixed with ``...``. Blocks are separated by a blank line, and an
    empty ``items`` gives ``""``.
    """
    def _render(n: int, item: Dict[str, Any]) -> str:
        body = item["text"]
        body = body[:max_chars] + "..." if len(body) > max_chars else body
        return f"[{n}] {item['source']} ({item['loc']})\n{body}"

    return "\n\n".join(_render(n, item) for n, item in enumerate(items, start=1))
def generate_text(prompt: str, max_new_tokens: int) -> str:
    """Run greedy (``do_sample=False``) generation with the seq2seq model.

    The prompt is tokenized with ``truncation=True``. With the Hugging Face
    defaults, that cuts overly long prompts to the model's maximum input
    length, dropping tokens from the end. Only the first generated sequence
    is decoded, with special tokens removed.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True)
    with torch.no_grad():
        generated = gen_model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
def citations(retrieved: List[Dict[str, Any]]) -> str:
    """Render the provenance of retrieved chunks as a numbered Markdown list.

    Numbering starts at 1, matching the ``[n]`` labels produced by
    ``make_context_snippets``. Returns ``"- (none)"`` when nothing was
    retrieved.
    """
    if not retrieved:
        return "- (none)"
    # The separator is an em dash (U+2014); the old literal was mis-decoded UTF-8.
    return "\n".join(
        f"- {i}. {it['source']} — {it['loc']}" for i, it in enumerate(retrieved, start=1)
    )
def answer_question(q: str, top_k: int) -> str:
    """Answer ``q`` from the top-``top_k`` retrieved chunks and append citations.

    The QUESTION is placed *before* the SOURCES, as make_notes and make_quiz
    already do with their TOPIC. ``generate_text`` tokenizes with
    ``truncation=True``, which (by default) drops tokens from the end of an
    over-long prompt. With a few ~700-char snippets the prompt easily exceeds
    the model's input length, so a question placed last was being cut off.
    """
    retrieved = retrieve(q, top_k)
    ctx = make_context_snippets(retrieved)
    prompt = (
        "You are a study assistant. Answer ONLY using the SOURCES.\n"
        "If not enough info, say: Not enough information in the provided files.\n\n"
        f"QUESTION: {q}\n\nSOURCES:\n{ctx}\n\nANSWER (bullets + 1-line summary):"
    )
    ans = generate_text(prompt, 256)
    return f"{ans}\n\nCitations:\n{citations(retrieved)}"
def make_notes(topic: str, top_k: int) -> str:
    """Generate concise study notes on ``topic`` from retrieved chunks, followed by citations."""
    hits = retrieve(topic, top_k)
    instructions = (
        "Create clean study notes ONLY from the SOURCES.\n"
        "Use headings and bullets. Keep concise.\n\n"
    )
    prompt = instructions + f"TOPIC: {topic}\n\nSOURCES:\n{make_context_snippets(hits)}\n\nNOTES:"
    notes = generate_text(prompt, 256)
    return f"{notes}\n\nCitations:\n{citations(hits)}"
def make_quiz(topic: str, n_q: int, top_k: int) -> str:
    """Generate an ``n_q``-question quiz on ``topic`` from retrieved chunks, followed by citations.

    The model is asked for mixed question types plus an answer key. Up to 512
    new tokens are generated.
    """
    hits = retrieve(topic, top_k)
    context = make_context_snippets(hits)
    header = "\n".join([
        "Create a tricky quiz ONLY from the SOURCES.",
        f"Generate exactly {n_q} questions.",
        "Mix MCQ, True/False, short answer. Include ANSWER KEY at end.",
    ])
    prompt = f"{header}\n\nTOPIC: {topic}\n\nSOURCES:\n{context}\n\nQUIZ:"
    quiz = generate_text(prompt, 512)
    return f"{quiz}\n\nCitations:\n{citations(hits)}"
# =========================
# Gradio callbacks (IMPORTANT: messages format)
# =========================
def cb_index(files, chunk_size, overlap):
    """Handle the Index button: coerce slider values to ``int`` and rebuild the index."""
    size, ov = int(chunk_size), int(overlap)
    return build_index(files, size, ov)
def cb_chat(user_text, history, top_k):
    """Handle the Chat tab (``gr.Chatbot(type="messages")`` history format).

    Appends the user's message and the assistant's reply as
    ``{"role": ..., "content": ...}`` dicts and returns ``(history, "")``, so
    the input textbox is cleared. Blank or whitespace-only input is ignored:
    the history is returned unchanged and no retrieval or generation runs.
    The incoming ``history`` list is copied, not mutated.
    """
    history = list(history or [])
    if not user_text or not user_text.strip():
        return history, ""

    if INDEX is None:
        reply = "❌ Upload files and click **Index** first."
    else:
        reply = answer_question(user_text, int(top_k))
    history.append({"role": "user", "content": user_text})
    history.append({"role": "assistant", "content": reply})
    return history, ""
def cb_notes(topic, top_k):
    """Handle the Notes tab; a blank or missing topic falls back to "main topics"."""
    if INDEX is None:
        return "❌ Upload files and click **Index** first."
    cleaned = (topic or "").strip()
    return make_notes(cleaned or "main topics", int(top_k))
def cb_quiz(topic, n_q, top_k):
    """Handle the Quiz tab; a blank or missing topic falls back to "important concepts"."""
    if INDEX is None:
        return "❌ Upload files and click **Index** first."
    subject = (topic or "").strip() or "important concepts"
    return make_quiz(subject, int(n_q), int(top_k))
# =========================
# UI (nicer layout + light CSS)
# =========================
CSS = """
#title {font-weight:800;}
.sidebar {border-right: 1px solid #2223;}
"""
with gr.Blocks(css=CSS, title="Study RAG Assistant") as demo:
gr.Markdown("## πŸ“š Study RAG Assistant", elem_id="title")
gr.Markdown("Upload your PDFs + notebooks β†’ Index β†’ Chat / Notes / Quiz grounded in your files.")
with gr.Row():
# Left sidebar
with gr.Column(scale=1, elem_classes=["sidebar"]):
gr.Markdown("### Sources")
files = gr.File(
label="Upload (.pdf, .ipynb)",
file_count="multiple",
file_types=[".pdf", ".ipynb"]
)
chunk_size = gr.Slider(300, 2000, value=DEFAULT_CHUNK_SIZE, step=50, label="Chunk size (chars)")
overlap = gr.Slider(0, 500, value=DEFAULT_OVERLAP, step=10, label="Chunk overlap (chars)")
index_btn = gr.Button("Index", variant="primary")
index_status = gr.Textbox(label="Index status", interactive=False)
# Main area
with gr.Column(scale=3):
with gr.Tabs():
with gr.Tab("Chat"):
top_k_chat = gr.Slider(2, 8, value=DEFAULT_TOP_K, step=1, label="Top-k chunks")
chat = gr.Chatbot(type="messages", height=420)
user = gr.Textbox(label="Ask a question", placeholder="e.g., explain backpropagation from my lecture")
ask = gr.Button("Ask", variant="primary")
ask.click(cb_chat, inputs=[user, chat, top_k_chat], outputs=[chat, user])
user.submit(cb_chat, inputs=[user, chat, top_k_chat], outputs=[chat, user])
with gr.Tab("Notes"):
top_k_notes = gr.Slider(2, 8, value=DEFAULT_TOP_K, step=1, label="Top-k chunks")
topic_notes = gr.Textbox(label="Topic (optional)", placeholder="e.g., activation functions")
notes_btn = gr.Button("Generate Notes", variant="primary")
notes_out = gr.Textbox(label="Notes", lines=18)
notes_btn.click(cb_notes, inputs=[topic_notes, top_k_notes], outputs=notes_out)
with gr.Tab("Quiz"):
top_k_quiz = gr.Slider(2, 8, value=DEFAULT_TOP_K, step=1, label="Top-k chunks")
topic_quiz = gr.Textbox(label="Topic (optional)", placeholder="e.g., CNN vs RNN")
n_q = gr.Slider(10, 50, value=10, step=1, label="Questions")
quiz_btn = gr.Button("Generate Quiz", variant="primary")
quiz_out = gr.Textbox(label="Quiz", lines=18)
quiz_btn.click(cb_quiz, inputs=[topic_quiz, n_q, top_k_quiz], outputs=quiz_out)
index_btn.click(cb_index, inputs=[files, chunk_size, overlap], outputs=index_status)
demo.launch()