# app.py
import os
import io

import streamlit as st
import pdfplumber
from pptx import Presentation
import docx as docx_lib
import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
from groq import Groq
import markdown2
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas

# ---------------- CONFIG ----------------
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
GROQ_LLM_MODEL = "llama-3.3-70b-versatile"

# ---------------- HELPERS ----------------
@st.cache_resource
def load_embedder():
    return SentenceTransformer(EMBED_MODEL)

embedder = load_embedder()


def parse_pdf_bytes(file_bytes):
    try:
        text = ""
        with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
            for page in pdf.pages:
                p = page.extract_text()
                if p:
                    text += p + "\n"
        return text
    except Exception as e:
        st.warning(f"PDF parse warning: {e}")
        return ""


def parse_docx_bytes(file_bytes):
    try:
        doc = docx_lib.Document(io.BytesIO(file_bytes))
        return "\n".join([p.text for p in doc.paragraphs])
    except Exception as e:
        st.warning(f"DOCX parse warning: {e}")
        return ""


def parse_pptx_bytes(file_bytes):
    try:
        prs = Presentation(io.BytesIO(file_bytes))
        text = ""
        for slide in prs.slides:
            for shape in slide.shapes:
                if hasattr(shape, "text"):
                    text += shape.text + "\n"
        return text
    except Exception as e:
        st.warning(f"PPTX parse warning: {e}")
        return ""


def parse_spreadsheet_bytes(file_bytes):
    try:
        try:
            df = pd.read_excel(io.BytesIO(file_bytes))
        except Exception:
            df = pd.read_csv(io.BytesIO(file_bytes))
        return df.to_csv(index=False)
    except Exception as e:
        st.warning(f"Spreadsheet parse warning: {e}")
        return ""


def parse_txt_bytes(file_bytes):
    try:
        return file_bytes.decode("utf-8", errors="ignore")
    except Exception as e:
        st.warning(f"TXT parse warning: {e}")
        return ""


def chunk_text(text, max_chars=1000, overlap=200):
    if not text:
        return []
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + max_chars, len(text))
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end == len(text):
            break
        start = end - overlap
    return chunks


def build_faiss_index(chunks, embedder):
    if not chunks:
        return None, None
    embeddings = embedder.encode(chunks, convert_to_numpy=True)
    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings.astype("float32"))
    return index, embeddings


def retrieve_chunks(query, embedder, faiss_index, chunks, k=5):
    if faiss_index is None or not chunks:
        return []
    q_emb = embedder.encode([query], convert_to_numpy=True).astype("float32")
    D, I = faiss_index.search(q_emb, k)
    results = []
    for idx in I[0]:
        if 0 <= idx < len(chunks):
            results.append(chunks[idx])
    return results
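# The helpers above form a small retrieval pipeline: chunk the text, embed the chunks,
# index them with FAISS (exact L2 search), and fetch the nearest chunks for a query.
# Below is a minimal sketch of how they compose, using made-up sample text; this
# helper is hypothetical and is never called by the app.
def _demo_retrieval_pipeline():
    sample_text = (
        "Photosynthesis converts light energy into chemical energy stored in glucose. "
    ) * 40
    demo_chunks = chunk_text(sample_text, max_chars=200, overlap=50)
    demo_index, _ = build_faiss_index(demo_chunks, embedder)
    return retrieve_chunks(
        "What does photosynthesis produce?", embedder, demo_index, demo_chunks, k=2
    )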
# ---------------- Groq LLM ----------------
EDU_PROMPTS = {
    "Primary School": "Explain this to me like I'm 5 years old, in a fun and simple way with examples and analogies.",
    "Middle School": "Explain this in a simple and clear way appropriate for a middle school student with examples.",
    "High School": "Explain this clearly, assuming knowledge up to high school level.",
    "Undergraduate": "Explain this in a university-level way, with clarity and useful details and examples.",
    "Graduate": "Explain this at graduate-level rigor, including key details, nuance, and technical terms as appropriate.",
}


def get_groq_client():
    api_key = None
    try:
        api_key = st.secrets["GROQ_API_KEY"]
    except Exception:
        pass
    if not api_key:
        api_key = st.session_state.get("groq_api_key") or os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise ValueError(
            "Groq API key not found. Set st.secrets['GROQ_API_KEY'], "
            "enter a key in the sidebar, or set the GROQ_API_KEY environment variable."
        )
    return Groq(api_key=api_key)


def call_llm_with_context(question, retrieved_chunks, edu_level):
    client = get_groq_client()
    edu_instr = EDU_PROMPTS.get(edu_level, "")
    context = "\n\n".join(retrieved_chunks) if retrieved_chunks else ""
    user_content = f"{edu_instr}\n\nContext:\n{context}\n\nQuestion: {question}"
    response = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a helpful and knowledgeable tutor."},
            {"role": "user", "content": user_content},
        ],
        model=GROQ_LLM_MODEL,
    )
    return response.choices[0].message.content


def make_summary(question, retrieved_chunks, edu_level):
    client = get_groq_client()
    edu_instr = EDU_PROMPTS.get(edu_level, "")
    context = "\n\n".join(retrieved_chunks) if retrieved_chunks else ""
    prompt = (
        f"{edu_instr}\n\nHere is some context:\n{context}\n\n"
        f"Please give a short, easy-to-understand summary of: {question}\n"
        "Keep it concise and simple; use bullet points if helpful."
    )
    response = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a concise summarizer."},
            {"role": "user", "content": prompt},
        ],
        model=GROQ_LLM_MODEL,
    )
    return response.choices[0].message.content


def make_mcqs_from_summary(summary_text, count=5, difficulty="medium"):
    client = get_groq_client()
    prompt = (
        f"Create {count} multiple choice questions (MCQs) from the following summary. "
        "Each question should have 4 options labeled A-D and indicate the correct option. "
        "Also provide a 1-2 sentence explanation for the correct answer. "
        f"Difficulty: {difficulty}.\n\nSummary:\n{summary_text}"
    )
    response = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are an assistant that generates high-quality multiple-choice questions."},
            {"role": "user", "content": prompt},
        ],
        model=GROQ_LLM_MODEL,
    )
    return response.choices[0].message.content
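# Optional sketch (not wired into the UI below): LLM API calls can fail transiently,
# so a thin retry wrapper around call_llm_with_context may be useful. The helper name,
# attempt count, and backoff schedule are illustrative assumptions, not part of the app.
def _call_llm_with_retries(question, retrieved_chunks, edu_level, attempts=3):
    import time
    for attempt in range(attempts):
        try:
            return call_llm_with_context(question, retrieved_chunks, edu_level)
        except Exception:
            if attempt == attempts - 1:
                raise
            time.sleep(2 ** attempt)  # simple exponential backoff: 1s, 2s, 4s, ...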
# ---------------- STREAMLIT UI ----------------
st.set_page_config(page_title="AI Study Assistant", layout="wide")
st.title("📚 AI Study Assistant — Exam Mode")

with st.sidebar:
    st.header("Settings")
    groq_key = st.text_input("Groq API key (optional)", type="password")
    if groq_key:
        st.session_state["groq_api_key"] = groq_key
    edu_level = st.selectbox("Education level", list(EDU_PROMPTS.keys()))
    st.info("Upload documents and ask questions. You can generate summaries + MCQs.")

uploaded_files = st.file_uploader(
    "Upload study documents (PDF, DOCX, PPTX, XLSX/CSV, TXT)",
    accept_multiple_files=True,
)
if not uploaded_files:
    st.info("Please upload at least one document.")
    st.stop()

# ---------------- PARSE FILES ----------------
all_text = ""
for uf in uploaded_files:
    raw = uf.read()
    text = ""
    name = uf.name.lower()
    if name.endswith(".pdf"):
        text = parse_pdf_bytes(raw)
    elif name.endswith(".docx"):
        text = parse_docx_bytes(raw)
    elif name.endswith(".pptx"):
        text = parse_pptx_bytes(raw)
    elif name.endswith((".xls", ".xlsx", ".csv")):
        text = parse_spreadsheet_bytes(raw)
    elif name.endswith(".txt"):
        text = parse_txt_bytes(raw)
    else:
        try:
            text = raw.decode("utf-8")
        except Exception:
            text = ""
    if text:
        all_text += f"\n\n### From file: {uf.name}\n\n{text}"

if not all_text.strip():
    st.error("No textual content extracted.")
    st.stop()

# ---------------- CHUNK + INDEX ----------------
with st.spinner("Processing documents..."):
    chunks = chunk_text(all_text)
    faiss_index, embeddings = build_faiss_index(chunks, embedder)
st.success(f"Prepared {len(chunks)} chunks and built vector index.")

# ---------------- ASK QUESTION ----------------
question = st.text_input("Ask a question about your materials:")
if not question:
    st.info("Type a question to begin.")
    st.stop()

topk = st.number_input("Top-k passages", min_value=1, max_value=10, value=5)
mcq_count = st.number_input("MCQs to generate", min_value=1, max_value=20, value=5)
mcq_diff = st.selectbox("MCQ difficulty", ["easy", "medium", "hard"], index=1)

retrieved = retrieve_chunks(question, embedder, faiss_index, chunks, k=int(topk))
if retrieved:
    st.subheader("Relevant passages:")
    for i, r in enumerate(retrieved):
        st.markdown(f"**Passage {i+1}:**")
        st.write(r[:800] + ("..." if len(r) > 800 else ""))
else:
    st.warning("No relevant passages found.")
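# Variant sketch (not used above): retrieval that also returns the FAISS L2 distance
# for each passage, which could be surfaced alongside the excerpts for transparency.
# The function name and signature are hypothetical; it reuses the index built above.
def _retrieve_with_scores(query, k=5):
    q_emb = embedder.encode([query], convert_to_numpy=True).astype("float32")
    D, I = faiss_index.search(q_emb, int(k))
    return [
        (chunks[idx], float(dist))
        for idx, dist in zip(I[0], D[0])
        if 0 <= idx < len(chunks)
    ]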
# ---------------- GENERATE ANSWER ----------------
try:
    answer = call_llm_with_context(question, retrieved, edu_level)
    st.subheader("Answer:")
    st.write(answer)
except Exception as e:
    st.error(f"LLM error: {e}")
    st.stop()

# ---------------- GENERATE SUMMARY + MCQs ----------------
if st.checkbox("Generate summary and MCQs"):
    try:
        summary = make_summary(question, retrieved, edu_level)
        st.subheader("📘 Summary")
        st.write(summary)

        # Downloads
        md_text = summary
        html_text = markdown2.markdown(summary)

        # PDF
        pdf_buffer = io.BytesIO()
        p = canvas.Canvas(pdf_buffer, pagesize=letter)
        width, height = letter
        text_obj = p.beginText(40, height - 40)
        for line in summary.split("\n"):
            while len(line) > 90:
                text_obj.textLine(line[:90])
                line = line[90:]
            text_obj.textLine(line)
        p.drawText(text_obj)
        p.showPage()
        p.save()
        pdf_buffer.seek(0)

        # DOCX
        docx_buffer = io.BytesIO()
        doc = docx_lib.Document()
        doc.add_heading("Summary", level=1)
        for line in summary.split("\n"):
            doc.add_paragraph(line)
        doc.save(docx_buffer)
        docx_buffer.seek(0)

        st.download_button("⬇️ Download Summary (Markdown)", md_text, file_name="summary.md")
        st.download_button("⬇️ Download Summary (HTML)", html_text, file_name="summary.html", mime="text/html")
        st.download_button("⬇️ Download Summary (PDF)", pdf_buffer, file_name="summary.pdf", mime="application/pdf")
        st.download_button(
            "⬇️ Download Summary (DOCX)",
            docx_buffer,
            file_name="summary.docx",
            mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        )

        # MCQs
        mcq_text = make_mcqs_from_summary(summary, count=int(mcq_count), difficulty=mcq_diff)
        st.subheader("📝 Generated MCQs")
        st.write(mcq_text)

        mcq_docx_buf = io.BytesIO()
        doc_mcq = docx_lib.Document()
        doc_mcq.add_heading("MCQs", level=1)
        for line in mcq_text.split("\n"):
            doc_mcq.add_paragraph(line)
        doc_mcq.save(mcq_docx_buf)
        mcq_docx_buf.seek(0)

        mcq_pdf_buf = io.BytesIO()
        p2 = canvas.Canvas(mcq_pdf_buf, pagesize=letter)
        text_obj2 = p2.beginText(40, height - 40)
        for line in mcq_text.split("\n"):
            while len(line) > 90:
                text_obj2.textLine(line[:90])
                line = line[90:]
            text_obj2.textLine(line)
        p2.drawText(text_obj2)
        p2.showPage()
        p2.save()
        mcq_pdf_buf.seek(0)

        st.download_button(
            "⬇️ Download MCQs (DOCX)",
            mcq_docx_buf,
            file_name="mcqs.docx",
            mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        )
        st.download_button("⬇️ Download MCQs (PDF)", mcq_pdf_buf, file_name="mcqs.pdf", mime="application/pdf")
    except Exception as e:
        st.error(f"Error generating summary or MCQs: {e}")
else:
    st.info("Check the box above to generate summary + MCQs from retrieved content.")
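# Typical local usage (assuming a standard Streamlit install and a Groq API key in the
# environment) would look something like:
#
#   export GROQ_API_KEY="..."
#   streamlit run app.py
#
# The key can also be supplied via st.secrets or the sidebar field, as handled in
# get_groq_client() above.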