# Streamlit chat interface for Meta-Llama-3.1-8B-Instruct (Hugging Face Space)
import os
import streamlit as st
import transformers
import torch
import time
# Ensure the Hugging Face API Token is set in the environment.
# Fail fast: the gated Llama model below cannot be downloaded without
# authentication, so halt this Streamlit script run entirely rather than
# fail later with a less helpful error.
hf_token = os.getenv("HUGGING_FACE_API_TOKEN")
if not hf_token:
    st.error("Hugging Face API token not found. Please set the HUGGING_FACE_API_TOKEN environment variable.")
    st.stop()  # aborts the current script run; nothing below executes
# Initialize the model and tokenizer using the Hugging Face pipeline.
# NOTE: the result is deliberately bound to the module-level name `pipeline`
# and used as a singleton by generate_response().
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},  # half-precision weights to cut memory use
    device_map="auto",  # let accelerate place layers on available GPU(s)/CPU
    token=hf_token,  # `use_auth_token` is deprecated in recent transformers; `token` is the replacement
)
def generate_response(prompt):
    """Return the assistant's reply text for a single-turn user *prompt*.

    Sends the prompt through the module-level chat ``pipeline`` and extracts
    only the newly generated assistant message as a plain string.
    """
    messages = [
        {"role": "user", "content": prompt}
    ]
    # Generate a response using the pipeline (capped at 256 new tokens).
    response = pipeline(messages, max_new_tokens=256)
    generated = response[0]["generated_text"]
    # For chat-format (list-of-messages) input, recent transformers returns
    # the whole conversation — input turns plus the new assistant turn — as a
    # list of dicts, not a string. Returning it unchanged would leak the full
    # history into the UI; pull out just the last (assistant) message.
    if isinstance(generated, list):
        return generated[-1]["content"]
    return generated  # older/string-output behavior: already plain text
def response_generator(content):
    """Stream *content* one word at a time (each with a trailing space),
    pausing briefly after each word so the UI appears to "type"."""
    tokens = content.split()
    idx = 0
    while idx < len(tokens):
        yield tokens[idx] + " "
        time.sleep(0.1)  # small delay between words; after the yield, as before
        idx += 1
def show_messages():
    """Render every message in the session-state conversation, each inside a
    chat bubble keyed by its role ("user" / "assistant")."""
    for entry in st.session_state.messages:
        with st.chat_message(entry["role"]):
            st.write(entry["content"])
def save_chat():
    """Persist the current conversation to a timestamped text file.

    Each message is written as one line of ``role: content``, with literal
    newlines in the content escaped as ``\\n`` (load_chat performs the
    inverse). On success the in-memory conversation is cleared; if there is
    nothing to save, a sidebar-visible warning is shown instead.
    """
    # exist_ok avoids the check-then-create race of the old exists()/makedirs() pair.
    os.makedirs('./Intermediate-Chats', exist_ok=True)
    if st.session_state['messages']:
        filename = f'./Intermediate-Chats/chat_{int(time.time())}.txt'
        with open(filename, 'w') as f:
            for message in st.session_state['messages']:
                # Keep each message on a single line for line-oriented parsing.
                encoded_content = message['content'].replace('\n', '\\n')
                f.write(f"{message['role']}: {encoded_content}\n")
        st.session_state['messages'].clear()
    else:
        st.warning("No chat messages to save.")
def load_saved_chats():
    """Show a sidebar button per saved chat (newest first); clicking one
    hides the history list, marks a chat as loaded, and loads that file."""
    chat_dir = './Intermediate-Chats'
    if not os.path.exists(chat_dir):
        return

    def modified_at(name):
        return os.path.getmtime(os.path.join(chat_dir, name))

    for file_name in sorted(os.listdir(chat_dir), key=modified_at, reverse=True):
        # Strip the ".txt" suffix for display only; keep the real name for loading.
        label = file_name[:-4] if file_name.endswith('.txt') else file_name
        if st.sidebar.button(label):
            st.session_state['show_chats'] = False
            st.session_state['is_loaded'] = True
            load_chat(os.path.join(chat_dir, file_name))
def load_chat(file_path):
    """Replace the in-memory conversation with the one saved at *file_path*.

    Expects the line format written by save_chat: ``role: content`` with
    literal ``\\n`` standing in for newlines inside the content. Malformed
    lines (no ``": "`` separator, e.g. blank lines) are skipped instead of
    raising and crashing the script run, as the old ``split`` unpack did.
    """
    st.session_state['messages'].clear()
    with open(file_path, 'r') as file:
        for line in file:  # iterate lazily; no need to materialize readlines()
            role, sep, content = line.strip().partition(': ')
            if not sep:
                continue  # malformed line: no "role: content" separator
            decoded_content = content.replace('\\n', '\n')
            st.session_state['messages'].append({'role': role, 'content': decoded_content})
def main():
    """Top-level Streamlit page: chat UI plus sidebar save/load controls."""
    st.title("LLaMA Chat Interface")
    if 'messages' not in st.session_state:
        st.session_state['messages'] = []
    if 'show_chats' not in st.session_state:
        st.session_state['show_chats'] = False
    show_messages()
    user_input = st.chat_input("Enter your prompt:")
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        # Echo the user's turn immediately — show_messages() already ran above,
        # so without this the message would only appear on the next rerun.
        with st.chat_message("user"):
            st.write(user_input)
        response = generate_response(user_input)
        st.session_state.messages.append({"role": "assistant", "content": response})
        with st.chat_message("assistant"):
            # st.write() has no `end` kwarg (the old call raised TypeError),
            # and separate write() calls would each render a new element.
            # Accumulate the stream into a single placeholder instead.
            placeholder = st.empty()
            streamed = ""
            for word in response_generator(response):
                streamed += word
                placeholder.write(streamed)
    chat_log = '\n'.join(f"{msg['role']}: {msg['content']}" for msg in st.session_state['messages'])
    st.sidebar.download_button(
        label="Download Chat Log",
        data=chat_log,
        file_name="chat_log.txt",
        mime="text/plain"
    )
    if st.sidebar.button("Save Chat"):
        save_chat()
    if st.sidebar.button("New Chat"):
        st.session_state['messages'].clear()
    if st.sidebar.checkbox("Show/hide chat history", value=st.session_state['show_chats']):
        st.sidebar.title("Previous Chats")
        load_saved_chats()


if __name__ == "__main__":
    main()