File size: 4,031 Bytes
ca9ae0b
6cd3a3a
521698c
 
6cd3a3a
 
af0011c
ca9ae0b
 
 
 
 
521698c
3fcc4ca
af0011c
 
 
 
 
 
 
6cd3a3a
521698c
af0011c
 
 
521698c
af0011c
 
6cd3a3a
521698c
 
 
 
 
6cd3a3a
521698c
6cd3a3a
 
 
 
 
 
 
 
 
521698c
6cd3a3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521698c
6cd3a3a
 
 
 
 
 
521698c
6cd3a3a
521698c
6cd3a3a
 
 
 
521698c
6cd3a3a
6156793
 
6cd3a3a
521698c
6cd3a3a
 
521698c
6cd3a3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import streamlit as st
import transformers
import torch
import time

# Ensure the Hugging Face API Token is set in the environment
hf_token = os.getenv("HUGGING_FACE_API_TOKEN")
if not hf_token:
    st.error("Hugging Face API token not found. Please set the HUGGING_FACE_API_TOKEN environment variable.")
    st.stop()

# Hugging Face model repository served by this app.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"


@st.cache_resource
def _load_pipeline(model_name, token):
    """Build the text-generation pipeline for *model_name* once per process.

    Streamlit re-executes this whole script on every widget interaction;
    caching the pipeline with ``st.cache_resource`` keeps the model loaded
    across reruns instead of reloading the weights each time.

    Args:
        model_name: Hugging Face model id to load.
        token: Hugging Face access token used to download the gated weights.

    Returns:
        A ``transformers`` text-generation pipeline.
    """
    return transformers.pipeline(
        "text-generation",
        model=model_name,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
        # ``token`` supersedes the deprecated ``use_auth_token`` argument.
        token=token,
    )


# Module-level handle used by generate_response().
pipeline = _load_pipeline(model_id, hf_token)

def generate_response(prompt, max_new_tokens=256):
    """Generate the assistant's reply to a single user prompt.

    The prompt is wrapped as a one-turn chat (a list of role/content dicts)
    before being passed to the module-level ``pipeline``. Only this prompt
    is sent; earlier turns of the conversation are not included.

    Args:
        prompt: The user's message text.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The assistant's reply as a plain string.
    """
    messages = [{"role": "user", "content": prompt}]
    response = pipeline(messages, max_new_tokens=max_new_tokens)
    generated = response[0]["generated_text"]
    # For chat-style input the text-generation pipeline returns the whole
    # conversation (input turns plus the new assistant turn) as a list of
    # message dicts; callers need only the newest turn's text.
    if isinstance(generated, list):
        return generated[-1]["content"]
    return generated

def response_generator(content):
    """Yield *content* one whitespace-separated word at a time.

    Every chunk is a word followed by a single space. The generator pauses
    0.1 seconds after each chunk to give a typing effect.
    """
    for token in content.split():
        yield f"{token} "
        time.sleep(0.1)

def show_messages():
    """Draw every stored chat message in a chat bubble labelled with its role."""
    for entry in st.session_state.messages:
        with st.chat_message(entry["role"]):
            st.write(entry["content"])

