Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import io | |
| import streamlit as st | |
| import pdfplumber | |
| from pptx import Presentation | |
| import docx as docx_lib | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| from groq import Groq | |
| import markdown2 | |
| from reportlab.lib.pagesizes import letter | |
| from reportlab.pdfgen import canvas | |
| # ---------------- CONFIG ---------------- | |
# Sentence-embedding model name passed to SentenceTransformer for retrieval.
EMBED_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
# Groq chat-completion model used for answers, summaries and MCQs.
GROQ_LLM_MODEL: str = "llama-3.3-70b-versatile"
| # ---------------- HELPERS ---------------- | |
@st.cache_resource(show_spinner=False)
def load_embedder():
    """Return the SentenceTransformer for ``EMBED_MODEL``, loaded once per process.

    Streamlit re-executes this whole script on every widget interaction; without
    ``st.cache_resource`` the model would be re-instantiated on each rerun.
    ``show_spinner=False`` keeps this call from rendering a spinner element before
    ``st.set_page_config`` is reached further down the script.
    """
    return SentenceTransformer(EMBED_MODEL)


embedder = load_embedder()
def parse_pdf_bytes(file_bytes):
    """Extract the text layer of a PDF supplied as raw bytes.

    Every page that yields text contributes that text followed by a newline; pages
    for which ``extract_text()`` returns nothing are skipped.

    Args:
        file_bytes: The complete PDF file contents.

    Returns:
        The concatenated page text, or ``""`` if parsing fails (a Streamlit warning
        is shown so one bad upload does not abort the app).
    """
    try:
        with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
            page_texts = [page.extract_text() for page in pdf.pages]
        # Join once instead of growing a string with += inside the loop.
        return "".join(f"{page_text}\n" for page_text in page_texts if page_text)
    except Exception as e:
        st.warning(f"PDF parse warning: {e}")
        return ""
def parse_docx_bytes(file_bytes):
    """Return the paragraph text of a .docx file, one paragraph per line.

    Args:
        file_bytes: The complete .docx file contents.

    Returns:
        Paragraph texts joined with ``"\\n"``, or ``""`` after showing a Streamlit
        warning if the document cannot be read.
    """
    try:
        paragraphs = docx_lib.Document(io.BytesIO(file_bytes)).paragraphs
        return "\n".join(para.text for para in paragraphs)
    except Exception as err:
        st.warning(f"DOCX parse warning: {err}")
        return ""
def _pptx_shape_texts(shapes):
    """Yield text from *shapes*: text shapes, non-empty table cells, and group members."""
    for shape in shapes:
        if hasattr(shape, "text"):
            yield shape.text
        elif getattr(shape, "has_table", False):
            # Table graphic frames expose no ``.text``; read each cell's text frame.
            for row in shape.table.rows:
                for cell in row.cells:
                    cell_text = cell.text_frame.text
                    if cell_text:
                        yield cell_text
        elif hasattr(shape, "shapes"):
            # Group shape: iterating slide.shapes does not descend into its children.
            yield from _pptx_shape_texts(shape.shapes)


def parse_pptx_bytes(file_bytes):
    """Extract the text of every slide in a .pptx file supplied as raw bytes.

    Each text item (shape text, table cell, or grouped shape) is emitted followed by
    a newline, slide by slide in document order.

    Args:
        file_bytes: The complete .pptx file contents.

    Returns:
        The concatenated text, or ``""`` after showing a Streamlit warning if the
        presentation cannot be read.
    """
    try:
        prs = Presentation(io.BytesIO(file_bytes))
        # Join once instead of growing a string with += inside nested loops.
        return "".join(
            f"{item}\n" for slide in prs.slides for item in _pptx_shape_texts(slide.shapes)
        )
    except Exception as e:
        st.warning(f"PPTX parse warning: {e}")
        return ""
def parse_spreadsheet_bytes(file_bytes):
    """Convert an Excel workbook or CSV file (raw bytes) to CSV text.

    The bytes are first read as an Excel workbook; if that fails they are read as
    CSV. Every worksheet is included. A single-sheet workbook yields plain CSV. A
    multi-sheet workbook yields one ``### Sheet: <name>`` section per sheet.

    Args:
        file_bytes: The complete spreadsheet file contents.

    Returns:
        CSV text without the index column, or ``""`` after showing a Streamlit
        warning if neither reader succeeds.
    """
    try:
        try:
            # sheet_name=None loads every worksheet as {sheet name: DataFrame};
            # the default (0) silently dropped all sheets but the first.
            sheets = pd.read_excel(io.BytesIO(file_bytes), sheet_name=None)
        except Exception:
            # Not an Excel workbook (or no Excel engine available): treat as CSV.
            return pd.read_csv(io.BytesIO(file_bytes)).to_csv(index=False)
        if len(sheets) == 1:
            return next(iter(sheets.values())).to_csv(index=False)
        return "\n".join(
            f"### Sheet: {sheet_name}\n{frame.to_csv(index=False)}"
            for sheet_name, frame in sheets.items()
        )
    except Exception as e:
        st.warning(f"Spreadsheet parse warning: {e}")
        return ""
