File size: 3,879 Bytes
043fc8b
 
376b382
d2681a2
043fc8b
376b382
043fc8b
 
 
d2681a2
 
 
043fc8b
d2681a2
 
 
63b2a2b
 
d2681a2
 
 
043fc8b
d2681a2
043fc8b
d2681a2
043fc8b
 
d2681a2
043fc8b
 
 
 
 
 
 
 
 
 
 
 
b9c9890
043fc8b
 
 
 
 
d2681a2
043fc8b
 
 
 
 
d2681a2
043fc8b
 
 
 
 
 
 
6ea5e8d
 
 
 
 
 
 
d2681a2
6ea5e8d
d2681a2
 
 
 
 
 
 
 
 
6ea5e8d
 
a9b6845
 
 
d2681a2
a9b6845
d2681a2
 
 
 
043fc8b
d2681a2
 
6ea5e8d
d2681a2
 
6ea5e8d
 
d2681a2
62e49f3
d2681a2
6ea5e8d
 
 
 
 
 
 
 
043fc8b
 
6ea5e8d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
from chats import init_db, get_all_chats, create_new_chat, save_message, get_messages, system_prompt

# Set HF cache directory
# NOTE: must happen before load_model() runs so transformers downloads the
# checkpoint into /tmp rather than the default ~/.cache location.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

# Initialize Streamlit app
# st.set_page_config must be the first Streamlit command executed on the page;
# keep it ahead of st.title and all other st.* calls.
st.set_page_config(page_title="AI Assistant", page_icon="πŸ€–")
st.title("πŸ€– Juma's Assistant")

# Load model and tokenizer once at startup
@st.cache_resource
def load_model():
    """Load the Qwen2.5-7B-Instruct tokenizer and model once per server process.

    ``@st.cache_resource`` ensures the (large) model is created a single time
    and shared across Streamlit reruns and sessions.

    Returns:
        tuple: ``(tokenizer, model, device)`` where ``device`` is ``"cuda"``
        when a GPU is available, else ``"cpu"``, and ``model`` has already
        been moved to that device.
    """
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
    # torch_dtype="auto" loads the weights in the checkpoint's native dtype
    # (bfloat16 for Qwen2.5) instead of upcasting to float32 — roughly halves
    # the memory footprint of a 7B model with no interface change.
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct", torch_dtype="auto"
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return tokenizer, model, device

tokenizer, model, device = load_model()

# Initialize database (creates tables on first run; no-op afterwards)
init_db()

# Sidebar for previous chats: pick an existing conversation or start fresh.
st.sidebar.title("πŸ’¬ Previous Chats")
all_chats = get_all_chats()

# Options are indices into all_chats; format_func renders the human label.
# With no chats, the selectbox has no options and returns None.
chat_titles = [f"{title} (ID: {chat_id})" for chat_id, title in all_chats]
selected_chat_index = st.sidebar.selectbox(
    "Select Chat", range(len(all_chats)), format_func=lambda i: chat_titles[i] if all_chats else "No chats available"
)

selected_chat_id = all_chats[selected_chat_index][0] if all_chats else None

if st.sidebar.button("πŸ†• Start New Chat"):
    selected_chat_id = create_new_chat()
    # st.experimental_rerun() was deprecated and then removed from Streamlit;
    # st.rerun() is the supported replacement with identical semantics.
    st.rerun()

# Nothing to render until a chat exists and is selected.
if selected_chat_id is None:
    st.warning("Please start a new chat or select one from the sidebar.")
    st.stop()

# Display chat history for the selected conversation.
messages = get_messages(selected_chat_id)
for role, content in messages:
    with st.chat_message(role):
        st.markdown(content)

# Handle user input
user_input = st.chat_input("Type your message...")
if user_input:
    st.chat_message("user").markdown(user_input)
    save_message(selected_chat_id, "user", user_input)

    with st.spinner("Thinking..."):
        try:
            # Build the conversation in the model's native chat format.
            # The previous hand-rolled "<|SYSTEM|> ... <|ASSISTANT>" prompt
            # used tags Qwen was never trained on (and the assistant tag was
            # malformed — missing the closing "|"). apply_chat_template emits
            # the correct ChatML markup, and passing the stored history gives
            # the model multi-turn context instead of only the latest message.
            chat = [{"role": "system", "content": system_prompt()}]
            chat += [{"role": role, "content": content} for role, content in messages]
            chat.append({"role": "user", "content": user_input})
            prompt = tokenizer.apply_chat_template(
                chat, tokenize=False, add_generation_prompt=True
            )

            # apply_chat_template already inserted the special tokens, so do
            # not let the tokenizer add another BOS/EOS layer on top.
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                add_special_tokens=False
            ).to(device)
            prompt_len = inputs["input_ids"].shape[-1]

            # Inference only: no_grad avoids building an autograd graph,
            # saving memory during generation.
            with torch.no_grad():
                outputs = model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=150,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    no_repeat_ngram_size=3,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                    output_scores=False
                )
            sequence = outputs.sequences[0]

            # Decode only the newly generated tokens (skip the prompt).
            final_response = tokenizer.decode(
                sequence[prompt_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            ).strip()
            st.chat_message("assistant").markdown(final_response)
            save_message(selected_chat_id, "assistant", final_response)

        except Exception as e:
            # Top-level UI boundary: surface the error instead of crashing
            # the Streamlit script run.
            st.error(f"Error: {str(e)}")