Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
"""
|
| 2 |
AI Study Assistant - Streamlit Application
|
| 3 |
Features:
|
|
@@ -14,11 +15,23 @@ import os
|
|
| 14 |
import io
|
| 15 |
import time
|
| 16 |
import base64
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
from typing import List, Tuple, Dict, Optional
|
| 18 |
-
import PyPDF2
|
| 19 |
-
|
| 20 |
|
| 21 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# Streamlit page config
|
| 24 |
st.set_page_config(page_title="AI Study Assistant", layout="wide", initial_sidebar_state="expanded")
|
|
@@ -27,52 +40,52 @@ st.set_page_config(page_title="AI Study Assistant", layout="wide", initial_sideb
|
|
| 27 |
# CSS / Fonts (handwriting)
|
| 28 |
# -------------------------
|
| 29 |
HANDWRITING_FONTS = [
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
]
|
| 35 |
google_fonts = "+".join([f"{f.replace(' ', '+')}:wght@400;700" for f in HANDWRITING_FONTS])
|
| 36 |
st.markdown(
|
| 37 |
-
|
| 38 |
-
|
| 39 |
)
|
| 40 |
|
| 41 |
st.markdown(
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
)
|
| 77 |
|
| 78 |
# -------------------------
|
|
@@ -83,9 +96,9 @@ st.sidebar.title("AI Study Assistant — Settings")
|
|
| 83 |
# API Key input (secure)
|
| 84 |
openai_key = st.sidebar.text_input("OpenAI API Key (start with sk-)", type="password", help="Your OpenAI API key. For Spaces add it to Secrets.")
|
| 85 |
if openai_key:
|
| 86 |
-
|
| 87 |
elif "OPENAI_API_KEY" in os.environ:
|
| 88 |
-
|
| 89 |
|
| 90 |
# Model selection
|
| 91 |
model_choice = st.sidebar.selectbox("Generation model", options=["gpt-4", "gpt-4o", "gpt-3.5-turbo"], index=0)
|
|
@@ -106,194 +119,194 @@ st.sidebar.markdown("**Tips:** Use PDFs with selectable text for best results. S
|
|
| 106 |
# OpenAI initialization
|
| 107 |
# -------------------------
|
| 108 |
def ensure_openai_key():
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
|
| 114 |
# -------------------------
|
| 115 |
# PDF extraction utilities
|
| 116 |
# -------------------------
|
| 117 |
@st.cache_data(show_spinner=False)
|
| 118 |
def extract_text_pdfplumber(file_bytes: bytes) -> str:
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
|
| 132 |
@st.cache_data(show_spinner=False)
|
| 133 |
def extract_text_pypdf2(file_bytes: bytes) -> str:
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
|
| 149 |
def extract_text(file_bytes: bytes) -> str:
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
|
| 160 |
# -------------------------
|
| 161 |
# Chunking / embeddings / retrieval
|
| 162 |
# -------------------------
|
| 163 |
@st.cache_data(show_spinner=False)
|
| 164 |
def chunk_text(text: str, words_per_chunk: int = 700, overlap: int = 150) -> List[str]:
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
|
| 178 |
@st.cache_data(show_spinner=False)
|
| 179 |
def get_embeddings(texts: List[str], model: str) -> List[List[float]]:
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
|
| 186 |
def top_k_chunks(question: str, chunks: List[str], chunk_embs: List[List[float]], k: int = 4, emb_model: str = "text-embedding-3-small"):
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
|
| 195 |
# -------------------------
|
| 196 |
# OpenAI Chat wrappers
|
| 197 |
# -------------------------
|
| 198 |
def call_chat_completion(messages: List[Dict], model: str = "gpt-3.5-turbo", max_tokens: int = 700, temperature: float = 0.2):
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
|
| 211 |
# -------------------------
|
| 212 |
# Prompt engineering functions
|
| 213 |
# -------------------------
|
| 214 |
def generate_summary(full_text: str, model: str = "gpt-4") -> str:
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
|
| 237 |
def generate_mcqs(full_text: str, model: str = "gpt-4", count: int = 30) -> str:
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
|
| 260 |
def answer_question(question: str, chunks: List[str], chunk_embs: List[List[float]], emb_model: str, gen_model: str, top_k: int = 4) -> str:
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
|
| 272 |
# -------------------------
|
| 273 |
# Download helpers
|
| 274 |
# -------------------------
|
| 275 |
def make_text_download(content: str, filename: str = "study_package.md"):
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
|
| 280 |
# -------------------------
|
| 281 |
# Session state initialization
|
| 282 |
# -------------------------
|
| 283 |
if "qa_history" not in st.session_state:
|
| 284 |
-
|
| 285 |
|
| 286 |
if "summary" not in st.session_state:
|
| 287 |
-
|
| 288 |
|
| 289 |
if "mcq_text" not in st.session_state:
|
| 290 |
-
|
| 291 |
|
| 292 |
if "chunks" not in st.session_state:
|
| 293 |
-
|
| 294 |
|
| 295 |
if "chunk_embeddings" not in st.session_state:
|
| 296 |
-
|
| 297 |
|
| 298 |
# -------------------------
|
| 299 |
# App UI layout
|
|
@@ -305,205 +318,205 @@ st.caption("Upload a PDF and generate a summary, 25+ MCQs, and interactively ask
|
|
| 305 |
left_col, right_col = st.columns([1.4, 2])
|
| 306 |
|
| 307 |
with left_col:
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
|
| 412 |
with right_col:
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
|
| 508 |
# -------------------------
|
| 509 |
# Footer
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
"""
|
| 3 |
AI Study Assistant - Streamlit Application
|
| 4 |
Features:
|
|
|
|
| 15 |
import io
|
| 16 |
import time
|
| 17 |
import base64
|
| 18 |
+
import openai
|
| 19 |
+
#import pypdf2
|
| 20 |
+
from PyPDF2 import PdfReader
|
| 21 |
+
import pdfplumber
|
| 22 |
+
import dotenv # Corrected from python-dotenv
|
| 23 |
from typing import List, Tuple, Dict, Optional
|
|
|
|
|
|
|
| 24 |
|
| 25 |
import streamlit as st
|
| 26 |
+
import pdfplumber
|
| 27 |
+
|
| 28 |
+
import pandas as pd
|
| 29 |
+
import numpy as np
|
| 30 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 31 |
+
from dotenv import load_dotenv
|
| 32 |
+
import openai
|
| 33 |
+
# Load .env if present (local dev)
|
| 34 |
+
load_dotenv()
|
| 35 |
|
| 36 |
# Streamlit page config
|
| 37 |
st.set_page_config(page_title="AI Study Assistant", layout="wide", initial_sidebar_state="expanded")
|
|
|
|
| 40 |
# CSS / Fonts (handwriting)
|
| 41 |
# -------------------------
|
| 42 |
# Handwriting-style font families; the first two are used for the CSS variable below.
HANDWRITING_FONTS = [
    "Patrick Hand",
    "Caveat",
    "Indie Flower",
    "Reenie Beanie",
]
# NOTE(review): nothing visible in this section reads `google_fonts`; it is
# kept unchanged in case code elsewhere in the file depends on it.
google_fonts = "+".join([f"{f.replace(' ', '+')}:wght@400;700" for f in HANDWRITING_FONTS])

