"""Streamlit chat UI backed by a locally loaded Qwen2.5-7B-Instruct model.

Persists conversations via the project-local `chats` module (SQLite-style
helpers: init_db / get_all_chats / create_new_chat / save_message /
get_messages / system_prompt -- see that module for the schema).
"""
import os

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from chats import (
    create_new_chat,
    get_all_chats,
    get_messages,
    init_db,
    save_message,
    system_prompt,
)

# Cache HF downloads in /tmp so repeated runs reuse the downloaded weights.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

st.set_page_config(page_title="AI Assistant", page_icon="🤖")
st.title("🤖 Juma's Assistant")

MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"


@st.cache_resource
def load_model():
    """Load tokenizer + model once per process (cached by Streamlit).

    Returns:
        (tokenizer, model, device) where device is "cuda" when available,
        otherwise "cpu".
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return tokenizer, model, device


tokenizer, model, device = load_model()

# Initialize database
init_db()

# --- Sidebar: previous chats -------------------------------------------------
st.sidebar.title("💬 Previous Chats")
all_chats = get_all_chats()
chat_titles = [f"{title} (ID: {chat_id})" for chat_id, title in all_chats]

selected_chat_index = st.sidebar.selectbox(
    "Select Chat",
    range(len(all_chats)),
    format_func=lambda i: chat_titles[i] if all_chats else "No chats available",
)
selected_chat_id = all_chats[selected_chat_index][0] if all_chats else None

if st.sidebar.button("🆕 Start New Chat"):
    selected_chat_id = create_new_chat()
    # st.experimental_rerun() was removed from Streamlit; st.rerun() is the
    # supported replacement.
    st.rerun()

if selected_chat_id is None:
    st.warning("Please start a new chat or select one from the sidebar.")
    st.stop()

# --- Chat history ------------------------------------------------------------
messages = get_messages(selected_chat_id)
for role, content in messages:
    with st.chat_message(role):
        st.markdown(content)

# --- Handle user input --------------------------------------------------------
user_input = st.chat_input("Type your message...")
if user_input:
    st.chat_message("user").markdown(user_input)
    save_message(selected_chat_id, "user", user_input)

    with st.spinner("Thinking..."):
        try:
            # Build the prompt with the model's own chat template instead of
            # hand-rolled markers: the previous "<|SYSTEM|>"/"<|ASSISTANT>"
            # strings are not Qwen special tokens (the assistant tag was even
            # missing its closing "|"), so the model never saw a correctly
            # formatted prompt. Also include the stored history so the
            # assistant keeps conversational context.
            chat = [{"role": "system", "content": system_prompt()}]
            chat.extend({"role": role, "content": content} for role, content in messages)
            chat.append({"role": "user", "content": user_input})

            prompt = tokenizer.apply_chat_template(
                chat, tokenize=False, add_generation_prompt=True
            )
            # The template already inserts special tokens; adding them again
            # would duplicate BOS-style markers.
            inputs = tokenizer(
                prompt, return_tensors="pt", add_special_tokens=False
            ).to(device)

            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=150,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                no_repeat_ngram_size=3,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                return_dict_in_generate=True,
                output_scores=False,
            )
            sequence = outputs.sequences[0]

            # Decode only the newly generated tokens (everything past the
            # prompt). The old token-by-token loop accumulated into an unused
            # buffer with its display line commented out -- dead code, removed.
            prompt_len = inputs["input_ids"].shape[-1]
            final_response = tokenizer.decode(
                sequence[prompt_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            ).strip()

            st.chat_message("assistant").markdown(final_response)
            save_message(selected_chat_id, "assistant", final_response)
        except Exception as e:
            # Top-level UI boundary: surface the error instead of crashing
            # the Streamlit session.
            st.error(f"Error: {str(e)}")