asad9641 committed on
Commit
cdad639
·
verified ·
1 Parent(s): 68a4eff

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +332 -0
app.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import io
4
+ import streamlit as st
5
+ import pdfplumber
6
+ from pptx import Presentation
7
+ import docx as docx_lib
8
+ import pandas as pd
9
+ from sentence_transformers import SentenceTransformer
10
+ import faiss
11
+ from groq import Groq
12
+ import markdown2
13
+ from reportlab.lib.pagesizes import letter
14
+ from reportlab.pdfgen import canvas
15
+
16
+ # ---------------- CONFIG ----------------
17
+ EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
18
+ GROQ_LLM_MODEL = "llama-3.3-70b-versatile"
19
+
20
+ # ---------------- HELPERS ----------------
21
@st.cache_resource
def load_embedder():
    """Instantiate the sentence-embedding model named by ``EMBED_MODEL``.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the returned
    model object instead of constructing it on every call.
    """
    model = SentenceTransformer(EMBED_MODEL)
    return model


# Single embedder instance used for both indexing and query encoding.
embedder = load_embedder()
26
+
27
def parse_pdf_bytes(file_bytes):
    """Return the text of a PDF supplied as raw bytes.

    Each page whose ``extract_text()`` result is non-empty contributes its
    text followed by a newline; other pages are skipped. Any exception is
    surfaced with ``st.warning`` and an empty string is returned.
    """
    try:
        with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
            page_texts = [page.extract_text() for page in pdf.pages]
        return "".join(body + "\n" for body in page_texts if body)
    except Exception as e:
        st.warning(f"PDF parse warning: {e}")
        return ""
39
+
40
def parse_docx_bytes(file_bytes):
    """Return the paragraph text of a DOCX supplied as raw bytes.

    Paragraph texts are joined with newlines. Any exception is surfaced
    with ``st.warning`` and an empty string is returned.
    """
    try:
        document = docx_lib.Document(io.BytesIO(file_bytes))
        paragraph_texts = (para.text for para in document.paragraphs)
        return "\n".join(paragraph_texts)
    except Exception as e:
        st.warning(f"DOCX parse warning: {e}")
        return ""
47
+
48
def parse_pptx_bytes(file_bytes):
    """Return the text of a PPTX supplied as raw bytes.

    Walks every shape of every slide and keeps the ``text`` of shapes that
    expose that attribute, each followed by a newline. Any exception is
    surfaced with ``st.warning`` and an empty string is returned.
    """
    try:
        deck = Presentation(io.BytesIO(file_bytes))
        pieces = [
            shape.text + "\n"
            for slide in deck.slides
            for shape in slide.shapes
            if hasattr(shape, "text")
        ]
        return "".join(pieces)
    except Exception as e:
        st.warning(f"PPTX parse warning: {e}")
        return ""
60
+
61
def parse_spreadsheet_bytes(file_bytes):
    """Return a spreadsheet (Excel or CSV bytes) rendered as CSV text.

    The bytes are first read with ``pd.read_excel``; if that raises, they
    are read with ``pd.read_csv`` instead. The resulting frame is serialised
    with ``to_csv(index=False)``. If everything fails, the error is surfaced
    with ``st.warning`` and an empty string is returned.
    """

    def _load_frame(raw):
        # Excel first; fall back to CSV on any failure.
        try:
            return pd.read_excel(io.BytesIO(raw))
        except Exception:
            return pd.read_csv(io.BytesIO(raw))

    try:
        csv_text = _load_frame(file_bytes).to_csv(index=False)
    except Exception as e:
        st.warning(f"Spreadsheet parse warning: {e}")
        return ""
    return csv_text
71
+
72
def parse_txt_bytes(file_bytes):
    """Decode raw bytes as UTF-8 text, dropping undecodable bytes.

    If decoding raises (for example because the argument has no ``decode``
    method), the error is surfaced with ``st.warning`` and an empty string
    is returned.
    """
    try:
        decoded = file_bytes.decode("utf-8", errors="ignore")
    except Exception as e:
        st.warning(f"TXT parse warning: {e}")
        return ""
    else:
        return decoded
