DevNumb committed on
Commit
86c9b38
·
verified ·
1 Parent(s): 24ad31d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -0
app.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+ import time
5
+
6
# Load the model and tokenizer
def load_model():
    """Load the GPT-2 model and tokenizer from the Hugging Face hub.

    Returns:
        tuple: (tokenizer, model) ready for inference.
    """
    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

    # GPT-2 ships without a pad token; reuse EOS so padding works in generate().
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE: the original wrapped this loader in @torch.no_grad(), which is a
    # no-op here (loading builds no autograd graph). Removed; instead we make
    # inference mode explicit (from_pretrained already returns an eval-mode
    # model, but being explicit documents intent).
    model.eval()

    print("Model loaded successfully!")
    return tokenizer, model

# Initialize the model once at import time; shared by all requests.
tokenizer, model = load_model()
22
+
23
def generate_response(message, history, temperature=0.7, max_length=150):
    """Generate a chat reply with GPT-2.

    Args:
        message: The user's current message.
        history: List of [human, assistant] message pairs from prior turns.
        temperature: Sampling temperature (higher = more random).
        max_length: Maximum number of NEW tokens to generate.

    Returns:
        str: The generated assistant reply, or an error description on failure.
    """
    try:
        # Rebuild the conversation transcript in Human:/Assistant: form.
        conversation_history = ""
        for human, assistant in history:
            conversation_history += f"Human: {human}\nAssistant: {assistant}\n"

        # Add the current message
        full_prompt = f"{conversation_history}Human: {message}\nAssistant:"

        # Tokenize; truncate to GPT-2's 1024-token context window.
        inputs = tokenizer.encode(full_prompt, return_tensors="pt", max_length=1024, truncation=True)

        # Generate response (no gradients needed for inference).
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                # max_new_tokens counts only generated tokens, so the prompt
                # length no longer needs to be added in by hand.
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                top_p=0.9,
                repetition_penalty=1.1,
                # FIX: early_stopping=True removed - it only applies to beam
                # search and raises a warning under pure sampling.
            )

        # Decode ONLY the newly generated token ids. Slicing the decoded
        # string by len(full_prompt) (as the original did) is fragile:
        # decode() may not reproduce the prompt text byte-for-byte.
        response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True).strip()

        # Clean up the response: drop any hallucinated follow-up "Human:" turn.
        if "Human:" in response:
            response = response.split("Human:")[0].strip()

        return response

    except Exception as e:
        # Deliberate best-effort: surface the error in the chat UI
        # instead of crashing the request.
        return f"Sorry, I encountered an error: {str(e)}"
66
+
67
def chat_interface(message, history, temperature, max_length):
    """Handle one chat turn.

    Generates a reply for *message*, appends the [user, bot] pair to
    *history*, and returns an empty string (to clear the textbox) plus
    the updated history.
    """
    # Guard clause: ignore blank / whitespace-only submissions.
    if not message.strip():
        return "", history

    reply = generate_response(message, history, temperature, max_length)
    history.append([message, reply])
    return "", history
81
+
82
def clear_chat():
    """Reset the conversation by handing the Chatbot an empty history."""
    return []
87
+
88
# Custom CSS for beautiful styling.
# Injected via gr.Blocks(css=...). NOTE(review): selectors such as
# .gr-button and .contain depend on Gradio's internal markup and may
# break across Gradio versions - verify against the pinned version.
custom_css = """
#chatbot {
    background-color: #f8f9fa;
    border-radius: 10px;
    padding: 20px;
    height: 500px;
}

.gradio-container {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}

.contain {
    background-color: white;
    border-radius: 15px;
    padding: 20px;
    box-shadow: 0 10px 30px rgba(0,0,0,0.1);
}

.dark .contain {
    background-color: #1e1e1e;
}

.gr-button {
    background: linear-gradient(45deg, #FF6B6B, #FF8E53) !important;
    border: none !important;
    color: white !important;
    border-radius: 25px !important;
    padding: 10px 20px !important;
}

.gr-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(255,107,107,0.4);
}

.slider-container {
    background: white;
    padding: 15px;
    border-radius: 10px;
    margin: 10px 0;
}

.dark .slider-container {
    background: #2d2d2d;
}
"""
136
+
137
# Create the Gradio interface.
# Declarative UI wiring: layout is defined inside the Blocks context and
# event handlers are attached at the end. Left logically unchanged.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown(
        """
        # 🤖 GPT-2 Chatbot
        **Chat with OpenAI's GPT-2 model!** This AI assistant can help you with conversations, creative writing, and more.
        """
    )

    with gr.Row():
        # Left column (wider): chat history, input box, send/clear buttons.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Chat History",
                elem_id="chatbot",
                show_copy_button=True,
                avatar_images=("👤", "🤖")
            )

            with gr.Row():
                msg = gr.Textbox(
                    label="Type your message here...",
                    placeholder="Hello! How can I help you today?",
                    lines=2,
                    scale=4
                )
                submit_btn = gr.Button("Send 🚀", scale=1)

            with gr.Row():
                clear_btn = gr.Button("Clear Chat 🗑️")

        # Right column (narrower): generation settings and help text.
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")

            with gr.Group(elem_classes="slider-container"):
                # Sampling temperature passed straight into generate_response.
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative, Lower = more focused"
                )

                # Cap on the generated reply length, in tokens.
                max_length = gr.Slider(
                    minimum=50,
                    maximum=300,
                    value=150,
                    step=10,
                    label="Max Response Length",
                    info="Maximum length of generated responses"
                )

            gr.Markdown("### ℹ️ About")
            gr.Markdown("""
            This chatbot uses **GPT-2** from OpenAI.

            **Tips:**
            - Be specific in your questions
            - Adjust temperature for creativity
            - Clear chat to start fresh
            """)

    # Event handlers
    # Pressing Enter in the textbox submits the message.
    submit_event = msg.submit(
        chat_interface,
        inputs=[msg, chatbot, temperature, max_length],
        outputs=[msg, chatbot]
    )

    # Clicking "Send" does the same as pressing Enter.
    submit_btn.click(
        chat_interface,
        inputs=[msg, chatbot, temperature, max_length],
        outputs=[msg, chatbot]
    )

    # "Clear Chat" empties the history shown in the Chatbot component.
    clear_btn.click(
        clear_chat,
        outputs=[chatbot]
    )

    # Additional handlers for Enter key: after the submit chain finishes,
    # reset the textbox to empty and re-enable it.
    submit_event.then(
        lambda: gr.Textbox(value="", interactive=True),
        outputs=[msg]
    )
222
+
223
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside a
    # container (typical for hosted deployments); share=False disables
    # the temporary public Gradio tunnel link.
    demo.launch(
        server_name="0.0.0.0",
        share=False
    )