def parse_txt_bytes(file_bytes):
    """Decode a plain-text upload as UTF-8.

    ``utf-8-sig`` is used so that a leading byte-order mark is removed instead of
    ending up as a stray ``\\ufeff`` character in the indexed text. Undecodable
    bytes are dropped (``errors="ignore"``).

    Args:
        file_bytes: The raw file contents.

    Returns:
        The decoded text, or ``""`` after showing a Streamlit warning if decoding
        fails (for example, when *file_bytes* is not a bytes-like object).
    """
    try:
        return file_bytes.decode("utf-8-sig", errors="ignore")
    except Exception as e:
        st.warning(f"TXT parse warning: {e}")
        return ""
def chunk_text(text, max_chars=1000, overlap=200):
    """Split *text* into overlapping character windows.

    Each window is ``max_chars`` characters long. Each window starts
    ``max_chars - overlap`` characters after the previous one, so neighbouring
    windows share ``overlap`` characters. Windows are whitespace-stripped, and any
    that are empty after stripping are dropped.

    Args:
        text: Text to split; falsy input (``""`` or ``None``) yields ``[]``.
        max_chars: Window length in characters; must be positive.
        overlap: Characters shared by consecutive windows; must satisfy
            ``0 <= overlap < max_chars``.

    Returns:
        The non-empty chunks in document order.

    Raises:
        ValueError: If ``max_chars``/``overlap`` would keep the window from
            advancing.
    """
    if not text:
        return []
    if max_chars <= 0 or not 0 <= overlap < max_chars:
        # With overlap >= max_chars (or max_chars <= 0) the window start never moves
        # forward, so the scan would never terminate.
        raise ValueError(
            f"chunk_text requires max_chars > 0 and 0 <= overlap < max_chars "
            f"(got max_chars={max_chars}, overlap={overlap})"
        )
    step = max_chars - overlap
    chunks = []
    for start in range(0, len(text), step):
        end = min(start + max_chars, len(text))
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end == len(text):
            # This window already reaches the end; later windows would only repeat
            # its tail.
            break
    return chunks
def build_faiss_index(chunks, embedder):
    """Embed *chunks* and load them into a ``faiss.IndexFlatL2`` (exhaustive L2 search).

    Args:
        chunks: Texts to index.
        embedder: Object whose ``encode(texts, convert_to_numpy=True)`` returns a
            2-D array with one row per text.

    Returns:
        ``(index, embeddings)``, where ``embeddings`` is the array exactly as
        returned by ``encode``, or ``(None, None)`` when *chunks* is empty.
    """
    if chunks:
        vectors = embedder.encode(chunks, convert_to_numpy=True)
        flat_index = faiss.IndexFlatL2(vectors.shape[1])
        # FAISS expects float32 input; cast a copy for the index only.
        flat_index.add(vectors.astype("float32"))
        return flat_index, vectors
    return None, None
def retrieve_chunks(query, embedder, faiss_index, chunks, k=5):
    """Return up to *k* chunks nearest to *query*, in the order the index ranks them.

    Args:
        query: Question text to embed and search with.
        embedder: Object with ``encode(texts, convert_to_numpy=True)``.
        faiss_index: Index built over *chunks*, or ``None``.
        chunks: The texts whose positions match the index ids.
        k: Number of neighbours requested.

    Returns:
        The matching chunk strings; ``[]`` when there is no index or no chunks.
    """
    if faiss_index is None or not chunks:
        return []
    query_vec = embedder.encode([query], convert_to_numpy=True).astype("float32")
    _, neighbour_ids = faiss_index.search(query_vec, k)
    # Ids outside [0, len(chunks)) (FAISS pads with -1 when k exceeds the index
    # size) are skipped.
    return [chunks[i] for i in neighbour_ids[0] if 0 <= i < len(chunks)]