78
+
79
def chunk_text(text, max_chars=1000, overlap=200):
    """Split ``text`` into overlapping, whitespace-stripped chunks.

    Consecutive windows are ``max_chars`` characters long and start
    ``max_chars - overlap`` characters apart, so each window repeats the
    last ``overlap`` characters of the previous one. Windows that are empty
    after stripping are omitted.

    Args:
        text: The string to split. Empty or ``None`` yields ``[]``.
        max_chars: Maximum window length; must be positive.
        overlap: Characters shared by neighbouring windows; must satisfy
            ``0 <= overlap < max_chars``.

    Returns:
        A list of chunk strings in document order.

    Raises:
        ValueError: If ``max_chars`` or ``overlap`` is out of range. Without
            this check an ``overlap >= max_chars`` would never advance the
            window and the loop would not terminate.
    """
    if not text:
        return []
    if max_chars <= 0:
        raise ValueError("max_chars must be a positive integer")
    if not 0 <= overlap < max_chars:
        raise ValueError("overlap must satisfy 0 <= overlap < max_chars")

    stride = max_chars - overlap  # strictly positive thanks to the checks above
    total = len(text)
    chunks = []
    for start in range(0, total, stride):
        end = start + max_chars
        piece = text[start:end].strip()
        if piece:
            chunks.append(piece)
        if end >= total:
            # This window already reached the end of the text; a further
            # window would only repeat its tail.
            break
    return chunks
93
+
94
def build_faiss_index(chunks, embedder):
    """Embed ``chunks`` and load the vectors into a flat L2 FAISS index.

    Returns ``(index, embeddings)``, where ``embeddings`` is the array
    returned by ``embedder.encode``; returns ``(None, None)`` when
    ``chunks`` is empty.
    """
    if chunks:
        vectors = embedder.encode(chunks, convert_to_numpy=True)
        flat_index = faiss.IndexFlatL2(vectors.shape[1])
        # The index is fed float32 vectors.
        flat_index.add(vectors.astype("float32"))
        return flat_index, vectors
    return None, None
102
+
103
def retrieve_chunks(query, embedder, faiss_index, chunks, k=5):
    """Return the chunks whose vectors the index reports nearest to ``query``.

    The query is encoded with ``embedder``, the index is asked for ``k``
    neighbours, and every returned position that is a valid index into
    ``chunks`` is mapped back to its text. Returns ``[]`` when there is no
    index or no chunks.
    """
    if faiss_index is None or not chunks:
        return []
    query_vec = embedder.encode([query], convert_to_numpy=True).astype("float32")
    _, neighbor_ids = faiss_index.search(query_vec, k)
    total = len(chunks)
    # Positions outside [0, total) are discarded rather than indexed.
    return [chunks[pos] for pos in neighbor_ids[0] if 0 <= pos < total]
113
+
114
+ # ---------------- Groq LLM ----------------
115
# Maps each selectable education level to the instruction text that is
# prepended to the LLM prompt. The keys double as the options shown in the
# "Education level" select box.
EDU_PROMPTS = {
    "Primary School": "Explain this to me like I'm 5 years old, in a fun and simple way with examples and analogies.",
    "Middle School": "Explain this in a simple and clear way appropriate for a middle school student with examples.",
    "High School": "Explain this clearly, assuming knowledge up to high school level.",
    "Undergraduate": "Explain this in a university-level way, with clarity and useful details and examples.",
    "Graduate": "Explain this at graduate-level rigor, including key details, nuance, and technical terms as appropriate.",
}
122
+
123
def get_groq_client():
    """Create a Groq client from the first API key that can be found.

    Lookup order:
      1. ``st.secrets["GROQ_API_KEY"]``
      2. ``st.session_state["groq_api_key"]`` (filled from the sidebar input)
      3. the ``GROQ_API_KEY`` environment variable

    Returns:
        A ``Groq`` client configured with the resolved key.

    Raises:
        ValueError: If no key is available from any of the sources.
    """
    api_key = None
    try:
        # Secrets are looked up by *name*. A credential value must never be
        # written into source code; any key that has been committed should
        # be treated as compromised and rotated.
        api_key = st.secrets["GROQ_API_KEY"]
    except Exception:
        # Best effort: a missing secrets store or missing entry simply means
        # the other sources below are tried.
        api_key = None
    if not api_key:
        api_key = st.session_state.get("groq_api_key") or os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise ValueError("Groq API key not found. Set st.secrets['GROQ_API_KEY'], or enter in sidebar, or set env GROQ_API_KEY.")
    return Groq(api_key=api_key)
