import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import nltk
# Download the NLTK 'punkt' sentence-tokenizer data at import time.
# NOTE(review): nothing in the visible file uses nltk after this download —
# confirm whether the dependency is still needed.
nltk.download('punkt')
# Constants
# NOTE(review): this is a gated, decoder-only Llama checkpoint — it must be
# loaded as a causal LM, not a seq2seq model.
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
MAX_LENGTH = 512              # Max prompt tokens fed to the tokenizer (truncated beyond this).
RESPONSE_MAX_LENGTH = 50      # Upper bound used for generation length.
RESPONSE_MIN_LENGTH = 20      # Lower bound used for generation length.
LENGTH_PENALTY = 1.0          # Neutral length penalty for beam search.
NUM_BEAMS = 2                 # Beam width used during generation.
NO_REPEAT_NGRAM_SIZE = 2      # Forbid repeating any 2-gram in the output.
TEMPERATURE = 0.9             # Sampling temperature (do_sample=True downstream).
TOP_K = 30                    # Keep only the 30 most likely next tokens.
TOP_P = 0.85                  # Nucleus-sampling probability mass.
# Load Pre-Trained Model and Tokenizer
@st.cache_resource
def load_model():
    """Load and cache the pre-trained tokenizer and model.

    Returns:
        tuple: ``(tokenizer, model)`` for ``MODEL_NAME``.
    """
    # Llama-3.1 is a decoder-only (causal) architecture; loading it with
    # AutoModelForSeq2SeqLM raises a ValueError because the checkpoint has
    # no encoder-decoder config. Use the causal-LM auto class instead.
    from transformers import AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    return tokenizer, model
# Function to generate a response using the model
def generate_response(text, tokenizer, model):
    """Generate a model response for *text*.

    Args:
        text: The user's prompt string.
        tokenizer: Tokenizer matching *model*.
        model: A Hugging Face model exposing ``generate``.

    Returns:
        str: The decoded reply with special tokens removed.
    """
    input_ids = tokenizer.encode(text, return_tensors="pt", max_length=MAX_LENGTH, truncation=True)
    response_ids = model.generate(
        input_ids=input_ids,
        # max_new_tokens/min_new_tokens bound only the generated text.
        # The original max_length/min_length also count prompt tokens, so a
        # prompt longer than 50 tokens would exhaust the budget before any
        # reply could be produced.
        max_new_tokens=RESPONSE_MAX_LENGTH,
        min_new_tokens=RESPONSE_MIN_LENGTH,
        length_penalty=LENGTH_PENALTY,
        num_beams=NUM_BEAMS,
        no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
        temperature=TEMPERATURE,
        top_k=TOP_K,
        top_p=TOP_P,
        do_sample=True
    )
    generated = response_ids[0]
    # Decoder-only models echo the prompt at the start of the output; strip
    # it (only when it is actually present) so the UI shows just the reply.
    prompt_len = input_ids.shape[-1]
    if generated.shape[-1] > prompt_len and generated[:prompt_len].tolist() == input_ids[0].tolist():
        generated = generated[prompt_len:]
    return tokenizer.decode(generated, skip_special_tokens=True)
# Function to format messages for display
def format_messages_for_display(messages):
    """Render a list of chat messages as a single Markdown transcript.

    Each message dict carries a ``role`` ("assistant" or anything else,
    treated as the user) and a ``content`` string; turns are separated by
    a blank line.
    """
    rendered = [
        f"**{'Assistant' if msg['role'] == 'assistant' else 'User'}**: {msg['content']}"
        for msg in messages
    ]
    return "\n\n".join(rendered)
# Main function to run the Streamlit app
def main():
    """Entry point: render the chat UI and drive one Streamlit rerun."""
    st.set_page_config(page_title="LLaMA Chat Interface", page_icon="", layout="wide")
    st.title("LLaMA Chat Interface")
    st.write("This is a chat interface using the LLaMA model for generating responses. Enter a prompt below to start chatting with the model.")
    # Load the (cached) tokenizer/model pair.
    tokenizer, model = load_model()
    # Conversation history must survive Streamlit reruns, so it lives in
    # session state rather than a local variable.
    if 'messages' not in st.session_state:
        st.session_state['messages'] = []
    # Placeholder lets the transcript be re-rendered in place further down.
    history_area = st.empty()
    with history_area.container():
        st.markdown(format_messages_for_display(st.session_state['messages']))
    # Prompt input plus send button.
    prompt = st.text_input("Enter your prompt:", key="user_input")
    if st.button("Send") and prompt.strip():
        # Record the user's turn first.
        st.session_state['messages'].append({"role": "user", "content": prompt})
        # Then generate and record the assistant's turn.
        with st.spinner("Generating response..."):
            reply = generate_response(prompt, tokenizer, model)
            st.session_state['messages'].append({"role": "assistant", "content": reply})
        # Refresh the transcript so both new turns appear immediately.
        with history_area.container():
            st.markdown(format_messages_for_display(st.session_state['messages']))
    # Option to clear the chat history.
    if st.button("Clear Chat"):
        st.session_state['messages'] = []
        with history_area.container():
            st.markdown("")
# Script entry point. (The trailing " |" on the original last line was a
# table-extraction artifact and a syntax error; it has been removed.)
if __name__ == '__main__':
    main()