| # ---------------- Groq LLM ---------------- | |
# Instruction text per education level, prepended to answer and summary prompts.
# Insertion order matters: the sidebar selectbox lists the keys in this order.
EDU_PROMPTS = dict(
    [
        ("Primary School", "Explain this to me like I'm 5 years old, in a fun and simple way with examples and analogies."),
        ("Middle School", "Explain this in a simple and clear way appropriate for a middle school student with examples."),
        ("High School", "Explain this clearly, assuming knowledge up to high school level."),
        ("Undergraduate", "Explain this in a university-level way, with clarity and useful details and examples."),
        ("Graduate", "Explain this at graduate-level rigor, including key details, nuance, and technical terms as appropriate."),
    ]
)
def get_groq_client():
    """Create a Groq client from the first API key that can be found.

    Lookup order:
        1. ``st.secrets["GROQ_API_KEY"]``
        2. the key typed into the sidebar (``st.session_state["groq_api_key"]``)
        3. the ``GROQ_API_KEY`` environment variable

    Returns:
        A ``Groq`` client.

    Raises:
        ValueError: If no non-empty key is found in any of those places.
    """
    api_key = None
    try:
        # The secret name must match the one the error message below tells users
        # to set; an empty key name ("") could never be found.
        api_key = st.secrets["GROQ_API_KEY"]
    except Exception:
        # Missing secrets file or missing key is expected; try the next source.
        pass
    if not api_key:
        api_key = st.session_state.get("groq_api_key") or os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise ValueError("Groq API key not found. Set st.secrets['GROQ_API_KEY'], or enter in sidebar, or set env GROQ_API_KEY.")
    return Groq(api_key=api_key)
def call_llm_with_context(question, retrieved_chunks, edu_level):
    """Ask the Groq chat model to answer *question* using retrieved passages as context.

    Args:
        question: The user's question.
        retrieved_chunks: Passages joined with blank lines into the prompt's
            context section; falsy means an empty context.
        edu_level: Key into ``EDU_PROMPTS``; unknown levels add no instruction.

    Returns:
        The content of the first completion choice.
    """
    client = get_groq_client()
    level_instruction = EDU_PROMPTS.get(edu_level, "")
    joined_context = "\n\n".join(retrieved_chunks or [])
    user_message = (
        f"{level_instruction}\n\n"
        f"Context:\n{joined_context}\n\n"
        f"Question: {question}"
    )
    conversation = [
        {"role": "system", "content": "You are a helpful and knowledgeable tutor."},
        {"role": "user", "content": user_message},
    ]
    completion = client.chat.completions.create(model=GROQ_LLM_MODEL, messages=conversation)
    return completion.choices[0].message.content
def make_summary(question, retrieved_chunks, edu_level):
    """Ask the Groq chat model for a short summary of *question* given the passages.

    Args:
        question: Topic or question to summarise.
        retrieved_chunks: Passages joined with blank lines into the prompt;
            falsy means an empty context.
        edu_level: Key into ``EDU_PROMPTS``; unknown levels add no instruction.

    Returns:
        The content of the first completion choice.
    """
    client = get_groq_client()
    level_instruction = EDU_PROMPTS.get(edu_level, "")
    joined_context = "\n\n".join(retrieved_chunks or [])
    request = (
        f"{level_instruction}\n\n"
        f"Here is some context:\n{joined_context}\n\n"
        f"Please give a short, easy-to-understand summary of: {question}\n"
        "Keep it concise and simple; use bullet points if helpful."
    )
    conversation = [
        {"role": "system", "content": "You are a concise summarizer."},
        {"role": "user", "content": request},
    ]
    completion = client.chat.completions.create(model=GROQ_LLM_MODEL, messages=conversation)
    return completion.choices[0].message.content
def make_mcqs_from_summary(summary_text, count=5, difficulty="medium"):
    """Ask the Groq chat model to write multiple-choice questions from *summary_text*.

    Args:
        summary_text: Source material for the questions.
        count: Number of questions requested in the prompt.
        difficulty: Difficulty label inserted verbatim into the prompt.

    Returns:
        The content of the first completion choice (free-form text).
    """
    client = get_groq_client()
    instructions = "".join(
        (
            f"Create {count} multiple choice questions (MCQs) from the following summary. ",
            "Each question should have 4 options labeled A-D and indicate the correct option. ",
            "Also provide a 1-2 sentence explanation for the correct answer. ",
            f"Difficulty: {difficulty}.\n\nSummary:\n{summary_text}",
        )
    )
    conversation = [
        {"role": "system", "content": "You are an assistant that generates high-quality multiple-choice questions."},
        {"role": "user", "content": instructions},
    ]
    completion = client.chat.completions.create(model=GROQ_LLM_MODEL, messages=conversation)
    return completion.choices[0].message.content
