# Source: abid-ai — commit 8db4914 ("Update app.py", verified)
# app.py
"""
AI Study Assistant - Streamlit Application
Features:
- Upload PDF, extract text (pdfplumber / PyPDF2 fallback)
- Summarize document using OpenAI Chat API
- Generate 25+ MCQs (4 options each) using OpenAI
- Retrieval-based Q&A (embeddings + similarity)
- Handwriting-style fonts and professional UI
- Download combined output (summary, MCQs, Q&A history) as markdown (.md/.txt)
- Caching and basic cost-optimizations
"""
import base64
import hashlib
import html
import io
import os
import time
from typing import List, Tuple, Dict, Optional

import dotenv # Corrected from python-dotenv
import numpy as np
import openai
import openai
import pandas as pd
import pdfplumber
import pdfplumber
import streamlit as st
#import pypdf2
from PyPDF2 import PdfReader
from dotenv import load_dotenv
from sklearn.metrics.pairwise import cosine_similarity
# Load .env if present (local dev)
load_dotenv()  # pick up a local .env file when one exists (local development)
# Page-level Streamlit settings: browser title, wide layout, sidebar open on first load.
st.set_page_config(
    page_title="AI Study Assistant",
    layout="wide",
    initial_sidebar_state="expanded",
)
# -------------------------
# CSS / Fonts (handwriting)
# -------------------------
HANDWRITING_FONTS = [
    "Patrick Hand",
    "Caveat",
    "Indie Flower",
    "Reenie Beanie",
]
# Google Fonts css2 query derived from HANDWRITING_FONTS, so the list above is the single
# source of truth for which families are loaded. css2 expects one `family=` pair per font,
# joined with `&`. No weight axis is requested: a single-weight family can reject a
# request for weights it does not ship.
google_fonts = "&".join(f"family={font.replace(' ', '+')}" for font in HANDWRITING_FONTS)
st.markdown(
    f'<link href="https://fonts.googleapis.com/css2?{google_fonts}&display=swap" rel="stylesheet">',
    unsafe_allow_html=True,
)
# App-wide CSS. The first two handwriting families form the --handwriting font stack used
# by the .handwriting and .mcq-block classes.
st.markdown(
    f"""
<style>
:root {{
--handwriting: "{HANDWRITING_FONTS[0]}", "{HANDWRITING_FONTS[1]}", cursive, sans-serif;
}}
body {{
background: linear-gradient(180deg,#fbfbff,#ffffff);
}}
.handwriting {{
font-family: var(--handwriting);
}}
.mcq-block {{
white-space: pre-wrap;
font-family: var(--handwriting);
padding: 12px;
border-radius: 8px;
background: #fffdf7;
border: 1px solid #f1e6d6;
}}
.qa-box {{
background: #ffffff;
border-radius: 8px;
padding: 10px;
box-shadow: 0 2px 8px rgba(12,12,12,0.05);
}}
.small-muted {{
font-size:12px;color:#6b7280;
}}
.download-link {{
margin-top: 8px;
}}
</style>
""",
    unsafe_allow_html=True,
)
# -------------------------
# Sidebar inputs / config
# -------------------------
st.sidebar.title("AI Study Assistant — Settings")

# --- Credentials -------------------------------------------------------------
# A key typed into the sidebar takes precedence and is exported to the environment;
# otherwise an OPENAI_API_KEY already present in the environment (e.g. from .env or
# deployment secrets) is picked up.
# NOTE(review): os.environ is process-wide, so on a multi-user deployment a key typed by
# one session is visible to (and used by) every other session. Consider st.session_state.
openai_key = st.sidebar.text_input(
    "OpenAI API Key (start with sk-)",
    type="password",
    help="Your OpenAI API key. For Spaces add it to Secrets.",
)
if not openai_key:
    if "OPENAI_API_KEY" in os.environ:
        openai_key = os.environ.get("OPENAI_API_KEY")
else:
    os.environ["OPENAI_API_KEY"] = openai_key

# --- Models ------------------------------------------------------------------
model_choice = st.sidebar.selectbox(
    "Generation model",
    options=["gpt-4", "gpt-4o", "gpt-3.5-turbo"],
    index=0,
)
emb_model_choice = st.sidebar.selectbox(
    "Embedding model",
    options=["text-embedding-3-small", "text-embedding-3-large"],
    index=0,
)

