Spaces:
Runtime error
Runtime error
| import re | |
| from typing import List, Dict, Any, Tuple | |
| import numpy as np | |
| import gradio as gr | |
| import faiss | |
| from pypdf import PdfReader | |
| import nbformat | |
| from sentence_transformers import SentenceTransformer | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| # ========================= | |
| # Config | |
| # ========================= | |
# Name of the sentence-embedding model loaded into `embedder`; it encodes both
# the document chunks and the user's query.
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Name of the seq2seq model that writes answers, notes and quizzes.
GEN_MODEL_NAME = "google/flan-t5-base" # CPU-friendly baseline
# Initial values for the chunking sliders and the top-k sliders in the UI.
DEFAULT_CHUNK_SIZE = 900 # chars
DEFAULT_OVERLAP = 150 # chars
DEFAULT_TOP_K = 4
| # ========================= | |
| # Globals (in-memory) | |
| # ========================= | |
# Models are loaded once, at import time, and shared by every callback.
embedder = SentenceTransformer(EMBED_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)
# Retrieval state; all three are (re)assigned by build_index().
# INDEX: faiss inner-product index over EMBEDS, or None when nothing is indexed.
INDEX = None
# CHUNKS: one dict per chunk with keys "text", "source" and "loc"; position i
# corresponds to vector i in INDEX.
CHUNKS: List[Dict[str, Any]] = []
# EMBEDS: float32 embedding matrix that was added to INDEX, or None.
EMBEDS = None
| # ========================= | |
| # Helpers | |
| # ========================= | |
def clean_text(t: str) -> str:
    """Normalise whitespace in ``t``.

    Non-breaking spaces become ordinary spaces, every run of whitespace is
    collapsed to a single space, and the result is trimmed at both ends.
    Falsy input (``None`` or ``""``) yields ``""``.
    """
    if t:
        without_nbsp = t.replace("\u00a0", " ")
        collapsed = re.sub(r"\s+", " ", without_nbsp)
        return collapsed.strip()
    return ""
def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split ``text`` into overlapping windows of at most ``chunk_size`` chars.

    The text is whitespace-normalised with ``clean_text`` first. Consecutive
    windows share ``overlap`` characters; the final window ends exactly at the
    end of the text and may be shorter than ``chunk_size``.

    Args:
        text: Raw text to split.
        chunk_size: Maximum window length in characters; must be >= 1.
        overlap: Characters shared by neighbouring windows. Values outside
            ``0 .. chunk_size - 1`` are clamped into that range.

    Returns:
        The windows in document order, or ``[]`` when the cleaned text is empty.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    text = clean_text(text)
    if not text:
        return []

    # Each window has to start at least one character after the previous one.
    # Without this clamp, overlap >= chunk_size keeps `start` in place and the
    # loop appends the same slice forever.
    effective_overlap = min(max(overlap, 0), chunk_size - 1)
    step = chunk_size - effective_overlap

    chunks = []
    n = len(text)
    start = 0
    while True:
        end = min(n, start + chunk_size)
        chunks.append(text[start:end])
        if end == n:
            return chunks
        start += step
def read_pdf(path: str) -> List[Tuple[int, str]]:
    """Read the PDF at ``path`` and return its text page by page.

    Returns:
        ``(page_number, text)`` pairs with 1-based page numbers and text
        passed through ``clean_text``. Pages whose cleaned text is empty are
        left out.
    """
    extracted = (
        (page_no, clean_text(page.extract_text() or ""))
        for page_no, page in enumerate(PdfReader(path).pages, start=1)
    )
    return [(page_no, page_text) for page_no, page_text in extracted if page_text]
def read_ipynb(path: str) -> List[Tuple[int, str, str]]:
    """Read the notebook at ``path`` and return its markdown and code cells.

    Returns:
        ``(cell_number, cell_type, text)`` triples. Cell numbers are 1-based
        positions among all cells of the notebook, the text is passed through
        ``clean_text``, and cells of other types or with empty text are
        skipped.
    """
    notebook = nbformat.read(path, as_version=4)
    found = []
    for position, cell in enumerate(notebook.cells, start=1):
        kind = cell.get("cell_type")
        if kind not in ("markdown", "code"):
            continue
        body = clean_text(cell.get("source", ""))
        if not body:
            continue
        found.append((position, kind, body))
    return found