# Build the Google Fonts query from the list above so the <link> tag and the
# font list cannot drift apart (yields the same URL as the hand-written one).
_font_families = "&".join(f"family={name.replace(' ', '+')}" for name in HANDWRITING_FONTS)
st.markdown(
    f"<link href=\"https://fonts.googleapis.com/css2?{_font_families}&display=swap\" rel=\"stylesheet\">",
    unsafe_allow_html=True,
)

# Page-wide CSS: handwriting font variable plus the helper classes used by
# the summary, MCQ and Q&A panels.
_primary_font, _secondary_font = HANDWRITING_FONTS[0], HANDWRITING_FONTS[1]
st.markdown(
    f"""
<style>
:root {{
--handwriting: "{_primary_font}", "{_secondary_font}", cursive, sans-serif;
}}
body {{
background: linear-gradient(180deg,#fbfbff,#ffffff);
}}
.handwriting {{
font-family: var(--handwriting);
}}
.mcq-block {{
white-space: pre-wrap;
font-family: var(--handwriting);
padding: 12px;
border-radius: 8px;
background: #fffdf7;
border: 1px solid #f1e6d6;
}}
.qa-box {{
background: #ffffff;
border-radius: 8px;
padding: 10px;
box-shadow: 0 2px 8px rgba(12,12,12,0.05);
}}
.small-muted {{
font-size:12px;color:#6b7280;
}}
.download-link {{
margin-top: 8px;
}}
</style>
""",
    unsafe_allow_html=True,
)
|
| 90 |
|
| 91 |
# -------------------------
|
|
|
|
| 96 |
# API Key input (secure)
|
| 97 |
# Sidebar key entry: a key typed here wins; otherwise fall back to the
# environment (e.g. a Spaces secret or a value loaded from .env).
# NOTE(review): os.environ is shared by the whole server process, so a key
# typed by one visitor may be picked up by other sessions — confirm this is
# acceptable for the deployment.
openai_key = st.sidebar.text_input(
    "OpenAI API Key (start with sk-)",
    type="password",
    help="Your OpenAI API key. For Spaces add it to Secrets.",
)
if openai_key:
    os.environ["OPENAI_API_KEY"] = openai_key
