Spaces:
Sleeping
Sleeping
import time

import streamlit as st

from utils import load_model, generate_stream
def init_chat():
    """Set up session state and load the model/tokenizer once per session.

    Initializes the message history and the "generating" flag, lazily loads
    the model and tokenizer, and marks the model as ready. On any failure it
    shows the error, offers a retry button (clicking any Streamlit button
    triggers a rerun, which re-attempts the load), and halts the script.
    """
    try:
        # Conversation history and in-flight-generation flag.
        if "messages" not in st.session_state:
            st.session_state.messages = []
        if "generating" not in st.session_state:
            st.session_state.generating = False
        # Load the model/tokenizer only once per session.
        if "model" not in st.session_state or "tokenizer" not in st.session_state:
            with st.spinner("π Initializing AI model..."):
                st.session_state.model, st.session_state.tokenizer = load_model()
        # BUG FIX: nothing ever set this flag, so the page-level
        # `model_loaded` guard stopped the app on every run. Mark the
        # model ready once both objects are in session state.
        st.session_state.model_loaded = True
    except Exception as e:
        st.error(f"Initialization error: {str(e)}")
        # Clicking the button triggers a Streamlit rerun, retrying the load.
        st.button("π Retry Loading")
        st.stop()
# Page header.
st.title("π Chat Interface")

# Create session state and load the model/tokenizer.
init_chat()

# BUG FIX: the original guard checked st.session_state["model_loaded"],
# a flag nothing ever set, so the app always stopped here. Gate on the
# objects init_chat() actually stores instead.
if "model" not in st.session_state or "tokenizer" not in st.session_state:
    st.warning("β οΈ Model not fully loaded. Please wait...")
    st.stop()
# Sidebar: chat-management buttons.
with st.sidebar:
    st.markdown("### Chat Controls")
    clear_col, reset_col = st.columns(2)
    # Wipe the conversation but keep the loaded model.
    with clear_col:
        if st.button("ποΈ Clear Chat", use_container_width=True):
            st.session_state.messages = []
            st.session_state.generating = False
            st.rerun()
    # Drop all session state and cached resources so the model reloads.
    with reset_col:
        if st.button("π Reset Model", use_container_width=True):
            st.session_state.clear()
            st.cache_resource.clear()
            st.rerun()
# Render the conversation so far, one chat bubble per stored message.
chat_container = st.container()
with chat_container:
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
| # Input handling | |
| if prompt := st.chat_input( | |
| "Ask me anything...", | |
| disabled=st.session_state.generating | |
| ): | |
| # Update generating state | |
| st.session_state.generating = True | |
| # Show user message | |
| st.session_state.messages.append({"role": "user", "content": prompt}) | |
| with st.chat_message("user"): | |
| st.markdown(prompt) | |
| # Generate and show response | |
| with st.chat_message("assistant"): | |
| try: | |
| context = "\n".join([ | |
| f"{m['role']}: {m['content']}" | |
| for m in st.session_state.messages[-3:] | |
| ]) | |
| response = generate_stream( | |
| context, | |
| st.session_state.model, | |
| st.session_state.tokenizer | |
| ) | |
| if response: | |
| st.session_state.messages.append({ | |
| "role": "assistant", | |
| "content": response | |
| }) | |
| except Exception as e: | |
| st.error("Failed to generate response. Please try again.") | |
| st.error(f"Error details: {str(e)}") | |
| finally: | |
| st.session_state.generating = False | |