# fuzzylab / app2.py
# (renamed from app.py; commit 8588df7)
import os
import streamlit as st
import transformers
import torch
import time
# Ensure the Hugging Face API token is set before anything else; the gated
# Llama model cannot be downloaded without it.
hf_token = os.getenv("HUGGING_FACE_API_TOKEN")
if not hf_token:
    st.error("Hugging Face API token not found. Please set the HUGGING_FACE_API_TOKEN environment variable.")
    st.stop()

# Initialize the text-generation pipeline once at module load.
# NOTE: `use_auth_token` is deprecated in recent transformers releases;
# `token` is the supported replacement keyword.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
    token=hf_token,
)
def generate_response(prompt):
    """Generate an assistant reply for *prompt* via the chat pipeline.

    Returns only the newly generated assistant message text. For
    chat-style (list-of-messages) input, the pipeline's
    ``generated_text`` field is the full conversation (prompt + reply),
    so the reply is the last message's content — returning the raw field
    would leak the whole transcript into the chat history.
    """
    messages = [{"role": "user", "content": prompt}]
    response = pipeline(messages, max_new_tokens=256)
    generated = response[0]["generated_text"]
    if isinstance(generated, list):
        # Chat format: last entry is the assistant's new message.
        return generated[-1]["content"]
    return generated
def response_generator(content):
    """Yield *content* one whitespace-separated word at a time (each with
    a trailing space), pausing briefly after each to simulate streaming."""
    for token in content.split():
        yield f"{token} "
        time.sleep(0.1)
def show_messages():
    """Render every message in the session history as a chat bubble."""
    for entry in st.session_state.messages:
        with st.chat_message(entry["role"]):
            st.write(entry["content"])
def save_chat():
    """Persist the current chat history to a timestamped text file.

    Each message is written as one ``role: content`` line, with embedded
    newlines escaped as ``\\n`` so `load_chat` can round-trip them. The
    in-memory history is cleared after a successful save; a warning is
    shown when there is nothing to save.
    """
    # exist_ok avoids the check-then-create race of the original code.
    os.makedirs('./Intermediate-Chats', exist_ok=True)
    if not st.session_state['messages']:
        st.warning("No chat messages to save.")
        return
    filename = f'./Intermediate-Chats/chat_{int(time.time())}.txt'
    with open(filename, 'w') as f:
        for message in st.session_state['messages']:
            # Escape newlines so each message stays on a single line.
            encoded_content = message['content'].replace('\n', '\\n')
            f.write(f"{message['role']}: {encoded_content}\n")
    st.session_state['messages'].clear()
def load_saved_chats():
    """List saved chat files (newest first) as sidebar buttons; clicking
    one loads that conversation into the current session."""
    chat_dir = './Intermediate-Chats'
    if not os.path.exists(chat_dir):
        return

    def mtime(name):
        return os.path.getmtime(os.path.join(chat_dir, name))

    for file_name in sorted(os.listdir(chat_dir), key=mtime, reverse=True):
        # Drop the .txt extension for a cleaner button label.
        label = file_name[:-4] if file_name.endswith('.txt') else file_name
        if st.sidebar.button(label):
            st.session_state['show_chats'] = False
            st.session_state['is_loaded'] = True
            load_chat(os.path.join(chat_dir, file_name))
def load_chat(file_path):
    """Replace the session history with messages parsed from *file_path*.

    Expects one ``role: content`` line per message, as written by
    `save_chat`; escaped ``\\n`` sequences are restored to real newlines.
    Lines missing the ``': '`` separator (blank or malformed) are skipped
    instead of raising ValueError and crashing the app.
    """
    st.session_state['messages'].clear()
    with open(file_path, 'r') as file:
        for line in file:
            role, sep, content = line.strip().partition(': ')
            if not sep:
                continue  # blank or malformed line — skip, don't crash
            st.session_state['messages'].append(
                {'role': role, 'content': content.replace('\\n', '\n')}
            )
def main():
    """Entry point: render the chat UI, handle prompts, and wire up the
    sidebar actions (download / save / new chat / history browser)."""
    st.title("LLaMA Chat Interface")

    # Initialize session state on first run.
    if 'messages' not in st.session_state:
        st.session_state['messages'] = []
    if 'show_chats' not in st.session_state:
        st.session_state['show_chats'] = False

    show_messages()

    user_input = st.chat_input("Enter your prompt:")
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        # Echo the user's message immediately: show_messages() ran before
        # this input existed, so it would otherwise not appear until rerun.
        with st.chat_message("user"):
            st.write(user_input)
        response = generate_response(user_input)
        st.session_state.messages.append({"role": "assistant", "content": response})
        with st.chat_message("assistant"):
            # BUG FIX: st.write() has no `end` parameter; st.write_stream()
            # is the supported way to render the simulated token stream.
            st.write_stream(response_generator(response))

    chat_log = '\n'.join(
        f"{msg['role']}: {msg['content']}" for msg in st.session_state['messages']
    )
    st.sidebar.download_button(
        label="Download Chat Log",
        data=chat_log,
        file_name="chat_log.txt",
        mime="text/plain",
    )
    if st.sidebar.button("Save Chat"):
        save_chat()
    if st.sidebar.button("New Chat"):
        st.session_state['messages'].clear()
    if st.sidebar.checkbox("Show/hide chat history", value=st.session_state['show_chats']):
        st.sidebar.title("Previous Chats")
        load_saved_chats()
# Run the Streamlit app when executed as a script.
if __name__ == "__main__":
    main()