else:
    openai_key = os.environ.get("OPENAI_API_KEY", openai_key)
|
| 102 |
|
| 103 |
# Model selection
|
| 104 |
model_choice = st.sidebar.selectbox("Generation model", options=["gpt-4", "gpt-4o", "gpt-3.5-turbo"], index=0)
|
|
|
|
| 119 |
# OpenAI initialization
|
| 120 |
# -------------------------
|
| 121 |
def ensure_openai_key():
    """Configure the OpenAI client from the OPENAI_API_KEY environment variable.

    Surrounding whitespace (common when a key is pasted) is stripped, and a
    key that is empty after stripping is treated as missing.

    Raises:
        RuntimeError: if no usable key is present in the environment.
    """
    key = (os.environ.get("OPENAI_API_KEY") or "").strip()
    if not key:
        raise RuntimeError("OpenAI API key not found. Set it in the sidebar or add OPENAI_API_KEY to environment.")
    openai.api_key = key
|
| 126 |
|
| 127 |
# -------------------------
|
| 128 |
# PDF extraction utilities
|
| 129 |
# -------------------------
|
| 130 |
@st.cache_data(show_spinner=False)
def extract_text_pdfplumber(file_bytes: bytes) -> str:
    """Return the text of a PDF using pdfplumber, pages joined by blank lines.

    Pages that yield no text are skipped. Any pdfplumber error propagates to
    the caller, which is expected to fall back to another extractor.
    """
    with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
        per_page = [page.extract_text() for page in pdf.pages]
    return "\n\n".join(txt for txt in per_page if txt).strip()
|
| 144 |
|
| 145 |
@st.cache_data(show_spinner=False)
def extract_text_pypdf2(file_bytes: bytes) -> str:
    """Fallback PDF text extraction using PyPDF2's ``PdfReader``.

    Args:
        file_bytes: Raw bytes of the PDF.

    Returns:
        Text of all pages that produced any, joined by blank lines.

    Raises:
        Exception: whatever ``PdfReader`` raises when the document itself
            cannot be opened; per-page extraction errors are tolerated.
    """
    # The file imports the class directly (`from PyPDF2 import PdfReader`);
    # the bare module name `PyPDF2` is not bound, so it must not be used here.
    reader = PdfReader(io.BytesIO(file_bytes))
    text_pages = []
    for page in reader.pages:
        try:
            txt = page.extract_text()
        except Exception:
            # Best effort: one unreadable page should not discard the rest.
            txt = None
        if txt:
            text_pages.append(txt)
    return "\n\n".join(text_pages).strip()
|
| 161 |
|
| 162 |
def extract_text(file_bytes: bytes) -> str:
    """Extract PDF text, preferring pdfplumber and falling back to PyPDF2.

    The fallback is used both when pdfplumber raises and when it returns no
    text. Errors from the fallback extractor propagate to the caller.
    """
    try:
        primary = extract_text_pdfplumber(file_bytes)
    except Exception:
        primary = ""
    if primary:
        return primary
    return extract_text_pypdf2(file_bytes)