| # ---------------- STREAMLIT UI ---------------- | |
st.set_page_config(page_title="AI Study Assistant", layout="wide")
st.title("📚 AI Study Assistant — Exam Mode")

with st.sidebar:
    st.header("Settings")
    groq_key = st.text_input("Groq API key (optional)", type="password")
    if groq_key:
        st.session_state["groq_api_key"] = groq_key
    else:
        # Mirror the widget: once the field is cleared, forget any key entered on an
        # earlier rerun so get_groq_client() no longer uses it and can fall back to
        # the GROQ_API_KEY environment variable.
        st.session_state.pop("groq_api_key", None)
    edu_level = st.selectbox("Education level", list(EDU_PROMPTS.keys()))
    st.info("Upload documents and ask questions. You can generate summaries + MCQs.")

uploaded_files = st.file_uploader("Upload study documents (PDF, DOCX, PPTX, XLSX/CSV, TXT)", accept_multiple_files=True)
if not uploaded_files:
    # Nothing to work with yet: end this script run here.
    st.info("Please upload at least one document.")
    st.stop()
| # ---------------- PARSE FILES ---------------- | |
# (suffix or tuple of suffixes, parser) pairs, tried in order with str.endswith
# against the lower-cased file name.
_SUFFIX_PARSERS = (
    (".pdf", parse_pdf_bytes),
    (".docx", parse_docx_bytes),
    (".pptx", parse_pptx_bytes),
    ((".xls", ".xlsx", ".csv"), parse_spreadsheet_bytes),
    (".txt", parse_txt_bytes),
)


def _extract_text(file_name, payload):
    """Return the text of one upload, picking a parser from *file_name*'s suffix."""
    lowered = file_name.lower()
    for suffixes, parser in _SUFFIX_PARSERS:
        if lowered.endswith(suffixes):
            return parser(payload)
    # Unrecognised type: keep it only if it decodes as strict UTF-8.
    try:
        return payload.decode("utf-8")
    except Exception:
        return ""


text_parts = []
for uf in uploaded_files:
    extracted = _extract_text(uf.name, uf.read())
    if extracted:
        text_parts.append(f"\n\n### From file: {uf.name}\n\n{extracted}")
all_text = "".join(text_parts)

if not all_text.strip():
    st.error("No textual content extracted.")
    st.stop()
| # ---------------- CHUNK + INDEX ---------------- | |
@st.cache_resource(show_spinner=False, max_entries=8)
def _prepare_index(corpus_text):
    """Chunk *corpus_text* and build its FAISS index, memoized on the text itself.

    Streamlit re-executes the whole script on every widget interaction (typing a
    question, ticking the summary checkbox, clicking a download button). Without
    caching, each of those reruns re-embedded every chunk of every upload.
    ``cache_resource`` hands back the same objects (not copies) for identical text,
    so callers must treat them as read-only. ``max_entries`` bounds how many
    distinct corpora are kept in memory.

    Returns:
        ``(chunks, faiss_index, embeddings)`` as produced by ``chunk_text`` and
        ``build_faiss_index``.
    """
    corpus_chunks = chunk_text(corpus_text)
    index, vectors = build_faiss_index(corpus_chunks, embedder)
    return corpus_chunks, index, vectors


with st.spinner("Processing documents..."):
    chunks, faiss_index, embeddings = _prepare_index(all_text)
st.success(f"Prepared {len(chunks)} chunks and built vector index.")
| # ---------------- ASK QUESTION ---------------- | |
question = st.text_input("Ask a question about your materials:")
if not question:
    # No question yet: end this script run here.
    st.info("Type a question to begin.")
    st.stop()

topk = st.number_input("Top-k passages", min_value=1, max_value=10, value=5)
mcq_count = st.number_input("MCQs to generate", min_value=1, max_value=20, value=5)
mcq_diff = st.selectbox("MCQ difficulty", ["easy", "medium", "hard"], index=1)

retrieved = retrieve_chunks(question, embedder, faiss_index, chunks, k=int(topk))
if not retrieved:
    st.warning("No relevant passages found.")
else:
    st.subheader("Relevant passages:")
    for number, passage in enumerate(retrieved, start=1):
        st.markdown(f"**Passage {number}:**")
        # Show at most the first 800 characters, marking truncation with "...".
        preview = passage if len(passage) <= 800 else passage[:800] + "..."
        st.write(preview)
