"""ModuleMate AI — a course-aware RAG study assistant built on Streamlit.

Students upload course material (PDF, DOCX, PPTX, TXT, IPYNB, PY); the app
chunks it, embeds it with sentence-transformers, indexes it in FAISS, and
answers questions / generates flashcards with Gemini, grounded in the
retrieved chunks.
"""

import os
import re

import streamlit as st
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import PyPDF2
from pptx import Presentation
import docx2txt
import nbformat
import google.generativeai as genai

# -----------------------------
# BASIC CONFIG
# -----------------------------
st.set_page_config(page_title="ModuleMate AI", page_icon="🎓", layout="wide")

# ⚠️ For local prototype only – hardcoded key
# Replace this with your own Gemini key (and NEVER commit the api key to GitHub).
API_KEY = "your gemini api key here"
genai.configure(api_key=API_KEY)


# -----------------------------
# CACHED MODELS (LAZY LOADING)
# -----------------------------
@st.cache_resource
def get_embed_model():
    """Return the sentence-transformers embedding model (cached per session)."""
    # Loaded once per session, but ONLY when first called
    return SentenceTransformer("all-MiniLM-L6-v2")


@st.cache_resource
def get_gemini_model():
    """Return the Gemini generative model (cached per session)."""
    # Loaded once per session, but ONLY when first called
    return genai.GenerativeModel("gemini-1.5-flash")


