import pickle
from functools import lru_cache

import streamlit as st
import torch
from transformers import Conversation, pipeline

from upload import get_file, upload_file
from utils import clear_uploader, undo, restart
# Session-state keys that get serialized when the user shares a chat.
share_keys = ["messages", "model_name"]

# Conversational checkpoints offered in the sidebar model picker.
MODELS = ["facebook/blenderbot-400M-distill", "facebook/blenderbot-90M"]

st.set_page_config(
    page_title="LLM",
    page_icon="π",
)

# Default to the larger BlenderBot checkpoint on a fresh session.
st.session_state.setdefault("model_name", MODELS[0])
@lru_cache(maxsize=None)
def get_pipeline(model_name):
    """Build (and memoize) a conversational pipeline for ``model_name``.

    Streamlit re-executes the whole script on every interaction; without a
    cache the model weights would be reloaded on each rerun.  ``lru_cache``
    keeps one pipeline instance per model name for the process lifetime.

    Args:
        model_name: Hugging Face model identifier accepted by ``pipeline``.

    Returns:
        A ``transformers`` conversational pipeline, on GPU 0 when CUDA is
        available and on CPU otherwise.
    """
    # transformers device convention: 0 = first CUDA device, -1 = CPU.
    device = 0 if torch.cuda.is_available() else -1
    return pipeline(model=model_name, task="conversational", device=device)
# Chat pipeline for the model currently selected in this session.
chatbot = get_pipeline(st.session_state.model_name)

# A fresh session starts with an empty transcript.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Restore a shared chat when the page was opened through a share link
# (?id=...) and the user has not typed anything yet in this session.
if not st.session_state.messages and "id" in st.query_params:
    with st.spinner("Loading chat..."):
        file_id = st.query_params["id"]
        blob = get_file(file_id, 'llm-007')
        # SECURITY NOTE(review): `blob` is fetched via a user-controlled
        # query parameter; unpickling untrusted bytes can execute arbitrary
        # code. Consider a safe format (e.g. JSON) for shared chats.
        restored = pickle.loads(blob)
        for key, value in restored.items():
            st.session_state[key] = value
def share():
    """Serialize the shareable session keys, upload them, and show the URL."""
    payload = {key: st.session_state[key] for key in share_keys if key in st.session_state}
    blob = pickle.dumps(payload)
    file_id = upload_file(blob, 'llm-007')
    url = f"https://umbc-nlp-llm.hf.space/?id={file_id}"
    st.markdown(f"[share](/?id={file_id})")
    st.success(f"Share URL: {url}")
with st.sidebar:
    st.title(":blue[LLM Only]")

    st.subheader("Model")
    # BUG FIX: the original stored the selectbox result only in a local
    # variable and never wrote it back to st.session_state.model_name, so
    # picking a different model had no effect. Binding the widget to the
    # "model_name" session key persists the choice across reruns; the key is
    # initialised earlier in the script, so the widget starts on the current
    # model (no `index` needed — passing one alongside a set key triggers a
    # Streamlit warning).
    model_name = st.selectbox("Model", MODELS, key="model_name")

    if st.button("Share", use_container_width=True):
        share()

    cols = st.columns(2)
    with cols[0]:
        if st.button("Restart", type="primary", use_container_width=True):
            restart()
    with cols[1]:
        if st.button("Undo", use_container_width=True):
            undo()

    # When checked, user messages accumulate without triggering a reply.
    append = st.checkbox("Append to previous message", value=False)
# Replay the stored transcript so the whole conversation stays visible
# across Streamlit reruns.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
def push_message(role, content):
    """Append a ``{"role", "content"}`` entry to the transcript and return it."""
    entry = {"role": role, "content": content}
    st.session_state.messages.append(entry)
    return entry
if prompt := st.chat_input("Type a message", key="chat_input"):
    # Record and echo the user's message.
    push_message("user", prompt)
    with st.chat_message("user"):
        st.markdown(prompt)

    # With "append" checked we only accumulate user text; the model call is
    # deferred to a later turn.
    if not append:
        with st.chat_message("assistant"):
            # Rebuild the full Conversation from the stored transcript so the
            # model sees the entire history, not just the latest prompt.
            conversation = Conversation()
            for m in st.session_state.messages:
                conversation.add_message(m)
            # (removed leftover debug `print(conversation)`)
            with st.spinner("Generating response..."):
                result = chatbot(conversation)
            # The pipeline returns the updated Conversation; its last entry
            # is the newly generated assistant reply.
            response = result[-1]["content"]
            st.write(response)
        push_message("assistant", response)

    clear_uploader()