TobDeBer committed on
Commit
093459e
Β·
verified Β·
1 Parent(s): ccdf284

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +301 -0
app.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Gradio front-end for CPU inference with a small chat LLM."""
import random
import time

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Model configuration - TinyLlama is small enough for CPU-only inference.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Lazily-populated globals shared by the UI callbacks below; they stay None
# until load_model() succeeds.
tokenizer = None
model = None
text_generator = None
def load_model():
    """Load the TinyLlama tokenizer/model and build the generation pipeline.

    Populates the module-level ``tokenizer``, ``model`` and ``text_generator``
    globals used by the Gradio callbacks.

    Returns:
        str: a user-facing status message (success or error text).
    """
    global tokenizer, model, text_generator
    try:
        print(f"Loading model: {MODEL_NAME}")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

        # Some chat tokenizers ship without a pad token; fall back to EOS
        # *before* the pipeline is created so padding is configured correctly
        # (the original set it afterwards, too late for the pipeline).
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,  # float32 is the safe dtype on CPU
            device_map="auto",  # NOTE(review): requires `accelerate` — confirm it is installed
        )

        # Default sampling settings; generate_text() overrides them per call.
        text_generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
        )

        return "✅ Model loaded successfully!"
    except Exception as e:
        return f"❌ Error loading model: {str(e)}"
def format_prompt(prompt, system_prompt=None):
    """Wrap a user prompt in the TinyLlama chat template.

    An optional system instruction is prepended when supplied; the result
    always ends with the assistant tag so generation continues in the
    assistant's voice.
    """
    segments = []
    if system_prompt:
        segments.append(f"<|system|>\n{system_prompt}")
    segments.append(f"<|user|>\n{prompt}")
    segments.append("<|assistant|>")
    return "\n".join(segments)
def generate_text(
    prompt,
    max_length=200,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.1,
    system_prompt="You are a helpful AI assistant. Provide clear and concise answers."
):
    """Generate a response for *prompt* using the loaded pipeline.

    Args:
        prompt: the user's message; must be non-empty.
        max_length: maximum number of new tokens to generate.
        temperature: sampling temperature, forwarded to the pipeline call.
        top_p: nucleus-sampling threshold, forwarded to the pipeline call.
        repetition_penalty: penalty factor, forwarded to the pipeline call.
        system_prompt: instruction prepended via the chat template.

    Returns:
        str: markdown containing the response and the generation time, or a
        warning/error message.
    """
    global text_generator

    if text_generator is None:
        return "⚠️ Please load the model first using the 'Load Model' button."

    if not prompt.strip():
        return "⚠️ Please enter a prompt."

    try:
        formatted_prompt = format_prompt(prompt, system_prompt)

        # Sampling parameters are passed per call below; mutating attributes
        # on the pipeline object (as an earlier revision did) has no effect
        # on HF pipelines, so that dead code was removed.
        start_time = time.time()
        result = text_generator(
            formatted_prompt,
            max_new_tokens=max_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        generation_time = time.time() - start_time

        # The pipeline echoes the prompt; keep only the assistant's part.
        generated_text = result[0]["generated_text"]
        if "<|assistant|>" in generated_text:
            response = generated_text.split("<|assistant|>")[-1].strip()
        else:
            response = generated_text

        # Format output with timing metadata.
        output = f"**Response:**\n{response}\n\n---\n*Generated in {generation_time:.2f} seconds*"

        return output

    except Exception as e:
        return f"❌ Error during generation: {str(e)}"
def clear_chat():
    """Reset the prompt box and the output panel to empty strings."""
    return ("", "")
# Create custom theme
custom_theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="indigo",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
    text_size="lg",
    spacing_size="lg",
    radius_size="md",
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
    block_title_text_weight="600",
)

# Build the Gradio interface. The theme must be supplied to gr.Blocks();
# launch() has no `theme` kwarg, so the original's theme was never applied
# and launch() would raise a TypeError.
with gr.Blocks(theme=custom_theme) as demo:
    gr.Markdown(
        """
        # 🤖 Smol LLM Inference GUI

        **Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)** -
        Efficient text generation using TinyLlama

        This application runs a compact language model locally for text generation.
        Perfect for chat, completion tasks, and creative writing.
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            # Model loading section
            with gr.Group():
                gr.Markdown("### 📦 Model Management")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="Model not loaded. Click 'Load Model' to start.",
                    interactive=False,
                )
                load_btn = gr.Button(
                    "🔄 Load Model",
                    variant="primary",
                    size="lg",
                )

            # Generation parameters
            gr.Markdown("### ⚙️ Generation Parameters")

            with gr.Row():
                max_length = gr.Slider(
                    minimum=50,
                    maximum=1024,
                    value=200,
                    step=50,
                    label="Max Tokens",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )

            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p",
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.1,
                    label="Repetition Penalty",
                )

            system_prompt = gr.Textbox(
                label="System Prompt",
                value="You are a helpful AI assistant. Provide clear and concise answers.",
                lines=3,
                placeholder="Enter a system prompt to guide the model's behavior...",
            )

        with gr.Column(scale=3):
            # Main interface
            with gr.Group():
                gr.Markdown("### 💬 Text Generation")

                prompt_input = gr.Textbox(
                    label="Enter your prompt",
                    placeholder="Type your message here...",
                    lines=4,
                    autofocus=True,
                )

                with gr.Row():
                    generate_btn = gr.Button(
                        "🚀 Generate",
                        variant="primary",
                        size="lg",
                    )
                    clear_btn = gr.Button(
                        "🗑️ Clear",
                        variant="secondary",
                    )

                output_text = gr.Markdown(
                    label="Generated Response",
                    value="*Response will appear here...*",
                )

            # Example prompts
            with gr.Accordion("📝 Example Prompts", open=False):
                gr.Examples(
                    examples=[
                        ["Write a short story about a robot discovering music."],
                        ["Explain quantum computing in simple terms."],
                        ["Create a poem about the changing seasons."],
                        ["What are the benefits of renewable energy?"],
                        ["Write a Python function to calculate fibonacci numbers."],
                        ["Describe the perfect day in your own words."],
                        ["Explain the concept of machine learning to a beginner."],
                        ["Create a dialogue between two friends planning a trip."],
                    ],
                    inputs=[prompt_input],
                    label="Click an example to get started",
                )

    # Event handlers. Gradio event listeners have no `api_visibility` kwarg
    # (the original passed one, which raises a TypeError); endpoints are
    # public by default and `api_name=False` hides one from the API.
    load_btn.click(
        fn=load_model,
        outputs=[model_status],
    )

    generate_btn.click(
        fn=generate_text,
        inputs=[
            prompt_input,
            max_length,
            temperature,
            top_p,
            repetition_penalty,
            system_prompt,
        ],
        outputs=[output_text],
    )

    clear_btn.click(
        fn=clear_chat,
        # clear_chat returns two values; wire both targets (the original
        # listed only prompt_input, mismatching the return arity).
        outputs=[prompt_input, output_text],
        api_name=False,
    )

    # Allow Enter key to generate
    prompt_input.submit(
        fn=generate_text,
        inputs=[
            prompt_input,
            max_length,
            temperature,
            top_p,
            repetition_penalty,
            system_prompt,
        ],
        outputs=[output_text],
    )

# Launch the application. `footer_links` is not a launch() parameter and
# was dropped; the theme is applied on gr.Blocks above.
if __name__ == "__main__":
    demo.launch(
        share=False,
        show_error=True,
    )