File size: 12,013 Bytes
cdad639
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a46dea
cdad639
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
# app.py
import os
import io
import streamlit as st
import pdfplumber
from pptx import Presentation
import docx as docx_lib
import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
from groq import Groq
import markdown2
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas

# ---------------- CONFIG ----------------
# Sentence-transformers model used to embed both document chunks and user queries.
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Groq chat-completion model used for answers, summaries and MCQ generation.
GROQ_LLM_MODEL = "llama-3.3-70b-versatile"

# ---------------- HELPERS ----------------
@st.cache_resource
def load_embedder():
    """Construct the SentenceTransformer named by ``EMBED_MODEL``.

    Decorated with ``st.cache_resource`` so the model object is reused across
    Streamlit reruns instead of being re-instantiated each time.
    """
    model = SentenceTransformer(EMBED_MODEL)
    return model

embedder = load_embedder()

def parse_pdf_bytes(file_bytes):
    """Extract the text of a PDF supplied as raw bytes.

    Every page whose ``extract_text()`` result is non-empty contributes that
    text followed by a newline; pages yielding nothing (e.g. no text layer)
    are skipped.

    Returns:
        The concatenated page text, or "" after showing a Streamlit warning
        if the PDF cannot be opened or parsed.
    """
    try:
        with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
            page_texts = [page.extract_text() for page in pdf.pages]
        # Join once instead of growing a string with += per page.
        return "".join(page_text + "\n" for page_text in page_texts if page_text)
    except Exception as e:
        st.warning(f"PDF parse warning: {e}")
        return ""

def parse_docx_bytes(file_bytes):
    """Extract plain text from a .docx file supplied as raw bytes.

    Body paragraphs come first, one per line (unchanged from before). The
    rows of every top-level table follow, one row per line with cell texts
    separated by " | ". ``Document.paragraphs`` does not include paragraphs
    inside tables, so without the second pass tabular content would be
    silently lost. NOTE(review): merged cells may repeat their text across
    the cells they span.

    Returns:
        The extracted text, or "" after showing a Streamlit warning if the
        bytes cannot be parsed as a .docx.
    """
    try:
        doc = docx_lib.Document(io.BytesIO(file_bytes))
        lines = [p.text for p in doc.paragraphs]
        for table in doc.tables:
            for row in table.rows:
                lines.append(" | ".join(cell.text for cell in row.cells))
        return "\n".join(lines)
    except Exception as e:
        st.warning(f"DOCX parse warning: {e}")
        return ""

def parse_pptx_bytes(file_bytes):
    """Extract text from a .pptx presentation supplied as raw bytes.

    For every slide, each shape exposing ``.text`` contributes its text plus
    a newline, as before. In addition:

    * group shapes are descended into, so text inside grouped shapes is no
      longer skipped;
    * tables contribute one line per row, cells separated by " | ".

    Returns:
        The extracted text, or "" after showing a Streamlit warning if the
        presentation cannot be parsed.
    """

    def _shape_lines(shape):
        # Group shapes have no text of their own but expose child shapes.
        children = getattr(shape, "shapes", None)
        if children is not None:
            for child in children:
                yield from _shape_lines(child)
            return
        if hasattr(shape, "text"):
            yield shape.text
        elif getattr(shape, "has_table", False):
            for row in shape.table.rows:
                yield " | ".join(cell.text for cell in row.cells)

    try:
        prs = Presentation(io.BytesIO(file_bytes))
        # The generator is consumed inside join(), i.e. inside this try block,
        # so errors raised while walking shapes are still reported below.
        return "".join(
            line + "\n"
            for slide in prs.slides
            for shape in slide.shapes
            for line in _shape_lines(shape)
        )
    except Exception as e:
        st.warning(f"PPTX parse warning: {e}")
        return ""

def parse_spreadsheet_bytes(file_bytes):
    """Convert an Excel workbook or CSV file (raw bytes) into CSV text.

    The bytes are first read as an Excel workbook with *all* sheets loaded;
    previously only the first sheet was read and the rest were dropped.

    * A single-sheet workbook yields exactly that sheet's CSV, as before.
    * With several sheets, each sheet's CSV is prefixed by a
      ``### Sheet: <name>`` line.
    * If the bytes are not a readable workbook, they are parsed as CSV instead.

    Returns:
        CSV text (no index column), or "" after showing a Streamlit warning
        if neither parse succeeds.
    """
    try:
        try:
            sheets = pd.read_excel(io.BytesIO(file_bytes), sheet_name=None)
        except Exception:
            # Not an Excel workbook (or unreadable as one): fall back to CSV.
            return pd.read_csv(io.BytesIO(file_bytes)).to_csv(index=False)
        if len(sheets) == 1:
            (only_sheet,) = sheets.values()
            return only_sheet.to_csv(index=False)
        return "\n".join(
            f"### Sheet: {sheet_name}\n{frame.to_csv(index=False)}"
            for sheet_name, frame in sheets.items()
        )
    except Exception as e:
        st.warning(f"Spreadsheet parse warning: {e}")
        return ""

