# M3_prototype / app.py
# (HuggingFace Space metadata: uploaded by houlie3 — "Rename prototype.py to app.py", commit 2136067, verified)
import os
import re
import streamlit as st
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import PyPDF2
from pptx import Presentation
import docx2txt
import nbformat
import google.generativeai as genai
# -----------------------------
# BASIC CONFIG
# -----------------------------
st.set_page_config(page_title="ModuleMate AI", page_icon="🎓", layout="wide")

# Prefer an environment variable so the real key is never committed to version
# control; the hardcoded placeholder remains only as a local-prototype fallback.
# (NEVER commit an actual API key to GitHub.)
API_KEY = os.getenv("GEMINI_API_KEY", "your gemini api key here")
genai.configure(api_key=API_KEY)
# -----------------------------
# CACHED MODELS (LAZY LOADING)
# -----------------------------
@st.cache_resource
def get_embed_model():
    """Return the sentence-embedding model, created once per session on first use."""
    return SentenceTransformer("all-MiniLM-L6-v2")
@st.cache_resource
def get_gemini_model():
    """Return the Gemini chat model, created once per session on first use."""
    return genai.GenerativeModel("gemini-1.5-flash")
# -----------------------------
# SESSION STATE
# -----------------------------
if "vector_index" not in st.session_state:
st.session_state.vector_index = None
if "chunk_texts" not in st.session_state:
st.session_state.chunk_texts = []
if "chunk_sources" not in st.session_state:
st.session_state.chunk_sources = [] # file names for each chunk
if "flashcards" not in st.session_state:
st.session_state.flashcards = []
if "flashcard_index" not in st.session_state:
st.session_state.flashcard_index = 0
if "flash_sources" not in st.session_state:
st.session_state.flash_sources = ""
if "show_answer" not in st.session_state:
st.session_state.show_answer = False
# -----------------------------
# FILE LOADERS (for uploaded file objects)
# -----------------------------
def load_pdf(file):
    """Extract the text of every page of an uploaded PDF, newline-joined."""
    pages = PyPDF2.PdfReader(file).pages
    # extract_text() may return None for image-only pages; substitute "".
    return "\n".join(page.extract_text() or "" for page in pages)
def load_docx(file):
    """Extract plain text from an uploaded Word document."""
    return docx2txt.process(file)
def load_pptx(file):
    """Collect the text of every text-bearing shape on every slide."""
    deck = Presentation(file)
    fragments = [
        shape.text
        for slide in deck.slides
        for shape in slide.shapes
        if hasattr(shape, "text")
    ]
    return "\n".join(fragments)
def load_txt(file):
    """Read an uploaded text file as UTF-8.

    Uses errors="replace" so a file saved in another encoding (e.g. latin-1)
    degrades to replacement characters instead of crashing the whole app
    with a UnicodeDecodeError.
    """
    return file.read().decode("utf-8", errors="replace")
def load_ipynb(file):
    """Return the source of all code cells in an uploaded notebook (markdown ignored)."""
    notebook = nbformat.read(file, as_version=4)
    return "\n".join(
        cell.source for cell in notebook.cells if cell.cell_type == "code"
    )
def load_py(file):
    """Read an uploaded Python source file as UTF-8.

    Uses errors="replace" (matching load_txt) so files with stray non-UTF-8
    bytes degrade gracefully instead of raising UnicodeDecodeError.
    """
    return file.read().decode("utf-8", errors="replace")
def load_file(file):
    """Dispatch an uploaded file to the loader matching its extension.

    Unsupported extensions yield an empty string, which callers treat as
    "no readable text".
    """
    loaders = {
        ".pdf": load_pdf,
        ".docx": load_docx,
        ".pptx": load_pptx,
        ".txt": load_txt,
        ".ipynb": load_ipynb,
        ".py": load_py,
    }
    extension = os.path.splitext(file.name)[1].lower()
    loader = loaders.get(extension)
    return loader(file) if loader else ""
