"""Study Buddy Chatbot: a Streamlit app that coaches a student using a local LLM."""

import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

MODEL_NAME = "HuggingFaceH4/zephyr-7b-alpha"
STUDENT_PREFIX = "Student: "
COACH_PREFIX = "Coach: "
# Messages (not exchanges) fed back to the model as context: 6 messages = 3 exchanges.
HISTORY_MESSAGES = 6
# Messages rendered in the "Conversation History" section.
DISPLAY_MESSAGES = 10

st.title("📚 Study Buddy Chatbot")
st.write("Ask a question or type a topic, and I'll help you learn interactively!")

# Conversation is a flat list of prefixed strings ("Student: ..." / "Coach: ..."),
# oldest first.
if "conversation" not in st.session_state:
    st.session_state.conversation = []


@st.cache_resource
def load_model():
    """Load the tokenizer and model once per server process.

    ``st.cache_resource`` shares the returned objects across sessions and reruns.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # NOTE(review): float16 is meant for GPU inference; on a CPU-only host half
    # precision may be slow or unsupported for some ops — confirm the target.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    return tokenizer, model


# Fetch the (cached) model once per session; the spinner shows only on that fetch.
if "tokenizer" not in st.session_state or "model" not in st.session_state:
    with st.spinner("Loading AI model (this may take a minute)..."):
        st.session_state.tokenizer, st.session_state.model = load_model()


def get_response(user_input):
    """Generate the coach's reply to ``user_input``.

    Args:
        user_input: The student's latest question or topic.

    Returns:
        The generated reply text, stripped of surrounding whitespace and cut off
        before any further "Student:" line the model may have produced itself.
    """
    tokenizer = st.session_state.tokenizer
    model = st.session_state.model

    history = "\n".join(st.session_state.conversation[-HISTORY_MESSAGES:])
    prompt = (
        "You are a knowledgeable study coach. Engage the student in conversation. "
        "Ask open-ended questions to deepen understanding. "
        "Provide feedback and encourage explanations.\n\n"
        f"Previous conversation:\n{history}\n\n"
        f"{STUDENT_PREFIX}{user_input}\n"
        f"{COACH_PREFIX}"
    )

    # Pass the attention mask along with the ids, and an explicit pad token id
    # (the tokenizer may define none), so generate() does not have to guess.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=250,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(output[0, prompt_length:], skip_special_tokens=True)
    # With a plain-text transcript prompt the model may continue and write the
    # student's next line itself; keep only the coach's turn.
    response = response.split("\n" + STUDENT_PREFIX.rstrip())[0]
    return response.strip()


def _clear_conversation():
    """Button callback: empty the conversation history.

    Callbacks run before the script reruns, so the cleared history is rendered
    immediately without calling ``st.experimental_rerun`` (deprecated in favour
    of ``st.rerun`` and removed in later Streamlit releases).
    """
    st.session_state.conversation = []


def _strip_prefix(message, prefix):
    """Return ``message`` without ``prefix``, removing it only at the start."""
    return message[len(prefix):] if message.startswith(prefix) else message


# User interface. A form with clear_on_submit reports a submission exactly once,
# so unrelated reruns (e.g. clicking "Clear Conversation") do not re-send the
# previous question the way a bare st.text_input would.
with st.form("question_form", clear_on_submit=True):
    user_input = st.text_input("Type your question or topic:")
    submitted = st.form_submit_button("Ask")

if submitted and user_input.strip():
    with st.spinner("Thinking..."):
        response = get_response(user_input)

    # Add to conversation history
    st.session_state.conversation.append(f"{STUDENT_PREFIX}{user_input}")
    st.session_state.conversation.append(f"{COACH_PREFIX}{response}")

st.subheader("Conversation History")
for message in st.session_state.conversation[-DISPLAY_MESSAGES:]:
    # Decide the speaker from the stored prefix rather than list position.
    if message.startswith(STUDENT_PREFIX):
        st.markdown(f"**You**: {_strip_prefix(message, STUDENT_PREFIX)}")
    else:
        st.markdown(f"**Coach**: {_strip_prefix(message, COACH_PREFIX)}")

st.button("Clear Conversation", on_click=_clear_conversation)