# -----------------------------
# SESSION STATE
# -----------------------------
# Initialise every key we rely on exactly once; reruns keep existing values.
_SESSION_DEFAULTS = {
    "vector_index": None,      # faiss.IndexFlatL2 or None before first build
    "chunk_texts": [],         # chunk text, parallel to chunk_sources
    "chunk_sources": [],       # file names for each chunk
    "flashcards": [],          # list of {"q": ..., "a": ...}
    "flashcard_index": 0,      # currently displayed card
    "flash_sources": "",       # markdown source summary for the cards
    "show_answer": False,      # whether the current card's answer is revealed
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default


# -----------------------------
# FILE LOADERS (for uploaded file objects)
# -----------------------------
def load_pdf(file):
    """Extract text from every page of an uploaded PDF."""
    reader = PyPDF2.PdfReader(file)
    # extract_text() may return None for image-only pages; substitute "".
    return "\n".join([p.extract_text() or "" for p in reader.pages])


def load_docx(file):
    """Extract text from an uploaded DOCX file."""
    return docx2txt.process(file)


def load_pptx(file):
    """Extract text from every shape on every slide of an uploaded PPTX."""
    prs = Presentation(file)
    texts = []
    for slide in prs.slides:
        for shape in slide.shapes:
            # Only some shape types (text frames) carry a .text attribute.
            if hasattr(shape, "text"):
                texts.append(shape.text)
    return "\n".join(texts)


def load_txt(file):
    """Decode an uploaded plain-text file as UTF-8."""
    return file.read().decode("utf-8")


def load_ipynb(file):
    """Return the concatenated source of all code cells in a notebook."""
    nb = nbformat.read(file, as_version=4)
    code_cells = [c.source for c in nb.cells if c.cell_type == "code"]
    return "\n".join(code_cells)


def load_py(file):
    """Decode an uploaded Python source file as UTF-8."""
    return file.read().decode("utf-8")


def load_file(file):
    """Dispatch an uploaded file to the loader for its extension.

    Returns "" for unsupported extensions so callers can skip the file.
    """
    ext = os.path.splitext(file.name)[1].lower()
    loaders = {
        ".pdf": load_pdf,
        ".docx": load_docx,
        ".pptx": load_pptx,
        ".txt": load_txt,
        ".ipynb": load_ipynb,
        ".py": load_py,
    }
    loader = loaders.get(ext)
    return loader(file) if loader else ""


# -----------------------------
# CHUNKING
# -----------------------------
def chunk_text(text, chunk_size=1000, overlap=100):
    """Simple sliding window over the text.

    Args:
        text: the full document text.
        chunk_size: maximum characters per chunk.
        overlap: characters shared between consecutive chunks.

    Returns:
        List of overlapping text chunks (empty list for empty text).
    """
    # Guard: a non-positive step (overlap >= chunk_size) would loop forever.
    step = max(1, chunk_size - overlap)
    chunks = []
    start = 0
    while start < len(text):
        end = start + chunk_size
        chunks.append(text[start:end])
        start += step
    return chunks


# -----------------------------
# BUILD INDEX (stores in session_state)
# -----------------------------
def build_index(files):
    """Chunk and embed all uploaded files and store a FAISS index in session state.

    Side effects: sets st.session_state.vector_index / chunk_texts /
    chunk_sources and shows a success or error message in the UI.
    """
    all_texts = []
    all_sources = []
    for file in files:
        text = load_file(file)
        if not text:
            continue
        chunks = chunk_text(text)
        for ch in chunks:
            all_texts.append(ch)
            all_sources.append(file.name)

    if not all_texts:
        st.error("No readable text found in the uploaded files.")
        return

    # Load embedding model lazily and cache it
    embed_model = get_embed_model()
    embeddings = embed_model.encode(all_texts)

    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(np.array(embeddings).astype("float32"))

    st.session_state.vector_index = index
    st.session_state.chunk_texts = all_texts
    st.session_state.chunk_sources = all_sources

    st.success(
        f"Index built with {len(all_texts)} chunks from {len(set(all_sources))} files."
    )


# -----------------------------
# RETRIEVAL HELPER
# -----------------------------
def retrieve_context(question, k=5):
    """Return (context_text, sources_summary, selected_chunks).

    context_text is None (with an error message in sources_summary) when no
    index has been built yet.
    """
    if st.session_state.vector_index is None or not st.session_state.chunk_texts:
        return None, "❌ Please upload files and build the index.", []

    embed_model = get_embed_model()
    query_emb = embed_model.encode([question])
    D, I = st.session_state.vector_index.search(
        np.array(query_emb).astype("float32"), k=k
    )

    selected_chunks = []
    for idx in I[0]:
        # FAISS pads with -1 when fewer than k vectors exist; skip those
        # instead of accidentally indexing chunk_texts[-1].
        if idx < 0:
            continue
        text = st.session_state.chunk_texts[idx]
        src = st.session_state.chunk_sources[idx]
        selected_chunks.append((text, src))

    context = "\n\n---\n\n".join([c[0] for c in selected_chunks])

    # Human-readable source summary
    source_lines = []
    for i, (chunk, src) in enumerate(selected_chunks, start=1):
        preview = chunk.replace("\n", " ")
        if len(preview) > 160:
            preview = preview[:160] + "..."
        source_lines.append(f"**Chunk {i}** – `{src}`\n> {preview}")

    sources_summary = (
        "\n\n".join(source_lines) if source_lines else "No sources retrieved."
    )

    return context, sources_summary, selected_chunks


# -----------------------------
# PROMPT HELPERS
# -----------------------------
def task_type_instructions(task_type):
    """Return prompt guidance tailored to the selected task type ("" if unknown)."""
    mapping = {
        "Brainstorming": "Focus on generating diverse, creative ideas and alternative angles. Avoid giving final polished answers.",
        "Studying for exam": "Focus on clear explanations, key concepts, bullet-point summaries and small self-test questions.",
        "Writing assistance": "Help with structure, clarity, argumentation and academic phrasing. Avoid writing full assignments in restricted mode.",
        "Coding help": "Explain code step by step, suggest improvements and debugging hints. Emphasise understanding over copy-paste solutions.",
        "Homework help": "Guide the student through the problem with hints and intermediate steps, not full final solutions.",
        "Research assistance": "Help with concept clarification, comparing theories and outlining approaches. Avoid fabricating citations.",
    }
    return mapping.get(task_type, "")


def difficulty_instructions(level):
    """Return prompt guidance for the chosen explanation level."""
    if level == "Beginner":
        return "Explain as if to a first-year student. Use simple language, concrete examples and avoid heavy notation."
    elif level == "Intermediate":
        return "Assume some prior knowledge. Use correct terminology but keep explanations clear and structured."
    else:  # Advanced
        return "Assume a strong background. You may use technical terminology, formal definitions and more compact explanations."


def material_focus_instructions(material_focus):
    """Return prompt guidance restricting focus to selected material types."""
    if not material_focus:
        return ""
    joined = ", ".join(material_focus)
    return f"Prioritise information that appears to come from: {joined}. If relevant material is missing, say so."


# -----------------------------
# RAG QUERY (normal answer)
# -----------------------------
def rag_query(question, mode, task_type, difficulty, material_focus):
    """Answer a question from the indexed material via Gemini.

    Returns (answer_text, sources_summary). On Gemini API failure the answer
    falls back to the raw retrieved context so the student still gets material.
    """
    context, sources_summary, _ = retrieve_context(question)
    if context is None:
        return "❌ Please upload files and build the index.", ""

    task_instr = task_type_instructions(task_type)
    diff_instr = difficulty_instructions(difficulty)
    focus_instr = material_focus_instructions(material_focus)

    if mode == "AI Integrity Mode (restricted)":
        system_block = f"""
You are ModuleMate AI in RESTRICTED MODE at a university.
Your role is to support learning **without** solving assessments directly.

You must:
- Use ONLY the context below.
- Give conceptual hints, guidance and explanations.
- **Do NOT** provide full final solutions to assignments, exam questions or project work.
- If the question looks like an assignment/exam question, remind the student that you can only give hints.
"""
    else:
        system_block = """
You are ModuleMate AI, a course-aware study assistant.
You may give full explanations and worked examples,
but keep them grounded in the provided material.
"""

    prompt = f"""
{system_block}

Task type: {task_type}
Difficulty level: {difficulty}

Intent-specific instructions:
{task_instr}
{diff_instr}
{focus_instr}

CONTEXT FROM COURSE MATERIAL:
{context}

STUDENT QUESTION:
{question}

Now answer accordingly.
"""

    try:
        model = get_gemini_model()
        response = model.generate_content(prompt)
        return response.text, sources_summary
    except Exception as e:
        # Best-effort fallback: surface the retrieved chunks directly.
        fallback = (
            "⚠️ Gemini API error: " + str(e)
            + "\n\nHere are the most relevant sections from your course material:\n\n"
            + context
        )
        return fallback, sources_summary


# -----------------------------
# FLASHCARD PARSING + GENERATION
# -----------------------------
def parse_flashcards(raw_text):
    """
    Parse model output into a list of {'q': ..., 'a': ...}.
    Handles formats like:

    Q: ...
    A: ...

    or

    1. Q: ...
       A: ...

    or with bullets / markdown.
    """
    cards = []
    # Normalise line breaks and strip
    text = raw_text.strip()
    if not text:
        return cards

    # Split into blocks starting with Q / q
    blocks = re.split(r"\n(?=[Qq][\.:)\- ]|\*\*Q)", text)

    for block in blocks:
        lines = [ln.strip() for ln in block.splitlines() if ln.strip()]
        if not lines:
            continue

        # Find question line
        q_line = None
        a_line = None
        for ln in lines:
            if re.match(r"^(Q|q)[\.:)\- ]", ln) or ln.startswith("**Q"):
                q_line = ln
                break
        for ln in lines:
            if re.match(r"^(A|a)[\.:)\- ]", ln) or ln.startswith("**A"):
                a_line = ln
                break

        if q_line and a_line:
            q = re.sub(r"^(Q|q)[\.:)\- ]\s*|\*\*Q\*\*[:\- ]\s*", "", q_line).strip()
            a = re.sub(r"^(A|a)[\.:)\- ]\s*|\*\*A\*\*[:\- ]\s*", "", a_line).strip()
            if q and a:
                cards.append({"q": q, "a": a})

    return cards


def create_flashcards(question, mode, task_type, difficulty, material_focus):
    """Generate exam-revision flashcards for a topic from the indexed material.

    Returns (cards, sources_summary) where cards is a list of {'q', 'a'} dicts;
    parse or API failures produce a single fallback card instead of raising.
    """
    context, sources_summary, _ = retrieve_context(question)
    if context is None:
        return [], sources_summary

    diff_instr = difficulty_instructions(difficulty)
    focus_instr = material_focus_instructions(material_focus)

    prompt = f"""
You are ModuleMate AI, helping a university student create study flashcards.

The topic is: {question}

Generate 8–12 concise flashcards (Q–A pairs) based ONLY on the context below.
Each flashcard should be useful for **exam revision** in this module.

Very important formatting rules:
- Write each card as two lines:
  Q:
  A:
- Do NOT add numbering, bullets or extra commentary.
- Just repeat this pattern for every card.

{diff_instr}
{focus_instr}

CONTEXT:
{context}
"""

    try:
        model = get_gemini_model()
        response = model.generate_content(prompt)
        raw = response.text
        cards = parse_flashcards(raw)

        # If parsing fails, fall back to a single card with raw output
        if not cards:
            cards = [
                {
                    "q": "Flashcards could not be parsed from the AI output.",
                    "a": raw.strip() or "No content returned.",
                }
            ]
        return cards, sources_summary

    except Exception as e:
        fallback_answer = (
            "Gemini API error while creating flashcards: " + str(e)
            + "\n\nHere are the most relevant sections from your course material:\n\n"
            + context
        )
        cards = [
            {
                "q": "Flashcards could not be generated",
                "a": fallback_answer,
            }
        ]
        return cards, sources_summary


# -----------------------------
# UI
# -----------------------------
st.markdown("## 🎓 ModuleMate AI — Your Course-Aware Study Assistant")

# Real modules
modules = [
    "Module 1 - Introduction to Data Handling, Exploration & Applied Machine Learning (E25)",
    "Module 2 - Natural Language Processing and Network Analysis (E25)",
    "Module 3 - Data-Driven Business Modelling and Strategy (E25)",
]
selected_module = st.selectbox("Select your current module:", modules)

mode = st.radio("Choose Mode:", ["Normal RAG Mode", "AI Integrity Mode (restricted)"])

task_type = st.selectbox(
    "What do you need help with?",
    [
        "Studying for exam",
        "Brainstorming",
        "Writing assistance",
        "Coding help",
        "Homework help",
        "Research assistance",
    ],
)

difficulty = st.selectbox(
    "Preferred explanation level:",
    ["Beginner", "Intermediate", "Advanced"],
)

material_focus = st.multiselect(
    "Which materials should I focus on?",
    ["Lecture slides", "Readings", "Assignments", "Tutorials", "Code examples"],
    default=[
        "Lecture slides",
        "Readings",
        "Assignments",
        "Tutorials",
        "Code examples",
    ],
)

st.divider()

uploaded_files = st.file_uploader(
    "Upload course material (PDF, DOCX, PPTX, TXT, IPYNB, PY)",
    accept_multiple_files=True,
)

if st.button("Build Index"):
    if uploaded_files:
        with st.spinner("Building vector index from your course material..."):
            build_index(uploaded_files)
    else:
        st.error("Please upload at least one file.")

st.divider()

question = st.text_input("Ask a question about your module (or topic for flashcards):")

col1, col2 = st.columns(2)
with col1:
    ask_clicked = st.button("Ask (normal answer)")
with col2:
    flash_clicked = st.button("Generate flashcards")

# --- Normal RAG answer ---
if ask_clicked:
    if question.strip():
        with st.spinner("Searching your course material and generating an answer..."):
            answer, sources = rag_query(
                question, mode, task_type, difficulty, material_focus
            )
        st.subheader("Answer")
        st.write(answer)
        st.subheader("Sources used")
        st.markdown(sources)
    else:
        st.error("Please enter a question or topic.")

# --- Flashcard generation ---
if flash_clicked:
    if question.strip():
        with st.spinner("Creating flashcards from your course material..."):
            cards, sources = create_flashcards(
                question, mode, task_type, difficulty, material_focus
            )
        st.session_state.flashcards = cards
        st.session_state.flashcard_index = 0
        st.session_state.flash_sources = sources
        st.session_state.show_answer = False
    else:
        st.error("Please enter a topic for the flashcards.")

# --- Flashcard viewer ---
if st.session_state.flashcards:
    cards = st.session_state.flashcards
    idx = st.session_state.flashcard_index
    card = cards[idx]

    st.subheader("Flashcards")
    st.write(f"Card {idx + 1} of {len(cards)}")

    # Navigation + reveal controls
    nav_col1, nav_col2, nav_col3 = st.columns([1, 2, 1])
    with nav_col1:
        prev_clicked = st.button("◀ Previous")
    with nav_col2:
        show_clicked = st.button("Show answer")
    with nav_col3:
        next_clicked = st.button("Next ▶")

    # Update state based on clicks
    if prev_clicked and idx > 0:
        st.session_state.flashcard_index -= 1
        st.session_state.show_answer = False
        idx = st.session_state.flashcard_index
        card = cards[idx]

    if next_clicked and idx < len(cards) - 1:
        st.session_state.flashcard_index += 1
        st.session_state.show_answer = False
        idx = st.session_state.flashcard_index
        card = cards[idx]

    if show_clicked:
        st.session_state.show_answer = True

    # Display current card
    st.markdown("#### Question:")
    st.info(card["q"])

    if st.session_state.show_answer:
        st.markdown("#### Answer:")
        st.success(card["a"])

    st.subheader("Sources used for these flashcards")
    st.markdown(st.session_state.flash_sources)

st.markdown("---")
st.caption("MSc in Business Data Science – Student ID: 12345. Current semester: 1.")
st.caption("Developed by Group 4 – Ali Moghadas, Amalie Hougaard Lang and Emina Gracanin.")