# -----------------------------
# CHUNKING
# -----------------------------
def chunk_text(text, chunk_size=1000, overlap=100):
"""Simple sliding window over the text."""
chunks = []
start = 0
while start < len(text):
end = start + chunk_size
chunks.append(text[start:end])
start += chunk_size - overlap
return chunks
# -----------------------------
# BUILD INDEX (stores in session_state)
# -----------------------------
def build_index(files):
    """Embed every chunk of every readable uploaded file and store a FAISS
    L2 index (plus parallel chunk/source lists) in session state.

    Shows a Streamlit error if no uploaded file produced any text, and a
    success banner with chunk/file counts otherwise.
    """
    all_texts, all_sources = [], []
    for uploaded in files:
        raw = load_file(uploaded)
        if not raw:
            continue  # unreadable or empty file — skip it
        for piece in chunk_text(raw):
            all_texts.append(piece)
            all_sources.append(uploaded.name)

    if not all_texts:
        st.error("No readable text found in the uploaded files.")
        return

    # The embedding model is loaded lazily (and cached) only when needed.
    embeddings = get_embed_model().encode(all_texts)

    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(np.array(embeddings).astype("float32"))

    st.session_state.vector_index = index
    st.session_state.chunk_texts = all_texts
    st.session_state.chunk_sources = all_sources
    st.success(
        f"Index built with {len(all_texts)} chunks from {len(set(all_sources))} files."
    )
# -----------------------------
# RETRIEVAL HELPER
# -----------------------------
def retrieve_context(question, k=5):
    """Return (context_text, sources_summary, selected_chunks) for a query.

    context_text is None (with an error message in sources_summary) when no
    index has been built yet. selected_chunks is a list of (text, filename)
    tuples for the top-k nearest chunks.
    """
    if st.session_state.vector_index is None or not st.session_state.chunk_texts:
        return None, "❌ Please upload files and build the index.", []

    embed_model = get_embed_model()
    query_emb = embed_model.encode([question])

    # BUG FIX: never ask FAISS for more neighbours than indexed vectors.
    # With k > ntotal, IndexFlatL2 pads the result with -1 ids, and a -1
    # would silently select the *last* chunk via Python negative indexing.
    k = min(k, len(st.session_state.chunk_texts))
    D, I = st.session_state.vector_index.search(
        np.array(query_emb).astype("float32"), k=k
    )

    selected_chunks = []
    for idx in I[0]:
        if idx < 0:  # defensive: skip any padded "no result" slots
            continue
        selected_chunks.append(
            (st.session_state.chunk_texts[idx], st.session_state.chunk_sources[idx])
        )

    context = "\n\n---\n\n".join(text for text, _ in selected_chunks)

    # Human-readable provenance: one short preview line per retrieved chunk.
    source_lines = []
    for i, (chunk, src) in enumerate(selected_chunks, start=1):
        preview = chunk.replace("\n", " ")
        if len(preview) > 160:
            preview = preview[:160] + "..."
        source_lines.append(f"**Chunk {i}** – `{src}`\n> {preview}")
    sources_summary = (
        "\n\n".join(source_lines) if source_lines else "No sources retrieved."
    )
    return context, sources_summary, selected_chunks
# -----------------------------
# PROMPT HELPERS
# -----------------------------
# Prompt guidance per task type; unknown task types contribute nothing.
_TASK_INSTRUCTIONS = {
    "Brainstorming": "Focus on generating diverse, creative ideas and alternative angles. Avoid giving final polished answers.",
    "Studying for exam": "Focus on clear explanations, key concepts, bullet-point summaries and small self-test questions.",
    "Writing assistance": "Help with structure, clarity, argumentation and academic phrasing. Avoid writing full assignments in restricted mode.",
    "Coding help": "Explain code step by step, suggest improvements and debugging hints. Emphasise understanding over copy-paste solutions.",
    "Homework help": "Guide the student through the problem with hints and intermediate steps, not full final solutions.",
    "Research assistance": "Help with concept clarification, comparing theories and outlining approaches. Avoid fabricating citations.",
}


