Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import streamlit as st | |
| import faiss | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| import PyPDF2 | |
| from pptx import Presentation | |
| import docx2txt | |
| import nbformat | |
| import google.generativeai as genai | |
# -----------------------------
# BASIC CONFIG
# -----------------------------
st.set_page_config(page_title="ModuleMate AI", page_icon="📘", layout="wide")
# ⚠️ For local prototype only — NEVER commit a real API key to GitHub.
# Prefer the GEMINI_API_KEY environment variable; the placeholder string is
# kept as the fallback so local behaviour matches the original prototype.
API_KEY = os.getenv("GEMINI_API_KEY", "your gemini api key here")
genai.configure(api_key=API_KEY)
| # ----------------------------- | |
| # CACHED MODELS (LAZY LOADING) | |
| # ----------------------------- | |
@st.cache_resource
def get_embed_model():
    """Return the sentence-transformer embedding model (cached).

    The original comment promised once-per-session loading, but without a
    cache decorator the heavyweight model was rebuilt on every call;
    @st.cache_resource makes the promise true across Streamlit reruns.
    """
    return SentenceTransformer("all-MiniLM-L6-v2")
@st.cache_resource
def get_gemini_model():
    """Return the Gemini chat model (cached).

    Mirrors get_embed_model: the original claimed once-per-session loading
    but re-created the model object on every call; the cache fixes that.
    """
    return genai.GenerativeModel("gemini-1.5-flash")
