# Streamlit chat interface app (originally hosted as a Hugging Face Space).
import nltk
import streamlit as st
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
# Fetch the NLTK sentence-tokenizer data at startup. quiet=True suppresses
# the download progress output, which otherwise repeats on every Streamlit
# rerun even when the data is already cached locally.
nltk.download('punkt', quiet=True)
# --- Model / generation configuration ---------------------------------------
# Hugging Face model id used for both tokenizer and model.
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Maximum number of prompt tokens fed to the model (longer input is truncated).
MAX_LENGTH = 512
# Bounds on the generated response length, in tokens.
RESPONSE_MAX_LENGTH = 50
RESPONSE_MIN_LENGTH = 20
# Generation hyperparameters passed to `model.generate`.
LENGTH_PENALTY = 1.0      # 1.0 = no length bias during beam search
NUM_BEAMS = 2             # beam-search width
NO_REPEAT_NGRAM_SIZE = 2  # forbid repeating any 2-gram in the output
TEMPERATURE = 0.9         # sampling temperature (lower = more deterministic)
TOP_K = 30                # sample only from the 30 most likely next tokens
TOP_P = 0.85              # nucleus-sampling probability mass
# Load Pre-Trained Model and Tokenizer
@st.cache_resource  # cache so the 8B model is not reloaded on every rerun
def load_model():
    """Load and cache the pre-trained tokenizer and model.

    Returns:
        tuple: ``(tokenizer, model)`` loaded from ``MODEL_NAME``.

    NOTE(review): ``MODEL_NAME`` is Meta-Llama-3.1, a decoder-only (causal)
    model, so ``AutoModelForSeq2SeqLM`` cannot load it; use
    ``AutoModelForCausalLM`` instead. Access to this model is gated on the
    Hugging Face Hub — presumably an auth token is configured in the
    environment; verify in deployment.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    return tokenizer, model
# Function to generate a response using the model
def generate_response(text, tokenizer, model):
    """Generate a model response for the given prompt text.

    Args:
        text: The user's prompt string.
        tokenizer: Tokenizer paired with ``model``.
        model: A Hugging Face model exposing ``generate``.

    Returns:
        str: The decoded response with special tokens stripped.
    """
    # Tokenize the prompt, truncating anything beyond MAX_LENGTH tokens.
    encoded = tokenizer.encode(
        text, return_tensors="pt", max_length=MAX_LENGTH, truncation=True
    )
    # Sampled beam search using the module-level generation settings.
    generation_kwargs = dict(
        max_length=RESPONSE_MAX_LENGTH,
        min_length=RESPONSE_MIN_LENGTH,
        length_penalty=LENGTH_PENALTY,
        num_beams=NUM_BEAMS,
        no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
        temperature=TEMPERATURE,
        top_k=TOP_K,
        top_p=TOP_P,
        do_sample=True,
    )
    output_ids = model.generate(input_ids=encoded, **generation_kwargs)
    # Decode the first (and only) generated sequence back to text.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Function to format messages for display
def format_messages_for_display(messages):
    """Render a chat transcript as Markdown.

    Args:
        messages: Iterable of dicts with ``"role"`` and ``"content"`` keys.
            A role of ``"assistant"`` is labeled Assistant; any other role
            is labeled User, matching the original behavior.

    Returns:
        str: Messages joined by blank lines, each prefixed with a bold
        speaker label; an empty transcript yields ``""``.
    """
    labels = {"assistant": "Assistant"}
    return "\n\n".join(
        f"**{labels.get(message['role'], 'User')}**: {message['content']}"
        for message in messages
    )
# Main function to run the Streamlit app
def main():
    """Run the Streamlit chat application."""
    st.set_page_config(page_title="LLaMA Chat Interface", page_icon="", layout="wide")
    st.title("LLaMA Chat Interface")
    st.write(
        "This is a chat interface using the LLaMA model for generating "
        "responses. Enter a prompt below to start chatting with the model."
    )

    # Load the model and tokenizer
    tokenizer, model = load_model()

    # Initialize the chat history on first visit.
    if 'messages' not in st.session_state:
        st.session_state['messages'] = []

    # Placeholder lets the transcript be redrawn in place after each action.
    chat_placeholder = st.empty()
    with chat_placeholder.container():
        st.markdown(format_messages_for_display(st.session_state['messages']))

    # Prompt input and send button.
    user_input = st.text_input("Enter your prompt:", key="user_input")
    if st.button("Send") and user_input.strip():
        # Record the user's message.
        st.session_state['messages'].append({"role": "user", "content": user_input})
        # Generate and record the assistant's reply.
        with st.spinner("Generating response..."):
            reply = generate_response(user_input, tokenizer, model)
        st.session_state['messages'].append({"role": "assistant", "content": reply})
        # Redraw the transcript including the new exchange.
        with chat_placeholder.container():
            st.markdown(format_messages_for_display(st.session_state['messages']))

    # Option to clear the chat history
    if st.button("Clear Chat"):
        st.session_state['messages'] = []
        with chat_placeholder.container():
            st.markdown("")
# Script entry point.
if __name__ == '__main__':
    main()