def task_type_instructions(task_type):
    """Return the prompt guidance for *task_type* ("" if unrecognised)."""
    return _TASK_INSTRUCTIONS.get(task_type, "")
def difficulty_instructions(level):
    """Translate an explanation level into prompt guidance.

    Anything other than "Beginner" or "Intermediate" is treated as Advanced,
    mirroring the selectbox's third option.
    """
    known = {
        "Beginner": "Explain as if to a first-year student. Use simple language, concrete examples and avoid heavy notation.",
        "Intermediate": "Assume some prior knowledge. Use correct terminology but keep explanations clear and structured.",
    }
    advanced = "Assume a strong background. You may use technical terminology, formal definitions and more compact explanations."
    return known.get(level, advanced)
def material_focus_instructions(material_focus):
    """Build a prompt line prioritising the selected material types.

    Returns "" when nothing is selected (empty list or None).
    """
    if not material_focus:
        return ""
    return (
        "Prioritise information that appears to come from: "
        + ", ".join(material_focus)
        + ". If relevant material is missing, say so."
    )
# -----------------------------
# RAG QUERY (normal answer)
# -----------------------------
def rag_query(question, mode, task_type, difficulty, material_focus):
    """Answer a student question with Gemini, grounded in retrieved chunks.

    Returns (answer_text, sources_summary). When no index exists, an error
    string is returned instead of an answer; when the Gemini call fails,
    the raw retrieved context is returned as a best-effort fallback.
    """
    context, sources_summary, _ = retrieve_context(question)
    if context is None:
        return "❌ Please upload files and build the index.", ""

    task_instr = task_type_instructions(task_type)
    diff_instr = difficulty_instructions(difficulty)
    focus_instr = material_focus_instructions(material_focus)

    # Restricted mode forbids full solutions; normal mode allows them.
    if mode == "AI Integrity Mode (restricted)":
        system_block = """
You are ModuleMate AI in RESTRICTED MODE at a university.
Your role is to support learning **without** solving assessments directly.
You must:
- Use ONLY the context below.
- Give conceptual hints, guidance and explanations.
- **Do NOT** provide full final solutions to assignments, exam questions or project work.
- If the question looks like an assignment/exam question, remind the student that you can only give hints.
"""
    else:
        system_block = """
You are ModuleMate AI, a course-aware study assistant.
You may give full explanations and worked examples, but keep them grounded in the provided material.
"""

    prompt = f"""
{system_block}
Task type: {task_type}
Difficulty level: {difficulty}
Intent-specific instructions: {task_instr}
{diff_instr}
{focus_instr}
CONTEXT FROM COURSE MATERIAL:
{context}
STUDENT QUESTION:
{question}
Now answer accordingly.
"""

    try:
        response = get_gemini_model().generate_content(prompt)
        return response.text, sources_summary
    except Exception as err:
        # Degrade gracefully: surface the error plus the retrieved material
        # so the student still gets something useful when the API is down.
        fallback = (
            "⚠️ Gemini API error: "
            + str(err)
            + "\n\nHere are the most relevant sections from your course material:\n\n"
            + context
        )
        return fallback, sources_summary
# -----------------------------
# FLASHCARD PARSING + GENERATION
# -----------------------------
# Strips an optional markdown-bold marker, the Q/A letter, optional
# punctuation and whitespace from the front of a line — covers "Q: ",
# "q) ", "Q. ", and bold variants like "**Q:** " / "**Q**: ".
_Q_PREFIX = re.compile(r"^\**[Qq]\**\s*[\.:)\-]?\**\s*")
_A_PREFIX = re.compile(r"^\**[Aa]\**\s*[\.:)\-]?\**\s*")


