# src/streamlit_app.py — Streamlit chat assistant.
# Last update by JumaRubea (commit 63b2a2b).
import os

# HF_HOME must be set BEFORE importing transformers/huggingface_hub:
# the hub library resolves its cache location from env vars at import
# time, so setting it after the import (as the code previously did)
# has no effect on where model weights are cached.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
from chats import init_db, get_all_chats, create_new_chat, save_message, get_messages, system_prompt

# Initialize Streamlit app; set_page_config must be the first Streamlit call.
st.set_page_config(page_title="AI Assistant", page_icon="🤖")
st.title("🤖 Juma's Assistant")
# Load model and tokenizer once at startup (cached across reruns).
@st.cache_resource
def load_model(model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
    """Load the tokenizer and model once per process.

    Args:
        model_name: Hugging Face model id to load. Defaults to the
            Qwen2.5-7B instruct checkpoint the app was built around;
            parameterized so other checkpoints can be swapped in
            without editing the function body.

    Returns:
        Tuple ``(tokenizer, model, device)`` where the model has already
        been moved to ``device`` ("cuda" if available, else "cpu").
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return tokenizer, model, device
# Materialize the cached model/tokenizer and ensure the chat database exists.
tokenizer, model, device = load_model()
init_db()

# Sidebar: pick an existing conversation or start a fresh one.
st.sidebar.title("💬 Previous Chats")
all_chats = get_all_chats()
chat_titles = [f"{title} (ID: {chat_id})" for chat_id, title in all_chats]
selected_chat_index = st.sidebar.selectbox(
    "Select Chat",
    range(len(all_chats)),
    format_func=lambda i: chat_titles[i] if all_chats else "No chats available",
)
selected_chat_id = all_chats[selected_chat_index][0] if all_chats else None

if st.sidebar.button("🆕 Start New Chat"):
    selected_chat_id = create_new_chat()
    # NOTE(review): this assignment is lost on rerun (the script restarts
    # from the top), so the new chat is not auto-selected; persisting the id
    # in st.session_state would fix that — confirm desired UX first.
    # st.experimental_rerun() was deprecated and removed from Streamlit;
    # st.rerun() is the supported replacement.
    st.rerun()

if selected_chat_id is None:
    st.warning("Please start a new chat or select one from the sidebar.")
    st.stop()
# Replay the stored transcript for the currently selected conversation.
for role, content in get_messages(selected_chat_id):
    st.chat_message(role).markdown(content)
# Handle user input: echo it, persist it, generate a reply, persist that too.
user_input = st.chat_input("Type your message...")
if user_input:
    st.chat_message("user").markdown(user_input)
    save_message(selected_chat_id, "user", user_input)

    with st.spinner("Thinking..."):
        try:
            # Build the prompt with the model's own chat template rather than
            # hand-written pseudo tokens: Qwen2.5 expects its ChatML format,
            # and the previous hard-coded "<|SYSTEM|>/<|USER|>/<|ASSISTANT>"
            # markers (note the missing "|" in the last one) are not special
            # tokens for this tokenizer at all, so the model saw them as
            # literal text.
            chat = [
                {"role": "system", "content": system_prompt()},
                {"role": "user", "content": user_input},
            ]
            prompt = tokenizer.apply_chat_template(
                chat, tokenize=False, add_generation_prompt=True
            )

            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            prompt_len = inputs["input_ids"].shape[-1]

            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=150,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                no_repeat_ngram_size=3,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                return_dict_in_generate=True,
            )

            # Decode only the newly generated tokens, skipping the prompt.
            # (The old per-token decode loop built a string that was never
            # rendered — dead code — so it has been removed.)
            final_response = tokenizer.decode(
                outputs.sequences[0][prompt_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            ).strip()

            st.chat_message("assistant").markdown(final_response)
            save_message(selected_chat_id, "assistant", final_response)
        except Exception as e:
            # Broad catch is deliberate at this top-level UI boundary:
            # surface the error in the app instead of crashing the script.
            st.error(f"Error: {str(e)}")