def build_index(file_objs, chunk_size: int, overlap: int) -> str:
    """Chunk and embed the uploaded files, then replace the global index.

    PDFs are chunked page by page and notebooks cell by cell; files with any
    other extension are ignored. ``INDEX``, ``CHUNKS`` and ``EMBEDS`` are
    replaced together only after embedding and indexing have succeeded, so an
    exception part-way through leaves the previously built index usable and
    consistent.

    Args:
        file_objs: Uploaded files, each either a path string or an object
            whose ``name`` attribute holds the path.
        chunk_size: Chunk length in characters, forwarded to ``chunk_text``.
        overlap: Chunk overlap in characters, forwarded to ``chunk_text``.

    Returns:
        A human-readable status line for the UI.
    """
    global INDEX, CHUNKS, EMBEDS

    # NOTE(review): the leading "β" in the status strings looks like a
    # mis-decoded symbol; it is kept as-is until the intended glyph is confirmed.
    if not file_objs:
        INDEX, CHUNKS, EMBEDS = None, [], None
        return "β Upload at least 1 PDF or IPYNB."

    new_chunks: List[Dict[str, Any]] = []
    for f in file_objs:
        # Accept both plain path strings and file wrappers exposing `.name`.
        path = f if isinstance(f, str) else f.name
        name = path.split("/")[-1].split("\\")[-1]
        lname = name.lower()

        # Reduce both formats to (location label, text) sections so that the
        # chunking below is written once.
        if lname.endswith(".pdf"):
            sections = [
                (f"page {page_no}", page_text)
                for page_no, page_text in read_pdf(path)
            ]
        elif lname.endswith(".ipynb"):
            sections = [
                (f"{cell_type} cell {cell_no}", cell_text)
                for cell_no, cell_type, cell_text in read_ipynb(path)
            ]
        else:
            continue

        for label, section_text in sections:
            pieces = chunk_text(section_text, chunk_size, overlap)
            for j, ch in enumerate(pieces, start=1):
                new_chunks.append(
                    {"text": ch, "source": name, "loc": f"{label} · chunk {j}"}
                )

    if not new_chunks:
        INDEX, CHUNKS, EMBEDS = None, [], None
        return "β No readable text found (scanned PDFs will look empty)."

    texts_for_embed = [chunk["text"] for chunk in new_chunks]
    X = embedder.encode(texts_for_embed, normalize_embeddings=True, show_progress_bar=False)
    new_embeds = X.astype("float32")
    # Embeddings are normalised, so inner product ranks by cosine similarity.
    new_index = faiss.IndexFlatIP(new_embeds.shape[1])
    new_index.add(new_embeds)

    # Publish everything at once so INDEX and CHUNKS can never disagree.
    INDEX, CHUNKS, EMBEDS = new_index, new_chunks, new_embeds
    return f"β Indexed {len(file_objs)} files β {len(CHUNKS)} chunks."
def retrieve(query: str, k: int) -> List[Dict[str, Any]]:
    """Return up to ``k`` indexed chunks ranked for ``query``.

    Returns:
        Copies of the matching ``CHUNKS`` entries, each extended with a float
        ``"score"`` key, in the order reported by the index search. The list
        is empty when no index has been built. Hits reported with a negative
        position are dropped (presumably the index's "no result" marker).
    """
    if INDEX is None:
        return []

    query_vec = embedder.encode([query], normalize_embeddings=True, show_progress_bar=False)
    scores, idxs = INDEX.search(query_vec.astype("float32"), k)
    return [
        {**CHUNKS[idx], "score": float(score)}
        for score, idx in zip(scores[0], idxs[0])
        if idx >= 0
    ]
def make_context_snippets(items: List[Dict[str, Any]], max_chars=700) -> str:
    """Render retrieved chunks as a numbered SOURCES section for a prompt.

    Each entry is a ``[n] source (loc)`` header line followed by the chunk
    text, cut to ``max_chars`` characters with ``...`` appended when it is
    longer. Entries are separated by a blank line.
    """

    def render(position: int, item: Dict[str, Any]) -> str:
        body = item["text"]
        excerpt = body if len(body) <= max_chars else body[:max_chars] + "..."
        return f"[{position}] {item['source']} ({item['loc']})\n{excerpt}"

    return "\n\n".join(
        render(position, item) for position, item in enumerate(items, start=1)
    )