def parse_txt_bytes(file_bytes):
    """Decode raw bytes as UTF-8, silently dropping undecodable sequences.

    Returns:
        The decoded text, or "" after showing a Streamlit warning if the
        decode call itself raises.
    """
    decoded = ""
    try:
        decoded = file_bytes.decode("utf-8", errors="ignore")
    except Exception as err:
        st.warning(f"TXT parse warning: {err}")
    return decoded

def chunk_text(text, max_chars=1000, overlap=200):
    """Split *text* into overlapping character windows.

    Windows are at most *max_chars* characters long. Each new window starts
    *overlap* characters before the previous one ended, and the final window
    ends exactly at the end of *text*. Every window is stripped of surrounding
    whitespace, and windows that become empty are dropped.

    Args:
        text: The text to split; empty/None yields [].
        max_chars: Maximum window length in characters; must be > 0.
        overlap: Characters shared by consecutive windows; must satisfy
            ``0 <= overlap < max_chars``.

    Returns:
        The list of non-empty, stripped chunks.

    Raises:
        ValueError: if *max_chars* or *overlap* is out of range. These values
            previously caused an infinite loop (the window start never
            advanced) or silently skipped characters (negative overlap).
    """
    if not text:
        return []
    if max_chars <= 0:
        raise ValueError(f"max_chars must be positive, got {max_chars}")
    if not 0 <= overlap < max_chars:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < max_chars, got overlap={overlap}, max_chars={max_chars}"
        )
    chunks = []
    length = len(text)
    start = 0
    while start < length:
        end = min(start + max_chars, length)
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end == length:
            break
        # Guaranteed progress: end - overlap > start because overlap < max_chars.
        start = end - overlap
    return chunks

def build_faiss_index(chunks, embedder):
    """Embed *chunks* and load the vectors into an exact (flat) L2 FAISS index.

    Args:
        chunks: Text chunks to index.
        embedder: Object whose ``encode(list, convert_to_numpy=True)`` returns
            a 2-D array with one row per chunk.

    Returns:
        ``(index, embeddings)``, where *embeddings* is the array exactly as
        returned by ``encode``; ``(None, None)`` when *chunks* is empty.
    """
    if not chunks:
        return None, None
    vectors = embedder.encode(chunks, convert_to_numpy=True)
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    # FAISS expects float32 input.
    flat_index.add(vectors.astype("float32"))
    return flat_index, vectors

def retrieve_chunks(query, embedder, faiss_index, chunks, k=5):
    """Return up to *k* chunks nearest to *query*, nearest first.

    Neighbour ids outside ``range(len(chunks))`` are discarded; FAISS reports
    -1 for missing neighbours when fewer than *k* vectors are indexed.
    Returns [] when there is no index or no chunks.
    """
    if faiss_index is None or not chunks:
        return []
    query_vec = embedder.encode([query], convert_to_numpy=True).astype("float32")
    _, neighbour_ids = faiss_index.search(query_vec, k)
    chunk_count = len(chunks)
    return [chunks[idx] for idx in neighbour_ids[0] if 0 <= idx < chunk_count]

# ---------------- Groq LLM ----------------
# Education level -> instruction prepended to the answer and summary prompts.
# The keys are also the choices offered by the sidebar "Education level"
# selectbox (in this insertion order). Unknown levels fall back to "" at the
# call sites.
EDU_PROMPTS = {
    "Primary School": "Explain this to me like I'm 5 years old, in a fun and simple way with examples and analogies.",
    "Middle School": "Explain this in a simple and clear way appropriate for a middle school student with examples.",
    "High School": "Explain this clearly, assuming knowledge up to high school level.",
    "Undergraduate": "Explain this in a university-level way, with clarity and useful details and examples.",
    "Graduate": "Explain this at graduate-level rigor, including key details, nuance, and technical terms as appropriate.",
}