|
| 172 |
|
| 173 |
# -------------------------
|
| 174 |
# Chunking / embeddings / retrieval
|
| 175 |
# -------------------------
|
| 176 |
@st.cache_data(show_spinner=False)
|
| 177 |
def chunk_text(text: str, words_per_chunk: int = 700, overlap: int = 150) -> List[str]:
    """Split text into word-based chunks that overlap their predecessor.

    Args:
        text: Source text; it is split on whitespace.
        words_per_chunk: Maximum number of words per chunk (must be > 0).
        overlap: Number of words shared between consecutive chunks. It is
            clamped to ``[0, words_per_chunk - 1]`` so every step advances.

    Returns:
        The chunks in document order; an empty list for blank text.

    Raises:
        ValueError: if ``words_per_chunk`` is not positive.
    """
    if words_per_chunk <= 0:
        raise ValueError("words_per_chunk must be a positive integer.")

    words = text.split()
    total = len(words)
    # A step of at least one word guarantees termination; stepping back by
    # `overlap` from the end of the text would otherwise never reach `total`.
    overlap = max(0, min(overlap, words_per_chunk - 1))
    step = words_per_chunk - overlap

    chunks = []
    for start in range(0, total, step):
        end = start + words_per_chunk
        chunks.append(" ".join(words[start:end]))
        if end >= total:
            # This chunk already reaches the last word; later windows would
            # only repeat its tail.
            break
    return chunks
|
| 190 |
|
| 191 |
@st.cache_data(show_spinner=False)
def get_embeddings(texts: List[str], model: str) -> List[List[float]]:
    """Embed every text in one batched request and return the vectors in order.

    NOTE(review): uses the pre-1.0 ``openai.Embedding`` interface — confirm
    the pinned openai package version still provides it.
    """
    ensure_openai_key()
    response = openai.Embedding.create(model=model, input=texts)
    vectors = []
    for item in response["data"]:
        vectors.append(item["embedding"])
    return vectors
|
| 198 |
|
| 199 |
def top_k_chunks(question: str, chunks: List[str], chunk_embs: List[List[float]], k: int = 4, emb_model: str = "text-embedding-3-small"):
    """Return the ``k`` chunks most similar to the question.

    Args:
        question: The user's question; it is embedded with ``emb_model``.
        chunks: Candidate text chunks.
        chunk_embs: Embeddings for ``chunks``, in the same order.
        k: Number of chunks wanted; clamped to the number available.
        emb_model: Embedding model used for the question.

    Returns:
        ``(selected, idx)`` — the chosen chunks ordered from most to least
        similar, and a NumPy array with their positions in ``chunks``.
        Both are empty when ``k`` is not positive or there are no chunks.
    """
    ensure_openai_key()
    k = min(int(k), len(chunks))
    if k <= 0:
        # Guard: a slice of [-0:] would select every chunk, and an empty
        # embedding matrix cannot be compared at all.
        return [], np.array([], dtype=int)
    # compute question embedding
    q_emb = get_embeddings([question], model=emb_model)[0]
    sims = cosine_similarity([q_emb], chunk_embs)[0]
    # argsort is ascending: keep the last k and flip to best-first order.
    idx = np.argsort(sims)[-k:][::-1]
    selected = [chunks[i] for i in idx]
    return selected, idx
|
| 207 |
|
| 208 |
# -------------------------
|
| 209 |
# OpenAI Chat wrappers
|
| 210 |
# -------------------------
|
| 211 |
def call_chat_completion(messages: List[Dict], model: str = "gpt-3.5-turbo", max_tokens: int = 700, temperature: float = 0.2):
    """Send a chat request and return the first choice's text.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        model: Chat model name.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.

    Returns:
        The reply text with surrounding whitespace removed; an empty string
        if the response carries no content.

    Raises:
        RuntimeError: wrapping any ``openai.error.OpenAIError`` (the original
            error is kept as ``__cause__``), or if no API key is configured.

    NOTE(review): uses the pre-1.0 ``openai.ChatCompletion`` interface —
    confirm the pinned openai package version still provides it.
    """
    ensure_openai_key()
    try:
        resp = openai.ChatCompletion.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
    except openai.error.OpenAIError as e:
        raise RuntimeError(f"OpenAI API error: {e}") from e
    content = resp["choices"][0]["message"]["content"]
    # A missing content value would otherwise crash on .strip().
    return (content or "").strip()
|
| 223 |
|
| 224 |
# -------------------------
|
| 225 |
# Prompt engineering functions
|
| 226 |
# -------------------------
|
| 227 |
def generate_summary(full_text: str, model: str = "gpt-4") -> str:
    """Ask the chat model for a revision-oriented summary of the document.

    The whole document is sent in a single request (no per-section
    summarisation), and the reply is capped at 900 tokens.
    """
    system_message = {
        "role": "system",
        "content": "You are an assistant that summarizes documents for study and revision.",
    }
    request_text = (
        "Summarize the following document for exam revision. "
        "Provide a concise executive summary (3-6 sentences), then key takeaways as bullet points, and a short list of important terms and definitions. "
        "Use clear headings. Keep the style formal and compact.\n\n"
        f"Document:\n\n{full_text}"
    )
    user_message = {"role": "user", "content": request_text}
    # The token cap bounds cost; very large documents may need chunked summarisation.
    return call_chat_completion(
        [system_message, user_message],
        model=model,
        max_tokens=900,
        temperature=0.2,
    )
