# fuzzylab / app1.py
# Provenance: Hugging Face Space page artifact — uploaded by "odaly",
# commit 168c1d5 ("Update app1.py"). Kept as a comment so the file parses.
import streamlit as st
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
import nltk

# Download the NLTK 'punkt' sentence tokenizer data at startup.
# NOTE(review): nothing in this file uses nltk after this call — presumably
# left over from an earlier revision; confirm with other modules before removing.
nltk.download('punkt')
# ---- Configuration constants ----
# Hugging Face model repo id. NOTE(review): this repo is license-gated on the
# Hub — loading it requires an authenticated token; verify the deployment has one.
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# Maximum number of input tokens kept after truncation when encoding the prompt.
MAX_LENGTH = 512
# Upper/lower bounds (in tokens) for the generated reply.
RESPONSE_MAX_LENGTH = 50
RESPONSE_MIN_LENGTH = 20
# Beam-search length penalty (1.0 = neutral, no bias toward short/long outputs).
LENGTH_PENALTY = 1.0
# Number of beams; combined with do_sample=True this yields beam-sample decoding.
NUM_BEAMS = 2
# Forbid repeating any 2-gram within the generated text.
NO_REPEAT_NGRAM_SIZE = 2
# Sampling controls: softmax temperature, top-k cutoff, and nucleus (top-p) mass.
TEMPERATURE = 0.9
TOP_K = 30
TOP_P = 0.85
# Load Pre-Trained Model and Tokenizer
@st.cache_resource
def load_model():
    """Load and cache the pre-trained tokenizer and model.

    Cached with st.cache_resource so the (large) model is loaded once per
    Streamlit server process rather than on every script rerun.

    Returns:
        tuple: (tokenizer, model) for MODEL_NAME.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # BUG FIX: Llama 3.1 is a decoder-only (causal) architecture.
    # AutoModelForSeq2SeqLM has no mapping for Llama and raises ValueError
    # at load time; AutoModelForCausalLM is the correct auto class.
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    return tokenizer, model
# Function to generate a response using the model
def generate_response(text, tokenizer, model):
    """Encode *text*, run generation, and return the decoded reply string.

    The prompt is truncated to MAX_LENGTH tokens; decoding uses the
    module-level sampling/beam configuration constants.
    """
    encoded_prompt = tokenizer.encode(
        text,
        return_tensors="pt",
        max_length=MAX_LENGTH,
        truncation=True,
    )
    generated = model.generate(
        input_ids=encoded_prompt,
        max_length=RESPONSE_MAX_LENGTH,
        min_length=RESPONSE_MIN_LENGTH,
        length_penalty=LENGTH_PENALTY,
        num_beams=NUM_BEAMS,
        no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
        temperature=TEMPERATURE,
        top_k=TOP_K,
        top_p=TOP_P,
        do_sample=True,
    )
    # Decode the first (best) sequence, dropping special tokens.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Function to format messages for display
def format_messages_for_display(messages):
    """Render a chat transcript as a single Markdown string.

    Each message dict is shown as "**Assistant**: ..." or "**User**: ..."
    depending on its "role"; entries are separated by blank lines.
    """
    return "\n\n".join(
        f"**{'Assistant' if msg['role'] == 'assistant' else 'User'}**: {msg['content']}"
        for msg in messages
    )
# Main function to run the Streamlit app
def main():
    """Run the Streamlit chat app: render history, handle send/clear actions."""
    st.set_page_config(page_title="LLaMA Chat Interface", page_icon="", layout="wide")
    st.title("LLaMA Chat Interface")
    st.write("This is a chat interface using the LLaMA model for generating responses. Enter a prompt below to start chatting with the model.")

    # Load the (cached) model and tokenizer.
    tokenizer, model = load_model()

    # Initialize the conversation history on first run.
    if "messages" not in st.session_state:
        st.session_state["messages"] = []

    # A single placeholder so the transcript can be re-rendered in place.
    history_slot = st.empty()
    with history_slot.container():
        st.markdown(format_messages_for_display(st.session_state["messages"]))

    # Prompt input and send action.
    prompt = st.text_input("Enter your prompt:", key="user_input")
    if st.button("Send") and prompt.strip():
        st.session_state["messages"].append({"role": "user", "content": prompt})
        with st.spinner("Generating response..."):
            reply = generate_response(prompt, tokenizer, model)
        st.session_state["messages"].append({"role": "assistant", "content": reply})
        # Refresh the transcript with the new exchange.
        with history_slot.container():
            st.markdown(format_messages_for_display(st.session_state["messages"]))

    # Reset the conversation on demand.
    if st.button("Clear Chat"):
        st.session_state["messages"] = []
        with history_slot.container():
            st.markdown("")


if __name__ == '__main__':
    main()