def save_chat():
    """Write the current conversation to a timestamped file, then clear it.

    Each message is written on its own line as ``"<role>: <content>"``.
    Newlines inside the content are replaced by a backslash followed by
    ``n``, so every message occupies exactly one line. Files go to
    ``./Intermediate-Chats/chat_<unix seconds>.txt``. When there are no
    messages, a warning is shown and nothing is written.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs('./Intermediate-Chats', exist_ok=True)
    if st.session_state['messages']:
        filename = f'./Intermediate-Chats/chat_{int(time.time())}.txt'
        with open(filename, 'w') as f:
            for message in st.session_state['messages']:
                # Normalize CRLF / bare CR to LF before escaping: a raw '\r'
                # would be read back (universal newlines) as a line break,
                # splitting one message across lines.
                content = message['content'].replace('\r\n', '\n').replace('\r', '\n')
                encoded_content = content.replace('\n', '\\n')
                f.write(f"{message['role']}: {encoded_content}\n")
        st.session_state['messages'].clear()
    else:
        st.warning("No chat messages to save.")

def load_saved_chats():
    """List saved chats in the sidebar, newest first, one button per file.

    The button label is the file name without a trailing ``.txt``. Clicking
    a button sets ``show_chats`` to False and ``is_loaded`` to True in the
    session state, then loads that file via ``load_chat``. Nothing is shown
    when ``./Intermediate-Chats`` does not exist.
    """
    chat_dir = './Intermediate-Chats'
    if not os.path.exists(chat_dir):
        return

    def modified_at(name):
        return os.path.getmtime(os.path.join(chat_dir, name))

    for file_name in sorted(os.listdir(chat_dir), key=modified_at, reverse=True):
        label = file_name
        if file_name.endswith('.txt'):
            label = file_name[:-4]
        if not st.sidebar.button(label):
            continue
        st.session_state['show_chats'] = False
        st.session_state['is_loaded'] = True
        load_chat(os.path.join(chat_dir, file_name))

def load_chat(file_path):
    """Replace the current conversation with the one stored in *file_path*.

    Expects the format written by ``save_chat``: one ``"<role>: <content>"``
    line per message, where a backslash followed by ``n`` stands for a
    newline inside the content. Blank lines are ignored.

    Raises:
        ValueError: if a non-blank line has no ``": "`` separator.
    """
    loaded = []
    with open(file_path, 'r') as file:
        for line_no, raw_line in enumerate(file, start=1):
            # Drop only the line terminator. A full strip() would also eat
            # whitespace belonging to the message, and would turn an empty
            # message ("role: ") into "role:", which has no separator.
            line = raw_line.rstrip('\n')
            if not line.strip():
                continue
            role, sep, content = line.partition(': ')
            if not sep:
                raise ValueError(f"Malformed chat line {line_no} in {file_path!r}: {line!r}")
            loaded.append({'role': role, 'content': content.replace('\\n', '\n')})
    # Parse the whole file before touching session state, so an unreadable
    # or malformed file does not wipe the current conversation.
    st.session_state['messages'].clear()
    st.session_state['messages'].extend(loaded)

def main():
    """Render the chat page: sidebar controls, conversation history, prompt box.

    Sidebar actions that change the conversation (save, new chat, loading a
    previous chat) are handled before the history is drawn, so the page
    reflects them in the same run rather than one interaction later.
    """
    st.title("LLaMA Chat Interface")

    if 'messages' not in st.session_state:
        st.session_state['messages'] = []
    if 'show_chats' not in st.session_state:
        st.session_state['show_chats'] = False

    # Keep the download button at the top of the sidebar, but fill it in at
    # the end so the log includes any exchange produced during this run.
    download_slot = st.sidebar.empty()

    if st.sidebar.button("Save Chat"):
        save_chat()

    if st.sidebar.button("New Chat"):
        st.session_state['messages'].clear()

    if st.sidebar.checkbox("Show/hide chat history", value=st.session_state['show_chats']):
        st.sidebar.title("Previous Chats")
        load_saved_chats()

    show_messages()

    user_input = st.chat_input("Enter your prompt:")
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        # show_messages() ran before this message existed, so echo it here.
        with st.chat_message("user"):
            st.write(user_input)

        response = generate_response(user_input)
        st.session_state.messages.append({"role": "assistant", "content": response})

        with st.chat_message("assistant"):
            # Grow one element in place; a separate st.write per word would
            # render every word as its own paragraph.
            placeholder = st.empty()
            shown = ""
            for chunk in response_generator(response):
                shown += chunk
                placeholder.markdown(shown)

    chat_log = '\n'.join(f"{msg['role']}: {msg['content']}" for msg in st.session_state['messages'])
    download_slot.download_button(
        label="Download Chat Log",
        data=chat_log,
        file_name="chat_log.txt",
        mime="text/plain"
    )


if __name__ == "__main__":
    main()