Keeby-smilyai committed on
Commit 853a0d4 · verified · 1 Parent(s): c783df6

Create app.py

Files changed (1)
  1. app.py +618 -0
app.py ADDED
@@ -0,0 +1,618 @@
+ import gradio as gr
+ import tensorflow as tf
+ from huggingface_hub import hf_hub_download
+ import json
+ import os
+ from tokenizers import Tokenizer
+ import numpy as np
+ import time
+
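+ # Runtime dependencies implied by these imports: gradio, tensorflow,
+ # huggingface_hub, tokenizers, numpy. In a Hugging Face Space these would
+ # normally be pinned in requirements.txt; no versions are specified here.
+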
+ # ============================================================================
+ # 🎊 FESTIVE MODE TOGGLE 🎊
+ # ============================================================================
+ FESTIVE = True  # Set to False for production-only mode
+
+ # ============================================================================
+ # Configuration & Model Loading
+ # ============================================================================
+
+ print("🚀 Loading SAM-Z-1 Model...")
+
+ MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow"
+ CACHE_DIR = "./model_cache"
+
+ # Download model files
+ config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
+ model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
+ tokenizer_path = hf_hub_download(MODEL_REPO, "tokenizer.json", cache_dir=CACHE_DIR)
+
+ # Load config
+ with open(config_path, 'r') as f:
+     config = json.load(f)
+
+ # Load tokenizer
+ tokenizer = Tokenizer.from_file(tokenizer_path)
+ eos_token_id = config.get('eos_token_id', 50256)
+
+ # Load the Keras model; compile=False skips optimizer state not needed for inference
+ model = tf.keras.models.load_model(model_path, compile=False)
+
+ # Create optimized inference function
+ @tf.function(reduce_retracing=True)
+ def fast_forward(input_tensor):
+     """TF-optimized forward pass for faster generation"""
+     return model(input_tensor, training=False)
+
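+ # Why reduce_retracing: the generation loop below grows the input by one token
+ # per step, so fast_forward sees a new shape on each call; without shape
+ # relaxation, tf.function would recompile the graph repeatedly. An explicit
+ # input_signature with a dynamic sequence axis would achieve the same effect,
+ # assuming the model accepts variable-length [batch, seq] int32 inputs, as the
+ # generation loop implies.
+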
+ print(f"✅ Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
+ print("✅ TF function optimization enabled for faster inference")
+
+ # Global stop flag
+ stop_generation = False
+
+ # ============================================================================
+ # Generation Function with Streaming & Stop Button
+ # ============================================================================
+
+ def generate_stream(
+     prompt: str,
+     max_tokens: int = 512,
+     temperature: float = 0.8,
+     top_k: int = 40,
+     top_p: float = 0.9,
+     repetition_penalty: float = 1.1
+ ):
+     """Generate text with streaming output and stop support"""
+     global stop_generation
+     stop_generation = False
+
+     # Tokenize prompt
+     input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
+
+     if len(input_ids) == 0:
+         yield "⚠️ Empty prompt after tokenization"
+         return
+
+     # Leave room in the context window for the requested new tokens
+     if len(input_ids) > config['max_position_embeddings'] - max_tokens:
+         input_ids = input_ids[-(config['max_position_embeddings'] - max_tokens):]
+
+     input_tensor = tf.constant([input_ids], dtype=tf.int32)
+     generated_text = ""
+     token_count = 0
+
+     # Track token frequencies for repetition penalty
+     token_freq = {}
+
+     start_time = time.time()
+
+     for step in range(max_tokens):
+         # Check stop flag
+         if stop_generation:
+             generated_text += "\n\n*[Generation stopped by user]*"
+             yield generated_text
+             break
+
+         # Get logits using optimized TF function
+         logits = fast_forward(input_tensor)
+         next_token_logits = logits[0, -1, :].numpy()
+
+         # Apply temperature
+         next_token_logits = next_token_logits / temperature
+
+         # Apply repetition penalty (compounds with each repeat of a token)
+         if repetition_penalty != 1.0:
+             for token_id, freq in token_freq.items():
+                 if token_id < len(next_token_logits):
+                     next_token_logits[token_id] /= (repetition_penalty ** freq)
+
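+         # Caveat: dividing logits unconditionally differs from the usual
+         # CTRL-style penalty, which divides positive logits and multiplies
+         # negative ones; for a token whose logit is negative, this division
+         # pushes it toward zero and can slightly raise its probability.
+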
+         # Top-k filtering
+         if top_k > 0:
+             top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:]
+             top_k_logits = next_token_logits[top_k_indices]
+             top_k_probs = tf.nn.softmax(top_k_logits).numpy()
+
+             # Top-p (nucleus) sampling
+             if top_p < 1.0:
+                 sorted_indices = np.argsort(top_k_probs)[::-1]
+                 cumsum = np.cumsum(top_k_probs[sorted_indices])
+                 cutoff_idx = np.searchsorted(cumsum, top_p)
+                 nucleus_indices = sorted_indices[:cutoff_idx + 1]
+
+                 nucleus_logits = top_k_logits[nucleus_indices]
+                 nucleus_probs = tf.nn.softmax(nucleus_logits).numpy()
+
+                 sampled_idx = np.random.choice(len(nucleus_probs), p=nucleus_probs)
+                 next_token_id = int(top_k_indices[nucleus_indices[sampled_idx]])
+             else:
+                 sampled_idx = np.random.choice(len(top_k_probs), p=top_k_probs)
+                 next_token_id = int(top_k_indices[sampled_idx])
+         else:
+             probs = tf.nn.softmax(next_token_logits).numpy()
+             next_token_id = np.random.choice(len(probs), p=probs)
+
+         # Stop on EOS
+         if next_token_id == eos_token_id:
+             break
+
+         # Update token frequency
+         token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
+
+         # Decode and yield
+         token_text = tokenizer.decode([next_token_id])
+         generated_text += token_text
+         token_count += 1
+
+         # Yield progressive output
+         yield generated_text
+
+         # Update input
+         input_tensor = tf.concat([input_tensor, [[next_token_id]]], axis=1)
+
+         # Truncate if too long
+         if input_tensor.shape[1] > config['max_position_embeddings']:
+             input_tensor = input_tensor[:, -config['max_position_embeddings']:]
+
+     # Calculate stats
+     elapsed = time.time() - start_time
+     tokens_per_sec = token_count / elapsed if elapsed > 0 else 0
+
+     # Add generation stats
+     if token_count > 0 and not stop_generation:
+         generated_text += f"\n\n*[Generated {token_count} tokens in {elapsed:.1f}s ({tokens_per_sec:.1f} tok/s)]*"
+
+     yield generated_text
+
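+ # Illustrative usage (not called by the app): each yield is the full text so
+ # far, so a consumer only needs the latest value, e.g.
+ #     for text in generate_stream("Once upon a time", max_tokens=32):
+ #         print(text)
+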
+ # ============================================================================
+ # Chat Interface Logic
+ # ============================================================================
+
+ def format_chat_prompt(message: str, history: list) -> str:
+     """Format message history into chat prompt"""
+     prompt = ""
+
+     # Add history
+     for user_msg, assistant_msg in history:
+         prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
+         if assistant_msg:
+             prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
+
+     # Add current message
+     prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
+
+     return prompt
+
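+ # For example, format_chat_prompt("Hi", [("Hello", "Hey!")]) returns the
+ # ChatML-style string
+ #     <|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHey!<|im_end|>\n
+ #     <|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n
+ # (wrapped here for readability; the actual value is one continuous string).
+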
+ def chat_stream(
+     message: str,
+     history: list,
+     max_tokens: int,
+     temperature: float,
+     top_k: int,
+     top_p: float,
+     repetition_penalty: float
+ ):
+     """Streaming chat response"""
+     if not message.strip():
+         yield history
+         return
+
+     # Format prompt
+     prompt = format_chat_prompt(message, history)
+
+     # Generate with streaming
+     partial_response = ""
+     for generated in generate_stream(
+         prompt,
+         max_tokens=max_tokens,
+         temperature=temperature,
+         top_k=top_k,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty
+     ):
+         partial_response = generated
+
+         # Truncate at the end-of-turn tag if the model emits one
+         if "<|im_end|>" in partial_response:
+             partial_response = partial_response.split("<|im_end|>")[0]
+
+         # Update history
+         yield history + [[message, partial_response.strip()]]
+
+ def stop_gen():
+     """Stop generation callback"""
+     global stop_generation
+     stop_generation = True
+     return None
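+
+ # Caveat: stop_generation is module-level state shared by every session, so
+ # with concurrent users one person's Stop can halt another's generation. A
+ # per-session flag (e.g. held in gr.State) would avoid this; the cancels=
+ # wiring on the Stop button below already covers the common single-user case.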
+
+ # ============================================================================
+ # Gradio UI
+ # ============================================================================
+
+ # Festive CSS
+ festive_css = """
+ .gradio-container {
+     max-width: 1200px !important;
+     margin: auto !important;
+ }
+
+ .header {
+     text-align: center;
+     padding: 2rem;
+     background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
+     color: white;
+     border-radius: 12px;
+     margin-bottom: 2rem;
+     box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);
+     animation: pulse 2s ease-in-out infinite;
+ }
+
+ @keyframes pulse {
+     0%, 100% { transform: scale(1); }
+     50% { transform: scale(1.02); }
+ }
+
+ .header h1 {
+     font-size: 2.8rem;
+     margin-bottom: 0.5rem;
+     font-weight: 700;
+     text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
+ }
+
+ .header p {
+     font-size: 1.1rem;
+     opacity: 0.95;
+ }
+
+ .celebration {
+     font-size: 2rem;
+     margin: 0.5rem;
+     animation: bounce 1s ease infinite;
+ }
+
+ @keyframes bounce {
+     0%, 100% { transform: translateY(0); }
+     50% { transform: translateY(-10px); }
+ }
+
+ .stats-card {
+     background: linear-gradient(135deg, #ffecd2 0%, #fcb69f 100%);
+     padding: 1.5rem;
+     border-radius: 12px;
+     border-left: 4px solid #f5576c;
+     margin: 1rem 0;
+     box-shadow: 0 4px 16px rgba(252, 182, 159, 0.3);
+ }
+
+ .twin-badge {
+     display: inline-block;
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     color: white;
+     padding: 0.5rem 1rem;
+     border-radius: 20px;
+     font-weight: bold;
+     margin: 0.5rem;
+     box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
+ }
+
+ footer {
+     text-align: center;
+     padding: 2rem;
+     color: #666;
+     border-top: 1px solid #eee;
+     margin-top: 2rem;
+ }
+
+ .confetti {
+     position: fixed;
+     width: 10px;
+     height: 10px;
+     background: #f5576c;
+     animation: confetti-fall 3s linear infinite;
+ }
+
+ @keyframes confetti-fall {
+     to { transform: translateY(100vh) rotate(360deg); }
+ }
+ """
+
+ # Production CSS
+ production_css = """
+ .gradio-container {
+     max-width: 1200px !important;
+     margin: auto !important;
+ }
+
+ .header {
+     text-align: center;
+     padding: 2rem;
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     color: white;
+     border-radius: 12px;
+     margin-bottom: 2rem;
+ }
+
+ .header h1 {
+     font-size: 2.5rem;
+     margin-bottom: 0.5rem;
+     font-weight: 700;
+ }
+
+ .header p {
+     font-size: 1.1rem;
+     opacity: 0.95;
+ }
+
+ .stats-card {
+     background: #f8f9fa;
+     padding: 1rem;
+     border-radius: 8px;
+     border-left: 4px solid #667eea;
+     margin: 1rem 0;
+ }
+
+ footer {
+     text-align: center;
+     padding: 2rem;
+     color: #666;
+     border-top: 1px solid #eee;
+     margin-top: 2rem;
+ }
+ """
+
+ # Select CSS based on mode
+ custom_css = festive_css if FESTIVE else production_css
+
+ # Build interface
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
+     # Header
+     if FESTIVE:
+         gr.HTML("""
+         <div class="header">
+             <div class="celebration">🎉 🎊 ✨ 🎈 🎆</div>
+             <h1>🤖 SAM-Z-1 Chat 🤖</h1>
+             <p><strong>LATEST RELEASE!</strong> Our fastest non-reasoning model</p>
+             <div class="twin-badge">Twin of SAM-X-1 (Reasoning Model)</div>
+             <p style="font-size: 0.9rem; margin-top: 1rem;">
+                 768D • 16 Layers • 12 Heads • ~140M Parameters • Trained on TPU v5e-8
+             </p>
+             <div class="celebration">🚀 💫 🎯 ⚡ 🔥</div>
+         </div>
+         """)
+     else:
+         gr.HTML("""
+         <div class="header">
+             <h1>🤖 SAM-Z-1 Chat</h1>
+             <p>Fast, direct responses without reasoning overhead</p>
+             <p style="font-size: 0.9rem; margin-top: 0.5rem;">
+                 768D • 16 Layers • 12 Heads • Trained on TPU v5e-8
+             </p>
+         </div>
+         """)
+
+     with gr.Row():
+         with gr.Column(scale=4):
+             # Chat interface
+             chatbot = gr.Chatbot(
+                 height=600,
+                 show_label=False,
+                 avatar_images=(None, "🤖" if not FESTIVE else "🎉"),
+                 bubble_full_width=False
+             )
+
+             with gr.Row():
+                 msg = gr.Textbox(
+                     placeholder="Type your message here..." if not FESTIVE else "Ask me anything! I'm the fast twin! ⚡",
+                     show_label=False,
+                     scale=8,
+                     container=False
+                 )
+                 submit_btn = gr.Button("Send 🚀" if FESTIVE else "Send", variant="primary", scale=1)
+                 stop_btn = gr.Button("⏹️ Stop", variant="stop", scale=1)
+
+             with gr.Row():
+                 clear_btn = gr.Button("🗑️ Clear Chat", size="sm")
+                 retry_btn = gr.Button("🔄 Retry", size="sm")
+
+         with gr.Column(scale=1):
+             gr.Markdown("### ⚙️ Generation Settings")
+
+             max_tokens = gr.Slider(
+                 minimum=50,
+                 maximum=1024,
+                 value=512,
+                 step=50,
+                 label="Max Tokens",
+                 info="Maximum length of response"
+             )
+
+             temperature = gr.Slider(
+                 minimum=0.1,
+                 maximum=2.0,
+                 value=0.8,
+                 step=0.1,
+                 label="Temperature",
+                 info="Higher = more creative"
+             )
+
+             top_k = gr.Slider(
+                 minimum=1,
+                 maximum=100,
+                 value=40,
+                 step=1,
+                 label="Top-K",
+                 info="Sample from top K tokens"
+             )
+
+             top_p = gr.Slider(
+                 minimum=0.1,
+                 maximum=1.0,
+                 value=0.9,
+                 step=0.05,
+                 label="Top-P",
+                 info="Nucleus sampling threshold"
+             )
+
+             repetition_penalty = gr.Slider(
+                 minimum=1.0,
+                 maximum=2.0,
+                 value=1.1,
+                 step=0.1,
+                 label="Repetition Penalty",
+                 info="Penalize repeated tokens"
+             )
+
+             gr.Markdown("---")
+
+             # Model info
+             if FESTIVE:
+                 gr.Markdown(f"""
+                 ### 🎊 SAM-Z-1 Model Info
+
+                 **🎯 The Fast Twin!**
+
+                 **Type:** Direct Response Model
+                 **Parameters:** ~140M
+                 **Context:** {config['max_position_embeddings']} tokens
+                 **Vocab:** {config['vocab_size']}
+                 **Speed:** ⚡ Optimized with TF Functions
+
+                 **Twin Models:**
+                 - **SAM-X-1**: Reasoning model (with thinking)
+                 - **SAM-Z-1**: Fast model (YOU ARE HERE! 🎉)
+
+                 **Architecture:**
+                 - RoPE positional encoding
+                 - SwiGLU activation
+                 - RMSNorm layers
+                 - No bias terms (efficient!)
+
+                 **Training:**
+                 - Trained from scratch
+                 - TPU v5e-8 (8 cores)
+                 - Mixed precision (bfloat16)
+                 - Cosine decay schedule
+                 """)
+             else:
+                 gr.Markdown(f"""
+                 ### 📊 Model Info
+
+                 **Architecture:** SAM-Z-1 (Direct Response)
+                 **Parameters:** ~140M
+                 **Context:** {config['max_position_embeddings']} tokens
+                 **Vocab:** {config['vocab_size']}
+
+                 **Twin Models:**
+                 - SAM-X-1: Reasoning model
+                 - SAM-Z-1: Direct response model
+
+                 **Features:**
+                 - RoPE positional encoding
+                 - SwiGLU activation
+                 - RMSNorm layers
+                 - TF-optimized inference
+                 """)
+
+     # Example prompts
+     gr.Examples(
+         examples=[
+             "Hi! What can you do?",
+             "Explain quantum computing in simple terms",
+             "Write a short poem about AI",
+             "What's the capital of France?",
+             "How do I learn programming?",
+             "Tell me an interesting fact about space",
+             "What's the difference between you and SAM-X-1?",
+             "Why are you called the fast twin?",
+         ],
+         inputs=msg,
+         label="💡 Try these examples" if not FESTIVE else "🎯 Try these examples!"
+     )
+
+     # Footer
+     if FESTIVE:
+         gr.HTML("""
+         <footer>
+             <p style="font-size: 1.2rem;"><strong>🎉 SAM-Z-1 - LATEST RELEASE! 🎉</strong></p>
+             <p><strong>The Fast Twin</strong> - Direct responses without reasoning overhead</p>
+             <p style="font-size: 0.9rem; color: #999; margin-top: 0.5rem;">
+                 Trained from scratch on TPU v5e-8 • Built with TensorFlow & Gradio
+             </p>
+             <p style="font-size: 0.9rem; color: #999;">
+                 Twin of SAM-X-1 (reasoning model) • Same architecture, different training objective
+             </p>
+             <div style="margin-top: 1rem; font-size: 1.5rem;">
+                 ⚡ 🚀 💫 ✨ 🎯
+             </div>
+         </footer>
+         """)
+     else:
+         gr.HTML("""
+         <footer>
+             <p><strong>SAM-Z-1</strong> - Direct response language model</p>
+             <p style="font-size: 0.9rem; color: #999;">
+                 Trained from scratch on TPU v5e-8 • Built with TensorFlow & Gradio
+             </p>
+             <p style="font-size: 0.9rem; color: #999;">
+                 Twin of SAM-X-1 (reasoning model)
+             </p>
+         </footer>
+         """)
+
+     # Event handlers (after generating, clear the textbox; the original
+     # passed None as an output component, which Gradio rejects)
+     submit_event = msg.submit(
+         chat_stream,
+         inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty],
+         outputs=[chatbot]
+     ).then(
+         lambda: "",
+         outputs=[msg]
+     )
+
+     click_event = submit_btn.click(
+         chat_stream,
+         inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty],
+         outputs=[chatbot]
+     ).then(
+         lambda: "",
+         outputs=[msg]
+     )
+
+     # Stop button
+     stop_btn.click(
+         fn=stop_gen,
+         inputs=None,
+         outputs=None,
+         cancels=[submit_event, click_event]
+     )
+
+     clear_btn.click(lambda: (None, ""), outputs=[chatbot, msg])
+
+     def retry_last(history, max_tok, temp, topk, topp, rep_pen):
+         # Generator: yield (not return) the unchanged history when empty
+         if not history:
+             yield history
+             return
+         last_user_msg = history[-1][0]
+         history = history[:-1]
+         for update in chat_stream(last_user_msg, history, max_tok, temp, topk, topp, rep_pen):
+             yield update
+
+     retry_event = retry_btn.click(
+         retry_last,
+         inputs=[chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty],
+         outputs=[chatbot]
+     )
+
+     stop_btn.click(
+         fn=stop_gen,
+         inputs=None,
+         outputs=None,
+         cancels=[retry_event]
+     )
+
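+ # Note: queuing is what lets generator (streaming) handlers run and lets the
+ # cancels= wiring above interrupt in-flight jobs; max_size=20 only caps how
+ # many requests may wait (a default here, not a tuned value).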
+ # Launch
+ if __name__ == "__main__":
+     demo.queue(max_size=20)
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         show_error=True
+     )