def get_groq_client():
    """Create a Groq client using the first API key that is available.

    Lookup order:
      1. ``st.secrets["GROQ_API_KEY"]``;
      2. ``st.session_state["groq_api_key"]`` (a key entered in the UI);
      3. the ``GROQ_API_KEY`` environment variable.

    Raises:
        ValueError: if none of these sources provides a key.
    """
    api_key = None
    try:
        # Previously this read st.secrets[""], so a configured GROQ_API_KEY
        # secret was never picked up.
        api_key = st.secrets["GROQ_API_KEY"]
    except Exception:
        # Accessing st.secrets can raise when no secrets are configured or
        # the key is absent; fall through to the other sources.
        pass
    if not api_key:
        api_key = st.session_state.get("groq_api_key") or os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise ValueError("Groq API key not found. Set st.secrets['GROQ_API_KEY'], or enter in sidebar, or set env GROQ_API_KEY.")
    return Groq(api_key=api_key)

def call_llm_with_context(question, retrieved_chunks, edu_level):
    """Ask the Groq chat model to answer *question* using *retrieved_chunks* as context.

    The user message is the education-level instruction, the chunks joined by
    blank lines, and the question. Returns the content of the first choice.
    Any error from ``get_groq_client`` or the API call propagates.
    """
    client = get_groq_client()
    level_instruction = EDU_PROMPTS.get(edu_level, "")
    context_block = "\n\n".join(retrieved_chunks or [])
    messages = [
        {"role": "system", "content": "You are a helpful and knowledgeable tutor."},
        {
            "role": "user",
            "content": f"{level_instruction}\n\nContext:\n{context_block}\n\nQuestion: {question}",
        },
    ]
    completion = client.chat.completions.create(messages=messages, model=GROQ_LLM_MODEL)
    return completion.choices[0].message.content

def make_summary(question, retrieved_chunks, edu_level):
    """Ask the Groq chat model for a short summary of *question*, grounded on *retrieved_chunks*.

    The prompt combines the education-level instruction, the chunks (joined by
    blank lines) and a request for a concise, possibly bulleted summary.
    Returns the content of the first choice.
    """
    client = get_groq_client()
    level_instruction = EDU_PROMPTS.get(edu_level, "")
    context_block = "\n\n".join(retrieved_chunks or [])
    user_prompt = (
        f"{level_instruction}\n\n"
        f"Here is some context:\n{context_block}\n\n"
        f"Please give a short, easy-to-understand summary of: {question}\n"
        "Keep it concise and simple; use bullet points if helpful."
    )
    messages = [
        {"role": "system", "content": "You are a concise summarizer."},
        {"role": "user", "content": user_prompt},
    ]
    completion = client.chat.completions.create(messages=messages, model=GROQ_LLM_MODEL)
    return completion.choices[0].message.content

def make_mcqs_from_summary(summary_text, count=5, difficulty="medium"):
    """Ask the Groq chat model to write *count* four-option MCQs from *summary_text*.

    The prompt requests A-D options, the correct option, and a 1-2 sentence
    explanation per question at the given *difficulty*. Returns the model's
    raw text (no parsing is done here).
    """
    client = get_groq_client()
    instructions = " ".join([
        f"Create {count} multiple choice questions (MCQs) from the following summary.",
        "Each question should have 4 options labeled A-D and indicate the correct option.",
        "Also provide a 1-2 sentence explanation for the correct answer.",
        f"Difficulty: {difficulty}.",
    ])
    user_prompt = f"{instructions}\n\nSummary:\n{summary_text}"
    messages = [
        {"role": "system", "content": "You are an assistant that generates high-quality multiple-choice questions."},
        {"role": "user", "content": user_prompt},
    ]
    completion = client.chat.completions.create(messages=messages, model=GROQ_LLM_MODEL)
    return completion.choices[0].message.content

# ---------------- STREAMLIT UI ----------------
# NOTE(review): load_embedder() is called at import time, before this line. If
# its cache spinner renders on a cold start, older Streamlit versions may
# reject set_page_config as "not the first Streamlit command" — confirm on the
# deployed version.
st.set_page_config(page_title="AI Study Assistant", layout="wide")
st.title("📚 AI Study Assistant — Exam Mode")

