File size: 2,483 Bytes
06c58f2
b5133a4
d0acc14
06c58f2
b5133a4
06c58f2
b5133a4
06c58f2
d0acc14
b5133a4
268d57b
d0acc14
 
 
 
 
b5133a4
 
 
 
d0acc14
b5133a4
 
d0acc14
b5133a4
 
 
 
06c58f2
b5133a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import streamlit as st
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM, pipeline
import random

# Browser-tab title/icon and a centered single-column layout for the quiz.
st.set_page_config(
    page_title="AI Flashcard Quiz",
    page_icon="🧠",
    layout="centered",
)

# Load models
@st.cache_resource
def load_models():
    tokenizer = T5Tokenizer.from_pretrained("iarfmoose/t5-base-question-generator")
    model = AutoModelForSeq2SeqLM.from_pretrained("iarfmoose/t5-base-question-generator")
    distractor_gen = pipeline("text-generation", model="gpt2", max_length=20)
    return tokenizer, model, distractor_gen

tokenizer, model, distractor_gen = load_models()

# The T5 question generator was trained on inputs of at most 512 tokens;
# longer study content is truncated rather than overflowing the encoder.
MAX_CONTEXT_TOKENS = 512
NUM_DISTRACTORS = 3
CARD_STATE_KEY = "flashcard"


def _pick_answer(text):
    """Return the (simplified) correct answer: the first 3-7 words of *text*."""
    words = text.split()
    return " ".join(words[: random.randint(3, 7)]).strip()


def _generate_question(answer, text):
    """Generate a question about *text* whose intended answer is *answer*."""
    # iarfmoose/t5-base-question-generator expects "<answer> ... <context> ...";
    # without the answer the question is unrelated to the option marked correct.
    input_ids = tokenizer.encode(
        f"<answer> {answer} <context> {text}",
        return_tensors="pt",
        truncation=True,
        max_length=MAX_CONTEXT_TOKENS,
    )
    outputs = model.generate(input_ids, max_length=64, num_beams=4, early_stopping=True)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def _generate_distractors(question, correct_answer, count=NUM_DISTRACTORS):
    """Return up to *count* distinct wrong options sampled from GPT-2.

    Each option is the first line/sentence of a GPT-2 continuation of the
    question. Empty continuations, duplicates and copies of the correct
    answer are discarded. If nothing usable remains, a single ``"None"``
    option is returned so the quiz still offers a wrong choice.
    """
    # Oversample in one batched call so discarding empties/duplicates still
    # usually leaves ``count`` options. Sampling is required for
    # num_return_sequences > 1, and return_full_text=False drops the prompt.
    samples = distractor_gen(
        question,
        max_new_tokens=20,
        do_sample=True,
        num_return_sequences=count * 2,
        return_full_text=False,
    )
    distractors = []
    for sample in samples:
        candidate = sample["generated_text"].strip().split("\n")[0].split(".")[0].strip()
        if candidate and candidate != correct_answer and candidate not in distractors:
            distractors.append(candidate)
        if len(distractors) == count:
            break
    return distractors or ["None"]


def _build_card(text):
    """Build one flashcard dict with ``question``, ``answer`` and shuffled ``options``."""
    answer = _pick_answer(text)
    question = _generate_question(answer, text)
    options = _generate_distractors(question, answer) + [answer]
    random.shuffle(options)
    return {"question": question, "answer": answer, "options": options}


st.markdown(
    "<h1 style='text-align: center; color: #4CAF50;'>🤖 AI Flashcard Quiz</h1>",
    unsafe_allow_html=True
)

# Input text
context = st.text_area("✍️ Paste your study content here:", height=200)

# Generate a card and persist it: Streamlit reruns the whole script on every
# widget interaction, and st.button is only True on the rerun caused by its
# own click. Without session state, selecting an option or pressing
# "Check Answer" would discard the card.
if st.button("🎯 Generate Flashcard"):
    if not context.strip():
        st.warning("Please enter some text to generate a flashcard.")
    else:
        with st.spinner("Generating question and options..."):
            st.session_state[CARD_STATE_KEY] = _build_card(context)

# Render the current card (if any) outside the generate-button branch so the
# radio selection and the answer check survive reruns.
card = st.session_state.get(CARD_STATE_KEY)
if card is not None:
    st.markdown(f"### ❓ {card['question']}")
    selected = st.radio("Choose the correct answer:", card["options"])

    if st.button("Check Answer"):
        if selected == card["answer"]:
            st.success("✅ Correct! Great job.")
            st.balloons()
        else:
            st.error(f"❌ Oops! Correct answer was: **{card['answer']}**")