import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

st.title("📚 Study Buddy Chatbot")
st.write("Ask a question or type a topic, and I'll help you learn interactively!")

# Initialize session state for conversation history
if "conversation" not in st.session_state:
    st.session_state.conversation = []

# Load the model once and cache it across reruns; fp16 weights and
# low_cpu_mem_usage keep the memory footprint down
@st.cache_resource
def load_model():
    MODEL_NAME = "HuggingFaceH4/zephyr-7b-alpha"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, 
        torch_dtype=torch.float16, 
        device_map="auto",
        low_cpu_mem_usage=True
    )
    return tokenizer, model
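# Rough footprint note (an estimate, not measured here): 7B parameters in
# float16 is about 14 GB of weights alone, so with device_map="auto" smaller
# GPUs may see layers offloaded to CPU, which is slower but still functional.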

# Stash the cached model/tokenizer in session state; @st.cache_resource
# already shares them across reruns, so this guard mainly avoids
# re-entering the spinner on every script run
if "tokenizer" not in st.session_state or "model" not in st.session_state:
    with st.spinner("Loading AI model (this may take a minute)..."):
        st.session_state.tokenizer, st.session_state.model = load_model()

def get_response(user_input):
    # Get tokenizer and model from session state
    tokenizer = st.session_state.tokenizer
    model = st.session_state.model
    
    # Format recent history for context: last 6 messages, i.e. 3 student/coach exchanges
    history = "\n".join(st.session_state.conversation[-6:])
    
    prompt = (
        f"You are a knowledgeable study coach. Engage the student in conversation. "
        f"Ask open-ended questions to deepen understanding. Provide feedback and encourage explanations.\n\n"
        f"Previous conversation:\n{history}\n\n"
        f"Student: {user_input}\n"
        f"Coach: "
    )
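    # Note: Zephyr is instruction-tuned with a chat template; an alternative
    # (an untested sketch, not what this app does above) would build the
    # prompt via the tokenizer instead of plain text, e.g.:
    #   messages = [{"role": "system", "content": "You are a study coach."},
    #               {"role": "user", "content": user_input}]
    #   prompt = tokenizer.apply_chat_template(
    #       messages, tokenize=False, add_generation_prompt=True)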
    
    # Generation parameters tuned for a conversational coach
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=250,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,  # avoids the missing-pad-token warning
        )

    # Decode only the newly generated tokens, then cut off any follow-up
    # "Student:" turn the model may have hallucinated
    response = tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True)
    return response.split("Student:")[0].strip()

# User interface
user_input = st.text_input("Type your question or topic:")

if user_input:
    with st.spinner("Thinking..."):
        response = get_response(user_input)
    
    # Add to conversation history
    st.session_state.conversation.append(f"Student: {user_input}")
    st.session_state.conversation.append(f"Coach: {response}")

# Display the most recent conversation turns
st.subheader("Conversation History")
for message in st.session_state.conversation[-10:]:  # last 5 exchanges
    # Route by the stored role prefix rather than list position; this stays
    # correct even if the history ever gets out of step
    if message.startswith("Student: "):
        st.markdown(f"**You**: {message[len('Student: '):]}")
    else:
        st.markdown(f"**Coach**: {message[len('Coach: '):]}")

# Add a clear conversation button
if st.button("Clear Conversation"):
    st.session_state.conversation = []
    st.rerun()  # st.experimental_rerun() is deprecated in recent Streamlit releases
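
# To try the app locally (a sketch of the expected setup: this file saved as
# app.py, with streamlit, transformers, torch, and accelerate installed, since
# device_map="auto" depends on accelerate):
#   streamlit run app.py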