with st.sidebar:
    st.header("Settings")
    # An optional key typed here is stored in session_state; get_groq_client()
    # falls back to st.session_state["groq_api_key"] when secrets give no key.
    groq_key = st.text_input("Groq API key (optional)", type="password")
    if groq_key:
        st.session_state["groq_api_key"] = groq_key
    # The selectable levels are exactly the keys of EDU_PROMPTS.
    edu_level = st.selectbox("Education level", list(EDU_PROMPTS.keys()))
    st.info("Upload documents and ask questions. You can generate summaries + MCQs.")

uploaded_files = st.file_uploader("Upload study documents (PDF, DOCX, PPTX, XLSX/CSV, TXT)", accept_multiple_files=True)
# Everything below needs documents: end this script run until some are uploaded.
if not uploaded_files:
    st.info("Please upload at least one document.")
    st.stop()

# ---------------- PARSE FILES ----------------
# Ordered (suffixes, parser) pairs; the first suffix match wins.
_PARSERS_BY_SUFFIX = (
    ((".pdf",), parse_pdf_bytes),
    ((".docx",), parse_docx_bytes),
    ((".pptx",), parse_pptx_bytes),
    ((".xls", ".xlsx", ".csv"), parse_spreadsheet_bytes),
    ((".txt",), parse_txt_bytes),
)


def _extract_text(file_name, raw):
    """Return the text of one upload, choosing a parser by case-insensitive extension.

    Files with an unknown extension are decoded as strict UTF-8 and yield ""
    if that fails, so undecodable (binary) uploads contribute nothing.
    """
    lowered = file_name.lower()
    for suffixes, parser in _PARSERS_BY_SUFFIX:
        if lowered.endswith(suffixes):
            return parser(raw)
    try:
        return raw.decode("utf-8")
    except UnicodeDecodeError:
        return ""


sections = []
for uf in uploaded_files:
    # getvalue() returns the full upload regardless of the stream position;
    # read() would return b"" if the buffer had already been consumed.
    text = _extract_text(uf.name, uf.getvalue())
    if text:
        sections.append(f"\n\n### From file: {uf.name}\n\n{text}")
# Join once instead of growing all_text with += per file.
all_text = "".join(sections)

if not all_text.strip():
    st.error("No textual content extracted.")
    st.stop()

# ---------------- CHUNK + INDEX ----------------
@st.cache_resource(show_spinner=False, max_entries=4)
def _index_documents(corpus_text):
    """Chunk *corpus_text* and build its FAISS index, cached by the text itself.

    Streamlit re-executes this script on every widget interaction. Without
    caching, every rerun re-embedded the whole corpus. Now the embedding pass
    only runs when the extracted text changes. ``max_entries`` bounds how many
    distinct corpora (with their embeddings) are kept in memory.

    Returns:
        ``(chunks, faiss_index, embeddings)``. The returned objects are shared
        between reruns and must be treated as read-only.
    """
    doc_chunks = chunk_text(corpus_text)
    doc_index, doc_embeddings = build_faiss_index(doc_chunks, embedder)
    return doc_chunks, doc_index, doc_embeddings


with st.spinner("Processing documents..."):
    chunks, faiss_index, embeddings = _index_documents(all_text)
    st.success(f"Prepared {len(chunks)} chunks and built vector index.")

# ---------------- ASK QUESTION ----------------
# Halt this script run until the user has typed a question.
question = st.text_input("Ask a question about your materials:")
if not question:
    st.info("Type a question to begin.")
    st.stop()

topk = st.number_input("Top-k passages", min_value=1, max_value=10, value=5)
mcq_count = st.number_input("MCQs to generate", min_value=1, max_value=20, value=5)
mcq_diff = st.selectbox("MCQ difficulty", ["easy", "medium", "hard"], index=1)

retrieved = retrieve_chunks(question, embedder, faiss_index, chunks, k=int(topk))

# Show each retrieved passage, truncated to 800 characters with a trailing ellipsis.
if not retrieved:
    st.warning("No relevant passages found.")
else:
    st.subheader("Relevant passages:")
    for number, passage in enumerate(retrieved, start=1):
        st.markdown(f"**Passage {number}:**")
        preview = passage[:800]
        if len(passage) > 800:
            preview += "..."
        st.write(preview)

# ---------------- GENERATE ANSWER ----------------
# Streamlit re-executes the whole script on every widget interaction: changing
# top-k, ticking the summary checkbox, clicking a download button. Memoise the
# answer per session so those reruns do not re-bill the LLM or replace the
# answer the user is reading. The key covers every input of the call; a new
# question, passage set or level triggers a fresh request. Failed calls are
# not cached.
if "_llm_cache" not in st.session_state:
    st.session_state["_llm_cache"] = {}
