import os

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

@st.cache_resource
def load_model():
    """Load the model and tokenizer once and cache them across Streamlit reruns."""
    model_name = os.getenv("MODEL_REPO", "afscomercial/streamlit-chatbot-aviation")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model

tokenizer, model = load_model()

# Initialise session state to keep track of the conversation
if "generated" not in st.session_state:
    st.session_state["generated"] = []
if "past" not in st.session_state:
    st.session_state["past"] = []
if "chat_history_ids" not in st.session_state:
    st.session_state["chat_history_ids"] = None

st.title("🗨️ Simple LLM Chatbot")
st.write("Chat with an open-source language model powered by Hugging Face Transformers")

# Display the conversation so far
for i in range(len(st.session_state["generated"])):
    st.chat_message("user").markdown(st.session_state["past"][i])
    st.chat_message("assistant").markdown(st.session_state["generated"][i])

# Prompt for user input
user_input = st.chat_input("Type your message and press Enter…")

if user_input:
    # Add the user message to the history
    st.session_state["past"].append(user_input)

    # Encode the user input and append the EOS token
    new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")

    # Append the new user input to the chat history (if it exists)
    if st.session_state["chat_history_ids"] is not None:
        bot_input_ids = torch.cat([st.session_state["chat_history_ids"], new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Generate a response (st.spinner only shows its message when used as a context manager)
    with st.spinner("Generating reply…"):
        output_ids = model.generate(
            bot_input_ids,
            max_length=bot_input_ids.shape[-1] + 100,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_p=0.92,
            temperature=0.75,
        )

    # Extract the generated tokens, excluding the input tokens
    response_ids = output_ids[:, bot_input_ids.shape[-1]:]
    response = tokenizer.decode(response_ids[0], skip_special_tokens=True)

    # Add the model's response to the chat history
    st.session_state["generated"].append(response)
    st.session_state["chat_history_ids"] = output_ids

    # Display the latest user & assistant messages immediately
    st.chat_message("user").markdown(user_input)
    st.chat_message("assistant").markdown(response)
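
To try the app locally, save the script (for example as app.py; the file name is an assumption) and launch it with `streamlit run app.py`. The `MODEL_REPO` environment variable read at the top of the script can be set to point the app at a different causal language model checkpoint on the Hugging Face Hub.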