import os

# Point the Hugging Face cache at a writable location *before* importing
# transformers: huggingface_hub reads HF_HOME at import time, so setting it
# after the import has no effect
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
from chats import init_db, get_all_chats, create_new_chat, save_message, get_messages, system_prompt
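# `chats` is a local helper module; it is assumed to wrap a small SQLite-style
# store for chat sessions, their messages, and the system prompt text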
# Initialize Streamlit app
st.set_page_config(page_title="AI Assistant", page_icon="🤖")
st.title("🤖 Juma's Assistant")
# Load the model and tokenizer once at startup; @st.cache_resource keeps them
# alive across Streamlit reruns instead of reloading on every interaction
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
    # torch_dtype="auto" loads the checkpoint in its native precision
    # (bfloat16 for Qwen2.5), roughly halving memory versus the float32 default
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct", torch_dtype="auto"
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return tokenizer, model, device

tokenizer, model, device = load_model()
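# Note: Qwen2.5-7B has ~7.6B parameters, so expect on the order of 15 GB of
# memory in bfloat16 and very slow generation without a GPU (rough figures)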
# Initialize database
init_db()
# Sidebar for previous chats
st.sidebar.title("💬 Previous Chats")
all_chats = get_all_chats()
chat_titles = [f"{title} (ID: {chat_id})" for chat_id, title in all_chats]
selected_chat_index = st.sidebar.selectbox(
    "Select Chat",
    range(len(all_chats)),
    format_func=lambda i: chat_titles[i],
)
# With no chats, the selectbox has no options and returns None
selected_chat_id = all_chats[selected_chat_index][0] if all_chats else None
if st.sidebar.button("🆕 Start New Chat"):
    selected_chat_id = create_new_chat()
    st.rerun()  # st.experimental_rerun() was removed in recent Streamlit releases
if selected_chat_id is None:
    st.warning("Please start a new chat or select one from the sidebar.")
    st.stop()
# Display chat history
messages = get_messages(selected_chat_id)
for role, content in messages:
    with st.chat_message(role):
        st.markdown(content)
# Handle user input
user_input = st.chat_input("Type your message...")
if user_input:
    st.chat_message("user").markdown(user_input)
    save_message(selected_chat_id, "user", user_input)
    with st.spinner("Thinking..."):
        try:
            # Build the prompt with the model's own chat template: Qwen2.5
            # expects ChatML-style <|im_start|>/<|im_end|> markers, which
            # apply_chat_template emits correctly
            chat = [
                {"role": "system", "content": system_prompt()},
                {"role": "user", "content": user_input},
            ]
            prompt = tokenizer.apply_chat_template(
                chat, tokenize=False, add_generation_prompt=True
            )
            # Tokenize the formatted prompt; the template already contains all
            # special tokens, so don't add another BOS/EOS here
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                add_special_tokens=False,
            ).to(device)
            # Generate the full response. model.generate() returns only after
            # all tokens are produced, so the loop below replays them for a
            # typewriter effect rather than streaming in real time
            prompt_length = inputs["input_ids"].shape[-1]
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=150,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                no_repeat_ngram_size=3,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                return_dict_in_generate=True,
                output_scores=False,
            )
            sequence = outputs.sequences[0]
            with st.chat_message("assistant"):
                placeholder = st.empty()
                full_response = ""
                # Replay the newly generated tokens one by one
                for i in range(prompt_length, sequence.shape[-1]):
                    token_id = sequence[i].unsqueeze(0)
                    text = tokenizer.decode(
                        token_id,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=True,
                    )
                    if text:
                        full_response += text
                        placeholder.markdown(full_response)
                # Decode the new tokens in a single pass for the final text;
                # per-token decoding can mangle whitespace in BPE vocabularies
                final_response = tokenizer.decode(
                    sequence[prompt_length:],
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                ).strip()
                placeholder.markdown(final_response)
            save_message(selected_chat_id, "assistant", final_response)
        except Exception as e:
            st.error(f"Error: {e}")