import os

# NOTE: set HF_HOME before importing transformers/huggingface_hub; both resolve
# their cache directory at import time, so setting it after the import is too late.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM

from chats import init_db, get_all_chats, create_new_chat, save_message, get_messages, system_prompt

st.set_page_config(page_title="AI Assistant", page_icon="🤖")
st.title("🤖 Juma's Assistant")
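
# Model loading is expensive: st.cache_resource caches the returned objects for
# the lifetime of the server process, so reruns and new sessions reuse them.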
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
    # torch_dtype="auto" keeps the checkpoint's native precision (bf16 for
    # Qwen2.5) instead of upcasting to float32, roughly halving memory use.
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct", torch_dtype="auto"
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return tokenizer, model, device

tokenizer, model, device = load_model()

init_db()

# Sidebar: list saved conversations and let the user pick or create one.
st.sidebar.title("💬 Previous Chats")
all_chats = get_all_chats()

chat_titles = [f"{title} (ID: {chat_id})" for chat_id, title in all_chats]
# With no chats, range(0) leaves the selectbox without options and it returns
# None, so format_func only ever sees valid indices.
selected_chat_index = st.sidebar.selectbox(
    "Select Chat",
    range(len(all_chats)),
    format_func=lambda i: chat_titles[i],
)

selected_chat_id = all_chats[selected_chat_index][0] if all_chats else None
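
# Starting a new chat persists it via create_new_chat(); st.rerun() then restarts
# the script from the top so the sidebar picks up the new conversation.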
if st.sidebar.button("🆕 Start New Chat"):
    selected_chat_id = create_new_chat()
    st.rerun()  # st.experimental_rerun() is deprecated; st.rerun() is the current API

if selected_chat_id is None:
    st.warning("Please start a new chat or select one from the sidebar.")
    st.stop()
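
# Streamlit reruns this script on every interaction, so the conversation is
# redrawn from the database each time.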
messages = get_messages(selected_chat_id)
for role, content in messages:
    with st.chat_message(role):
        st.markdown(content)
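
# One chat turn: echo and persist the user message, then generate a reply.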
user_input = st.chat_input("Type your message...")
if user_input:
    st.chat_message("user").markdown(user_input)
    save_message(selected_chat_id, "user", user_input)

    with st.spinner("Thinking..."):
        try:
            # Qwen2.5-Instruct is trained on the ChatML format, so build the
            # prompt with the tokenizer's chat template rather than hand-rolled
            # markers ("<|SYSTEM|>" etc. are not tokens the model was trained on).
            system_message = system_prompt()
            prompt = tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_input},
                ],
                tokenize=False,
                add_generation_prompt=True,
            )

            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                add_special_tokens=False,  # the chat template already adds them
            ).to(device)

            with st.chat_message("assistant"):
                placeholder = st.empty()
            full_response = ""

            prompt_len = inputs["input_ids"].shape[-1]
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=150,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                no_repeat_ngram_size=3,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                return_dict_in_generate=True,
                output_scores=False,
            )
            sequence = outputs.sequences[0]

            # Decode the completion token by token (skipping the prompt) and
            # render the partial text into the placeholder as it accumulates.
            for i in range(prompt_len, sequence.shape[-1]):
                token_id = sequence[i].unsqueeze(0)
                text = tokenizer.decode(token_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                if text:
                    full_response += text
                    placeholder.markdown(full_response + "▌")

            # Re-decode the whole completion in one pass for clean spacing.
            final_response = tokenizer.decode(
                sequence[prompt_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            ).strip()
            placeholder.markdown(final_response)
            save_message(selected_chat_id, "assistant", final_response)

        except Exception as e:
            st.error(f"Error: {e}")
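
# NOTE: model.generate() returns only after the full completion is produced, so
# the placeholder above fills in one burst rather than streaming live. For true
# token-by-token streaming, one could pass a transformers TextIteratorStreamer
# to model.generate() and consume it while generation runs in a background thread.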