def parse_flashcards(raw_text):
    """
    Parse model output into a list of {'q': ..., 'a': ...}.

    Handles formats like:
      Q: ...
      A: ...
    including markdown-bold variants (**Q:** ... / **A:** ...).

    BUG FIX: the original detected bold lines via startswith("**Q") but its
    strip pattern ``\\*\\*Q\\*\\*[:\\- ]`` could not match "**Q:**", leaving
    the prefix glued to the card text. A single tolerant prefix regex now
    strips all accepted forms.
    """
    cards = []
    text = raw_text.strip()
    if not text:
        return cards

    # Split into blocks, each beginning at a line that looks like a question.
    blocks = re.split(r"\n(?=[Qq][\.:)\- ]|\*\*Q)", text)
    for block in blocks:
        lines = [ln.strip() for ln in block.splitlines() if ln.strip()]
        if not lines:
            continue
        # First line that looks like a question / an answer (same detection
        # rules as the original).
        q_line = next(
            (ln for ln in lines
             if re.match(r"^(Q|q)[\.:)\- ]", ln) or ln.startswith("**Q")),
            None,
        )
        a_line = next(
            (ln for ln in lines
             if re.match(r"^(A|a)[\.:)\- ]", ln) or ln.startswith("**A")),
            None,
        )
        if q_line and a_line:
            q = _Q_PREFIX.sub("", q_line, count=1).strip()
            a = _A_PREFIX.sub("", a_line, count=1).strip()
            if q and a:
                cards.append({"q": q, "a": a})
    return cards
def create_flashcards(question, mode, task_type, difficulty, material_focus):
    """Generate Q/A flashcards for a topic from the indexed course material.

    Returns (cards, sources_summary) where cards is a list of
    {"q": ..., "a": ...} dicts. Both API failures and unparseable output
    degrade to a single explanatory card instead of raising.
    """
    context, sources_summary, _ = retrieve_context(question)
    if context is None:
        return [], sources_summary

    diff_instr = difficulty_instructions(difficulty)
    focus_instr = material_focus_instructions(material_focus)

    prompt = f"""
You are ModuleMate AI, helping a university student create study flashcards.
The topic is:
{question}
Generate 8–12 concise flashcards (Q–A pairs) based ONLY on the context below.
Each flashcard should be useful for **exam revision** in this module.
Very important formatting rules:
- Write each card as two lines:
Q: <short question>
A: <short answer>
- Do NOT add numbering, bullets or extra commentary.
- Just repeat this pattern for every card.
{diff_instr}
{focus_instr}
CONTEXT:
{context}
"""

    try:
        raw = get_gemini_model().generate_content(prompt).text
    except Exception as err:
        # API failure: surface the error plus the retrieved context.
        error_card = {
            "q": "Flashcards could not be generated",
            "a": (
                "Gemini API error while creating flashcards: "
                + str(err)
                + "\n\nHere are the most relevant sections from your course material:\n\n"
                + context
            ),
        }
        return [error_card], sources_summary

    cards = parse_flashcards(raw)
    if not cards:
        # Parsing failure: keep the raw model output visible on one card.
        cards = [
            {
                "q": "Flashcards could not be parsed from the AI output.",
                "a": raw.strip() or "No content returned.",
            }
        ]
    return cards, sources_summary
# -----------------------------
# UI
# -----------------------------
st.markdown("## 🎓 ModuleMate AI — Your Course-Aware Study Assistant")

# The three real modules this prototype is scoped to.
modules = [
    "Module 1 - Introduction to Data Handling, Exploration & Applied Machine Learning (E25)",
    "Module 2 - Natural Language Processing and Network Analysis (E25)",
    "Module 3 - Data-Driven Business Modelling and Strategy (E25)",
]
selected_module = st.selectbox("Select your current module:", modules)
# Mode and study-preference controls; their values feed straight into the prompts.
mode = st.radio("Choose Mode:", ["Normal RAG Mode", "AI Integrity Mode (restricted)"])