answer_key = ("answer", question, tuple(retrieved), edu_level)
try:
    llm_cache = st.session_state["_llm_cache"]
    if answer_key not in llm_cache:
        llm_cache[answer_key] = call_llm_with_context(question, retrieved, edu_level)
    answer = llm_cache[answer_key]
    st.subheader("Answer:")
    st.write(answer)
except Exception as e:
    st.error(f"LLM error: {e}")
    st.stop()

# ---------------- GENERATE SUMMARY + MCQs ----------------
def _text_to_pdf_buffer(text, wrap=90, margin=40):
    """Render plain *text* onto Letter-size pages; return the PDF as a rewound BytesIO.

    Lines longer than *wrap* characters are hard-split at *wrap*. When the
    text cursor falls below the bottom *margin*, a new page is started.
    Previously all text was drawn on a single page, so anything past the
    page bottom was clipped.
    """
    buf = io.BytesIO()
    pdf = canvas.Canvas(buf, pagesize=letter)
    _, page_height = letter
    text_obj = pdf.beginText(margin, page_height - margin)
    for raw_line in text.split("\n"):
        # An empty source line still produces one blank output line.
        pieces = [raw_line[i:i + wrap] for i in range(0, len(raw_line), wrap)] or [""]
        for piece in pieces:
            if text_obj.getY() < margin:
                pdf.drawText(text_obj)
                pdf.showPage()
                text_obj = pdf.beginText(margin, page_height - margin)
            text_obj.textLine(piece)
    pdf.drawText(text_obj)
    pdf.showPage()
    pdf.save()
    buf.seek(0)
    return buf


def _text_to_docx_buffer(heading, text):
    """Build a .docx with a level-1 *heading* and one paragraph per line of *text*; return a rewound BytesIO."""
    buf = io.BytesIO()
    document = docx_lib.Document()
    document.add_heading(heading, level=1)
    for line in text.split("\n"):
        document.add_paragraph(line)
    document.save(buf)
    buf.seek(0)
    return buf


if st.checkbox("Generate summary and MCQs"):
    # Clicking any download button reruns the script. Memoise the LLM outputs
    # per session so a click neither re-bills the API nor regenerates the
    # summary/MCQs the user is looking at.
    if "_llm_cache" not in st.session_state:
        st.session_state["_llm_cache"] = {}
    llm_cache = st.session_state["_llm_cache"]
    try:
        summary_key = ("summary", question, tuple(retrieved), edu_level)
        if summary_key not in llm_cache:
            llm_cache[summary_key] = make_summary(question, retrieved, edu_level)
        summary = llm_cache[summary_key]
        st.subheader("📘 Summary")
        st.write(summary)

        # Downloads
        md_text = summary
        html_text = markdown2.markdown(summary)
        pdf_buffer = _text_to_pdf_buffer(summary)
        docx_buffer = _text_to_docx_buffer("Summary", summary)

        st.download_button("⬇️ Download Summary (Markdown)", md_text, file_name="summary.md")
        st.download_button("⬇️ Download Summary (HTML)", html_text, file_name="summary.html", mime="text/html")
        st.download_button("⬇️ Download Summary (PDF)", pdf_buffer, file_name="summary.pdf", mime="application/pdf")
        st.download_button("⬇️ Download Summary (DOCX)", docx_buffer, file_name="summary.docx", mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document")

        # MCQs
        mcq_key = ("mcqs", summary, int(mcq_count), mcq_diff)
        if mcq_key not in llm_cache:
            llm_cache[mcq_key] = make_mcqs_from_summary(summary, count=int(mcq_count), difficulty=mcq_diff)
        mcq_text = llm_cache[mcq_key]
        st.subheader("📝 Generated MCQs")
        st.write(mcq_text)

        mcq_docx_buf = _text_to_docx_buffer("MCQs", mcq_text)
        mcq_pdf_buf = _text_to_pdf_buffer(mcq_text)

        st.download_button("⬇️ Download MCQs (DOCX)", mcq_docx_buf, file_name="mcqs.docx", mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document")
        st.download_button("⬇️ Download MCQs (PDF)", mcq_pdf_buf, file_name="mcqs.pdf", mime="application/pdf")

    except Exception as e:
        st.error(f"Error generating summary or MCQs: {e}")
else:
    st.info("Check the box above to generate summary + MCQs from retrieved content.")