DharavathSri commited on
Commit
448a6c5
·
verified ·
1 Parent(s): 8272bc1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_chat import message
3
+ import time
4
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
5
+
6
+ # Set page configuration
7
+ st.set_page_config(
8
+ page_title="ChatGPT-Style Chatbot",
9
+ page_icon="🤖",
10
+ layout="wide",
11
+ initial_sidebar_state="expanded"
12
+ )
13
+
14
+ # Custom CSS for styling
15
+ st.markdown("""
16
+ <style>
17
+ .stApp {
18
+ background-image: linear-gradient(to right, #f5f7fa, #c3cfe2);
19
+ }
20
+ .sidebar .sidebar-content {
21
+ background-image: linear-gradient(to bottom, #667eea, #764ba2);
22
+ color: white;
23
+ }
24
+ .stTextInput>div>div>input {
25
+ border-radius: 20px;
26
+ padding: 10px 15px;
27
+ }
28
+ .stButton>button {
29
+ border-radius: 20px;
30
+ padding: 10px 25px;
31
+ background-image: linear-gradient(to right, #667eea, #764ba2);
32
+ color: white;
33
+ border: none;
34
+ }
35
+ .stButton>button:hover {
36
+ background-image: linear-gradient(to right, #764ba2, #667eea);
37
+ }
38
+ .chat-container {
39
+ background-color: rgba(255, 255, 255, 0.9);
40
+ border-radius: 15px;
41
+ padding: 20px;
42
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
43
+ }
44
+ .title {
45
+ color: #4a4a4a;
46
+ text-align: center;
47
+ margin-bottom: 30px;
48
+ }
49
+ </style>
50
+ """, unsafe_allow_html=True)
51
+
52
+ # Sidebar
53
+ with st.sidebar:
54
+ st.title("🤖 Chatbot Settings")
55
+ st.markdown("""
56
+ ### About
57
+ This is a ChatGPT-style chatbot powered by a fine-tuned LLM.
58
+ """)
59
+
60
+ # Model selection
61
+ model_name = st.selectbox(
62
+ "Choose a model",
63
+ ["gpt2", "microsoft/DialoGPT-medium", "facebook/blenderbot-400M-distill"],
64
+ index=1
65
+ )
66
+
67
+ # Advanced settings
68
+ with st.expander("Advanced Settings"):
69
+ max_length = st.slider("Max response length", 50, 500, 100)
70
+ temperature = st.slider("Temperature", 0.1, 1.0, 0.7)
71
+ top_p = st.slider("Top-p", 0.1, 1.0, 0.9)
72
+
73
+ st.markdown("---")
74
+ st.markdown("Built with ❤️ using [Streamlit](https://streamlit.io/) and [Hugging Face](https://huggingface.co/)")
75
+
76
+ # Initialize chat history
77
+ if 'generated' not in st.session_state:
78
+ st.session_state['generated'] = []
79
+
80
+ if 'past' not in st.session_state:
81
+ st.session_state['past'] = []
82
+
83
+ if 'model' not in st.session_state:
84
+ st.session_state['model'] = None
85
+
86
+ if 'tokenizer' not in st.session_state:
87
+ st.session_state['tokenizer'] = None
88
+
89
+ # Load model
90
+ @st.cache_resource
91
+ def load_model(model_name):
92
+ try:
93
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
94
+ model = AutoModelForCausalLM.from_pretrained(model_name)
95
+ return model, tokenizer
96
+ except Exception as e:
97
+ st.error(f"Error loading model: {e}")
98
+ return None, None
99
+
100
+ # Generate response
101
+ def generate_response(prompt):
102
+ if st.session_state['model'] is None or st.session_state['tokenizer'] is None:
103
+ return "Model not loaded. Please try again."
104
+
105
+ try:
106
+ # Create conversation history context
107
+ history = "\n".join([f"User: {p}\nBot: {g}" for p, g in zip(st.session_state['past'], st.session_state['generated'])])
108
+ full_prompt = f"{history}\nUser: {prompt}\nBot:"
109
+
110
+ # Generate response
111
+ inputs = st.session_state['tokenizer'].encode(full_prompt, return_tensors="pt")
112
+ outputs = st.session_state['model'].generate(
113
+ inputs,
114
+ max_length=max_length + len(inputs[0]),
115
+ temperature=temperature,
116
+ top_p=top_p,
117
+ pad_token_id=st.session_state['tokenizer'].eos_token_id
118
+ )
119
+ response = st.session_state['tokenizer'].decode(outputs[0], skip_special_tokens=True)
120
+
121
+ # Extract only the new response
122
+ return response.split("Bot:")[-1].strip()
123
+ except Exception as e:
124
+ return f"Error generating response: {e}"
125
+
126
+ # Main app
127
+ st.title("💬 ChatGPT-Style Chatbot")
128
+ st.markdown("""
129
+ <div class='title'>
130
+ Experience a conversation with our fine-tuned LLM chatbot
131
+ </div>
132
+ """, unsafe_allow_html=True)
133
+
134
+ # Container for chat
135
+ chat_container = st.container()
136
+
137
+ # Load model button
138
+ if st.button("Load Model"):
139
+ with st.spinner(f"Loading {model_name}..."):
140
+ st.session_state['model'], st.session_state['tokenizer'] = load_model(model_name)
141
+ st.success(f"Model {model_name} loaded successfully!")
142
+
143
+ # Display chat
144
+ with chat_container:
145
+ if st.session_state['generated']:
146
+ for i in range(len(st.session_state['generated'])):
147
+ message(st.session_state['past'][i], is_user=True, key=str(i) + '_user', avatar_style="identicon")
148
+ message(st.session_state['generated'][i], key=str(i), avatar_style="bottts")
149
+
150
+ # User input
151
+ with st.form(key='chat_form', clear_on_submit=True):
152
+ user_input = st.text_input("You:", key='input', placeholder="Type your message here...")
153
+ submit_button = st.form_submit_button(label='Send')
154
+
155
+ if submit_button and user_input:
156
+ if st.session_state['model'] is None or st.session_state['tokenizer'] is None:
157
+ st.warning("Please load the model first!")
158
+ else:
159
+ # Add user message to chat history
160
+ st.session_state['past'].append(user_input)
161
+
162
+ # Generate response
163
+ with st.spinner("Thinking..."):
164
+ response = generate_response(user_input)
165
+
166
+ # Add bot response to chat history
167
+ st.session_state['generated'].append(response)
168
+
169
+ # Rerun to update the chat display
170
+ st.experimental_rerun()