"""Study RAG assistant.

Indexes uploaded PDFs and Jupyter notebooks into an in-memory FAISS index and
uses a seq2seq model to answer questions, write notes, and build quizzes from
the retrieved chunks. The UI is a Gradio Blocks app.
"""

import re
from typing import List, Dict, Any, Tuple

import numpy as np
import gradio as gr
import faiss
from pypdf import PdfReader
import nbformat

from sentence_transformers import SentenceTransformer
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# =========================
# Config
# =========================
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
GEN_MODEL_NAME = "google/flan-t5-base"  # CPU-friendly baseline

DEFAULT_CHUNK_SIZE = 900  # chars
DEFAULT_OVERLAP = 150  # chars
DEFAULT_TOP_K = 4

# =========================
# Globals (in-memory)
# =========================
embedder = SentenceTransformer(EMBED_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)

# INDEX, CHUNKS and EMBEDS always describe the same corpus: row i of EMBEDS
# and vector i of INDEX belong to CHUNKS[i]. build_index() replaces all three
# together.
INDEX = None
CHUNKS: List[Dict[str, Any]] = []
EMBEDS = None


# =========================
# Helpers
# =========================
def clean_text(t: str) -> str:
    """Return *t* with non-breaking spaces replaced and whitespace collapsed.

    Every run of whitespace becomes a single space and the result is
    stripped. Falsy input (``None`` or ``""``) yields ``""``.
    """
    if not t:
        return ""
    t = t.replace("\u00a0", " ")
    return re.sub(r"\s+", " ", t).strip()


def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split cleaned *text* into windows of at most *chunk_size* characters.

    Consecutive windows share *overlap* characters. The overlap is clamped to
    ``[0, chunk_size - 1]`` so the window always moves forward; without the
    clamp an overlap >= chunk_size would never advance and loop forever.

    Returns an empty list when the text is empty after cleaning.
    """
    text = clean_text(text)
    if not text:
        return []

    chunk_size = max(1, int(chunk_size))
    overlap = min(max(0, int(overlap)), chunk_size - 1)
    stride = chunk_size - overlap  # always >= 1 after the clamp above

    total = len(text)
    pieces = []
    for start in range(0, total, stride):
        pieces.append(text[start:start + chunk_size])
        if start + chunk_size >= total:
            # This window already reached the end of the text; a further
            # window would only repeat its tail.
            break
    return pieces


def read_pdf(path: str) -> List[Tuple[int, str]]:
    """Return ``(page_number, text)`` for each PDF page that has text.

    Page numbers are 1-based. Pages whose extracted text is empty after
    cleaning are skipped.
    """
    reader = PdfReader(path)
    pages = []
    for i, page in enumerate(reader.pages):
        txt = clean_text(page.extract_text() or "")
        if txt:
            pages.append((i + 1, txt))
    return pages


def read_ipynb(path: str) -> List[Tuple[int, str, str]]:
    """Return ``(cell_number, cell_type, text)`` for notebook cells with text.

    Only markdown and code cells are considered. Cell numbers are 1-based
    positions among all cells of the notebook, so skipped cells still count.
    """
    nb = nbformat.read(path, as_version=4)
    cells = []
    for i, cell in enumerate(nb.cells):
        ctype = cell.get("cell_type")
        if ctype in ("markdown", "code"):
            src = clean_text(cell.get("source", ""))
            if src:
                cells.append((i + 1, ctype, src))
    return cells


def build_index(file_objs, chunk_size: int, overlap: int) -> str:
    """Chunk and embed the uploaded files, then rebuild the global index.

    Args:
        file_objs: Uploaded files; each item is either an object with a
            ``name`` attribute holding its path, or a plain path string.
        chunk_size: Maximum chunk length in characters.
        overlap: Characters shared between consecutive chunks.

    Returns:
        A status message for the UI.

    The new index is assembled in local variables and only assigned to
    ``INDEX`` / ``CHUNKS`` / ``EMBEDS`` once everything succeeded. If parsing
    or embedding raises, the previous index stays intact instead of ending up
    with chunks that no longer match the index.
    """
    global INDEX, CHUNKS, EMBEDS

    if not file_objs:
        INDEX, CHUNKS, EMBEDS = None, [], None
        return "❌ Upload at least 1 PDF or IPYNB."

    new_chunks: List[Dict[str, Any]] = []
    for f in file_objs:
        path = str(getattr(f, "name", f))
        # Basename that works for both POSIX and Windows separators.
        name = path.split("/")[-1].split("\\")[-1]
        lname = name.lower()

        if lname.endswith(".pdf"):
            segments = [(f"page {page_no}", page_text) for page_no, page_text in read_pdf(path)]
        elif lname.endswith(".ipynb"):
            segments = [
                (f"{cell_type} cell {cell_no}", cell_text)
                for cell_no, cell_type, cell_text in read_ipynb(path)
            ]
        else:
            # Files with any other extension are ignored.
            continue

        for where, segment_text in segments:
            for j, ch in enumerate(chunk_text(segment_text, chunk_size, overlap), start=1):
                new_chunks.append({"text": ch, "source": name, "loc": f"{where} · chunk {j}"})

    if not new_chunks:
        INDEX, CHUNKS, EMBEDS = None, [], None
        return "❌ No readable text found (scanned PDFs will look empty)."

    texts_for_embed = [item["text"] for item in new_chunks]
    X = embedder.encode(texts_for_embed, normalize_embeddings=True, show_progress_bar=False)
    new_embeds = X.astype("float32")

    # Embeddings are L2-normalized, so inner product equals cosine similarity.
    new_index = faiss.IndexFlatIP(new_embeds.shape[1])
    new_index.add(new_embeds)

    INDEX, CHUNKS, EMBEDS = new_index, new_chunks, new_embeds
    return f"✅ Indexed {len(file_objs)} files → {len(CHUNKS)} chunks."


def retrieve(query: str, k: int) -> List[Dict[str, Any]]:
    """Return up to *k* indexed chunks most similar to *query*.

    Each result is a copy of the chunk dict with an added ``"score"`` key.
    Returns an empty list when nothing has been indexed. *k* is clamped to
    ``[1, number of chunks]``.
    """
    if INDEX is None or not CHUNKS:
        return []

    k = max(1, min(int(k), len(CHUNKS)))
    q = embedder.encode([query], normalize_embeddings=True, show_progress_bar=False).astype("float32")
    scores, idxs = INDEX.search(q, k)

    # FAISS pads missing neighbours with index -1; skip those slots.
    return [
        {**CHUNKS[idx], "score": float(score)}
        for score, idx in zip(scores[0], idxs[0])
        if idx >= 0
    ]


def make_context_snippets(items: List[Dict[str, Any]], max_chars=700) -> str:
    """Format retrieved chunks as numbered source snippets for a prompt.

    Each snippet is headed by ``[n] source (loc)``. Chunk text longer than
    *max_chars* is cut and suffixed with ``...``. Snippets are separated by a
    blank line.
    """
    parts = []
    for i, it in enumerate(items, start=1):
        s = it["text"]
        if len(s) > max_chars:
            s = s[:max_chars] + "..."
        parts.append(f"[{i}] {it['source']} ({it['loc']})\n{s}")
    return "\n\n".join(parts)


def generate_text(prompt: str, max_new_tokens: int) -> str:
    """Run the generation model on *prompt* and return the decoded output.

    Generation runs without sampling and without gradient tracking. The
    prompt is tokenized with ``truncation=True``, so the tail of a prompt
    longer than the tokenizer's limit is dropped; callers should put the
    essential instructions first.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    with torch.no_grad():
        out_ids = gen_model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    return tokenizer.decode(out_ids[0], skip_special_tokens=True)