|
| 249 |
|
| 250 |
def generate_mcqs(full_text: str, model: str = "gpt-4", count: int = 30) -> str:
    """Ask the chat model for ``count`` MCQs in a fixed plain-text layout.

    The layout (a ``Question <n>:`` line, four lettered options and an
    ``Answer:`` line) is spelled out in the prompt; the raw reply text is
    returned unchanged apart from the stripping done by the chat wrapper.
    """
    instruction_parts = [
        f"Create {count} multiple-choice questions (MCQs) based on the document below. ",
        "Each question must have 4 options labeled A, B, C, D and one correct answer. ",
        "Make questions diverse (recall, concept, application). Mark the correct answer on a separate 'Answer:' line. ",
        "Format EXACTLY like this for each question:\n\n",
        "Question <n>: <question text>\n\n",
        " A. <option A>\n",
        " B. <option B>\n",
        " C. <option C>\n",
        "D. <option D>\n\n",
        "Answer: <LETTER>\n\n",
        "Do NOT include explanations. Keep each question short and clear.",
    ]
    instruction = "".join(instruction_parts)
    system_message = {"role": "system", "content": "You are an experienced instructor who writes high-quality MCQs."}
    user_message = {"role": "user", "content": instruction + "\n\nDocument:\n\n" + full_text}
    return call_chat_completion(
        [system_message, user_message],
        model=model,
        max_tokens=2200,
        temperature=0.3,
    )
|
| 272 |
|
| 273 |
def answer_question(question: str, chunks: List[str], chunk_embs: List[List[float]], emb_model: str, gen_model: str, top_k: int = 4) -> str:
    """Answer a question from the most relevant document chunks.

    The ``top_k`` most similar chunks are retrieved and passed to the chat
    model as the only context. Each chunk is prefixed with its 0-based
    position in ``chunks`` so the model has real indexes to cite, as the
    prompt requests.

    Args:
        question: The user's question.
        chunks: All document chunks.
        chunk_embs: Embeddings for ``chunks``, in the same order.
        emb_model: Embedding model used for the question.
        gen_model: Chat model used to write the answer.
        top_k: Number of chunks to retrieve.

    Returns:
        The model's answer text.
    """
    selected_chunks, idx = top_k_chunks(question, chunks, chunk_embs, k=top_k, emb_model=emb_model)
    labelled_chunks = [
        f"[Chunk {int(position)}]\n{chunk}"
        for position, chunk in zip(idx, selected_chunks)
    ]
    context = "\n\n---\n\n".join(labelled_chunks)
    prompt = [
        {"role": "system", "content": "You are an assistant that answers questions using the provided context. If the answer is not in the context, say you could not find it."},
        {"role": "user", "content": f"Context:\n\n{context}\n\nQuestion: {question}\n\nAnswer concisely and cite which chunk indexes (0-based) you used."},
    ]
    return call_chat_completion(prompt, model=gen_model, max_tokens=400, temperature=0.2)
|
| 284 |
|
| 285 |
# -------------------------
|
| 286 |
# Download helpers
|
| 287 |
# -------------------------
|
| 288 |
def make_text_download(content: str, filename: str = "study_package.md"):
    """Build an HTML download link that embeds ``content`` as a data URI.

    Args:
        content: Text to offer for download; encoded as UTF-8 then base64.
        filename: Suggested file name. It is HTML-escaped because it can
            come from an uploaded file's name and the link is rendered as
            raw HTML.

    Returns:
        An ``<a class="download-link">`` element as a string.
    """
    import html  # local import: the file's import block does not provide it

    b64 = base64.b64encode(content.encode("utf-8")).decode("ascii")
    safe_name = html.escape(filename, quote=True)
    href = f'<a class="download-link" href="data:text/markdown;base64,{b64}" download="{safe_name}">Download {safe_name}</a>'
    return href
