"""Streamlit chat front-end for a Llama 3.1 instruct model.

The script loads a Hugging Face text-generation pipeline, keeps the running
conversation in ``st.session_state['messages']`` and lets the user save,
reload and download chat transcripts from the sidebar.
"""
import os
import time
from typing import Iterator

import streamlit as st
import torch
import transformers

# Directory in which saved transcripts are stored, one text file per chat.
CHAT_DIR = './Intermediate-Chats'

# Ensure the Hugging Face API Token is set in the environment
hf_token = os.getenv("HUGGING_FACE_API_TOKEN")
if not hf_token:
    st.error("Hugging Face API token not found. Please set the HUGGING_FACE_API_TOKEN environment variable.")
    st.stop()

# Initialize the model and tokenizer using Hugging Face pipeline
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"


@st.cache_resource
def _load_pipeline(model_name: str, token: str):
    """Build the text-generation pipeline once and reuse it across reruns.

    Streamlit re-executes this script on every interaction; without the
    cache the model would be reloaded each time.
    """
    return transformers.pipeline(
        "text-generation",
        model=model_name,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
        token=token,  # replaces the deprecated `use_auth_token` argument
    )


pipeline = _load_pipeline(model_id, hf_token)


def generate_response(prompt: str) -> str:
    """Return the model's reply to a single user prompt as plain text.

    Args:
        prompt: The user's message.

    Returns:
        The text of the generated assistant message.
    """
    messages = [{"role": "user", "content": prompt}]
    # Generate a response using the pipeline
    outputs = pipeline(messages, max_new_tokens=256)
    generated = outputs[0]["generated_text"]
    if isinstance(generated, str):
        return generated
    # For chat-style input the pipeline returns the whole conversation as a
    # list of message dicts; the newly generated reply is the last entry.
    return generated[-1]["content"]


def response_generator(content: str, delay: float = 0.1) -> Iterator[str]:
    """Yield ``content`` word by word to create a typing effect.

    Args:
        content: Text to stream.
        delay: Seconds to sleep after each yielded word.

    Yields:
        Each whitespace-separated word followed by a single space.
    """
    for word in content.split():
        yield word + " "
        time.sleep(delay)


def show_messages() -> None:
    """Render every message stored in the session as a chat bubble."""
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.write(msg["content"])


def save_chat() -> None:
    """Write the current conversation to a timestamped file and clear it.

    Each message is stored on one line as ``role: content`` with newlines in
    the content escaped as a literal ``\\n``. A warning is shown when there is
    nothing to save.
    """
    messages = st.session_state['messages']
    if not messages:
        st.warning("No chat messages to save.")
        return

    os.makedirs(CHAT_DIR, exist_ok=True)
    filename = f'{CHAT_DIR}/chat_{int(time.time())}.txt'
    # Explicit UTF-8 so non-ASCII model output does not depend on the
    # platform's default encoding.
    with open(filename, 'w', encoding='utf-8') as f:
        for message in messages:
            encoded_content = message['content'].replace('\n', '\\n')
            f.write(f"{message['role']}: {encoded_content}\n")
    messages.clear()


def load_saved_chats() -> None:
    """List saved chats in the sidebar, newest first, and load one on click."""
    if not os.path.isdir(CHAT_DIR):
        return

    # Only regular files can be opened as transcripts; skip sub-directories.
    paths = [
        path
        for path in (os.path.join(CHAT_DIR, name) for name in os.listdir(CHAT_DIR))
        if os.path.isfile(path)
    ]
    paths.sort(key=os.path.getmtime, reverse=True)

    for path in paths:
        file_name = os.path.basename(path)
        display_name = file_name[:-4] if file_name.endswith('.txt') else file_name
        # An explicit key keeps widgets distinct even if two files share a
        # display name (e.g. "chat" and "chat.txt").
        if st.sidebar.button(display_name, key=f"saved_chat_{file_name}"):
            st.session_state['show_chats'] = False
            st.session_state['is_loaded'] = True
            load_chat(path)


def load_chat(file_path: str) -> None:
    """Replace the session's messages with those stored in ``file_path``.

    Lines are expected in the ``role: content`` format written by
    ``save_chat``. A line without the ``': '`` separator is treated as a
    continuation of the previous message rather than raising an error.

    Args:
        file_path: Path of the transcript file to read.
    """
    loaded = []
    # Read before clearing so a failure to open the file does not wipe the
    # conversation currently on screen. Undecodable bytes (e.g. files written
    # with another encoding) are replaced instead of aborting the load.
    with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
        for raw_line in file:
            line = raw_line.rstrip('\n')
            if not line:
                continue
            role, separator, content = line.partition(': ')
            decoded_content = content.replace('\\n', '\n')
            if separator:
                loaded.append({'role': role, 'content': decoded_content})
            elif loaded:
                continuation = line.replace('\\n', '\n')
                loaded[-1]['content'] += '\n' + continuation

    messages = st.session_state['messages']
    messages.clear()
    messages.extend(loaded)


def main() -> None:
    """Draw the chat page: sidebar controls, history, input and reply."""
    st.title("LLaMA Chat Interface")

    if 'messages' not in st.session_state:
        st.session_state['messages'] = []
    if 'show_chats' not in st.session_state:
        st.session_state['show_chats'] = False

    # Reserve the first sidebar slot for the download button; it is filled at
    # the end of the run so the log includes any reply generated below.
    download_slot = st.sidebar.empty()

    # Handle sidebar actions before drawing the conversation so that saving,
    # starting a new chat or loading an old one is reflected in this run.
    if st.sidebar.button("Save Chat"):
        save_chat()
    if st.sidebar.button("New Chat"):
        st.session_state['messages'].clear()
    if st.sidebar.checkbox("Show/hide chat history", value=st.session_state['show_chats']):
        st.sidebar.title("Previous Chats")
        load_saved_chats()

    show_messages()

    user_input = st.chat_input("Enter your prompt:")
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        # Echo the prompt now; show_messages() already ran for this rerun.
        with st.chat_message("user"):
            st.write(user_input)

        response = generate_response(user_input)
        st.session_state.messages.append({"role": "assistant", "content": response})
        with st.chat_message("assistant"):
            # Stream into a single placeholder; separate st.write calls would
            # each create their own element instead of one growing message.
            placeholder = st.empty()
            shown_words = []
            for word in response_generator(response):
                shown_words.append(word)
                placeholder.markdown("".join(shown_words))
            # Finish with the original text so line breaks and formatting
            # match what show_messages() renders on later reruns.
            placeholder.markdown(response)

    chat_log = '\n'.join(f"{msg['role']}: {msg['content']}" for msg in st.session_state['messages'])
    download_slot.download_button(
        label="Download Chat Log",
        data=chat_log,
        file_name="chat_log.txt",
        mime="text/plain"
    )


if __name__ == "__main__":
    main()