134
+
135
def call_llm_with_context(question, retrieved_chunks, edu_level):
    """Ask the Groq chat model to answer ``question`` using retrieved passages.

    The user message is the instruction for ``edu_level`` (empty when the
    level is unknown), the passages joined by blank lines, and the question.
    Returns the content of the first completion choice.
    """
    client = get_groq_client()
    level_instruction = EDU_PROMPTS.get(edu_level, "")
    passages = "\n\n".join(retrieved_chunks) if retrieved_chunks else ""
    chat_messages = [
        {"role": "system", "content": "You are a helpful and knowledgeable tutor."},
        {
            "role": "user",
            "content": f"{level_instruction}\n\nContext:\n{passages}\n\nQuestion: {question}",
        },
    ]
    completion = client.chat.completions.create(
        messages=chat_messages,
        model=GROQ_LLM_MODEL,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
148
+
149
def make_summary(question, retrieved_chunks, edu_level):
    """Ask the Groq chat model for a short summary of ``question``.

    The prompt carries the instruction for ``edu_level`` (empty when the
    level is unknown) and the retrieved passages joined by blank lines.
    Returns the content of the first completion choice.
    """
    client = get_groq_client()
    level_instruction = EDU_PROMPTS.get(edu_level, "")
    passages = "\n\n".join(retrieved_chunks) if retrieved_chunks else ""
    request_text = f"{level_instruction}\n\nHere is some context:\n{passages}\n\nPlease give a short, easy-to-understand summary of: {question}\nKeep it concise and simple; use bullet points if helpful."
    chat_messages = [
        {"role": "system", "content": "You are a concise summarizer."},
        {"role": "user", "content": request_text},
    ]
    completion = client.chat.completions.create(
        messages=chat_messages,
        model=GROQ_LLM_MODEL,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
162
+
163
def make_mcqs_from_summary(summary_text, count=5, difficulty="medium"):
    """Ask the Groq chat model to write multiple-choice questions.

    The prompt requests ``count`` questions at the given ``difficulty``,
    each with options A-D, the correct option, and a short explanation,
    based on ``summary_text``. Returns the content of the first completion
    choice as raw text.
    """
    client = get_groq_client()
    prompt_parts = [
        f"Create {count} multiple choice questions (MCQs) from the following summary. ",
        "Each question should have 4 options labeled A-D and indicate the correct option. ",
        "Also provide a 1-2 sentence explanation for the correct answer. ",
        f"Difficulty: {difficulty}.\n\nSummary:\n{summary_text}",
    ]
    request_text = "".join(prompt_parts)
    chat_messages = [
        {"role": "system", "content": "You are an assistant that generates high-quality multiple-choice questions."},
        {"role": "user", "content": request_text},
    ]
    completion = client.chat.completions.create(
        messages=chat_messages,
        model=GROQ_LLM_MODEL,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
179
+
180
+ # ---------------- STREAMLIT UI ----------------
181
# Page chrome.
# NOTE(review): load_embedder() has already run above this call; confirm the
# installed Streamlit version accepts set_page_config at this point.
st.set_page_config(page_title="AI Study Assistant", layout="wide")
st.title("📚 AI Study Assistant — Exam Mode")

# Sidebar: optional API key plus the education level used to pick a prompt.
# NOTE(review): the nesting under `with st.sidebar` is reconstructed from a
# source that lost its indentation — confirm which widgets belong inside.
with st.sidebar:
    st.header("Settings")
    groq_key = st.text_input("Groq API key (optional)", type="password")
    if groq_key:
        # Stored so get_groq_client() can read it from session state.
        st.session_state["groq_api_key"] = groq_key
    edu_level = st.selectbox("Education level", list(EDU_PROMPTS.keys()))
    st.info("Upload documents and ask questions. You can generate summaries + MCQs.")

# Nothing below runs until at least one file has been uploaded.
uploaded_files = st.file_uploader("Upload study documents (PDF, DOCX, PPTX, XLSX/CSV, TXT)", accept_multiple_files=True)
if not uploaded_files:
    st.info("Please upload at least one document.")
    st.stop()
196
+
197
+ # ---------------- PARSE FILES ----------------
198
# Filename suffixes (lower-case) paired with the parser that handles them.
_PARSERS_BY_SUFFIX = (
    ((".pdf",), parse_pdf_bytes),
    ((".docx",), parse_docx_bytes),
    ((".pptx",), parse_pptx_bytes),
    ((".xls", ".xlsx", ".csv"), parse_spreadsheet_bytes),
    ((".txt",), parse_txt_bytes),
)

# Collect one headed section per file that produced any text.
sections = []
for uf in uploaded_files:
    raw = uf.read()
    name = uf.name.lower()
    parser = next(
        (fn for suffixes, fn in _PARSERS_BY_SUFFIX if name.endswith(suffixes)),
        None,
    )
    if parser is not None:
        text = parser(raw)
    else:
        # Unknown extension: accept the file only if it is valid UTF-8.
        try:
            text = raw.decode("utf-8")
        except Exception:
            text = ""
    if text:
        sections.append(f"\n\n### From file: {uf.name}\n\n{text}")
all_text = "".join(sections)

# Stop early when no file yielded any text.
if not all_text.strip():
    st.error("No textual content extracted.")
    st.stop()
224
+
225
+ # ---------------- CHUNK + INDEX ----------------
226
# Split the combined text into chunks and embed them into a vector index.
# NOTE(review): this appears to run on every script execution; confirm whether
# caching on the uploaded content is wanted.
# NOTE(review): placement of st.success relative to the spinner block is
# reconstructed from a source that lost its indentation.
with st.spinner("Processing documents..."):
    chunks = chunk_text(all_text)
    faiss_index, embeddings = build_faiss_index(chunks, embedder)
st.success(f"Prepared {len(chunks)} chunks and built vector index.")
230
+
231
+ # ---------------- ASK QUESTION ----------------
232
# The rest of the page needs a question; stop here until one is entered.
question = st.text_input("Ask a question about your materials:")
if not question:
    st.info("Type a question to begin.")
    st.stop()

# Retrieval and MCQ settings (the MCQ values are used by the summary section).
topk = st.number_input("Top-k passages", min_value=1, max_value=10, value=5)
mcq_count = st.number_input("MCQs to generate", min_value=1, max_value=20, value=5)
mcq_diff = st.selectbox("MCQ difficulty", ["easy", "medium", "hard"], index=1)

# number_input's value is passed through int() before being used as k.
retrieved = retrieve_chunks(question, embedder, faiss_index, chunks, k=int(topk))

if retrieved:
    st.subheader("Relevant passages:")
    for i, r in enumerate(retrieved):
        st.markdown(f"**Passage {i+1}:**")
        # Show at most the first 800 characters, marking truncation with "...".
        st.write(r[:800] + ("..." if len(r) > 800 else ""))
else:
    st.warning("No relevant passages found.")
250
+
251
+ # ---------------- GENERATE ANSWER ----------------
252
# Answer the question from the retrieved passages; on any failure show the
# error and halt the script so the summary section is not attempted.
try:
    answer = call_llm_with_context(question, retrieved, edu_level)
    st.subheader("Answer:")
    st.write(answer)
except Exception as e:
    st.error(f"LLM error: {e}")
    st.stop()
259
+
260
+ # ---------------- GENERATE SUMMARY + MCQs ----------------
261
# MIME type used for both Word downloads.
_DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"


def _wrap_fixed_width(text, width=90):
    """Yield the lines of ``text``, hard-split into pieces of at most ``width`` chars."""
    for line in text.split("\n"):
        while len(line) > width:
            yield line[:width]
            line = line[width:]
        yield line


def _text_to_pdf_buffer(text, margin=40):
    """Render ``text`` as a letter-size PDF and return it as a rewound BytesIO.

    A new page is started whenever the text cursor drops below the bottom
    margin, so long content is not drawn off the sheet.
    """
    buffer = io.BytesIO()
    pdf = canvas.Canvas(buffer, pagesize=letter)
    _, page_height = letter
    top = page_height - margin
    text_obj = pdf.beginText(margin, top)
    for line in _wrap_fixed_width(text):
        if text_obj.getY() < margin:
            # Flush the full page and continue on a fresh one.
            pdf.drawText(text_obj)
            pdf.showPage()
            text_obj = pdf.beginText(margin, top)
        text_obj.textLine(line)
    pdf.drawText(text_obj)
    pdf.showPage()
    pdf.save()
    buffer.seek(0)
    return buffer


def _text_to_docx_buffer(heading, text):
    """Render ``text`` as a DOCX (one paragraph per line) under ``heading``."""
    buffer = io.BytesIO()
    document = docx_lib.Document()
    document.add_heading(heading, level=1)
    for line in text.split("\n"):
        document.add_paragraph(line)
    document.save(buffer)
    buffer.seek(0)
    return buffer


# Optional step: summarise the retrieved material, offer it for download in
# several formats, then generate MCQs from that summary.
if st.checkbox("Generate summary and MCQs"):
    try:
        summary = make_summary(question, retrieved, edu_level)
        st.subheader("📘 Summary")
        st.write(summary)

        # Summary downloads: the raw text serves as Markdown, and is also
        # converted to HTML, PDF and DOCX.
        md_text = summary
        html_text = markdown2.markdown(summary)
        pdf_buffer = _text_to_pdf_buffer(summary)
        docx_buffer = _text_to_docx_buffer("Summary", summary)

        st.download_button("⬇️ Download Summary (Markdown)", md_text, file_name="summary.md")
        st.download_button("⬇️ Download Summary (HTML)", html_text, file_name="summary.html", mime="text/html")
        st.download_button("⬇️ Download Summary (PDF)", pdf_buffer, file_name="summary.pdf", mime="application/pdf")
        st.download_button("⬇️ Download Summary (DOCX)", docx_buffer, file_name="summary.docx", mime=_DOCX_MIME)

        # MCQs are generated from the summary, not from the raw passages.
        mcq_text = make_mcqs_from_summary(summary, count=int(mcq_count), difficulty=mcq_diff)
        st.subheader("📝 Generated MCQs")
        st.write(mcq_text)

        mcq_docx_buf = _text_to_docx_buffer("MCQs", mcq_text)
        mcq_pdf_buf = _text_to_pdf_buffer(mcq_text)

        st.download_button("⬇️ Download MCQs (DOCX)", mcq_docx_buf, file_name="mcqs.docx", mime=_DOCX_MIME)
        st.download_button("⬇️ Download MCQs (PDF)", mcq_pdf_buf, file_name="mcqs.pdf", mime="application/pdf")

    except Exception as e:
        st.error(f"Error generating summary or MCQs: {e}")
else:
    st.info("Check the box above to generate summary + MCQs from retrieved content.")