_task_options = [
    "Studying for exam",
    "Brainstorming",
    "Writing assistance",
    "Coding help",
    "Homework help",
    "Research assistance",
]
task_type = st.selectbox("What do you need help with?", _task_options)

difficulty = st.selectbox(
    "Preferred explanation level:",
    ["Beginner", "Intermediate", "Advanced"],
)

_material_options = [
    "Lecture slides",
    "Readings",
    "Assignments",
    "Tutorials",
    "Code examples",
]
# Every material type is selected by default.
material_focus = st.multiselect(
    "Which materials should I focus on?",
    _material_options,
    default=_material_options,
)
st.divider()

uploaded_files = st.file_uploader(
    "Upload course material (PDF, DOCX, PPTX, TXT, IPYNB, PY)",
    accept_multiple_files=True,
)

# Index building is an explicit action so re-runs don't re-embed everything
# each time another widget changes.
if st.button("Build Index"):
    if not uploaded_files:
        st.error("Please upload at least one file.")
    else:
        with st.spinner("Building vector index from your course material..."):
            build_index(uploaded_files)
st.divider()

question = st.text_input("Ask a question about your module (or topic for flashcards):")

# Two actions share the same question box: a normal answer or flashcards.
col1, col2 = st.columns(2)
ask_clicked = col1.button("Ask (normal answer)")
flash_clicked = col2.button("Generate flashcards")
# --- Normal RAG answer ---
if ask_clicked:
    if not question.strip():
        st.error("Please enter a question or topic.")
    else:
        with st.spinner("Searching your course material and generating an answer..."):
            answer, sources = rag_query(
                question, mode, task_type, difficulty, material_focus
            )
        st.subheader("Answer")
        st.write(answer)
        st.subheader("Sources used")
        st.markdown(sources)
# --- Flashcard generation ---
if flash_clicked:
    if not question.strip():
        st.error("Please enter a topic for the flashcards.")
    else:
        with st.spinner("Creating flashcards from your course material..."):
            cards, sources = create_flashcards(
                question, mode, task_type, difficulty, material_focus
            )
        # Reset the viewer to the first card with its answer hidden.
        st.session_state.flashcards = cards
        st.session_state.flashcard_index = 0
        st.session_state.flash_sources = sources
        st.session_state.show_answer = False
# --- Flashcard viewer ---
if st.session_state.flashcards:
    cards = st.session_state.flashcards

    st.subheader("Flashcards")

    # Navigation + reveal controls
    nav_col1, nav_col2, nav_col3 = st.columns([1, 2, 1])
    with nav_col1:
        prev_clicked = st.button("◀ Previous")
    with nav_col2:
        show_clicked = st.button("Show answer")
    with nav_col3:
        next_clicked = st.button("Next ▶")

    # BUG FIX: process clicks BEFORE rendering, so the counter and card
    # shown below reflect the post-click index. The original drew
    # "Card X of Y" before handling navigation, making the label lag
    # one click behind the displayed card.
    idx = st.session_state.flashcard_index
    if prev_clicked and idx > 0:
        st.session_state.flashcard_index = idx - 1
        st.session_state.show_answer = False
    if next_clicked and idx < len(cards) - 1:
        st.session_state.flashcard_index = idx + 1
        st.session_state.show_answer = False
    if show_clicked:
        st.session_state.show_answer = True

    idx = st.session_state.flashcard_index
    card = cards[idx]

    st.write(f"Card {idx + 1} of {len(cards)}")

    # Display current card
    st.markdown("#### Question:")
    st.info(card["q"])
    if st.session_state.show_answer:
        st.markdown("#### Answer:")
        st.success(card["a"])

    st.subheader("Sources used for these flashcards")
    st.markdown(st.session_state.flash_sources)
st.markdown("---")
st.caption("MSc in Business Data Science – Student ID: 12345. Current semester: 1.")
st.caption("Developed by Group 4 – Ali Moghadas, Amalie Hougaard Lang and Emina Gracanin.")