def citations(retrieved: List[Dict[str, Any]]) -> str:
    """Return a numbered bullet list of the sources behind *retrieved*.

    Returns ``"- (none)"`` when *retrieved* is empty.
    """
    if not retrieved:
        return "- (none)"
    return "\n".join(f"- {i}. {it['source']} — {it['loc']}" for i, it in enumerate(retrieved, start=1))


def answer_question(q: str, top_k: int) -> str:
    """Answer *q* from the *top_k* best-matching chunks, with citations."""
    retrieved = retrieve(q, top_k)
    ctx = make_context_snippets(retrieved)
    # The question goes before the sources (as the notes and quiz prompts do
    # with their topic) so that prompt truncation drops trailing source text
    # rather than the question itself.
    prompt = (
        "You are a study assistant. Answer ONLY using the SOURCES.\n"
        "If not enough info, say: Not enough information in the provided files.\n\n"
        f"QUESTION: {q}\n\nSOURCES:\n{ctx}\n\nANSWER (bullets + 1-line summary):"
    )
    ans = generate_text(prompt, 256)
    return f"{ans}\n\nCitations:\n{citations(retrieved)}"


def make_notes(topic: str, top_k: int) -> str:
    """Generate study notes on *topic* from the *top_k* best chunks, with citations."""
    retrieved = retrieve(topic, top_k)
    ctx = make_context_snippets(retrieved)
    prompt = (
        "Create clean study notes ONLY from the SOURCES.\n"
        "Use headings and bullets. Keep concise.\n\n"
        f"TOPIC: {topic}\n\nSOURCES:\n{ctx}\n\nNOTES:"
    )
    out = generate_text(prompt, 256)
    return f"{out}\n\nCitations:\n{citations(retrieved)}"


