File size: 3,401 Bytes
a8ff722
168c1d5
199d849
bbf6bac
168c1d5
25c6525
 
168c1d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e22531
349f68b
25c6525
168c1d5
3e22531
168c1d5
3e22531
bbf6bac
 
168c1d5
3e22531
168c1d5
 
25c6525
168c1d5
bbf6bac
168c1d5
 
25c6525
168c1d5
 
25c6525
168c1d5
 
25c6525
168c1d5
 
25c6525
168c1d5
 
 
 
25c6525
168c1d5
 
 
 
 
25c6525
168c1d5
 
 
 
25c6525
168c1d5
 
 
25c6525
168c1d5
 
 
 
 
25c6525
 
168c1d5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import nltk
import streamlit as st
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

# Download the necessary NLTK data
# NOTE(review): nltk/punkt does not appear to be used anywhere visible in
# this file — confirm it is needed (e.g. by another module) before removing.
nltk.download('punkt')

# Constants
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # Hugging Face model id
MAX_LENGTH = 512  # max input tokens kept after truncation when encoding the prompt
RESPONSE_MAX_LENGTH = 50  # upper bound passed as max_length to model.generate
RESPONSE_MIN_LENGTH = 20  # lower bound passed as min_length to model.generate
LENGTH_PENALTY = 1.0  # neutral length penalty for beam search
NUM_BEAMS = 2  # beam-search width
NO_REPEAT_NGRAM_SIZE = 2  # forbid repeating any 2-gram in the output
TEMPERATURE = 0.9  # sampling temperature (do_sample=True in generate_response)
TOP_K = 30  # sample only from the 30 most likely tokens
TOP_P = 0.85  # nucleus-sampling cumulative-probability cutoff

# Load Pre-Trained Model and Tokenizer
@st.cache_resource
def load_model():
    """Load and cache the tokenizer and model for ``MODEL_NAME``.

    Cached via ``st.cache_resource`` so the weights are loaded once per
    Streamlit server process, not on every script rerun.

    Returns:
        tuple: ``(tokenizer, model)`` ready for generation.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # BUG FIX: Llama 3.1 is a decoder-only (causal) language model.
    # AutoModelForSeq2SeqLM only supports encoder-decoder architectures
    # (T5, BART, ...) and raises a ValueError for Llama checkpoints, so
    # the causal-LM auto class must be used instead.
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    return tokenizer, model

# Function to generate a response using the model
def generate_response(text, tokenizer, model):
    """Generate a model response for the given input text.

    Args:
        text: Raw user prompt to feed the model.
        tokenizer: Tokenizer paired with *model*.
        model: A Hugging Face model exposing ``generate``.

    Returns:
        str: The decoded response with special tokens stripped.
    """
    # Tokenize the prompt, truncating anything beyond MAX_LENGTH tokens.
    encoded_prompt = tokenizer.encode(
        text, return_tensors="pt", max_length=MAX_LENGTH, truncation=True
    )
    # All sampling/beam parameters come from module-level constants so
    # they can be tuned in one place.
    generation_kwargs = {
        "max_length": RESPONSE_MAX_LENGTH,
        "min_length": RESPONSE_MIN_LENGTH,
        "length_penalty": LENGTH_PENALTY,
        "num_beams": NUM_BEAMS,
        "no_repeat_ngram_size": NO_REPEAT_NGRAM_SIZE,
        "temperature": TEMPERATURE,
        "top_k": TOP_K,
        "top_p": TOP_P,
        "do_sample": True,
    }
    generated = model.generate(input_ids=encoded_prompt, **generation_kwargs)
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Function to format messages for display
def format_messages_for_display(messages):
    """Render a chat history as one Markdown string.

    Each message is a dict with ``"role"`` and ``"content"`` keys; a role
    of ``"assistant"`` is labeled **Assistant**, every other role is
    labeled **User**. Messages are separated by blank lines.
    """
    rendered = [
        f"**{'Assistant' if msg['role'] == 'assistant' else 'User'}**: {msg['content']}"
        for msg in messages
    ]
    return "\n\n".join(rendered)

# Main function to run the Streamlit app
def main():
    """Run the Streamlit chat app: render history, take input, generate replies."""
    st.set_page_config(page_title="LLaMA Chat Interface", page_icon="", layout="wide")

    st.title("LLaMA Chat Interface")
    st.write("This is a chat interface using the LLaMA model for generating responses. Enter a prompt below to start chatting with the model.")

    # Load the (cached) model and tokenizer.
    tokenizer, model = load_model()

    # Initialize the conversation history on the first run of the session.
    if 'messages' not in st.session_state:
        st.session_state['messages'] = []

    # Placeholder lets the transcript be redrawn in place later on.
    chat_placeholder = st.empty()
    with chat_placeholder.container():
        st.markdown(format_messages_for_display(st.session_state['messages']))

    # Prompt input plus send button; ignore sends with whitespace-only text.
    user_input = st.text_input("Enter your prompt:", key="user_input")
    send_clicked = st.button("Send")
    if send_clicked and user_input.strip():
        # Record the user's message first.
        st.session_state['messages'].append({"role": "user", "content": user_input})

        # Then generate and record the assistant's reply.
        with st.spinner("Generating response..."):
            reply = generate_response(user_input, tokenizer, model)
            st.session_state['messages'].append({"role": "assistant", "content": reply})

        # Redraw the transcript including both new messages.
        with chat_placeholder.container():
            st.markdown(format_messages_for_display(st.session_state['messages']))

    # Allow the user to wipe the conversation.
    if st.button("Clear Chat"):
        st.session_state['messages'] = []
        with chat_placeholder.container():
            st.markdown("")

if __name__ == '__main__':
    main()