|
| 292 |
|
| 293 |
# -------------------------
|
| 294 |
# Session state initialization
|
| 295 |
# -------------------------
|
| 296 |
# Seed the per-session keys once so later code can read them safely.
if "qa_history" not in st.session_state:
    # list of dicts: question, answer, time
    st.session_state["qa_history"] = []

# Generated artefacts and retrieval data start out unset.
for _state_key in ("summary", "mcq_text", "chunks", "chunk_embeddings"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
|
| 310 |
|
| 311 |
# -------------------------
|
| 312 |
# App UI layout
|
|
|
|
| 318 |
left_col, right_col = st.columns([1.4, 2])
|
| 319 |
|
| 320 |
def _document_text_or_error(uploaded):
    """Return the extracted text for the current upload, or show an error and return None."""
    if not uploaded:
        st.error("Upload a PDF first.")
        return None
    text = st.session_state.get("full_text")
    if not text:
        # Extraction was too short or empty, so nothing was stored for this file.
        st.error("No usable text was extracted from this PDF. Try another PDF.")
        return None
    return text


with left_col:
    st.header("Upload & Settings")
    uploaded_file = st.file_uploader("Upload a PDF", type=["pdf"], help="Choose a PDF with selectable text for best results.")
    if uploaded_file:
        # getvalue() returns the full content regardless of the stream position.
        file_bytes = uploaded_file.getvalue()
        st.write(f"**Filename:** {uploaded_file.name} — {len(file_bytes)//1024} KB")
        with st.spinner("Extracting text from PDF..."):
            try:
                full_text = extract_text(file_bytes)
            except Exception as e:
                st.error(f"Failed to extract text: {e}")
                st.stop()
        if not full_text or len(full_text.strip()) < 50:
            st.warning("Extracted text is short or empty. The PDF may be scanned images. Try another PDF or enable OCR.")
            # Drop text left over from a previous upload so it is not used by mistake.
            if "full_text" in st.session_state:
                del st.session_state["full_text"]
        else:
            st.success(f"Extracted {len(full_text.split())} words from PDF.")
            st.session_state["full_text"] = full_text
    else:
        st.info("Please upload a PDF to enable summary and MCQ generation.")

    # Action buttons
    st.markdown("---")
    st.header("Generate Content")
    colA, colB = st.columns([1, 1])
    with colA:
        if st.button("Generate Summary"):
            doc_text = _document_text_or_error(uploaded_file)
            if doc_text is not None:
                try:
                    with st.spinner("Generating summary (OpenAI)..."):
                        ensure_openai_key()
                        # Very large documents may need chunked, iterative summarisation.
                        st.session_state["summary"] = generate_summary(doc_text, model=model_choice)
                    st.success("Summary generated.")
                except Exception as e:
                    st.error(f"Summary generation failed: {e}")

    with colB:
        if st.button(f"Generate {mcq_target} MCQs"):
            doc_text = _document_text_or_error(uploaded_file)
            if doc_text is not None:
                try:
                    with st.spinner("Generating MCQs (this may take a moment)..."):
                        ensure_openai_key()
                        st.session_state["mcq_text"] = generate_mcqs(doc_text, model=model_choice, count=int(mcq_target))
                    st.success("MCQs generated.")
                except Exception as e:
                    st.error(f"MCQ generation failed: {e}")

    # Generate both
    if st.button("Generate Summary + MCQs"):
        doc_text = _document_text_or_error(uploaded_file)
        if doc_text is not None:
            try:
                with st.spinner("Generating summary + MCQs..."):
                    ensure_openai_key()
                    st.session_state["summary"] = generate_summary(doc_text, model=model_choice)
                    st.session_state["mcq_text"] = generate_mcqs(doc_text, model=model_choice, count=int(mcq_target))
                st.success("Summary and MCQs generated.")
            except Exception as e:
                st.error(f"Combined generation failed: {e}")

    # Prepare retrieval infrastructure
    if uploaded_file and ("full_text" in st.session_state):
        if st.button("Prepare Q&A (create embeddings)"):
            try:
                with st.spinner("Chunking document and computing embeddings (costly operation)..."):
                    chunks = chunk_text(st.session_state["full_text"], words_per_chunk=int(chunk_size), overlap=int(chunk_overlap))
                    st.session_state["chunks"] = chunks
                    # Compute embeddings (cached)
                    chunk_embs = get_embeddings(chunks, model=emb_model_choice)
                    st.session_state["chunk_embeddings"] = chunk_embs
                st.success(f"Prepared {len(chunks)} chunks and embeddings for retrieval.")
            except Exception as e:
                st.error(f"Failed to prepare embeddings: {e}")

    st.markdown("---")
    st.header("Download / Export")
    st.markdown("After generating content, download a combined study package.")
    if st.session_state.get("summary") or st.session_state.get("mcq_text") or st.session_state["qa_history"]:
        # Compose markdown
        composed = []
        if st.session_state.get("summary"):
            composed.append("# Summary\n\n" + st.session_state["summary"] + "\n\n")
        if st.session_state.get("mcq_text"):
            composed.append("# MCQs\n\n" + st.session_state["mcq_text"] + "\n\n")
        if st.session_state.get("qa_history"):
            qalist = ["# Q&A History\n"]
            for qa in st.session_state["qa_history"]:
                qalist.append(f"**Q:** {qa['question']}\n\n**A:** {qa['answer']}\n\n_Time:_ {qa['time']}\n\n")
            composed.append("\n".join(qalist))
        package_md = "\n".join(composed)
        # Generated content outlives the upload widget: the file may have been
        # removed while session content remains, so do not assume it exists.
        package_name = f"{uploaded_file.name if uploaded_file else 'document'}_study_package.md"
        st.markdown(make_text_download(package_md, filename=package_name), unsafe_allow_html=True)
        st.download_button("Download study package (.md)", package_md, file_name=package_name, mime="text/markdown")
    else:
        st.info("No generated content yet. Run summary/MCQ generation first.")
|
| 424 |
|
| 425 |
with right_col:
    # Tabs: Summary, MCQ Quiz, Q&A
    tab1, tab2, tab3 = st.tabs(["\U0001f4d1 Summary", "\U0001f4dd MCQ Quiz", "\u2753 Q&A Dashboard"])

    with tab1:
        st.header("Document Summary")
        if not st.session_state.get("summary"):
            st.info("No summary yet. Click 'Generate Summary' in the left panel.")
        else:
            # NOTE(review): each st.markdown call is emitted separately, so the
            # opening/closing <div> presumably do not wrap the summary — verify
            # the rendered page.
            st.markdown("<div class='qa-box handwriting'>", unsafe_allow_html=True)
            st.markdown(st.session_state["summary"], unsafe_allow_html=True)
            st.markdown("</div>", unsafe_allow_html=True)
|
| 437 |
+
|
| 438 |
+
with tab2:
|
| 439 |
+
st.header("Generated MCQs")
|
| 440 |
+
if st.session_state.get("mcq_text"):
|
| 441 |
+
# Display with formatting: question line and indented options vertically
|
| 442 |
+
st.markdown("<div class='mcq-block'>", unsafe_allow_html=True)
|
| 443 |
+
# We display as preformatted but with handwriting font and indentation
|
| 444 |
+
st.text_area("MCQs (read-only)", value=st.session_state["mcq_text"], height=420, key="mcq_display")
|
| 445 |
+
st.markdown("</div>", unsafe_allow_html=True)
|
| 446 |
+
|
| 447 |
+
# Also provide CSV download parsed
|
| 448 |
+
def parse_mcqs_to_df(mcq_text: str) -> pd.DataFrame:
    """Parse model-generated MCQ text into one row per question.

    Recognised lines (after removing ``**`` bold markers and leading ``#``):

    * question: starts with "Question" (text after the first colon), or a
      numbered ``Q1:`` / ``Q1.`` / ``Q1)`` label;
    * option: ``A.``, ``A)`` or ``(A)`` for letters A-D, upper or lower case;
    * answer: starts with "Answer"; a leading option letter is normalised to
      upper case (``"b."`` -> ``"B"``), anything else is kept as written.

    Unlabelled lines that follow a question before its first option are
    treated as a continuation of the question text; other unlabelled lines
    are ignored. Questions with no text are dropped.

    Args:
        mcq_text: Raw MCQ text; ``None`` or empty yields an empty frame.

    Returns:
        DataFrame with columns ``question, A, B, C, D, answer`` (always
        present, even when no question was found).
    """
    import re  # local import: the file's import block does not provide it

    columns = ["question", "A", "B", "C", "D", "answer"]
    numbered_question = re.compile(r"^q\s*\d+\s*[:.)]\s*(.*)$", re.IGNORECASE)
    option_line = re.compile(r"^\(?([A-Da-d])[.)]\s*(.*)$")
    answer_letter = re.compile(r"^\(?([A-Da-d])\)?(?=[\s.):,]|$)")

    rows = []
    current = None

    def flush():
        # Keep only records that actually captured question text.
        if current is not None and current["question"]:
            rows.append(current)

    for raw_line in (mcq_text or "").splitlines():
        line = raw_line.replace("**", "").strip().lstrip("#").strip()
        if not line:
            continue

        numbered = numbered_question.match(line)
        if numbered or line.lower().startswith("question"):
            flush()
            if numbered:
                question = numbered.group(1)
            else:
                _, colon, rest = line.partition(":")
                question = rest if colon else line
            current = {"question": question.strip(), "A": "", "B": "", "C": "", "D": "", "answer": ""}
            continue

        if current is None:
            # Nothing before the first question belongs to any record.
            continue

        option = option_line.match(line)
        if option:
            current[option.group(1).upper()] = option.group(2).strip()
            continue

        if line.lower().startswith("answer"):
            _, colon, value = line.partition(":")
            if colon:
                value = value.strip()
                letter = answer_letter.match(value)
                current["answer"] = letter.group(1).upper() if letter else value
            continue

        # Unlabelled text before any option/answer continues the question.
        if not current["answer"] and not any(current[key] for key in "ABCD"):
            current["question"] = f"{current['question']} {line}".strip()

    flush()
    return pd.DataFrame(rows, columns=columns)
|
| 484 |
+
|
| 485 |
+
df_mcq = parse_mcqs_to_df(st.session_state["mcq_text"])
|
| 486 |
+
if not df_mcq.empty:
|
| 487 |
+
st.download_button("Download MCQs as CSV", df_mcq.to_csv(index=False), file_name=f"{uploaded_file.name}_mcqs.csv", mime="text/csv")
|
| 488 |
+
else:
|
| 489 |
+
st.info("No MCQs generated yet. Click 'Generate MCQs' in the left panel.")
|
| 490 |
+
|
| 491 |
+
with tab3:
    import html as _html  # local import: the file's import block does not provide it

    st.header("Q&A Dashboard")
    st.markdown("Ask questions about the PDF. Use 'Prepare Q&A' first (computes embeddings).")
    question_input = st.text_input("Enter your question here:")
    if st.button("Ask question"):
        if not st.session_state.get("chunks") or not st.session_state.get("chunk_embeddings"):
            st.warning("Please click 'Prepare Q&A (create embeddings)' in the left panel first.")
        elif not question_input.strip():
            st.error("Please type a question.")
        else:
            try:
                with st.spinner("Retrieving context and generating answer..."):
                    ans = answer_question(question_input, st.session_state["chunks"], st.session_state["chunk_embeddings"], emb_model_choice, model_choice, top_k=int(retrieval_k))
                    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
                    st.session_state["qa_history"].append({"question": question_input, "answer": ans, "time": timestamp})
                st.success("Answer generated.")
            except Exception as e:
                st.error(f"Q&A failed: {e}")

    # Show history
    if st.session_state["qa_history"]:
        st.markdown("### Recent Q&A")
        # Newest first, limited to the last eight exchanges.
        for qa in reversed(st.session_state["qa_history"][-8:]):
            # The question is user input and the answer is model output, and
            # both are placed in raw HTML: escape them, keeping line breaks.
            question_html = _html.escape(str(qa["question"])).replace("\n", "<br/>")
            answer_html = _html.escape(str(qa["answer"])).replace("\n", "<br/>")
            time_html = _html.escape(str(qa["time"]))
            st.markdown(f"<div class='qa-box'><strong>Q:</strong> {question_html}<br/><strong>A:</strong> {answer_html}<div class='small-muted'>Time: {time_html}</div></div>", unsafe_allow_html=True)
        # Download Q&A
        qa_md = "\n\n".join([f"Q: {qa['question']}\nA: {qa['answer']}\nTime: {qa['time']}" for qa in st.session_state["qa_history"]])
        # History outlives the upload widget; the file may no longer be present.
        history_name = f"{uploaded_file.name if uploaded_file else 'document'}_qa_history.txt"
        st.download_button("Download Q&A history (.txt)", qa_md, file_name=history_name, mime="text/plain")
    else:
        st.info("No Q&A history yet.")
|
| 520 |
|
| 521 |
# -------------------------
|
| 522 |
# Footer
|