def make_quiz(topic: str, n_q: int, top_k: int) -> str:
    """Generate an *n_q*-question quiz on *topic* from the *top_k* best chunks, with citations."""
    retrieved = retrieve(topic, top_k)
    ctx = make_context_snippets(retrieved)
    prompt = (
        "Create a tricky quiz ONLY from the SOURCES.\n"
        f"Generate exactly {n_q} questions.\n"
        "Mix MCQ, True/False, short answer. Include ANSWER KEY at end.\n\n"
        f"TOPIC: {topic}\n\nSOURCES:\n{ctx}\n\nQUIZ:"
    )
    out = generate_text(prompt, 512)
    return f"{out}\n\nCitations:\n{citations(retrieved)}"


# =========================
# Gradio callbacks (IMPORTANT: messages format)
# =========================
def cb_index(files, chunk_size, overlap):
    """Index button handler: rebuild the index and return the status text."""
    return build_index(files, int(chunk_size), int(overlap))


def cb_chat(user_text, history, top_k):
    """Chat handler: append the user turn and the assistant reply.

    Returns the updated history (a list of ``{"role", "content"}`` dicts) and
    an empty string that clears the input box. Blank input leaves the history
    unchanged.
    """
    history = history or []
    if not user_text or not user_text.strip():
        return history, ""

    if INDEX is None:
        reply = "❌ Upload files and click **Index** first."
    else:
        reply = answer_question(user_text, int(top_k))

    history.append({"role": "user", "content": user_text})
    history.append({"role": "assistant", "content": reply})
    return history, ""


def cb_notes(topic, top_k):
    """Notes handler: generate notes for *topic*, or "main topics" if blank."""
    if INDEX is None:
        return "❌ Upload files and click **Index** first."
    t = topic.strip() if topic and topic.strip() else "main topics"
    return make_notes(t, int(top_k))


def cb_quiz(topic, n_q, top_k):
    """Quiz handler: generate a quiz for *topic*, or "important concepts" if blank."""
    if INDEX is None:
        return "❌ Upload files and click **Index** first."
    t = topic.strip() if topic and topic.strip() else "important concepts"
    return make_quiz(t, int(n_q), int(top_k))


# =========================
# UI (nicer layout + light CSS)
# =========================
CSS = """
#title {font-weight:800;}
.sidebar {border-right: 1px solid #2223;}
"""

with gr.Blocks(css=CSS, title="Study RAG Assistant") as demo:
    gr.Markdown("## 📚 Study RAG Assistant", elem_id="title")
    gr.Markdown("Upload your PDFs + notebooks → Index → Chat / Notes / Quiz grounded in your files.")

    with gr.Row():
        # Left sidebar
        with gr.Column(scale=1, elem_classes=["sidebar"]):
            gr.Markdown("### Sources")
            files = gr.File(
                label="Upload (.pdf, .ipynb)",
                file_count="multiple",
                file_types=[".pdf", ".ipynb"]
            )
            chunk_size = gr.Slider(300, 2000, value=DEFAULT_CHUNK_SIZE, step=50, label="Chunk size (chars)")
            overlap = gr.Slider(0, 500, value=DEFAULT_OVERLAP, step=10, label="Chunk overlap (chars)")
            index_btn = gr.Button("Index", variant="primary")
            index_status = gr.Textbox(label="Index status", interactive=False)

        # Main area
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.Tab("Chat"):
                    top_k_chat = gr.Slider(2, 8, value=DEFAULT_TOP_K, step=1, label="Top-k chunks")
                    chat = gr.Chatbot(type="messages", height=420)
                    user = gr.Textbox(label="Ask a question", placeholder="e.g., explain backpropagation from my lecture")
                    ask = gr.Button("Ask", variant="primary")

                    ask.click(cb_chat, inputs=[user, chat, top_k_chat], outputs=[chat, user])
                    user.submit(cb_chat, inputs=[user, chat, top_k_chat], outputs=[chat, user])

                with gr.Tab("Notes"):
                    top_k_notes = gr.Slider(2, 8, value=DEFAULT_TOP_K, step=1, label="Top-k chunks")
                    topic_notes = gr.Textbox(label="Topic (optional)", placeholder="e.g., activation functions")
                    notes_btn = gr.Button("Generate Notes", variant="primary")
                    notes_out = gr.Textbox(label="Notes", lines=18)
                    notes_btn.click(cb_notes, inputs=[topic_notes, top_k_notes], outputs=notes_out)

                with gr.Tab("Quiz"):
                    top_k_quiz = gr.Slider(2, 8, value=DEFAULT_TOP_K, step=1, label="Top-k chunks")
                    topic_quiz = gr.Textbox(label="Topic (optional)", placeholder="e.g., CNN vs RNN")
                    n_q = gr.Slider(10, 50, value=10, step=1, label="Questions")
                    quiz_btn = gr.Button("Generate Quiz", variant="primary")
                    quiz_out = gr.Textbox(label="Quiz", lines=18)
                    quiz_btn.click(cb_quiz, inputs=[topic_quiz, n_q, top_k_quiz], outputs=quiz_out)

    index_btn.click(cb_index, inputs=[files, chunk_size, overlap], outputs=index_status)

demo.launch()