# -----------------------------
# SESSION STATE
# -----------------------------
# One table of defaults instead of seven near-identical "if not in" blocks.
_SESSION_DEFAULTS = {
    "vector_index": None,   # FAISS index over all chunks
    "chunk_texts": [],      # raw text per chunk
    "chunk_sources": [],    # file names for each chunk
    "flashcards": [],       # list of {'q': ..., 'a': ...}
    "flashcard_index": 0,   # currently displayed card
    "flash_sources": "",    # markdown summary of flashcard sources
    "show_answer": False,   # whether the current card's answer is revealed
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
| # ----------------------------- | |
| # FILE LOADERS (for uploaded file objects) | |
| # ----------------------------- | |
def load_pdf(file):
    """Concatenate the extracted text of every page of a PDF upload."""
    pages = PyPDF2.PdfReader(file).pages
    # extract_text() can return None for image-only pages; substitute "".
    return "\n".join(page.extract_text() or "" for page in pages)
def load_docx(file):
    """Extract the plain text of a .docx upload via docx2txt."""
    extracted_text = docx2txt.process(file)
    return extracted_text
def load_pptx(file):
    """Collect the text of every text-bearing shape on every slide."""
    deck = Presentation(file)
    fragments = [
        shape.text
        for slide in deck.slides
        for shape in slide.shapes
        if hasattr(shape, "text")  # not all shapes carry text (e.g. pictures)
    ]
    return "\n".join(fragments)
def load_txt(file):
    """Decode an uploaded text file as UTF-8.

    Malformed byte sequences are replaced with U+FFFD instead of raising,
    so one badly-encoded file cannot crash the whole indexing run.
    """
    return file.read().decode("utf-8", errors="replace")
def load_ipynb(file):
    """Join the source of all code cells in a Jupyter notebook upload.

    Markdown/raw cells are deliberately skipped — only code is indexed.
    """
    notebook = nbformat.read(file, as_version=4)
    sources = (cell.source for cell in notebook.cells if cell.cell_type == "code")
    return "\n".join(sources)
def load_py(file):
    """Decode an uploaded Python source file as UTF-8.

    Uses errors="replace" so a stray non-UTF-8 byte (e.g. a latin-1
    comment) degrades to U+FFFD instead of aborting index building.
    """
    return file.read().decode("utf-8", errors="replace")
def load_file(file):
    """Dispatch an uploaded file to the loader matching its extension.

    Returns "" for unsupported extensions so callers can simply skip them.
    """
    loaders = {
        ".pdf": load_pdf,
        ".docx": load_docx,
        ".pptx": load_pptx,
        ".txt": load_txt,
        ".ipynb": load_ipynb,
        ".py": load_py,
    }
    extension = os.path.splitext(file.name)[1].lower()
    loader = loaders.get(extension)
    return loader(file) if loader else ""
| # ----------------------------- | |
| # CHUNKING | |
| # ----------------------------- | |
def chunk_text(text, chunk_size=1000, overlap=100):
    """Split *text* into overlapping sliding-window chunks.

    Args:
        text: The raw document text.
        chunk_size: Maximum characters per chunk.
        overlap: Characters shared between consecutive chunks.

    Returns:
        List of chunk strings covering the whole text ([] for empty input).
    """
    # Guard: if overlap >= chunk_size the original step was <= 0 and the
    # loop never terminated. Always advance by at least one character.
    step = max(1, chunk_size - overlap)
    chunks = []
    start = 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        start += step
    return chunks
| # ----------------------------- | |
| # BUILD INDEX (stores in session_state) | |
| # ----------------------------- | |
def build_index(files):
    """Embed every chunk of every uploaded file into a FAISS L2 index.

    The index, chunk texts and per-chunk source file names are stored in
    st.session_state so they survive Streamlit reruns.
    """
    all_texts = []
    all_sources = []
    for file in files:
        text = load_file(file)
        chunks = chunk_text(text) if text else []
        all_texts.extend(chunks)
        all_sources.extend([file.name] * len(chunks))
    if not all_texts:
        st.error("No readable text found in the uploaded files.")
        return
    # Load embedding model lazily and cache it
    embeddings = get_embed_model().encode(all_texts)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(np.array(embeddings).astype("float32"))
    st.session_state.vector_index = index
    st.session_state.chunk_texts = all_texts
    st.session_state.chunk_sources = all_sources
    st.success(
        f"Index built with {len(all_texts)} chunks from {len(set(all_sources))} files."
    )
| # ----------------------------- | |
| # RETRIEVAL HELPER | |
| # ----------------------------- | |
def retrieve_context(question, k=5):
    """Return (context_text, sources_summary, selected_chunks) for a query.

    Args:
        question: Natural-language query to embed and search with.
        k: Number of nearest chunks to retrieve.

    Returns:
        Tuple of (joined context string or None when no index exists,
        markdown source summary, list of (chunk_text, source_file) pairs).
    """
    if st.session_state.vector_index is None or not st.session_state.chunk_texts:
        # No index yet — signal the caller with context=None.
        # NOTE(review): "❌" restores what looked like a mojibake'd emoji.
        return None, "❌ Please upload files and build the index.", []
    embed_model = get_embed_model()
    query_emb = embed_model.encode([question])
    D, I = st.session_state.vector_index.search(
        np.array(query_emb).astype("float32"), k=k
    )
    selected_chunks = []
    for idx in I[0]:
        # FAISS pads results with -1 when fewer than k vectors are indexed;
        # the original code then read chunk_texts[-1] and silently returned
        # the LAST chunk. Skip padding entries instead.
        if idx < 0:
            continue
        text = st.session_state.chunk_texts[idx]
        src = st.session_state.chunk_sources[idx]
        selected_chunks.append((text, src))
    context = "\n\n---\n\n".join([c[0] for c in selected_chunks])
    # Human-readable source summary (one bullet per retrieved chunk).
    source_lines = []
    for i, (chunk, src) in enumerate(selected_chunks, start=1):
        preview = chunk.replace("\n", " ")
        if len(preview) > 160:
            preview = preview[:160] + "..."
        source_lines.append(f"**Chunk {i}** — `{src}`\n> {preview}")
    sources_summary = (
        "\n\n".join(source_lines) if source_lines else "No sources retrieved."
    )
    return context, sources_summary, selected_chunks
| # ----------------------------- | |
| # PROMPT HELPERS | |
| # ----------------------------- | |
def task_type_instructions(task_type):
    """Return intent-specific prompt guidance for the chosen task type.

    Unknown task types yield "" so the prompt simply omits the section.
    """
    guidance = {
        "Brainstorming": "Focus on generating diverse, creative ideas and alternative angles. Avoid giving final polished answers.",
        "Studying for exam": "Focus on clear explanations, key concepts, bullet-point summaries and small self-test questions.",
        "Writing assistance": "Help with structure, clarity, argumentation and academic phrasing. Avoid writing full assignments in restricted mode.",
        "Coding help": "Explain code step by step, suggest improvements and debugging hints. Emphasise understanding over copy-paste solutions.",
        "Homework help": "Guide the student through the problem with hints and intermediate steps, not full final solutions.",
        "Research assistance": "Help with concept clarification, comparing theories and outlining approaches. Avoid fabricating citations.",
    }
    return guidance.get(task_type, "")
def difficulty_instructions(level):
    """Return phrasing guidance for the selected explanation level."""
    per_level = {
        "Beginner": "Explain as if to a first-year student. Use simple language, concrete examples and avoid heavy notation.",
        "Intermediate": "Assume some prior knowledge. Use correct terminology but keep explanations clear and structured.",
    }
    # Any other value falls through to the advanced text, matching the
    # original's bare "else" branch (the UI only offers "Advanced" there).
    advanced = "Assume a strong background. You may use technical terminology, formal definitions and more compact explanations."
    return per_level.get(level, advanced)
def material_focus_instructions(material_focus):
    """Build a prioritisation hint from the selected material categories.

    An empty/None selection yields "" so the prompt omits the section.
    """
    if material_focus:
        listed = ", ".join(material_focus)
        return f"Prioritise information that appears to come from: {listed}. If relevant material is missing, say so."
    return ""
| # ----------------------------- | |
| # RAG QUERY (normal answer) | |
| # ----------------------------- | |
def rag_query(question, mode, task_type, difficulty, material_focus):
    """Answer *question* with retrieval-augmented generation over the index.

    Args:
        question: The student's question.
        mode: "Normal RAG Mode" or "AI Integrity Mode (restricted)".
        task_type, difficulty, material_focus: UI selections that shape the
            prompt via the *_instructions helpers.

    Returns:
        (answer_text, sources_summary_markdown). On Gemini API failure the
        answer degrades to the raw retrieved context plus the error message.
    """
    context, sources_summary, _ = retrieve_context(question)
    if context is None:
        # Index not built yet. NOTE(review): "❌"/"⚠️" below restore what
        # looked like mojibake'd emoji in the original strings.
        return "❌ Please upload files and build the index.", ""
    task_instr = task_type_instructions(task_type)
    diff_instr = difficulty_instructions(difficulty)
    focus_instr = material_focus_instructions(material_focus)
    if mode == "AI Integrity Mode (restricted)":
        system_block = """
You are ModuleMate AI in RESTRICTED MODE at a university.
Your role is to support learning **without** solving assessments directly.
You must:
- Use ONLY the context below.
- Give conceptual hints, guidance and explanations.
- **Do NOT** provide full final solutions to assignments, exam questions or project work.
- If the question looks like an assignment/exam question, remind the student that you can only give hints.
"""
    else:
        system_block = """
You are ModuleMate AI, a course-aware study assistant.
You may give full explanations and worked examples, but keep them grounded in the provided material.
"""
    prompt = f"""
{system_block}
Task type: {task_type}
Difficulty level: {difficulty}
Intent-specific instructions: {task_instr}
{diff_instr}
{focus_instr}
CONTEXT FROM COURSE MATERIAL:
{context}
STUDENT QUESTION:
{question}
Now answer accordingly.
"""
    try:
        model = get_gemini_model()
        response = model.generate_content(prompt)
        return response.text, sources_summary
    except Exception as e:
        # Degrade gracefully: show the retrieved context when the API fails.
        fallback = (
            "⚠️ Gemini API error: "
            + str(e)
            + "\n\nHere are the most relevant sections from your course material:\n\n"
            + context
        )
        return fallback, sources_summary
| # ----------------------------- | |
| # FLASHCARD PARSING + GENERATION | |
| # ----------------------------- | |
def parse_flashcards(raw_text):
    """
    Parse model output into a list of {'q': ..., 'a': ...}.
    Handles formats like:
    Q: ...
    A: ...
    or with ** markdown around the Q/A markers.
    NOTE(review): despite the original docstring's claim, numbered lines like
    "1. Q: ..." are NOT matched by the anchors below ("^" requires the line to
    start with Q/q or **Q) — confirm against real model output.
    """
    cards = []
    # Normalise line breaks and strip
    text = raw_text.strip()
    if not text:
        return cards
    # Split into blocks starting with Q / q (lookahead keeps the Q-line in
    # the following block rather than consuming it).
    blocks = re.split(r"\n(?=[Qq][\.:)\- ]|\*\*Q)", text)
    for block in blocks:
        lines = [ln.strip() for ln in block.splitlines() if ln.strip()]
        if not lines:
            continue
        # Find the first question line and the first answer line in the block.
        q_line = None
        a_line = None
        for ln in lines:
            if re.match(r"^(Q|q)[\.:)\- ]", ln) or ln.startswith("**Q"):
                q_line = ln
                break
        for ln in lines:
            if re.match(r"^(A|a)[\.:)\- ]", ln) or ln.startswith("**A"):
                a_line = ln
                break
        # Only emit a card when BOTH a question and an answer were found.
        if q_line and a_line:
            # Strip the "Q:" / "**Q**:" marker prefix from each line.
            # NOTE(review): the bold-variant pattern expects "**Q**:" exactly;
            # a "**Q:**" line is detected above but the marker is not removed
            # here — presumably tolerated downstream; verify if cards show "**Q".
            q = re.sub(r"^(Q|q)[\.:)\- ]\s*|\*\*Q\*\*[:\- ]\s*", "", q_line).strip()
            a = re.sub(r"^(A|a)[\.:)\- ]\s*|\*\*A\*\*[:\- ]\s*", "", a_line).strip()
            if q and a:
                cards.append({"q": q, "a": a})
    return cards
def create_flashcards(question, mode, task_type, difficulty, material_focus):
    """Generate exam-revision flashcards for *question* from the indexed material.

    Args:
        question: Topic to generate flashcards about.
        mode, task_type: Accepted for a signature parallel to rag_query;
            currently unused in the flashcard prompt.
        difficulty, material_focus: Shape the prompt via the helpers.

    Returns:
        (cards, sources_summary) where cards is a list of {'q': ..., 'a': ...}.
        On API or parsing failure a single fallback card is returned so the
        viewer UI always has something to display.
    """
    context, sources_summary, _ = retrieve_context(question)
    if context is None:
        return [], sources_summary
    diff_instr = difficulty_instructions(difficulty)
    focus_instr = material_focus_instructions(material_focus)
    # NOTE(review): "8–12" / "Q–A" restore what looked like mojibake'd
    # en dashes ("8β12", "QβA") in the original prompt text.
    prompt = f"""
You are ModuleMate AI, helping a university student create study flashcards.
The topic is:
{question}
Generate 8–12 concise flashcards (Q–A pairs) based ONLY on the context below.
Each flashcard should be useful for **exam revision** in this module.
Very important formatting rules:
- Write each card as two lines:
Q: <short question>
A: <short answer>
- Do NOT add numbering, bullets or extra commentary.
- Just repeat this pattern for every card.
{diff_instr}
{focus_instr}
CONTEXT:
{context}
"""
    try:
        model = get_gemini_model()
        response = model.generate_content(prompt)
        raw = response.text
        cards = parse_flashcards(raw)
        # If parsing fails, fall back to a single card with the raw output
        # so the student still sees what the model produced.
        if not cards:
            cards = [
                {
                    "q": "Flashcards could not be parsed from the AI output.",
                    "a": raw.strip() or "No content returned.",
                }
            ]
        return cards, sources_summary
    except Exception as e:
        # API failure: degrade to a single card carrying the error + context.
        fallback_answer = (
            "Gemini API error while creating flashcards: "
            + str(e)
            + "\n\nHere are the most relevant sections from your course material:\n\n"
            + context
        )
        cards = [
            {
                "q": "Flashcards could not be generated",
                "a": fallback_answer,
            }
        ]
        return cards, sources_summary
# -----------------------------
# UI
# -----------------------------
# NOTE(review): "📘" and "—" restore what looked like mojibake ("π", "β").
st.markdown("## 📘 ModuleMate AI — Your Course-Aware Study Assistant")
# Real modules
# NOTE(review): selected_module is never used in retrieval or prompting —
# presumably a planned feature; confirm before removing.
modules = [
    "Module 1 - Introduction to Data Handling, Exploration & Applied Machine Learning (E25)",
    "Module 2 - Natural Language Processing and Network Analysis (E25)",
    "Module 3 - Data-Driven Business Modelling and Strategy (E25)",
]
selected_module = st.selectbox("Select your current module:", modules)
# The mode string is compared verbatim in rag_query — keep it unchanged.
mode = st.radio("Choose Mode:", ["Normal RAG Mode", "AI Integrity Mode (restricted)"])
task_type = st.selectbox(
    "What do you need help with?",
    [
        "Studying for exam",
        "Brainstorming",
        "Writing assistance",
        "Coding help",
        "Homework help",
        "Research assistance",
    ],
)
difficulty = st.selectbox(
    "Preferred explanation level:",
    ["Beginner", "Intermediate", "Advanced"],
)
material_focus = st.multiselect(
    "Which materials should I focus on?",
    ["Lecture slides", "Readings", "Assignments", "Tutorials", "Code examples"],
    default=[
        "Lecture slides",
        "Readings",
        "Assignments",
        "Tutorials",
        "Code examples",
    ],
)
st.divider()
# Course material upload: each file is parsed, chunked and embedded when the
# user clicks "Build Index".
uploaded_files = st.file_uploader(
    "Upload course material (PDF, DOCX, PPTX, TXT, IPYNB, PY)",
    accept_multiple_files=True,
)
if st.button("Build Index"):
    if uploaded_files:
        with st.spinner("Building vector index from your course material..."):
            build_index(uploaded_files)
    else:
        st.error("Please upload at least one file.")
st.divider()
# One shared text input feeds both the Q&A flow and flashcard generation.
question = st.text_input("Ask a question about your module (or topic for flashcards):")
col1, col2 = st.columns(2)
with col1:
    ask_clicked = st.button("Ask (normal answer)")
with col2:
    flash_clicked = st.button("Generate flashcards")
# --- Normal RAG answer ---
if ask_clicked:
    # Guard clause first: reject empty/whitespace-only questions.
    if not question.strip():
        st.error("Please enter a question or topic.")
    else:
        with st.spinner("Searching your course material and generating an answer..."):
            answer, sources = rag_query(
                question, mode, task_type, difficulty, material_focus
            )
        st.subheader("Answer")
        st.write(answer)
        st.subheader("Sources used")
        st.markdown(sources)
# --- Flashcard generation ---
if flash_clicked:
    # Guard clause first: reject empty/whitespace-only topics.
    if not question.strip():
        st.error("Please enter a topic for the flashcards.")
    else:
        with st.spinner("Creating flashcards from your course material..."):
            cards, sources = create_flashcards(
                question, mode, task_type, difficulty, material_focus
            )
        # Reset the viewer to the first card with the answer hidden.
        st.session_state.flashcards = cards
        st.session_state.flashcard_index = 0
        st.session_state.flash_sources = sources
        st.session_state.show_answer = False
# --- Flashcard viewer ---
if st.session_state.flashcards:
    cards = st.session_state.flashcards
    idx = st.session_state.flashcard_index
    card = cards[idx]
    st.subheader("Flashcards")
    st.write(f"Card {idx + 1} of {len(cards)}")
    # Navigation + reveal controls
    nav_col1, nav_col2, nav_col3 = st.columns([1, 2, 1])
    with nav_col1:
        # NOTE(review): "◀"/"▶" restore what looked like mojibake'd glyphs
        # in the original button labels ("β Previous", "Next βΆ").
        prev_clicked = st.button("◀ Previous")
    with nav_col2:
        show_clicked = st.button("Show answer")
    with nav_col3:
        next_clicked = st.button("Next ▶")
    # Update state based on clicks; moving cards always re-hides the answer.
    if prev_clicked and idx > 0:
        st.session_state.flashcard_index -= 1
        st.session_state.show_answer = False
        idx = st.session_state.flashcard_index
        card = cards[idx]
    if next_clicked and idx < len(cards) - 1:
        st.session_state.flashcard_index += 1
        st.session_state.show_answer = False
        idx = st.session_state.flashcard_index
        card = cards[idx]
    if show_clicked:
        st.session_state.show_answer = True
    # Display current card
    st.markdown("#### Question:")
    st.info(card["q"])
    if st.session_state.show_answer:
        st.markdown("#### Answer:")
        st.success(card["a"])
    st.subheader("Sources used for these flashcards")
    st.markdown(st.session_state.flash_sources)
st.markdown("---")
# NOTE(review): "—" restores what looked like a mojibake'd em dash ("β").
st.caption("MSc in Business Data Science — Student ID: 12345. Current semester: 1.")
st.caption("Developed by Group 4 — Ali Moghadas, Amalie Hougaard Lang and Emina Gracanin.")