def generate_text(prompt: str, max_new_tokens: int) -> str:
    """Generate a completion for ``prompt`` with the global seq2seq model.

    The prompt is tokenised with truncation enabled and decoded without
    sampling (``do_sample=False``). Only the first returned sequence is
    decoded, with special tokens stripped.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        generated = gen_model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
def citations(retrieved: List[Dict[str, Any]]) -> str:
    """Format retrieved chunks as a numbered bullet list of their origins.

    Returns ``"- (none)"`` when nothing was retrieved; otherwise one line per
    chunk giving its 1-based rank, source file name and location label.
    """
    if not retrieved:
        return "- (none)"
    numbered = enumerate(retrieved, start=1)
    return "\n".join(
        f"- {rank}. {entry['source']} β {entry['loc']}" for rank, entry in numbered
    )
def answer_question(q: str, top_k: int) -> str:
    """Answer ``q`` from the ``top_k`` best-matching chunks.

    Returns the generated answer followed by a ``Citations:`` section that
    lists the chunks supplied to the model.
    """
    hits = retrieve(q, top_k)
    sources_block = make_context_snippets(hits)
    instructions = (
        "You are a study assistant. Answer ONLY using the SOURCES.\n"
        "If not enough info, say: Not enough information in the provided files.\n\n"
    )
    task = f"SOURCES:\n{sources_block}\n\nQUESTION: {q}\nANSWER (bullets + 1-line summary):"
    reply = generate_text(instructions + task, 256)
    return reply + "\n\nCitations:\n" + citations(hits)
def make_notes(topic: str, top_k: int) -> str:
    """Generate study notes on ``topic`` from the ``top_k`` best-matching chunks.

    Returns the generated notes followed by a ``Citations:`` section that
    lists the chunks supplied to the model.
    """
    hits = retrieve(topic, top_k)
    sources_block = make_context_snippets(hits)
    instructions = (
        "Create clean study notes ONLY from the SOURCES.\n"
        "Use headings and bullets. Keep concise.\n\n"
    )
    task = f"TOPIC: {topic}\n\nSOURCES:\n{sources_block}\n\nNOTES:"
    notes = generate_text(instructions + task, 256)
    return notes + "\n\nCitations:\n" + citations(hits)
def make_quiz(topic: str, n_q: int, top_k: int) -> str:
    """Generate an ``n_q``-question quiz on ``topic`` from the retrieved chunks.

    Returns the generated quiz followed by a ``Citations:`` section that
    lists the chunks supplied to the model.
    """
    hits = retrieve(topic, top_k)
    sources_block = make_context_snippets(hits)
    instructions = "".join(
        [
            "Create a tricky quiz ONLY from the SOURCES.\n",
            f"Generate exactly {n_q} questions.\n",
            "Mix MCQ, True/False, short answer. Include ANSWER KEY at end.\n\n",
        ]
    )
    task = f"TOPIC: {topic}\n\nSOURCES:\n{sources_block}\n\nQUIZ:"
    quiz = generate_text(instructions + task, 512)
    return quiz + "\n\nCitations:\n" + citations(hits)
| # ========================= | |
| # Gradio callbacks (IMPORTANT: messages format) | |
| # ========================= | |
def cb_index(files, chunk_size, overlap):
    """Index-button callback: coerce the slider values to int and rebuild the index.

    Returns the status string produced by ``build_index``.
    """
    size = int(chunk_size)
    shared = int(overlap)
    return build_index(files, size, shared)
def cb_chat(user_text, history, top_k):
    """Chat callback: append the user's turn and the assistant's reply.

    Args:
        user_text: Contents of the question box; may be empty or ``None``.
        history: Existing conversation as a list of ``{"role", "content"}``
            dicts, or ``None`` for a fresh chat. It is copied, not mutated.
        top_k: Number of chunks to retrieve (coerced to ``int``).

    Returns:
        ``(updated_history, "")``; the empty string clears the question box.
        A blank submission returns the history unchanged instead of running
        retrieval and generation on an empty query.
    """
    history = list(history or [])

    # Ignore empty submissions (Enter on an empty box, whitespace only).
    if not (user_text or "").strip():
        return history, ""

    if INDEX is None:
        reply = "β Upload files and click **Index** first."
    else:
        reply = answer_question(user_text, int(top_k))

    history.append({"role": "user", "content": user_text})
    history.append({"role": "assistant", "content": reply})
    return history, ""
def cb_notes(topic, top_k):
    """Notes-button callback.

    Returns a reminder to index first when no index exists; otherwise the
    notes for ``topic``, falling back to ``"main topics"`` when the topic is
    empty or whitespace only.
    """
    if INDEX is None:
        return "β Upload files and click **Index** first."
    subject = (topic or "").strip() or "main topics"
    return make_notes(subject, int(top_k))
def cb_quiz(topic, n_q, top_k):
    """Quiz-button callback.

    Returns a reminder to index first when no index exists; otherwise a quiz
    of ``n_q`` questions on ``topic``, falling back to
    ``"important concepts"`` when the topic is empty or whitespace only.
    """
    if INDEX is None:
        return "β Upload files and click **Index** first."
    subject = (topic or "").strip() or "important concepts"
    return make_quiz(subject, int(n_q), int(top_k))
| # ========================= | |
| # UI (nicer layout + light CSS) | |
| # ========================= | |
# Custom CSS: bold title (matched through elem_id="title") and a divider on
# the right edge of the sidebar column (matched through the "sidebar" class).
CSS = """
#title {font-weight:800;}
.sidebar {border-right: 1px solid #2223;}
"""
# Page layout: a narrow left column for uploading/indexing and a wide right
# column with one tab per feature.
with gr.Blocks(css=CSS, title="Study RAG Assistant") as demo:
    gr.Markdown("## π Study RAG Assistant", elem_id="title")
    gr.Markdown("Upload your PDFs + notebooks β Index β Chat / Notes / Quiz grounded in your files.")
    with gr.Row():
        # Left sidebar
        with gr.Column(scale=1, elem_classes=["sidebar"]):
            gr.Markdown("### Sources")
            files = gr.File(
                label="Upload (.pdf, .ipynb)",
                file_count="multiple",
                file_types=[".pdf", ".ipynb"]
            )
            chunk_size = gr.Slider(300, 2000, value=DEFAULT_CHUNK_SIZE, step=50, label="Chunk size (chars)")
            overlap = gr.Slider(0, 500, value=DEFAULT_OVERLAP, step=10, label="Chunk overlap (chars)")
            index_btn = gr.Button("Index", variant="primary")
            index_status = gr.Textbox(label="Index status", interactive=False)
        # Main area
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.Tab("Chat"):
                    top_k_chat = gr.Slider(2, 8, value=DEFAULT_TOP_K, step=1, label="Top-k chunks")
                    chat = gr.Chatbot(type="messages", height=420)
                    user = gr.Textbox(label="Ask a question", placeholder="e.g., explain backpropagation from my lecture")
                    ask = gr.Button("Ask", variant="primary")
                    # The button and the Enter key share one handler; its second
                    # return value ("") is written back to the textbox, clearing it.
                    ask.click(cb_chat, inputs=[user, chat, top_k_chat], outputs=[chat, user])
                    user.submit(cb_chat, inputs=[user, chat, top_k_chat], outputs=[chat, user])
                with gr.Tab("Notes"):
                    top_k_notes = gr.Slider(2, 8, value=DEFAULT_TOP_K, step=1, label="Top-k chunks")
                    topic_notes = gr.Textbox(label="Topic (optional)", placeholder="e.g., activation functions")
                    notes_btn = gr.Button("Generate Notes", variant="primary")
                    notes_out = gr.Textbox(label="Notes", lines=18)
                    notes_btn.click(cb_notes, inputs=[topic_notes, top_k_notes], outputs=notes_out)
                with gr.Tab("Quiz"):
                    top_k_quiz = gr.Slider(2, 8, value=DEFAULT_TOP_K, step=1, label="Top-k chunks")
                    topic_quiz = gr.Textbox(label="Topic (optional)", placeholder="e.g., CNN vs RNN")
                    n_q = gr.Slider(10, 50, value=10, step=1, label="Questions")
                    quiz_btn = gr.Button("Generate Quiz", variant="primary")
                    quiz_out = gr.Textbox(label="Quiz", lines=18)
                    quiz_btn.click(cb_quiz, inputs=[topic_quiz, n_q, top_k_quiz], outputs=quiz_out)
    # Rebuild the index from the uploaded files and show the resulting status line.
    index_btn.click(cb_index, inputs=[files, chunk_size, overlap], outputs=index_status)
# Start the web server as soon as the script is executed.
demo.launch()