# Hugging Face Space (status: Sleeping)
| import os | |
| import streamlit as st | |
| import transformers | |
| import torch | |
| import time | |
# Ensure the Hugging Face API token is set in the environment; without it the
# app shows an error and halts this script run.
hf_token = os.getenv("HUGGING_FACE_API_TOKEN")
if not hf_token:
    st.error("Hugging Face API token not found. Please set the HUGGING_FACE_API_TOKEN environment variable.")
    st.stop()

# Model served by this app.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"


@st.cache_resource
def _load_pipeline(model_name, token):
    """Build the text-generation pipeline and cache it across reruns.

    Streamlit re-executes this whole script on every user interaction, so
    without ``st.cache_resource`` the model would be reloaded (and its
    weights re-allocated) on each rerun. The cache is keyed on the arguments.

    Args:
        model_name: Hugging Face Hub model id to load.
        token: Hugging Face access token used to download the model.

    Returns:
        The ``transformers`` text-generation pipeline.
    """
    return transformers.pipeline(
        "text-generation",
        model=model_name,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
        # `token` supersedes the deprecated `use_auth_token` argument.
        token=token,
    )


# Module-level name used by generate_response().
pipeline = _load_pipeline(model_id, hf_token)
def generate_response(prompt, max_new_tokens=256):
    """Generate the assistant's reply to a single user prompt.

    Only this prompt is sent to the model as one user message; earlier turns
    of the conversation are not included.

    Args:
        prompt: The user's message text.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The reply text as a string.
    """
    messages = [{"role": "user", "content": prompt}]
    outputs = pipeline(messages, max_new_tokens=max_new_tokens)
    generated = outputs[0]["generated_text"]
    # For chat-format input the pipeline returns the full conversation
    # (input messages plus the new assistant turn) as a list of role/content
    # dicts, so the reply is the content of the last entry. If a plain
    # string comes back instead, pass it through unchanged.
    if isinstance(generated, list):
        return generated[-1]["content"]
    return generated
def response_generator(content):
    """Yield *content* one whitespace-separated word at a time.

    Every yielded piece ends with a single space, and the generator pauses
    0.1 s after each piece so a caller can render the text progressively.
    """
    for token in content.split():
        yield f"{token} "
        time.sleep(0.1)
def show_messages():
    """Render every stored chat message inside a chat bubble for its role."""
    history = st.session_state.messages
    for entry in history:
        with st.chat_message(entry["role"]):
            st.write(entry["content"])
def save_chat():
    """Write the current conversation to a timestamped file, then clear it.

    Files go to ``./Intermediate-Chats/chat_<unix-seconds>.txt`` with one
    ``role: content`` line per message. Each newline inside a message is
    written as a literal backslash followed by ``n`` so that every message
    stays on a single line. If there are no messages, a warning is shown
    and nothing is written.
    """
    chat_dir = './Intermediate-Chats'
    # exist_ok avoids the race between checking for the folder and creating it.
    os.makedirs(chat_dir, exist_ok=True)

    messages = st.session_state['messages']
    if not messages:
        st.warning("No chat messages to save.")
        return

    # NOTE(review): two saves within the same second reuse the same filename
    # and the later one overwrites the earlier.
    filename = f'{chat_dir}/chat_{int(time.time())}.txt'
    with open(filename, 'w') as f:
        for message in messages:
            # NOTE(review): backslashes are not escaped, so content that already
            # contains a literal backslash-n is read back as a real newline.
            encoded_content = message['content'].replace('\n', '\\n')
            f.write(f"{message['role']}: {encoded_content}\n")
    messages.clear()
