File size: 6,323 Bytes
448a6c5
 
 
 
 
 
 
 
 
 
 
 
 
3aaee7a
448a6c5
 
 
9d3f976
448a6c5
 
3aaee7a
b67fa53
448a6c5
 
 
 
3aaee7a
448a6c5
 
 
 
3aaee7a
b67fa53
448a6c5
3aaee7a
 
448a6c5
 
3aaee7a
 
 
448a6c5
 
9d3f976
448a6c5
 
3aaee7a
 
448a6c5
 
3aaee7a
448a6c5
 
3aaee7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448a6c5
 
 
 
 
 
3aaee7a
448a6c5
3aaee7a
448a6c5
 
 
 
 
 
 
 
 
 
 
3aaee7a
448a6c5
 
 
 
 
3aaee7a
448a6c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3aaee7a
448a6c5
 
 
 
 
 
 
 
 
 
 
 
 
 
3aaee7a
448a6c5
 
 
3aaee7a
448a6c5
 
 
 
 
3aaee7a
448a6c5
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import streamlit as st
from streamlit_chat import message
import time
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Set page configuration.
# NOTE: st.set_page_config must be the first Streamlit call in the script,
# which is why it appears before any other st.* usage below.
st.set_page_config(
    page_title="ChatGPT-Style Chatbot",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS for styling with beautiful colors.
# NOTE(review): selectors such as .stApp, .stButton and .stTextInput target
# Streamlit's internal DOM class names, which are not a stable public API —
# re-verify this styling after upgrading Streamlit.
st.markdown("""
    <style>
        .stApp {
            background-image: linear-gradient(135deg, #ffcccc 0%, #ff9999 100%);
        }
        .sidebar .sidebar-content {
            background-image: linear-gradient(135deg, #6B73FF 0%, #000DFF 100%);
            color: black;
        }
        .stTextInput>div>div>input {
            border-radius: 20px;
            padding: 10px 15px;
            border: 1px solid #d1d5db;
        }
        .stButton>button {
            border-radius: 20px;
            padding: 10px 25px;
            background-image: linear-gradient(to right, #6B73FF 0%, #000DFF 100%);
            color: black;
            border: none;
            font-weight: 500;
            transition: all 0.3s ease;
        }
        .stButton>button:hover {
            background-image: linear-gradient(to right, #000DFF 0%, #6B73FF 100%);
            transform: translateY(-2px);
            box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
        }
        .chat-container {
            background-color: rgba(255, 230, 230, 0.95);
            border-radius: 15px;
            padding: 20px;
            box-shadow: 0 4px 15px rgba(0, 0, 0, 0.05);
            border: 1px solid #e5e7eb;
        }
        .title {
            color: #2d3748;
            text-align: center;
            margin-bottom: 30px;
            font-weight: 600;
        }
        .stSelectbox>div>div>select {
            border-radius: 12px;
            padding: 8px 12px;
        }
        .stSlider>div>div>div>div {
            background-color: #6B73FF;
        }
        .st-expander {
            border-radius: 12px;
            border: 1px solid #e5e7eb;
        }
        .stMarkdown h1 {
            color: #2d3748;
        }
    </style>
""", unsafe_allow_html=True)

# Sidebar: model choice plus generation hyper-parameters.
# The widgets below bind module-level names (model_name, max_length,
# temperature, top_p) that generate_response() reads later in the script.
with st.sidebar:
    st.title("⚙️ Chatbot Settings")
    st.markdown("""
    ### ✨ About
    This is a ChatGPT-style chatbot powered by a fine-tuned LLM.
    """)
    
    # Model selection — index=1 makes DialoGPT-medium the default choice.
    model_name = st.selectbox(
        "Choose a model",
        ["gpt2", "microsoft/DialoGPT-medium", "facebook/blenderbot-400M-distill"],
        index=1
    )
    
    # Advanced generation settings (slider signature: min, max, default).
    with st.expander("🔧 Advanced Settings"):
        max_length = st.slider("Max response length", 50, 500, 100)
        temperature = st.slider("Temperature", 0.1, 1.0, 0.7)
        top_p = st.slider("Top-p", 0.1, 1.0, 0.9)
    
    st.markdown("---")
    st.markdown("🚀 Built with ❤️ using [Streamlit](https://streamlit.io/) and [Hugging Face](https://huggingface.co/)")

# Initialize per-session chat state.  'past' holds user turns, 'generated'
# holds bot turns (parallel lists); 'model'/'tokenizer' stay None until the
# user presses the Load Model button.
for _key, _default in (
    ('generated', []),
    ('past', []),
    ('model', None),
    ('tokenizer', None),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Load model
@st.cache_resource
def load_model(model_name):
    """Download (or fetch from cache) the tokenizer/model pair for *model_name*.

    Returns ``(model, tokenizer)`` on success.  On any failure a Streamlit
    error banner is shown and ``(None, None)`` is returned so callers can
    detect the failed load instead of crashing.
    """
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
        loaded_model = AutoModelForCausalLM.from_pretrained(model_name)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None, None
    return loaded_model, loaded_tokenizer

# Generate response
def generate_response(prompt):
    """Generate a chatbot reply to *prompt* using the session's loaded model.

    Builds a "User:/Bot:" transcript from the session history so the model
    sees prior turns, generates a continuation, and returns only the text
    after the final "Bot:" marker.  Reads the sidebar-bound module globals
    ``max_length``, ``temperature`` and ``top_p``.  Never raises: returns a
    human-readable error string when the model is missing or generation fails.
    """
    if st.session_state['model'] is None or st.session_state['tokenizer'] is None:
        return "Model not loaded. Please try again."
    
    try:
        # Create conversation history context from the parallel turn lists.
        history = "\n".join(
            f"User: {p}\nBot: {g}"
            for p, g in zip(st.session_state['past'], st.session_state['generated'])
        )
        full_prompt = f"{history}\nUser: {prompt}\nBot:"
        
        # Generate response.
        tokenizer = st.session_state['tokenizer']
        inputs = tokenizer.encode(full_prompt, return_tensors="pt")
        outputs = st.session_state['model'].generate(
            inputs,
            # Budget the length relative to the prompt so long histories
            # still leave room for max_length new tokens.
            max_length=max_length + len(inputs[0]),
            # BUGFIX: transformers silently ignores temperature/top_p and
            # decodes greedily unless sampling is explicitly enabled.
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            # Avoid the "no pad token" warning for GPT-style models.
            pad_token_id=tokenizer.eos_token_id
        )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Extract only the new response: everything after the last "Bot:".
        return response.split("Bot:")[-1].strip()
    except Exception as e:
        return f"Error generating response: {e}"

# Main app: page heading plus a styled subtitle rendered via the
# .title CSS class injected above.
st.title("💬 ChatGPT-Style Chatbot")
st.markdown("""
    <div class='title'>
        Experience a conversation with our fine-tuned LLM chatbot
    </div>
""", unsafe_allow_html=True)

# Container for chat — created before the button so the transcript renders
# above the controls regardless of execution order.
chat_container = st.container()

# Load model button
if st.button("🚀 Load Model"):
    with st.spinner(f"Loading {model_name}..."):
        st.session_state['model'], st.session_state['tokenizer'] = load_model(model_name)
    # BUGFIX: load_model returns (None, None) on failure (after showing its
    # own error banner) — only report success when a model actually loaded.
    if st.session_state['model'] is not None and st.session_state['tokenizer'] is not None:
        st.success(f"Model {model_name} loaded successfully!")

# Display chat: render the transcript as alternating user/bot bubbles.
# 'past' and 'generated' are parallel lists, one entry per completed turn.
with chat_container:
    turns = zip(st.session_state['past'], st.session_state['generated'])
    for idx, (user_msg, bot_msg) in enumerate(turns):
        message(user_msg, is_user=True, key=f"{idx}_user", avatar_style="identicon")
        message(bot_msg, key=str(idx), avatar_style="bottts")

# User input: a form so Enter / Send submits atomically and the box clears.
with st.form(key='chat_form', clear_on_submit=True):
    user_input = st.text_input("You:", key='input', placeholder="Type your message here...")
    submit_button = st.form_submit_button(label='Send ➤')

if submit_button and user_input:
    if st.session_state['model'] is None or st.session_state['tokenizer'] is None:
        st.warning("⚠️ Please load the model first!")
    else:
        # Add user message to chat history
        st.session_state['past'].append(user_input)
        
        # Generate response
        with st.spinner("🤔 Thinking..."):
            response = generate_response(user_input)
        
        # Add bot response to chat history
        st.session_state['generated'].append(response)
        
        # Rerun so the transcript above re-renders with the new turn.
        # BUGFIX: st.experimental_rerun() was removed in Streamlit >= 1.27;
        # prefer st.rerun() and keep the old call as a fallback for
        # older installs.
        if hasattr(st, "rerun"):
            st.rerun()
        else:
            st.experimental_rerun()