# afscomercial's picture
# fine tuning
# 2237f11
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the chat tokenizer and causal language model.

    The repository id is taken from the ``MODEL_REPO`` environment variable,
    falling back to ``afscomercial/streamlit-chatbot-aviation`` when it is
    unset. The function is wrapped in ``st.cache_resource`` so the loaded
    objects are reused instead of being re-created on every call.

    Returns:
        A ``(tokenizer, model)`` tuple, both loaded with ``from_pretrained``
        from the same repository id.
    """
    repo_id = os.environ.get("MODEL_REPO", "afscomercial/streamlit-chatbot-aviation")
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(repo_id),
        AutoModelForCausalLM.from_pretrained(repo_id),
    )
tokenizer, model = load_model()

# Seed the per-session conversation state on first use:
#   "generated"        -> list of assistant replies, in order
#   "past"             -> list of user messages, in order
#   "chat_history_ids" -> token ids of the whole dialogue so far (None until
#                         the first reply has been generated)
# The tuple below is rebuilt on every script run, so each new session gets its
# own fresh list objects.
for state_key, initial_value in (
    ("generated", []),
    ("past", []),
    ("chat_history_ids", None),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value
st.title("🗨️ Simple LLM Chatbot")
st.write("Chat with an open-source language model powered by Hugging Face Transformers")

# Display the conversation so far. The two lists are walked in lockstep, so
# each user message is rendered together with the reply stored at the same
# position; zip stops at the shorter list instead of indexing past its end.
for user_text, assistant_text in zip(
    st.session_state["past"], st.session_state["generated"]
):
    st.chat_message("user").markdown(user_text)
    st.chat_message("assistant").markdown(assistant_text)
# Prompt for user input; the block below only runs when a message was submitted
user_input = st.chat_input("Type your message and press Enter…")
if user_input:
    # Show the user's message right away so it is visible while the model runs
    st.chat_message("user").markdown(user_input)

    # Encode the user input and append the EOS token
    new_user_input_ids = tokenizer.encode(
        user_input + tokenizer.eos_token, return_tensors="pt"
    )

    # Append the new user input to the chat history (if it exists)
    previous_ids = st.session_state["chat_history_ids"]
    if previous_ids is not None:
        bot_input_ids = torch.cat([previous_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Upper bound on the number of tokens generated for one reply
    max_new_tokens = 100

    # The stored history grows with every turn. Drop the oldest tokens so that
    # prompt + reply stay within the length the model config advertises.
    # NOTE(review): assumes `max_position_embeddings` is the model's maximum
    # sequence length; trimming is skipped when the config does not expose it.
    context_limit = getattr(model.config, "max_position_embeddings", None)
    if isinstance(context_limit, int) and context_limit > max_new_tokens:
        prompt_budget = context_limit - max_new_tokens
        if bot_input_ids.shape[-1] > prompt_budget:
            bot_input_ids = bot_input_ids[:, -prompt_budget:]

    # Generate a response. st.spinner is a context manager: it must be entered
    # with `with` for the spinner to be shown while generation is running.
    with st.spinner("Generating reply…"):
        output_ids = model.generate(
            bot_input_ids,
            max_length=bot_input_ids.shape[-1] + max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_p=0.92,
            temperature=0.75,
        )

    # Extract the generated tokens excluding the input tokens
    response_ids = output_ids[:, bot_input_ids.shape[-1]:]
    response = tokenizer.decode(response_ids[0], skip_special_tokens=True)

    # Record the turn only after generation succeeded, so "past" and
    # "generated" always have the same length and stay aligned by index.
    st.session_state["past"].append(user_input)
    st.session_state["generated"].append(response)
    st.session_state["chat_history_ids"] = output_ids

    # Display the assistant's reply for this turn
    st.chat_message("assistant").markdown(response)