def load_saved_chats():
    """List saved chats in the sidebar, newest first, one button per file.

    Clicking a button resets ``show_chats``, marks the session as loaded,
    and loads that file into the session via ``load_chat``. Does nothing if
    the chat folder does not exist yet.
    """
    chat_dir = './Intermediate-Chats'
    if not os.path.exists(chat_dir):
        return

    # Only regular files can be opened by load_chat; a sub-directory would fail.
    files = [
        name for name in os.listdir(chat_dir)
        if os.path.isfile(os.path.join(chat_dir, name))
    ]
    files.sort(key=lambda name: os.path.getmtime(os.path.join(chat_dir, name)), reverse=True)

    for file_name in files:
        display_name = file_name[:-4] if file_name.endswith('.txt') else file_name
        # An explicit key keeps widget IDs unique even when two files share a
        # display name, or a file is named like another sidebar button.
        if st.sidebar.button(display_name, key=f"saved_chat_{file_name}"):
            st.session_state['show_chats'] = False
            st.session_state['is_loaded'] = True
            load_chat(os.path.join(chat_dir, file_name))
def load_chat(file_path):
    """Replace the session's messages with the chat stored in *file_path*.

    Expects one ``role: content`` message per line, as written by
    ``save_chat``. A backslash followed by ``n`` in the content is turned
    back into a real newline. Blank lines are ignored. Lines without a
    ``": "`` separator are skipped and reported with a warning instead of
    raising.

    Args:
        file_path: Path of the saved chat file to read.
    """
    loaded = []
    skipped = 0
    with open(file_path, 'r') as file:
        for raw_line in file:
            # Strip only the line terminator. A full strip() would also drop
            # trailing spaces from the content, and would turn an empty
            # message ("user: ") into "user:", losing the separator.
            line = raw_line.rstrip('\r\n')
            if not line:
                continue
            role, sep, content = line.partition(': ')
            if not sep:
                skipped += 1
                continue
            loaded.append({'role': role, 'content': content.replace('\\n', '\n')})

    # Parse the whole file before touching the session, so a read error leaves
    # the current conversation intact. Mutate the existing list in place.
    messages = st.session_state['messages']
    messages.clear()
    messages.extend(loaded)

    if skipped:
        st.warning(f"Skipped {skipped} malformed line(s) in {file_path}.")
def main():
    """Render the chat UI and its sidebar controls (Streamlit entry point).

    Initializes session state and replays the stored conversation. When a
    new prompt arrives, it generates the reply and displays it word by word.
    The sidebar offers download, save, new-chat and saved-history controls.
    """
    st.title("LLaMA Chat Interface")

    if 'messages' not in st.session_state:
        st.session_state['messages'] = []
    if 'show_chats' not in st.session_state:
        st.session_state['show_chats'] = False

    show_messages()

    user_input = st.chat_input("Enter your prompt:")
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        # show_messages() already ran before this prompt was appended, so
        # render the new user turn here or it stays hidden until a rerun.
        with st.chat_message("user"):
            st.write(user_input)

        response = generate_response(user_input)
        st.session_state.messages.append({"role": "assistant", "content": response})
        with st.chat_message("assistant"):
            # st.write() has no print-style `end` argument, and each call adds
            # a new element. Instead, grow a single placeholder piece by piece.
            placeholder = st.empty()
            streamed = ""
            for word in response_generator(response):
                streamed += word
                placeholder.markdown(streamed)
            # The word stream collapses whitespace, so finish by rendering the
            # exact reply to keep its original line breaks.
            placeholder.markdown(response)

    chat_log = '\n'.join(f"{msg['role']}: {msg['content']}" for msg in st.session_state['messages'])
    st.sidebar.download_button(
        label="Download Chat Log",
        data=chat_log,
        file_name="chat_log.txt",
        mime="text/plain",
    )

    if st.sidebar.button("Save Chat"):
        save_chat()

    if st.sidebar.button("New Chat"):
        st.session_state['messages'].clear()

    if st.sidebar.checkbox("Show/hide chat history", value=st.session_state['show_chats']):
        st.sidebar.title("Previous Chats")
        load_saved_chats()


# Run the app when this file is executed as the main script.
if __name__ == "__main__":
    main()