# --- Generation / retrieval knobs -------------------------------------------
# At least 25 MCQs are always requested.
mcq_target = st.sidebar.number_input("Target number of MCQs", min_value=25, max_value=200, value=30, step=1)
# Word-based chunking for retrieval. The overlap's upper bound (500) can exceed the
# smallest chunk size (200), so the chunker must tolerate overlap >= chunk size.
chunk_size = st.sidebar.number_input("Chunk size (words)", min_value=200, max_value=2000, value=700, step=50)
chunk_overlap = st.sidebar.number_input("Chunk overlap (words)", min_value=50, max_value=500, value=150, step=10)
retrieval_k = st.sidebar.number_input("Retrieval top-k", min_value=1, max_value=8, value=4, step=1)

st.sidebar.markdown("---")
st.sidebar.markdown("**Tips:** Use PDFs with selectable text for best results. Scanned PDFs may require OCR.")
# -------------------------
# OpenAI initialization
# -------------------------
def ensure_openai_key():
    """Copy OPENAI_API_KEY from the environment onto the ``openai`` module.

    Called before every OpenAI request so a key entered in the sidebar (which the sidebar
    exports to the environment) is used.

    Raises:
        RuntimeError: if OPENAI_API_KEY is unset or empty.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        openai.api_key = api_key
        return
    raise RuntimeError("OpenAI API key not found. Set it in the sidebar or add OPENAI_API_KEY to environment.")
# -------------------------
# PDF extraction utilities
# -------------------------
@st.cache_data(show_spinner=False)
def extract_text_pdfplumber(file_bytes: bytes) -> str:
    """Return the PDF's page texts joined by blank lines, using pdfplumber.

    This is the primary extractor. Pages for which pdfplumber yields no text are skipped.
    Any pdfplumber error propagates so that ``extract_text`` can fall back to PyPDF2.
    Results are memoised by ``st.cache_data`` on the raw bytes, so reruns do not
    re-parse the same upload.
    """
    with pdfplumber.open(io.BytesIO(file_bytes)) as document:
        per_page = [page.extract_text() for page in document.pages]
    non_empty = [page_text for page_text in per_page if page_text]
    return "\n\n".join(non_empty).strip()
@st.cache_data(show_spinner=False)
def extract_text_pypdf2(file_bytes: bytes) -> str:
    """Fallback extraction using PyPDF2; returns page texts joined by blank lines.

    Extraction is best effort per page: a page whose ``extract_text`` raises or yields no
    text is skipped. A failure to open the document itself propagates to the caller.
    Cached by ``st.cache_data`` on the raw bytes.
    """
    # The file imports the class directly (``from PyPDF2 import PdfReader``); the module
    # name ``PyPDF2`` is never bound, so ``PyPDF2.PdfReader`` raised NameError and the
    # fallback path could never succeed.
    reader = PdfReader(io.BytesIO(file_bytes))
    text_pages = []
    for page in reader.pages:
        try:
            txt = page.extract_text()
        except Exception:  # one unreadable page should not sink the whole document
            txt = None
        if txt:
            text_pages.append(txt)
    return "\n\n".join(text_pages).strip()
def extract_text(file_bytes: bytes) -> str:
    """Extract PDF text, preferring pdfplumber and falling back to PyPDF2.

    The PyPDF2 fallback runs when pdfplumber raises *or* returns an empty string. Errors
    raised by the fallback itself propagate to the caller.
    """
    try:
        primary = extract_text_pdfplumber(file_bytes)
    except Exception:
        primary = ""
    return primary if primary else extract_text_pypdf2(file_bytes)
# -------------------------
# Chunking / embeddings / retrieval
# -------------------------
@st.cache_data(show_spinner=False)
def chunk_text(text: str, words_per_chunk: int = 700, overlap: int = 150) -> List[str]:
    """Split *text* into overlapping windows of whitespace-separated words.

    Args:
        text: Source text; split on any whitespace.
        words_per_chunk: Maximum words per chunk (values below 1 are treated as 1).
        overlap: Words shared between consecutive chunks. It is clamped so each window
            starts at least one word after the previous one.

    Returns:
        Chunks in document order; ``[]`` for empty or whitespace-only text. The last chunk
        always ends at the final word and no chunk is emitted past it.
    """
    words = text.split()
    if not words:
        return []
    total = len(words)
    words_per_chunk = max(1, int(words_per_chunk))
    # Stride between window starts. The previous ``start = end - overlap`` never advanced
    # once the final window was reached (end == total), and never advanced at all when
    # overlap >= words_per_chunk; both cases looped forever.
    step = max(1, words_per_chunk - max(0, int(overlap)))
    chunks = []
    for start in range(0, total, step):
        end = min(start + words_per_chunk, total)
        chunks.append(" ".join(words[start:end]))
        if end == total:  # this window already reached the end; further ones are redundant
            break
    return chunks
@st.cache_data(show_spinner=False)
def get_embeddings(texts: List[str], model: str, batch_size: int = 64) -> List[List[float]]:
    """Embed *texts* with the OpenAI embeddings endpoint, preserving input order.

    Args:
        texts: Strings to embed.
        model: Embedding model name.
        batch_size: Maximum number of inputs per API request. Large documents are split
            across several requests instead of one oversized call.

    Returns:
        One embedding vector per input, in input order; ``[]`` when *texts* is empty.

    Raises:
        RuntimeError: from ``ensure_openai_key`` when no API key is configured.
        Errors raised by the OpenAI client propagate unchanged.
    """
    if not texts:
        return []  # nothing to embed; avoid sending an empty (invalid) request
    ensure_openai_key()
    batch_size = max(1, int(batch_size))
    embeddings: List[List[float]] = []
    for offset in range(0, len(texts), batch_size):
        batch = texts[offset:offset + batch_size]
        resp = openai.Embedding.create(model=model, input=batch)
        # Each returned row carries its position within ``batch``; order by it instead of
        # relying on the response list order.
        rows = sorted(resp["data"], key=lambda row: row["index"])
        embeddings.extend(row["embedding"] for row in rows)
    return embeddings
def top_k_chunks(question: str, chunks: List[str], chunk_embs: List[List[float]], k: int = 4, emb_model: str = "text-embedding-3-small"):
    """Return the *k* chunks most similar to *question*, best match first.

    Similarity is the cosine similarity between the question's embedding and each entry
    of *chunk_embs*. This assumes *chunk_embs* is aligned index-for-index with *chunks*.

    Returns:
        ``(selected_chunks, indices)``, where ``indices`` is a NumPy integer array of the
        selected chunks' positions in *chunks*. Both are empty when there is nothing to
        select (no chunks, or ``k <= 0``).
    """
    # Clamp k. With k == 0 the slice ``[-0:]`` below would select *every* chunk, and a k
    # larger than the corpus just means "all of them".
    k = min(int(k), len(chunks))
    if k <= 0:
        return [], np.array([], dtype=int)
    ensure_openai_key()
    # Question embedding (memoised by get_embeddings' cache for repeated questions).
    q_emb = get_embeddings([question], model=emb_model)[0]
    sims = cosine_similarity([q_emb], chunk_embs)[0]
    idx = np.argsort(sims)[-k:][::-1]  # highest similarity first
    selected = [chunks[i] for i in idx]
    return selected, idx
# -------------------------
# OpenAI Chat wrappers
# -------------------------
def call_chat_completion(messages: List[Dict], model: str = "gpt-3.5-turbo", max_tokens: int = 700, temperature: float = 0.2):
    """Send *messages* to the chat completions API and return the first reply's text.

    Args:
        messages: Chat messages (dicts with ``role`` and ``content``).
        model: Chat model name.
        max_tokens: Cap on the reply length.
        temperature: Sampling temperature.

    Returns:
        The first choice's message content, whitespace-stripped ("" if the API returned
        no content).

    Raises:
        RuntimeError: when no API key is configured, or wrapping an OpenAI API error
            (the original error is chained as ``__cause__``).
    """
    ensure_openai_key()
    try:
        resp = openai.ChatCompletion.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
    except openai.error.OpenAIError as e:
        # Chain the original so the API traceback stays available for debugging.
        raise RuntimeError(f"OpenAI API error: {e}") from e
    # ``content`` can come back null in some responses; treat that as an empty reply
    # instead of crashing on ``None.strip()``.
    content = resp["choices"][0]["message"].get("content") or ""
    return content.strip()
# -------------------------
# Prompt engineering functions
# -------------------------
def generate_summary(full_text: str, model: str = "gpt-4") -> str:
    """Produce an exam-revision summary of *full_text*.

    The model is asked for a short executive summary (3-6 sentences), bullet-point key
    takeaways and a brief glossary of important terms, under clear headings. The whole
    document goes out in one request, so very long inputs may not fit the model's context
    window (chunked summarisation would be needed for those).

    Args:
        full_text: Extracted document text.
        model: Chat model name forwarded to ``call_chat_completion``.

    Returns:
        The model's reply, whitespace-stripped by ``call_chat_completion``.
    """
    system_msg = {
        "role": "system",
        "content": "You are an assistant that summarizes documents for study and revision.",
    }
    user_msg = {
        "role": "user",
        "content": (
            "Summarize the following document for exam revision. "
            "Provide a concise executive summary (3-6 sentences), then key takeaways as bullet points, and a short list of important terms and definitions. "
            "Use clear headings. Keep the style formal and compact.\n\n"
            f"Document:\n\n{full_text}"
        ),
    }
    # max_tokens bounds the reply (and its cost); it does not shrink the prompt.
    return call_chat_completion([system_msg, user_msg], model=model, max_tokens=900, temperature=0.2)
def generate_mcqs(full_text: str, model: str = "gpt-4", count: int = 30, max_tokens: int = 2200) -> str:
    """Ask the model for *count* four-option MCQs about *full_text*, as structured plain text.

    The reply is expected (though not guaranteed) to follow the layout spelled out in the
    prompt: ``Question <n>: ...``, options ``A.``-``D.``, then ``Answer: <LETTER>``.

    Args:
        full_text: Extracted document text, sent in full in a single request.
        model: Chat model name.
        count: Number of questions to request.
        max_tokens: Cap on the reply length. The default (2200, the previous hard-coded
            value) fits only a few dozen short questions; large *count* values may come back
            truncated unless this is raised within the chosen model's limits.

    Returns:
        The model's reply, whitespace-stripped.
    """
    # Every option line uses the same one-space indent. The template previously indented
    # A-C but not D, and the model tends to reproduce such inconsistencies verbatim.
    instruction = (
        f"Create {count} multiple-choice questions (MCQs) based on the document below. "
        "Each question must have 4 options labeled A, B, C, D and one correct answer. "
        "Make questions diverse (recall, concept, application). Mark the correct answer on a separate 'Answer:' line. "
        "Format EXACTLY like this for each question:\n\n"
        "Question <n>: <question text>\n\n"
        " A. <option A>\n"
        " B. <option B>\n"
        " C. <option C>\n"
        " D. <option D>\n\n"
        "Answer: <LETTER>\n\n"
        "Do NOT include explanations. Keep each question short and clear."
    )
    prompt = [
        {"role": "system", "content": "You are an experienced instructor who writes high-quality MCQs."},
        {"role": "user", "content": instruction + "\n\nDocument:\n\n" + full_text},
    ]
    return call_chat_completion(prompt, model=model, max_tokens=max_tokens, temperature=0.3)
def answer_question(question: str, chunks: List[str], chunk_embs: List[List[float]], emb_model: str, gen_model: str, top_k: int = 4) -> str:
    """Answer *question* from the *top_k* most relevant chunks (retrieval-augmented).

    Each retrieved chunk is labelled with its 0-based position in *chunks*, so the model
    can cite its sources as the prompt requests. The system prompt tells the model to
    say so when the answer is not in the provided context.

    Returns:
        The model's answer text, whitespace-stripped.
    """
    selected_chunks, idx = top_k_chunks(question, chunks, chunk_embs, k=top_k, emb_model=emb_model)
    # The prompt asks for "chunk indexes (0-based)", but the chunks used to be pasted
    # unlabelled, leaving the model nothing to cite. Tag each one with its document index.
    context = "\n\n---\n\n".join(
        f"[Chunk {int(chunk_index)}]\n{chunk}" for chunk_index, chunk in zip(idx, selected_chunks)
    )
    prompt = [
        {"role": "system", "content": "You are an assistant that answers questions using the provided context. If the answer is not in the context, say you could not find it."},
        {"role": "user", "content": f"Context:\n\n{context}\n\nQuestion: {question}\n\nAnswer concisely and cite which chunk indexes (0-based, as labelled above) you used."}
    ]
    return call_chat_completion(prompt, model=gen_model, max_tokens=400, temperature=0.2)
# -------------------------
# Download helpers
# -------------------------
def make_text_download(content: str, filename: str = "study_package.md"):
    """Return an HTML ``<a>`` link that downloads *content* as a Markdown file.

    The content is embedded as a base64 ``data:`` URI (UTF-8), so no server endpoint is
    needed. *filename* is HTML-escaped because the link is rendered with
    ``unsafe_allow_html=True`` and the name is built from the uploaded file's name, which
    the user controls.

    Args:
        content: Markdown text to offer for download.
        filename: Suggested file name for the browser's save dialog.

    Returns:
        The anchor element as an HTML string.
    """
    b64 = base64.b64encode(content.encode("utf-8")).decode("ascii")
    safe_name = html.escape(filename, quote=True)
    return (
        f'<a class="download-link" href="data:text/markdown;charset=utf-8;base64,{b64}" '
        f'download="{safe_name}">Download {safe_name}</a>'
    )
# -------------------------
# Session state initialization
# -------------------------
# Seed per-session state. st.session_state survives reruns, so only keys that are not
# present yet are initialised; each default is produced by calling its factory.
for _state_key, _make_default in (
    ("qa_history", list),  # list of dicts: question, answer, time
    ("summary", lambda: None),
    ("mcq_text", lambda: None),
    ("chunks", lambda: None),
    ("chunk_embeddings", lambda: None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# -------------------------
# App UI layout
# -------------------------
st.title("📘 AI Study Assistant")
st.caption("Upload a PDF and generate a summary, 25+ MCQs, and interactively ask questions about the content.")
# Main layout: left column for upload + actions, right for results
left_col, right_col = st.columns([1.4, 2])
with left_col:
    st.header("Upload & Settings")
    uploaded_file = st.file_uploader("Upload a PDF", type=["pdf"], help="Choose a PDF with selectable text for best results.")
    if uploaded_file:
        # getvalue() returns the whole upload regardless of the buffer's read position;
        # read() would return b"" if the buffer had already been consumed.
        file_bytes = uploaded_file.getvalue()
        # Separator between name and size was missing ("report.pdf123 KB").
        st.write(f"**Filename:** {uploaded_file.name} — {len(file_bytes)//1024} KB")
        # Summary, MCQs, embeddings and Q&A all describe one document. When a different
        # file is uploaded, discard them so stale results are not shown, queried against,
        # or exported under the new file's name.
        file_digest = hashlib.sha256(file_bytes).hexdigest()
        if st.session_state.get("source_digest") != file_digest:
            for stale_key in ("summary", "mcq_text", "chunks", "chunk_embeddings"):
                st.session_state[stale_key] = None
            st.session_state["qa_history"] = []
            st.session_state["source_digest"] = file_digest
        # Try extracting text
        with st.spinner("Extracting text from PDF..."):
            try:
                full_text = extract_text(file_bytes)
                if not full_text or len(full_text.strip()) < 50:
                    st.warning("Extracted text is short or empty. The PDF may be scanned images. Try another PDF or enable OCR.")
                else:
                    st.success(f"Extracted {len(full_text.split())} words from PDF.")
                # Save in session
                st.session_state["full_text"] = full_text
            except Exception as e:
                st.error(f"Failed to extract text: {e}")
                st.stop()
    else:
        st.info("Please upload a PDF to enable summary and MCQ generation.")

    # Action buttons
    st.markdown("---")
    st.header("Generate Content")
    colA, colB = st.columns([1, 1])
    with colA:
        if st.button("Generate Summary"):
            if not uploaded_file:
                st.error("Upload a PDF first.")
            else:
                try:
                    with st.spinner("Generating summary (OpenAI)..."):
                        ensure_openai_key()
                        # If document is very large, you might want to chunk and summarize iteratively.
                        summary_text = generate_summary(st.session_state["full_text"], model=model_choice)
                        st.session_state["summary"] = summary_text
                        st.success("Summary generated.")
                except Exception as e:
                    st.error(f"Summary generation failed: {e}")
    with colB:
        if st.button(f"Generate {mcq_target} MCQs"):
            if not uploaded_file:
                st.error("Upload a PDF first.")
            else:
                try:
                    with st.spinner("Generating MCQs (this may take a moment)..."):
                        ensure_openai_key()
                        mcq_text = generate_mcqs(st.session_state["full_text"], model=model_choice, count=int(mcq_target))
                        st.session_state["mcq_text"] = mcq_text
                        st.success("MCQs generated.")
                except Exception as e:
                    st.error(f"MCQ generation failed: {e}")

    # Generate both
    if st.button("Generate Summary + MCQs"):
        if not uploaded_file:
            st.error("Upload a PDF first.")
        else:
            try:
                with st.spinner("Generating summary + MCQs..."):
                    ensure_openai_key()
                    st.session_state["summary"] = generate_summary(st.session_state["full_text"], model=model_choice)
                    st.session_state["mcq_text"] = generate_mcqs(st.session_state["full_text"], model=model_choice, count=int(mcq_target))
                    st.success("Summary and MCQs generated.")
            except Exception as e:
                st.error(f"Combined generation failed: {e}")

    # Prepare retrieval infrastructure
    if uploaded_file and ("full_text" in st.session_state):
        if st.button("Prepare Q&A (create embeddings)"):
            try:
                with st.spinner("Chunking document and computing embeddings (costly operation)..."):
                    chunks = chunk_text(st.session_state["full_text"], words_per_chunk=int(chunk_size), overlap=int(chunk_overlap))
                    st.session_state["chunks"] = chunks
                    # Compute embeddings (cached)
                    chunk_embs = get_embeddings(chunks, model=emb_model_choice)
                    st.session_state["chunk_embeddings"] = chunk_embs
                    st.success(f"Prepared {len(chunks)} chunks and embeddings for retrieval.")
            except Exception as e:
                st.error(f"Failed to prepare embeddings: {e}")

    st.markdown("---")
    st.header("Download / Export")
    st.markdown("After generating content, download a combined study package.")
    if st.session_state.get("summary") or st.session_state.get("mcq_text") or st.session_state["qa_history"]:
        # Compose markdown
        composed = []
        if st.session_state.get("summary"):
            composed.append("# Summary\n\n" + st.session_state["summary"] + "\n\n")
        if st.session_state.get("mcq_text"):
            composed.append("# MCQs\n\n" + st.session_state["mcq_text"] + "\n\n")
        if st.session_state.get("qa_history"):
            qalist = ["# Q&A History\n"]
            for qa in st.session_state["qa_history"]:
                qalist.append(f"**Q:** {qa['question']}\n\n**A:** {qa['answer']}\n\n_Time:_ {qa['time']}\n\n")
            composed.append("\n".join(qalist))
        package_md = "\n".join(composed)
        # Generated content outlives the upload in session state, so the uploader may be
        # empty here; fall back to a generic stem instead of dereferencing None.
        package_name = f"{uploaded_file.name if uploaded_file else 'document'}_study_package.md"
        st.markdown(make_text_download(package_md, filename=package_name), unsafe_allow_html=True)
        st.download_button("Download study package (.md)", package_md, file_name=package_name, mime="text/markdown")
    else:
        st.info("No generated content yet. Run summary/MCQ generation first.")
with right_col:
    # File-name stem for downloads. Generated content stays in session state after the
    # uploader is cleared, so a file is not guaranteed to be present here.
    export_stem = uploaded_file.name if uploaded_file else "document"
    # Tabs: Summary, MCQ Quiz, Q&A
    tab1, tab2, tab3 = st.tabs(["\U0001f4d1 Summary", "\U0001f4dd MCQ Quiz", "\u2753 Q&A Dashboard"])
    with tab1:
        st.header("Document Summary")
        if st.session_state.get("summary"):
            # NOTE(review): each st.markdown call renders as its own element, so these
            # opening/closing <div> calls likely do not actually wrap the summary.
            st.markdown("<div class='qa-box handwriting'>", unsafe_allow_html=True)
            # Model output derived from an untrusted PDF: render as Markdown only, never
            # as raw HTML.
            st.markdown(st.session_state["summary"])
            st.markdown("</div>", unsafe_allow_html=True)
        else:
            st.info("No summary yet. Click 'Generate Summary' in the left panel.")
    with tab2:
        st.header("Generated MCQs")
        if st.session_state.get("mcq_text"):
            # Display with formatting: question line and indented options vertically
            st.markdown("<div class='mcq-block'>", unsafe_allow_html=True)
            # We display as preformatted but with handwriting font and indentation
            st.text_area("MCQs (read-only)", value=st.session_state["mcq_text"], height=420, key="mcq_display")
            st.markdown("</div>", unsafe_allow_html=True)

            # Also provide CSV download parsed
            def parse_mcqs_to_df(mcq_text: str) -> pd.DataFrame:
                """Parse the "Question n: / A.-D. / Answer:" text layout into one row per question.

                Lines that match none of the markers are ignored. Options may use either
                "A." or "A)" prefixes. Missing options or answers are left as "".
                """
                rows = []
                q_text = None
                opts = {"A": "", "B": "", "C": "", "D": ""}
                answer = ""

                def _flush():
                    # Emit the question collected so far, if any.
                    if q_text:
                        rows.append({"question": q_text.strip(), "A": opts["A"].strip(), "B": opts["B"].strip(), "C": opts["C"].strip(), "D": opts["D"].strip(), "answer": answer.strip()})

                for ln in mcq_text.splitlines():
                    stripped = ln.strip()
                    if not stripped:
                        continue
                    if stripped.lower().startswith("question"):
                        _flush()
                        parts = ln.split(":", 1)
                        q_text = parts[1].strip() if len(parts) > 1 else stripped
                        opts = {"A": "", "B": "", "C": "", "D": ""}
                        answer = ""
                        continue
                    letter = stripped[:1]
                    if letter in opts and stripped[1:2] in (".", ")"):
                        opts[letter] = stripped[2:].strip()
                    elif stripped.lower().startswith("answer"):
                        parts = ln.split(":", 1)
                        if len(parts) > 1:
                            answer = parts[1].strip()
                _flush()
                return pd.DataFrame(rows)

            df_mcq = parse_mcqs_to_df(st.session_state["mcq_text"])
            if not df_mcq.empty:
                st.download_button("Download MCQs as CSV", df_mcq.to_csv(index=False), file_name=f"{export_stem}_mcqs.csv", mime="text/csv")
        else:
            st.info("No MCQs generated yet. Click 'Generate MCQs' in the left panel.")
    with tab3:
        st.header("Q&A Dashboard")
        st.markdown("Ask questions about the PDF. Use 'Prepare Q&A' first (computes embeddings).")
        question_input = st.text_input("Enter your question here:")
        if st.button("Ask question"):
            if not st.session_state.get("chunks") or not st.session_state.get("chunk_embeddings"):
                st.warning("Please click 'Prepare Q&A (create embeddings)' in the left panel first.")
            elif not question_input.strip():
                st.error("Please type a question.")
            else:
                try:
                    with st.spinner("Retrieving context and generating answer..."):
                        ans = answer_question(question_input, st.session_state["chunks"], st.session_state["chunk_embeddings"], emb_model_choice, model_choice, top_k=int(retrieval_k))
                        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
                        st.session_state["qa_history"].append({"question": question_input, "answer": ans, "time": timestamp})
                        st.success("Answer generated.")
                except Exception as e:
                    st.error(f"Q&A failed: {e}")
        # Show history
        if st.session_state["qa_history"]:
            st.markdown("### Recent Q&A")
            for qa in reversed(st.session_state["qa_history"][-8:]):
                # The question is user input and the answer is model output: escape both
                # before embedding them in raw HTML, and keep their line breaks visible.
                q_html = html.escape(qa["question"]).replace("\n", "<br/>")
                a_html = html.escape(qa["answer"]).replace("\n", "<br/>")
                st.markdown(f"<div class='qa-box'><strong>Q:</strong> {q_html}<br/><strong>A:</strong> {a_html}<div class='small-muted'>Time: {qa['time']}</div></div>", unsafe_allow_html=True)
            # Download Q&A
            qa_md = "\n\n".join([f"Q: {qa['question']}\nA: {qa['answer']}\nTime: {qa['time']}" for qa in st.session_state["qa_history"]])
            st.download_button("Download Q&A history (.txt)", qa_md, file_name=f"{export_stem}_qa_history.txt", mime="text/plain")
        else:
            st.info("No Q&A history yet.")
# -------------------------
# Footer
# -------------------------
st.markdown("---")
st.markdown("Developed as **AI Study Assistant** — Upload a PDF, generate summary & MCQs, and ask questions!")
# End of app.py