| # ---------------- GENERATE ANSWER ---------------- | |
@st.cache_data(show_spinner=False, max_entries=64)
def _cached_answer(question_text, passages, level):
    """Return the LLM answer for (question, passages, level), memoized across reruns.

    Every widget interaction reruns this script. Without caching, toggling the
    summary checkbox or clicking a download button re-queried the LLM, which cost
    an extra API call and replaced the answer with a different one. *passages* is
    a tuple so the arguments are hashable cache keys. Exceptions are not cached,
    so a missing API key is retried on the next run.
    """
    return call_llm_with_context(question_text, list(passages), level)


try:
    answer = _cached_answer(question, tuple(retrieved), edu_level)
    st.subheader("Answer:")
    st.write(answer)
except Exception as e:
    st.error(f"LLM error: {e}")
    st.stop()
| # ---------------- GENERATE SUMMARY + MCQs ---------------- | |
def _text_to_pdf_buffer(body_text, wrap_width=90, margin=40, leading=14):
    """Render *body_text* as a plain-text, letter-size PDF and return a rewound buffer.

    Lines are word-wrapped at *wrap_width* characters and spread over as many pages
    as needed, so long output is not clipped at the bottom of the first page.
    """
    import textwrap

    buffer = io.BytesIO()
    pdf = canvas.Canvas(buffer, pagesize=letter)
    _, page_height = letter
    lines_per_page = max(1, int((page_height - 2 * margin) // leading))
    wrapped_lines = []
    for raw_line in body_text.split("\n"):
        # wrap() returns [] for an empty line; keep it as a blank line.
        wrapped_lines.extend(textwrap.wrap(raw_line, width=wrap_width) or [""])
    for first in range(0, len(wrapped_lines), lines_per_page):
        text_obj = pdf.beginText(margin, page_height - margin)
        text_obj.setLeading(leading)
        for line in wrapped_lines[first:first + lines_per_page]:
            text_obj.textLine(line)
        pdf.drawText(text_obj)
        pdf.showPage()
    pdf.save()
    buffer.seek(0)
    return buffer


def _text_to_docx_buffer(body_text, heading):
    """Write *heading* plus one paragraph per line of *body_text* to an in-memory DOCX."""
    buffer = io.BytesIO()
    document = docx_lib.Document()
    document.add_heading(heading, level=1)
    for line in body_text.split("\n"):
        document.add_paragraph(line)
    document.save(buffer)
    buffer.seek(0)
    return buffer


# Memoize the LLM calls: every rerun (including each download-button click) would
# otherwise regenerate the summary and MCQs, replacing what the user just saw.
@st.cache_data(show_spinner=False, max_entries=64)
def _cached_summary(question_text, passages, level):
    """Return make_summary(...) for these arguments, cached across reruns."""
    return make_summary(question_text, list(passages), level)


@st.cache_data(show_spinner=False, max_entries=64)
def _cached_mcqs(summary_text, count, difficulty):
    """Return make_mcqs_from_summary(...) for these arguments, cached across reruns."""
    return make_mcqs_from_summary(summary_text, count=count, difficulty=difficulty)


if st.checkbox("Generate summary and MCQs"):
    try:
        summary = _cached_summary(question, tuple(retrieved), edu_level)
        st.subheader("📘 Summary")
        st.write(summary)

        # Downloads
        st.download_button("⬇️ Download Summary (Markdown)", summary, file_name="summary.md")
        st.download_button("⬇️ Download Summary (HTML)", markdown2.markdown(summary), file_name="summary.html", mime="text/html")
        st.download_button("⬇️ Download Summary (PDF)", _text_to_pdf_buffer(summary), file_name="summary.pdf", mime="application/pdf")
        st.download_button("⬇️ Download Summary (DOCX)", _text_to_docx_buffer(summary, "Summary"), file_name="summary.docx", mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document")

        # MCQs
        mcq_text = _cached_mcqs(summary, int(mcq_count), mcq_diff)
        st.subheader("📝 Generated MCQs")
        st.write(mcq_text)
        st.download_button("⬇️ Download MCQs (DOCX)", _text_to_docx_buffer(mcq_text, "MCQs"), file_name="mcqs.docx", mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document")
        st.download_button("⬇️ Download MCQs (PDF)", _text_to_pdf_buffer(mcq_text), file_name="mcqs.pdf", mime="application/pdf")
    except Exception as e:
        st.error(f"Error generating summary or MCQs: {e}")
else:
    st.info("Check the box above to generate summary + MCQs from retrieved content.")