Keeby-smilyai committed on
Commit acf0e5f · verified · 1 Parent(s): 765bb8c

Update app.py: replace the multi-model "SmilyAI Studio" task-queue app (SQLite accounts, background worker, SAM-X-1 + SAM-Z-1) with a single-model SAM-Z-1 streaming chat interface.

Files changed (1):
  app.py (+837, -381)

app.py CHANGED
@@ -1,61 +1,31 @@
- import os
- os.environ['KERAS_BACKEND'] = 'tensorflow'
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-
  import gradio as gr
  import tensorflow as tf
  import keras
  from huggingface_hub import hf_hub_download
  import json
- import numpy as np
  from tokenizers import Tokenizer
- import threading
  import time
- import queue
- import hashlib
- import sqlite3
- from datetime import datetime
- import uuid

- # ==============================================================================
- # 1. Hardware Optimization & Setup
- # ==============================================================================
- tf.config.threading.set_inter_op_parallelism_threads(2)
- tf.config.threading.set_intra_op_parallelism_threads(4)
- tf.config.optimizer.set_jit(True)
-
- print(f"🚀 SmilyAI Pro System Initializing...")
-
- # ==============================================================================
- # 2. Database
- # ==============================================================================
- def init_db():
-     conn = sqlite3.connect('sam_tasks.db', check_same_thread=False)
-     c = conn.cursor()
-     c.execute('''CREATE TABLE IF NOT EXISTS users
-                  (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                   username TEXT UNIQUE NOT NULL,
-                   password_hash TEXT NOT NULL)''')
-     c.execute('''CREATE TABLE IF NOT EXISTS tasks
-                  (id TEXT PRIMARY KEY,
-                   user_id INTEGER,
-                   model_name TEXT,
-                   prompt TEXT,
-                   status TEXT,
-                   progress INTEGER DEFAULT 0,
-                   result TEXT,
-                   created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                   tokens_per_sec REAL DEFAULT 0,
-                   FOREIGN KEY (user_id) REFERENCES users(id))''')
-     conn.commit()
-     return conn
-
- db_conn = init_db()
- db_lock = threading.Lock()

- # ==============================================================================
- # 3. Model (Fixed with tf.cond)
- # ==============================================================================
  @keras.saving.register_keras_serializable()
  class RotaryEmbedding(keras.layers.Layer):
      def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
@@ -64,403 +34,889 @@ class RotaryEmbedding(keras.layers.Layer):
          self.max_len = max_len
          self.theta = theta
          self.built_cache = False
-
      def _build_cache(self):
          if not self.built_cache:
              inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
              t = tf.range(self.max_len, dtype=tf.float32)
              freqs = tf.einsum("i,j->ij", t, inv_freq)
              emb = tf.concat([freqs, freqs], axis=-1)
              self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
              self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
              self.built_cache = True
-
      def call(self, q, k):
          self._build_cache()
          seq_len = tf.shape(q)[2]
-         cos = self.cos_cached[:seq_len, :][None, None, :, :]
-         sin = self.sin_cached[:seq_len, :][None, None, :, :]
-
-         def rotate_half(x):
-             x1, x2 = tf.split(x, 2, axis=-1)
-             return tf.concat([-x2, x1], axis=-1)
-
-         q_rot = (q * cos) + (rotate_half(q) * sin)
-         k_rot = (k * cos) + (rotate_half(k) * sin)
-         return q_rot, k_rot
  @keras.saving.register_keras_serializable()
  class TransformerBlock(keras.layers.Layer):
-     def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, **kwargs):
          super().__init__(**kwargs)
          self.n_heads = n_heads
          self.head_dim = d_model // n_heads
-         self.d_model = d_model
          self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
-         self.pre_attn_norm = keras.layers.LayerNormalization(epsilon=1e-5)
-         self.pre_ffn_norm = keras.layers.LayerNormalization(epsilon=1e-5)
-         self.q_proj = keras.layers.Dense(d_model, use_bias=False)
-         self.k_proj = keras.layers.Dense(d_model, use_bias=False)
-         self.v_proj = keras.layers.Dense(d_model, use_bias=False)
-         self.out_proj = keras.layers.Dense(d_model, use_bias=False)
-         self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False)
-         self.up_proj = keras.layers.Dense(ff_dim, use_bias=False)
-         self.down_proj = keras.layers.Dense(d_model, use_bias=False)
          self.dropout = keras.layers.Dropout(dropout)
-
-     def call(self, x, cache=None, training=None):
-         B = tf.shape(x)[0]
-         T = tf.shape(x)[1]

-         # 1. Attention
          res = x
          y = self.pre_attn_norm(x)

-         q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
-         k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
-         v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
-
-         # KV Cache
-         if cache is not None:
-             k_cache, v_cache = cache
-             k = tf.concat([k_cache, k], axis=1)
-             v = tf.concat([v_cache, v], axis=1)
-         new_cache = (k, v)
-
-         # RoPE
-         q = tf.transpose(q, [0, 2, 1, 3])
-         k_rot = tf.transpose(k, [0, 2, 1, 3])
-         v_t = tf.transpose(v, [0, 2, 1, 3])
-         q, k_rot = self.rope(q, k_rot)
-
-         # Attention Scores
-         scores = tf.matmul(q, k_rot, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, x.dtype))
-
-         # --- 🛠️ FIX: Graph-Safe Causal Mask ---
-         def apply_mask():
-             # Create triangular mask for prefill (T > 1)
-             mask = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
-             return (1.0 - mask) * -1e9
-
-         def no_mask():
-             # No mask needed for decoding step (T=1 attends to all past)
-             return tf.zeros((1, 1))  # Broadcastable 0
-
-         # Use tf.cond instead of python 'if'
-         mask_offset = tf.cond(tf.greater(T, 1), apply_mask, no_mask)
-         scores = scores + mask_offset
-         # -----------------------------------------
-
-         attn = tf.nn.softmax(scores, axis=-1)
-         out = tf.matmul(attn, v_t)
-         out = tf.reshape(tf.transpose(out, [0, 2, 1, 3]), [B, T, self.d_model])
-         x = res + self.out_proj(out)

-         # 2. FFN
          res = x
          y = self.pre_ffn_norm(x)
          ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))

-         return res + ffn, new_cache
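Note: the "🛠️ FIX" in the deleted block above is the interesting part of the old implementation. Inside a `@tf.function` graph, `T = tf.shape(x)[1]` is a symbolic tensor, so a Python-level `if T > 1:` cannot be evaluated at trace time; the branch has to be expressed with `tf.cond` so it is resolved at graph execution. A minimal standalone sketch of the same pattern (not part of the app):

```python
import tensorflow as tf

@tf.function
def causal_offset(x):
    # T is symbolic inside the traced graph; a Python `if T > 1` would
    # fail here, so the branch is expressed with tf.cond instead.
    T = tf.shape(x)[1]

    def apply_mask():
        # Lower-triangular ones: 0 offset where attention is allowed,
        # -1e9 where a query would attend to a future key.
        mask = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
        return (1.0 - mask) * -1e9

    def no_mask():
        # Single-step decode (T == 1) may attend to everything.
        return tf.zeros((1, 1))  # broadcastable no-op

    return tf.cond(tf.greater(T, 1), apply_mask, no_mask)

print(causal_offset(tf.zeros((1, 3, 8))).numpy())  # 3x3 triangular offsets
```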

  @keras.saving.register_keras_serializable()
  class SAM1Model(keras.Model):
-     def __init__(self, config, **kwargs):
-         super().__init__(**kwargs)
-         self.embed = keras.layers.Embedding(config['vocab_size'], config['d_model'])
-         ff_dim = int(config['d_model'] * config['ff_mult'])
-         self.blocks = [
-             TransformerBlock(
-                 config['d_model'], config['n_heads'], ff_dim, config['dropout'],
-                 config['max_len'], config['rope_theta']
-             ) for i in range(config['n_layers'])
-         ]
-         self.norm = keras.layers.LayerNormalization(epsilon=1e-5)
-         self.lm_head = keras.layers.Dense(config['vocab_size'], use_bias=False)
-
-     def call(self, input_ids, cache=None, training=None):
          x = self.embed(input_ids)
-         new_caches = []
-         for i, block in enumerate(self.blocks):
-             c_i = cache[i] if cache is not None else None
-             x, nc_i = block(x, cache=c_i, training=training)
-             new_caches.append(nc_i)
-         return self.lm_head(self.norm(x)), new_caches
- # ==============================================================================
- # 4. Load Models
- # ==============================================================================
- print("\n📦 Loading Resources...")

- dummy_in = tf.zeros((1, 1), dtype=tf.int32)

- # SAM-X (Reasoning)
- print("🔹 SAM-X-1 (Reasoning)")
  try:
-     samx_cfg = json.load(open(hf_hub_download("Smilyai-labs/Sam-1-large-it-0002", "config.json")))
-     samx_model = SAM1Model({
-         'vocab_size': samx_cfg['vocab_size'], 'd_model': samx_cfg['hidden_size'],
-         'n_layers': samx_cfg['num_hidden_layers'], 'n_heads': samx_cfg['num_attention_heads'],
-         'ff_mult': samx_cfg['intermediate_size']/samx_cfg['hidden_size'],
-         'max_len': samx_cfg['max_position_embeddings'], 'rope_theta': samx_cfg['rope_theta'], 'dropout': 0.0
-     })
-     _ = samx_model(dummy_in)
-     samx_model.load_weights(hf_hub_download("Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5"))
-     tokenizer_x = Tokenizer.from_file(hf_hub_download("Smilyai-labs/Sam-1-large-it-0002", "tokenizer.json"))
- except Exception as e: print(f"⚠️ Failed to load SAM-X: {e}")
-
- # SAM-Z (Speed)
- print("🔹 SAM-Z-1 (Fast)")
- try:
-     samz_cfg = json.load(open(hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "config.json")))
-     samz_model = SAM1Model({
-         'vocab_size': samz_cfg['vocab_size'], 'd_model': samz_cfg['hidden_size'],
-         'n_layers': samz_cfg['num_hidden_layers'], 'n_heads': samz_cfg['num_attention_heads'],
-         'ff_mult': samz_cfg['intermediate_size']/samz_cfg['hidden_size'],
-         'max_len': samz_cfg['max_position_embeddings'], 'rope_theta': samz_cfg['rope_theta'], 'dropout': 0.0
-     })
-     _ = samz_model(dummy_in)
-     samz_model.load_weights(hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "ckpt.weights.h5"))
-     tokenizer_z = Tokenizer.from_file(hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "tokenizer.json"))
- except Exception as e: print(f"⚠️ Failed to load SAM-Z: {e}")
-
- @tf.function(jit_compile=True)
- def predict_x(ids, cache): return samx_model(ids, cache=cache, training=False)
-
- @tf.function(jit_compile=True)
- def predict_z(ids, cache): return samz_model(ids, cache=cache, training=False)
  # ==============================================================================
- # 5. Backend Workers
  # ==============================================================================
- task_queue = queue.Queue()

- def worker():
-     while True:
          try:
-             tid, model, prompt = task_queue.get(timeout=1)
-
-             # Select Model
-             if "SAM-X" in model: pred_fn, tok = predict_x, tokenizer_x
-             else: pred_fn, tok = predict_z, tokenizer_z

-             # Inference
-             try:
-                 ids = [i for i in tok.encode(prompt).ids]
-                 gen = []

-                 # Prefill
-                 curr = tf.constant([ids], dtype=tf.int32)
-                 logits, cache = pred_fn(curr, cache=None)
-                 next_t = np.argmax(logits[0, -1, :])
-                 gen.append(next_t)

-                 # Decode
-                 start = time.time()
-                 for i in range(1024):
-                     curr = tf.constant([[gen[-1]]], dtype=tf.int32)
-                     logits, cache = pred_fn(curr, cache=cache)
-                     next_t = np.argmax(logits[0, -1, :])
-                     if next_t == 50256: break
-                     gen.append(next_t)
-
-                     if i % 5 == 0:
-                         txt = tok.decode(gen)
-                         with db_lock:
-                             db_conn.execute("UPDATE tasks SET status='processing', result=?, progress=? WHERE id=?",
-                                             (txt, int(i/10.24), tid))
-                             db_conn.commit()
-
-                 # Done
-                 txt = tok.decode(gen)
-                 with db_lock:
-                     db_conn.execute("UPDATE tasks SET status='completed', result=?, progress=100, completed_at=? WHERE id=?",
-                                     (txt, datetime.now().isoformat(), tid))
-                     db_conn.commit()
-
-             except Exception as e:
-                 print(f"Error {tid}: {e}")
-                 with db_lock:
-                     db_conn.execute("UPDATE tasks SET status='failed', result=? WHERE id=?", (str(e), tid))
-                     db_conn.commit()
-
-             task_queue.task_done()
-         except queue.Empty: continue

- threading.Thread(target=worker, daemon=True).start()
- # ==============================================================================
- # 6. "More Better" UI (Custom CSS + Chat Layout)
- # ==============================================================================
- css = """
- body { background-color: #0b0f19; color: #e5e7eb; }
- .sidebar { background-color: #111827; border-right: 1px solid #374151; height: 100vh; overflow-y: auto; padding: 20px; }
- .main-content { padding: 20px; max-width: 900px; margin: 0 auto; }
- .task-card {
-     background: #1f2937; border: 1px solid #374151; border-radius: 8px;
-     padding: 12px; margin-bottom: 8px; cursor: pointer; transition: all 0.2s;
  }
- .task-card:hover { background: #374151; border-color: #60a5fa; }
- .status-badge {
-     font-size: 10px; padding: 2px 6px; border-radius: 4px; text-transform: uppercase; font-weight: bold;
  }
- .status-queued { background: #f59e0b20; color: #f59e0b; }
- .status-processing { background: #3b82f620; color: #3b82f6; animation: pulse 2s infinite; }
- .status-completed { background: #10b98120; color: #10b981; }
- .status-failed { background: #ef444420; color: #ef4444; }
-
- /* Message Bubbles */
- .chat-container { display: flex; flex-direction: column; gap: 20px; margin-top: 20px; }
- .message { padding: 16px; border-radius: 12px; max-width: 85%; line-height: 1.6; }
- .user-msg { align-self: flex-end; background: #2563eb; color: white; }
- .bot-msg { align-self: flex-start; background: #1f2937; border: 1px solid #374151; color: #e5e7eb; width: 100%; }
-
- /* Thought Block */
- details.think {
-     background: #172554; border-left: 3px solid #3b82f6; border-radius: 4px;
-     padding: 8px; margin-bottom: 12px; font-size: 0.9em; color: #93c5fd;
  }
- details.think summary { cursor: pointer; font-weight: bold; opacity: 0.8; }
- details.think[open] summary { margin-bottom: 8px; border-bottom: 1px solid #3b82f640; padding-bottom: 4px; }

- @keyframes pulse { 0% { opacity: 1; } 50% { opacity: 0.6; } 100% { opacity: 1; } }
  """
- def format_chat(text):
-     if not text: return ""
-     # Beautiful formatted thought blocks
-     if "<think>" in text:
-         parts = text.split("<think>")
-         pre = parts[0]
-         rest = parts[1]
-         if "</think>" in rest:
-             thought, ans = rest.split("</think>")
-             return f"{pre}<details class='think'><summary>🧠 Thought Process</summary>{thought}</details>{ans}"
-         return f"{pre}<details class='think' open><summary>🧠 Thinking...</summary>{rest} <span class='status-processing'>●</span></details>"
-     return text.replace("\n", "<br>")
-
- with gr.Blocks(css=css, title="SmilyAI Studio", theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate")) as demo:
-     user_id = gr.State(value=None)
-     current_task = gr.State(value=None)
-
-     with gr.Row(elem_classes="container"):
-         # --- Left Sidebar (History) ---
-         with gr.Column(scale=1, elem_classes="sidebar"):
-             gr.Markdown("### 🗂️ History")
-             refresh_btn = gr.Button("🔄 Refresh", size="sm", variant="secondary")
-             history_list = gr.HTML("Log in to see tasks")

              gr.Markdown("---")
-             gr.Markdown("### 👤 Account")
-             u_in = gr.Textbox(placeholder="Username", show_label=False)
-             p_in = gr.Textbox(placeholder="Password", show_label=False, type="password")
-             login_btn = gr.Button("Login", size="sm")
-
-         # --- Main Content (Chat & Monitor) ---
-         with gr.Column(scale=3, elem_classes="main-content"):
-             gr.Markdown("# ✨ SmilyAI Studio")

-             with gr.Group():
-                 with gr.Row():
-                     model_sel = gr.Dropdown(
-                         ["SAM-X-1 (Reasoning)", "SAM-Z-1 (Fast)"],
-                         value="SAM-Z-1 (Fast)", label="Select Model", interactive=True
-                     )
-                 prompt_in = gr.Textbox(
-                     placeholder="Ask anything... (e.g. 'Explain quantum physics')",
-                     lines=3, show_label=False
-                 )
-                 with gr.Row():
-                     generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
-
-             # Live View
-             gr.Markdown("### 📡 Live Monitor")
-             with gr.Group():
-                 stream_display = gr.HTML(
-                     "<div style='padding:20px; text-align:center; color:#6b7280'>Select a task to watch</div>",
-                     elem_id="stream-box"
-                 )
-
-     # --- Logic Functions ---
-     def login(u, p):
-         h = hashlib.sha256(p.encode()).hexdigest()
-         with db_lock:
-             c = db_conn.cursor()
-             c.execute("SELECT id FROM users WHERE username=?", (u,))
-             row = c.fetchone()
-             if not row:  # Auto-register for demo
-                 c.execute("INSERT INTO users (username, password_hash) VALUES (?,?)", (u, h))
-                 db_conn.commit()
-                 row = (c.lastrowid,)
-         return row[0], load_history(row[0])
-
-     def create_task(uid, model, text):
-         if not uid: return None, "Please login first"
-         tid = str(uuid.uuid4())
-         with db_lock:
-             db_conn.execute("INSERT INTO tasks (id, user_id, model_name, prompt, status) VALUES (?,?,?,?,?)",
-                             (tid, uid, model, text, 'queued'))
-             db_conn.commit()
-         task_queue.put((tid, model, text))
-         return tid, tid  # Set current task
-
-     def load_history(uid):
-         if not uid: return "Please Login"
-         with db_lock:
-             rows = db_conn.execute("SELECT id, model_name, status, prompt FROM tasks WHERE user_id=? ORDER BY created_at DESC LIMIT 10", (uid,)).fetchall()
-
-         html = ""
-         for r in rows:
-             tid, mod, stat, p = r
-             short_mod = "Reasoning" if "SAM-X" in mod else "Fast"
-             html += f"""
-             <div class='task-card' onclick="setTask('{tid}')">
-                 <div style='display:flex; justify-content:space-between; margin-bottom:4px'>
-                     <span style='font-weight:bold; color:#e5e7eb'>{short_mod}</span>
-                     <span class='status-badge status-{stat}'>{stat}</span>
                  </div>
-                 <div style='font-size:12px; color:#9ca3af; white-space:nowrap; overflow:hidden; text-overflow:ellipsis'>{p}</div>
-                 <div style='font-size:10px; color:#4b5563; margin-top:4px'>ID: {tid[:8]}</div>
-             </div>
-             """
-         return html
-
-     def watch_stream(tid):
-         if not tid: return "Select a task..."
-         with db_lock:
-             row = db_conn.execute("SELECT result, status FROM tasks WHERE id=?", (tid,)).fetchone()
-         if not row: return "Task not found"
-
-         text, status = row
-         formatted = format_chat(text)
-
-         container = f"""
-         <div class='chat-container'>
-             <div class='message bot-msg'>
-                 {formatted}
-             </div>
-         </div>
-         """
-         return container
-
-     # --- Wiring ---
-     login_btn.click(login, [u_in, p_in], [user_id, history_list])

-     generate_btn.click(
-         create_task, [user_id, model_sel, prompt_in], [current_task, current_task]
      ).then(
-         load_history, [user_id], [history_list]
      )

-     refresh_btn.click(load_history, [user_id], [history_list])

-     # Helper to handle Javascript click on HTML cards
-     # Requires a hidden text input to bridge JS -> Python (omitted for brevity, polling works fine)

-     # Auto-refresh stream
-     timer = gr.Timer(0.5)
-     timer.tick(watch_stream, [current_task], [stream_display])
-     timer.tick(load_history, [user_id], [history_list])

  if __name__ == "__main__":
-     demo.queue().launch(server_name="0.0.0.0", server_port=7860)
  import gradio as gr
  import tensorflow as tf
  import keras
  from huggingface_hub import hf_hub_download
  import json
+ import os
  from tokenizers import Tokenizer
+ import numpy as np
  import time

+ # ============================================================================
+ # 🎊 FESTIVE MODE TOGGLE 🎊
+ # ============================================================================
+ FESTIVE = True  # Set to False for production-only mode

+ # ============================================================================
+ # Configuration & Model Loading
+ # ============================================================================

+ print("🚀 Loading SAM-Z-1 Model...")
+
+ MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow"
+ CACHE_DIR = "./model_cache"
+
+ # ============================================================================
+ # Model Architecture Definitions (FIXED for model loading)
+ # ============================================================================
  @keras.saving.register_keras_serializable()
  class RotaryEmbedding(keras.layers.Layer):
      def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
          self.max_len = max_len
          self.theta = theta
          self.built_cache = False
+
+     def build(self, input_shape):
+         # Use the ORIGINAL training code - compute cache on first call, not in build
+         super().build(input_shape)
+
      def _build_cache(self):
+         """Build RoPE cache on first forward pass"""
          if not self.built_cache:
              inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
              t = tf.range(self.max_len, dtype=tf.float32)
              freqs = tf.einsum("i,j->ij", t, inv_freq)
              emb = tf.concat([freqs, freqs], axis=-1)
+
+             # Precompute with NumPy and store as constants to avoid graph issues
              self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
              self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
              self.built_cache = True
+
+     def rotate_half(self, x):
+         x1, x2 = tf.split(x, 2, axis=-1)
+         return tf.concat([-x2, x1], axis=-1)
+
      def call(self, q, k):
+         # Build cache on first call (avoids build-time issues)
          self._build_cache()
+
          seq_len = tf.shape(q)[2]
+         dtype = q.dtype
+         cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
+         sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]

+         q_rotated = (q * cos) + (self.rotate_half(q) * sin)
+         k_rotated = (k * cos) + (self.rotate_half(k) * sin)
+
+         return q_rotated, k_rotated
+
+     def get_config(self):
+         config = super().get_config()
+         config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
+         return config
80
+ class RMSNorm(keras.layers.Layer):
81
+ def __init__(self, epsilon=1e-5, **kwargs):
82
+ super().__init__(**kwargs)
83
+ self.epsilon = epsilon
84
+
85
+ def build(self, input_shape):
86
+ self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
87
+
88
+ def call(self, x):
89
+ variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
90
+ return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
91
+
92
+ def get_config(self):
93
+ config = super().get_config()
94
+ config.update({"epsilon": self.epsilon})
95
+ return config
96
+
97
 
98
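Note: `RMSNorm` replaces the `LayerNormalization` used in the deleted version. It skips mean-centering and the bias term and normalizes by the root mean square alone, `y = x / sqrt(mean(x²) + ε) · scale`, the cheaper variant used by Llama-style stacks. The same computation in plain NumPy, for comparison:

```python
import numpy as np

def rms_norm(x, scale, eps=1e-5):
    # Unlike LayerNorm: no mean subtraction, no bias term.
    return x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps) * scale

x = np.random.randn(2, 4).astype(np.float32)
print(rms_norm(x, scale=np.ones(4, dtype=np.float32)))
```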
  @keras.saving.register_keras_serializable()
99
  class TransformerBlock(keras.layers.Layer):
100
+ def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
101
  super().__init__(**kwargs)
102
+ self.d_model = d_model
103
  self.n_heads = n_heads
104
+ self.ff_dim = ff_dim
105
+ self.dropout_rate = dropout
106
+ self.max_len = max_len
107
+ self.rope_theta = rope_theta
108
  self.head_dim = d_model // n_heads
109
+ self.layer_idx = layer_idx
110
+
111
+ self.pre_attn_norm = RMSNorm()
112
+ self.pre_ffn_norm = RMSNorm()
113
+
114
+ self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
115
+ self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
116
+ self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
117
+ self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
118
+
119
  self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
120
+
121
+ self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
122
+ self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
123
+ self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
124
+
 
 
 
 
125
  self.dropout = keras.layers.Dropout(dropout)
126
+
127
+ def call(self, x, training=None):
128
+ B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
129
+ dtype = x.dtype
130
 
131
+ # Attention
132
  res = x
133
  y = self.pre_attn_norm(x)
134
 
135
+ q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
136
+ k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
137
+ v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
138
+
139
+ q, k = self.rope(q, k)
140
+
141
+ scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ mask = tf.where(
144
+ tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
145
+ tf.constant(-1e9, dtype=dtype),
146
+ tf.constant(0.0, dtype=dtype)
147
+ )
148
+ scores += mask
149
+ attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
150
+
151
+ attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
152
+ x = res + self.dropout(self.out_proj(attn), training=training)
153
+
154
+ # FFN (SwiGLU)
155
  res = x
156
  y = self.pre_ffn_norm(x)
157
  ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
158
 
159
+ return res + self.dropout(ffn, training=training)
160
+
161
+ def get_config(self):
162
+ config = super().get_config()
163
+ config.update({
164
+ "d_model": self.d_model,
165
+ "n_heads": self.n_heads,
166
+ "ff_dim": self.ff_dim,
167
+ "dropout": self.dropout_rate,
168
+ "max_len": self.max_len,
169
+ "rope_theta": self.rope_theta,
170
+ "layer_idx": self.layer_idx
171
+ })
172
+ return config
173
+
174
 
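Note: compared with the deleted block, this `call` drops the KV cache and the `tf.cond` branch: the causal mask is built unconditionally and every decoding step re-runs attention over the whole prefix, which is simpler and serializes cleanly but costs O(T²) per generated token. The `tf.where`/`band_part` mask is equivalent to the old `(1 - tril) * -1e9` trick, which a quick standalone check confirms:

```python
import tensorflow as tf

T = 4
tril = tf.linalg.band_part(tf.ones([T, T]), -1, 0)                   # lower-triangular ones
mask_new = tf.where(tril == 0, tf.constant(-1e9), tf.constant(0.0))  # this commit
mask_old = (1.0 - tril) * -1e9                                       # deleted version
tf.debugging.assert_near(mask_new, mask_old)                         # same additive mask
```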
  @keras.saving.register_keras_serializable()
  class SAM1Model(keras.Model):
+     def __init__(self, **kwargs):
+         super().__init__()
+         if 'config' in kwargs and isinstance(kwargs['config'], dict):
+             self.cfg = kwargs['config']
+         elif 'vocab_size' in kwargs:
+             self.cfg = kwargs
+         else:
+             self.cfg = kwargs.get('cfg', kwargs)
+
+         self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
+
+         ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
+         block_args = {
+             'd_model': self.cfg['d_model'],
+             'n_heads': self.cfg['n_heads'],
+             'ff_dim': ff_dim,
+             'dropout': self.cfg['dropout'],
+             'max_len': self.cfg['max_len'],
+             'rope_theta': self.cfg['rope_theta']
+         }
+
+         self.blocks = []
+         for i in range(self.cfg['n_layers']):
+             block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
+             self.blocks.append(block)
+
+         self.norm = RMSNorm(name="final_norm")
+         self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
+
+     def call(self, input_ids, training=None):
          x = self.embed(input_ids)
+
+         for block in self.blocks:
+             x = block(x, training=training)
+
+         return self.lm_head(self.norm(x))
+
+     def get_config(self):
+         base_config = super().get_config()
+         base_config['config'] = self.cfg
+         return base_config

+ print("✅ Model architecture registered")
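Note: every custom class above carries `@keras.saving.register_keras_serializable()` plus a matching `get_config`, which is what lets `keras.models.load_model("model.keras")` further down rebuild `SAM1Model`/`TransformerBlock`/`RMSNorm` without the training script on hand. The contract in miniature (standalone sketch with a hypothetical `Scale` layer):

```python
import keras

@keras.saving.register_keras_serializable()
class Scale(keras.layers.Layer):
    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, x):
        return x * self.factor

    def get_config(self):
        # Everything __init__ needs must round-trip through the config.
        config = super().get_config()
        config.update({"factor": self.factor})
        return config

model = keras.Sequential([keras.Input(shape=(4,)), Scale(3.0)])
model.save("scale_demo.keras")
reloaded = keras.models.load_model("scale_demo.keras")  # no custom_objects needed
```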
+ # Download model files
+ config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)

+ # Try to download checkpoint weights first (more reliable)
  try:
+     weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR)
+     print("✅ Found checkpoint weights (ckpt.weights.h5)")
+     use_checkpoint = True
+ except Exception as e:
+     print(f"⚠️ Checkpoint not found, falling back to model.keras: {e}")
+     model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
+     use_checkpoint = False
+
+ # Load config
+ with open(config_path, 'r') as f:
+     config = json.load(f)
+
+ # Create tokenizer from scratch
+ print("📦 Creating tokenizer from GPT-2 base...")
+ from transformers import AutoTokenizer
+
+ hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+ # Add custom tokens to match model's vocab size
+ custom_tokens = ["<|im_start|>", "<|im_end|>", "<think>", "<think/>"]
+ hf_tokenizer.add_special_tokens({"additional_special_tokens": custom_tokens})
+
+ # Save and reload as tokenizers format
+ os.makedirs("./temp_tokenizer", exist_ok=True)
+ hf_tokenizer.save_pretrained("./temp_tokenizer")
+ tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
+
+ print(f"✅ Tokenizer created with vocab size: {tokenizer.get_vocab_size()}")
+ print(f"   Custom tokens added: {custom_tokens}")
+ print(f"   Model vocab size: {config.get('vocab_size', 'unknown')}")
+
+ # Verify vocab sizes match
+ if tokenizer.get_vocab_size() != config.get('vocab_size'):
+     print(f"⚠️ WARNING: Tokenizer vocab ({tokenizer.get_vocab_size()}) != Model vocab ({config.get('vocab_size')})")
+     print(f"   Model was trained with these tokens, but SAM-Z-1 doesn't use <think> tags in generation")
+
+ eos_token_id = config.get('eos_token_id', 50256)
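Note: the tokenizer is rebuilt at startup instead of downloaded: GPT-2's fast tokenizer plus four added special tokens, written out with `save_pretrained` (which, for fast tokenizers, emits a `tokenizer.json`) and re-read through the `tokenizers` library. A quick sanity check in the same spirit, assuming `transformers` is installed:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.add_special_tokens({"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]})

# Each special token should map to a single id, not be split into pieces.
ids = tok.encode("<|im_start|>user", add_special_tokens=False)
print(ids, tok.convert_ids_to_tokens(ids))  # e.g. [50257, 7220] -> ['<|im_start|>', 'user']
```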
  # ==============================================================================
+ # Load Model - Priority: checkpoint weights > saved model
  # ==============================================================================
+ print("\n🔄 Loading model...")

+ if use_checkpoint:
+     print("📦 Building model from config and loading checkpoint weights...")
+
+     # Build model from scratch with config
+     model_config = {
+         'vocab_size': config['vocab_size'],
+         'd_model': config['hidden_size'],
+         'n_layers': config['num_hidden_layers'],
+         'n_heads': config['num_attention_heads'],
+         'ff_mult': config['intermediate_size'] / config['hidden_size'],
+         'max_len': config['max_position_embeddings'],
+         'dropout': 0.1,  # Default dropout
+         'rope_theta': config['rope_theta']
+     }
+
+     model = SAM1Model(config=model_config)
+
+     # Build model by running a dummy forward pass
+     dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
+     _ = model(dummy_input, training=False)
+
+     print(f"✅ Model architecture built: {model.count_params():,} parameters")
+
+     # Load checkpoint weights
+     print(f"📥 Loading checkpoint weights from: {weights_path}")
+     model.load_weights(weights_path)
+     print("✅ Checkpoint weights loaded successfully!")
+
+ else:
+     print("📦 Loading full saved model...")
+     try:
+         model = keras.models.load_model(model_path, compile=False)
+         print("✅ Model loaded successfully")
+     except Exception as e:
+         print(f"❌ Failed to load model: {e}")
+         print("\n🔄 Trying alternative: building from config + loading weights...")
+
+         # Fallback to building model
+         model_config = {
+             'vocab_size': config['vocab_size'],
+             'd_model': config['hidden_size'],
+             'n_layers': config['num_hidden_layers'],
+             'n_heads': config['num_attention_heads'],
+             'ff_mult': config['intermediate_size'] / config['hidden_size'],
+             'max_len': config['max_position_embeddings'],
+             'dropout': 0.1,
+             'rope_theta': config['rope_theta']
+         }
+
+         model = SAM1Model(config=model_config)
+         dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
+         _ = model(dummy_input, training=False)
+
+         # Try to load weights from model.keras
          try:
+             temp_model = keras.models.load_model(model_path, compile=False)
+             model.set_weights(temp_model.get_weights())
+             print("✅ Weights transferred successfully")
+         except:
+             print("❌ Could not load weights - model may not work correctly!")
+             raise
+
+ # Create optimized inference function
+ @tf.function(reduce_retracing=True)
+ def fast_forward(input_tensor):
+     """TF-optimized forward pass for faster generation"""
+     return model(input_tensor, training=False)
+
+ print(f"✅ Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
+ print(f"✅ TF function optimization enabled for faster inference")
+
+ # Global stop flag
341
+ stop_generation = False
342
+
343
+ # ============================================================================
344
+ # Generation Function with Streaming & Stop Button
345
+ # ============================================================================
346
+
347
+ def generate_stream(
348
+ prompt: str,
349
+ max_tokens: int = 512,
350
+ temperature: float = 0.8,
351
+ top_k: int = 40,
352
+ top_p: float = 0.9,
353
+ repetition_penalty: float = 1.1
354
+ ):
355
+ """Generate text with streaming output and stop support"""
356
+ global stop_generation
357
+ stop_generation = False
358
+
359
+ # Tokenize prompt
360
+ input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
361
+
362
+ if len(input_ids) == 0:
363
+ yield "⚠️ Empty prompt after tokenization"
364
+ return
365
+
366
+ if len(input_ids) > config['max_position_embeddings'] - max_tokens:
367
+ input_ids = input_ids[-(config['max_position_embeddings'] - max_tokens):]
368
+
369
+ input_tensor = tf.constant([input_ids], dtype=tf.int32)
370
+ generated_text = ""
371
+ token_count = 0
372
+
373
+ # Track token frequencies for repetition penalty
374
+ token_freq = {}
375
+
376
+ start_time = time.time()
377
+
378
+ for step in range(max_tokens):
379
+ # Check stop flag
380
+ if stop_generation:
381
+ generated_text += "\n\n*[Generation stopped by user]*"
382
+ yield generated_text
383
+ break
384
+
385
+ # Get logits using optimized TF function
386
+ logits = fast_forward(input_tensor)
387
+ next_token_logits = logits[0, -1, :].numpy()
388
+
389
+ # Apply temperature
390
+ next_token_logits = next_token_logits / temperature
391
+
392
+ # Apply repetition penalty
393
+ if repetition_penalty != 1.0:
394
+ for token_id, freq in token_freq.items():
395
+ if token_id < len(next_token_logits):
396
+ next_token_logits[token_id] /= (repetition_penalty ** freq)
397
+
398
+ # Top-k filtering
399
+ if top_k > 0:
400
+ top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:]
401
+ top_k_logits = next_token_logits[top_k_indices]
402
+ top_k_probs = tf.nn.softmax(top_k_logits).numpy()
403
 
404
+ # Top-p (nucleus) sampling
405
+ if top_p < 1.0:
406
+ sorted_indices = np.argsort(top_k_probs)[::-1]
407
+ cumsum = np.cumsum(top_k_probs[sorted_indices])
408
+ cutoff_idx = np.searchsorted(cumsum, top_p)
409
+ nucleus_indices = sorted_indices[:cutoff_idx + 1]
410
 
411
+ nucleus_logits = top_k_logits[nucleus_indices]
412
+ nucleus_probs = tf.nn.softmax(nucleus_logits).numpy()
 
 
 
413
 
414
+ sampled_idx = np.random.choice(len(nucleus_probs), p=nucleus_probs)
415
+ next_token_id = int(top_k_indices[nucleus_indices[sampled_idx]])
416
+ else:
417
+ sampled_idx = np.random.choice(len(top_k_probs), p=top_k_probs)
418
+ next_token_id = int(top_k_indices[sampled_idx])
419
+ else:
420
+ probs = tf.nn.softmax(next_token_logits).numpy()
421
+ next_token_id = np.random.choice(len(probs), p=probs)
422
+
423
+ # Stop on EOS
424
+ if next_token_id == eos_token_id:
425
+ break
426
+
427
+ # Update token frequency
428
+ token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
429
+
430
+ # Decode and yield
431
+ token_text = tokenizer.decode([next_token_id])
432
+ generated_text += token_text
433
+ token_count += 1
434
+
435
+ # Yield progressive output
436
+ yield generated_text
437
+
438
+ # Update input
439
+ input_tensor = tf.concat([input_tensor, [[next_token_id]]], axis=1)
440
+
441
+ # Truncate if too long
442
+ if input_tensor.shape[1] > config['max_position_embeddings']:
443
+ input_tensor = input_tensor[:, -config['max_position_embeddings']:]
444
+
445
+ # Calculate stats
446
+ elapsed = time.time() - start_time
447
+ tokens_per_sec = token_count / elapsed if elapsed > 0 else 0
448
+
449
+ # Add generation stats
450
+ if token_count > 0 and not stop_generation:
451
+ generated_text += f"\n\n*[Generated {token_count} tokens in {elapsed:.1f}s ({tokens_per_sec:.1f} tok/s)]*"
452
+
453
+ yield generated_text
454
 
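Note: `generate_stream` is a plain Python generator that yields the accumulated completion after every sampled token; sampling applies temperature scaling, a frequency-scaled repetition penalty, top-k filtering, and optionally top-p nucleus re-weighting inside the top-k shortlist. Consumed outside Gradio it would look like this (sketch; assumes the model and tokenizer above are loaded):

```python
# Stream a completion to stdout, printing only each newly added suffix.
previous = ""
for text in generate_stream("The capital of France is", max_tokens=32,
                            temperature=0.7, top_k=40, top_p=0.9,
                            repetition_penalty=1.1):
    print(text[len(previous):], end="", flush=True)
    previous = text
print()
```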
+ # ============================================================================
+ # Chat Interface Logic
+ # ============================================================================
+
+ def format_chat_prompt(message: str, history: list) -> str:
+     """Format message history into chat prompt"""
+     prompt = ""
+
+     # Add history
+     for user_msg, assistant_msg in history:
+         prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
+         if assistant_msg:
+             prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
+
+     # Add current message
+     prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
+
+     return prompt
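Note: `format_chat_prompt` serializes the history into ChatML-style turns using the `<|im_start|>`/`<|im_end|>` tokens added to the tokenizer earlier, ending with an open assistant turn for the model to complete. For one prior exchange the prompt handed to `generate_stream` looks like this:

```python
history = [["Hi!", "Hello! How can I help?"]]
print(format_chat_prompt("What's 2+2?", history))
# <|im_start|>user
# Hi!<|im_end|>
# <|im_start|>assistant
# Hello! How can I help?<|im_end|>
# <|im_start|>user
# What's 2+2?<|im_end|>
# <|im_start|>assistant
```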
+ def chat_stream(
+     message: str,
+     history: list,
+     max_tokens: int,
+     temperature: float,
+     top_k: int,
+     top_p: float,
+     repetition_penalty: float
+ ):
+     """Streaming chat response"""
+     if not message.strip():
+         yield history
+         return
+
+     # Format prompt
+     prompt = format_chat_prompt(message, history)
+
+     # Generate with streaming
+     partial_response = ""
+     for generated in generate_stream(
+         prompt,
+         max_tokens=max_tokens,
+         temperature=temperature,
+         top_k=top_k,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty
+     ):
+         partial_response = generated
+
+         # Stop at end tags
+         if "<|im_end|>" in partial_response:
+             partial_response = partial_response.split("<|im_end|>")[0]
+
+         # Update history
+         yield history + [[message, partial_response.strip()]]
+
+ def stop_gen():
+     """Stop generation callback"""
+     global stop_generation
+     stop_generation = True
+     return None
+
+ # ============================================================================
+ # Gradio UI
+ # ============================================================================
+
+ # Festive CSS
+ festive_css = """
+ .gradio-container {
+     max-width: 1200px !important;
+     margin: auto !important;
+ }
+
+ .header {
+     text-align: center;
+     padding: 2rem;
+     background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
+     color: white;
+     border-radius: 12px;
+     margin-bottom: 2rem;
+     box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);
+     animation: pulse 2s ease-in-out infinite;
+ }
+
+ @keyframes pulse {
+     0%, 100% { transform: scale(1); }
+     50% { transform: scale(1.02); }
+ }
+
+ .header h1 {
+     font-size: 2.8rem;
+     margin-bottom: 0.5rem;
+     font-weight: 700;
+     text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
+ }
+
+ .header p {
+     font-size: 1.1rem;
+     opacity: 0.95;
+ }
+
+ .celebration {
+     font-size: 2rem;
+     margin: 0.5rem;
+     animation: bounce 1s ease infinite;
+ }
+
+ @keyframes bounce {
+     0%, 100% { transform: translateY(0); }
+     50% { transform: translateY(-10px); }
+ }
+
+ .stats-card {
+     background: linear-gradient(135deg, #ffecd2 0%, #fcb69f 100%);
+     padding: 1.5rem;
+     border-radius: 12px;
+     border-left: 4px solid #f5576c;
+     margin: 1rem 0;
+     box-shadow: 0 4px 16px rgba(252, 182, 159, 0.3);
+ }
+
+ .twin-badge {
+     display: inline-block;
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     color: white;
+     padding: 0.5rem 1rem;
+     border-radius: 20px;
+     font-weight: bold;
+     margin: 0.5rem;
+     box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
+ }
+
+ footer {
+     text-align: center;
+     padding: 2rem;
+     color: #666;
+     border-top: 1px solid #eee;
+     margin-top: 2rem;
+ }
+
+ .confetti {
+     position: fixed;
+     width: 10px;
+     height: 10px;
+     background: #f5576c;
+     position: absolute;
+     animation: confetti-fall 3s linear infinite;
+ }
+
+ @keyframes confetti-fall {
+     to { transform: translateY(100vh) rotate(360deg); }
+ }
+ """
+
+ # Production CSS
+ production_css = """
+ .gradio-container {
+     max-width: 1200px !important;
+     margin: auto !important;
+ }
+
+ .header {
+     text-align: center;
+     padding: 2rem;
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     color: white;
+     border-radius: 12px;
+     margin-bottom: 2rem;
+ }
+
+ .header h1 {
+     font-size: 2.5rem;
+     margin-bottom: 0.5rem;
+     font-weight: 700;
  }
+
+ .header p {
+     font-size: 1.1rem;
+     opacity: 0.95;
  }
+
+ .stats-card {
+     background: #f8f9fa;
+     padding: 1rem;
+     border-radius: 8px;
+     border-left: 4px solid #667eea;
+     margin: 1rem 0;
  }

+ footer {
+     text-align: center;
+     padding: 2rem;
+     color: #666;
+     border-top: 1px solid #eee;
+     margin-top: 2rem;
+ }
  """
+ # Select CSS based on mode
+ custom_css = festive_css if FESTIVE else production_css
+
+ # Build interface
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
+     # Header
+     if FESTIVE:
+         gr.HTML("""
+         <div class="header">
+             <div class="celebration">🎉 🎊 ✨ 🎈 🎆</div>
+             <img src="https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/yBUDdaTze1L84NaDSpZGf.jpeg"
+                  alt="SAM-Z-1"
+                  style="max-width: 400px; border-radius: 12px; margin: 1rem auto; display: block; box-shadow: 0 8px 24px rgba(0,0,0,0.2);">
+             <h1>🤖 SAM-Z-1 Chat 🤖</h1>
+             <p><strong>LATEST RELEASE!</strong> Our <strong>best</strong> non-reasoning model</p>
+             <div class="twin-badge">Twin of SAM-X-1 (Reasoning Model)</div>
+             <p style="font-size: 0.9rem; margin-top: 1rem;">
+                 768D • 16 Layers • 12 Heads • ~313M Parameters • Trained on TPU v5e-8
+             </p>
+             <div class="celebration">🚀 💫 🎯 ⚡ 🔥</div>
+         </div>
+         """)
+     else:
+         gr.HTML("""
+         <div class="header">
+             <img src="https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/yBUDdaTze1L84NaDSpZGf.jpeg"
+                  alt="SAM-Z-1"
+                  style="max-width: 300px; border-radius: 12px; margin: 1rem auto; display: block; box-shadow: 0 4px 16px rgba(0,0,0,0.15);">
+             <h1>🤖 SAM-Z-1 Chat</h1>
+             <p>Fast, direct responses without reasoning overhead</p>
+             <p style="font-size: 0.9rem; margin-top: 0.5rem;">
+                 768D • 16 Layers • 12 Heads • Trained on TPU v5e-8
+             </p>
+         </div>
+         """)
+
+     with gr.Row():
+         with gr.Column(scale=4):
+             # Chat interface with bot avatar
+             chatbot = gr.Chatbot(
+                 height=600,
+                 show_label=False,
+                 avatar_images=(
+                     None,
+                     "https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/KtiMi-aDUOOeN--YNT-Fu.jpeg"
+                 ),
+                 bubble_full_width=False
+             )
+
+             with gr.Row():
+                 msg = gr.Textbox(
+                     placeholder="Type your message here..." if not FESTIVE else "Ask me anything! I'm the fast twin! ⚡",
+                     show_label=False,
+                     scale=8,
+                     container=False
+                 )
+                 submit_btn = gr.Button("Send 🚀" if FESTIVE else "Send", variant="primary", scale=1)
+                 stop_btn = gr.Button("⏹️ Stop", variant="stop", scale=1)
+
+             with gr.Row():
+                 clear_btn = gr.Button("🗑️ Clear Chat", size="sm")
+                 retry_btn = gr.Button("🔄 Retry", size="sm")
+
+         with gr.Column(scale=1):
+             gr.Markdown("### ⚙️ Generation Settings")
+
+             max_tokens = gr.Slider(
+                 minimum=50,
+                 maximum=1024,
+                 value=512,
+                 step=50,
+                 label="Max Tokens",
+                 info="Maximum length of response"
+             )
+
+             temperature = gr.Slider(
+                 minimum=0.1,
+                 maximum=2.0,
+                 value=0.8,
+                 step=0.1,
+                 label="Temperature",
+                 info="Higher = more creative"
+             )
+
+             top_k = gr.Slider(
+                 minimum=1,
+                 maximum=100,
+                 value=40,
+                 step=1,
+                 label="Top-K",
+                 info="Sample from top K tokens"
+             )
+
+             top_p = gr.Slider(
+                 minimum=0.1,
+                 maximum=1.0,
+                 value=0.9,
+                 step=0.05,
+                 label="Top-P",
+                 info="Nucleus sampling threshold"
+             )
+
+             repetition_penalty = gr.Slider(
+                 minimum=1.0,
+                 maximum=2.0,
+                 value=1.1,
+                 step=0.1,
+                 label="Repetition Penalty",
+                 info="Penalize repeated tokens"
+             )

              gr.Markdown("---")

+             # Model info
+             if FESTIVE:
+                 gr.Markdown(f"""
+                 ### 🎊 SAM-Z-1 Model Info
+
+                 **🎯 The Fast Twin!**
+
+                 **Type:** Direct Response Model
+                 **Parameters:** ~313M
+                 **Context:** {config['max_position_embeddings']} tokens
+                 **Vocab:** {config['vocab_size']}
+                 **Speed:** ⚡ Optimized with TF Functions
+
+                 **Twin Model:**
+                 - **SAM-X-1**: Reasoning model (uses `<think>` tags)
+                 - **SAM-Z-1**: Fast model (no thinking, direct answers! 🎉)
+
+                 **Note:** Model includes `<think>` tokens in its vocab but doesn't use them. Training used the same tokenizer as SAM-X-1.
+
+                 **Architecture:**
+                 - RoPE positional encoding
+                 - SwiGLU activation
+                 - RMSNorm layers
+                 - No bias terms (efficient!)
+
+                 **Training:**
+                 - Trained from scratch
+                 - TPU v5e-8 (8 cores)
+                 - Mixed precision (bfloat16)
+                 - Cosine decay schedule
+                 """)
+             else:
+                 gr.Markdown(f"""
+                 ### 📊 Model Info
+
+                 **Architecture:** SAM-Z-1 (Direct Response)
+                 **Parameters:** ~313M
+                 **Context:** {config['max_position_embeddings']} tokens
+                 **Vocab:** {config['vocab_size']}
+
+                 **Twin Models:**
+                 - SAM-X-1: Reasoning model (uses `<think>` tags)
+                 - SAM-Z-1: Direct response model (no thinking)
+
+                 **Note:** Vocab includes `<think>` tokens but the model doesn't use them in generation.
+
+                 **Features:**
+                 - RoPE positional encoding
+                 - SwiGLU activation
+                 - RMSNorm layers
+                 - TF-optimized inference
+                 """)
+
+     # Example prompts
+     gr.Examples(
+         examples=[
+             "Hi! What can you do?",
+             "Explain quantum computing in simple terms",
+             "Write a short poem about AI",
+             "What's the capital of France?",
+             "How do I learn programming?",
+             "Tell me an interesting fact about space",
+             "What's the difference between you and SAM-X-1?",
+             "Why are you called the fast twin?",
+         ],
+         inputs=msg,
+         label="💡 Try these examples" if not FESTIVE else "🎯 Try these examples!"
+     )
+
+     # Footer
+     if FESTIVE:
+         gr.HTML("""
+         <footer>
+             <p style="font-size: 1.2rem;"><strong>🎉 SAM-Z-1 - LATEST RELEASE! 🎉</strong></p>
+             <p><strong>The Fast Twin</strong> - Direct responses without reasoning overhead</p>
+             <p style="font-size: 0.9rem; color: #999; margin-top: 0.5rem;">
+                 Trained from scratch on TPU v5e-8 • Built with TensorFlow & Gradio
+             </p>
+             <p style="font-size: 0.9rem; color: #999;">
+                 Twin of SAM-X-1 (reasoning model) • Same architecture, different training objective
+             </p>
+             <div style="margin-top: 1rem; font-size: 1.5rem;">
+                 ⚡ 🚀 💫 ✨ 🎯
              </div>
+         </footer>
+         """)
+     else:
+         gr.HTML("""
+         <footer>
+             <p><strong>SAM-Z-1</strong> - Direct response language model</p>
+             <p style="font-size: 0.9rem; color: #999;">
+                 Trained from scratch on TPU v5e-8 • Built with TensorFlow & Gradio
+             </p>
+             <p style="font-size: 0.9rem; color: #999;">
+                 Twin of SAM-X-1 (reasoning model)
+             </p>
+         </footer>
+         """)
+
+ # Event handlers
865
+ submit_event = msg.submit(
866
+ chat_stream,
867
+ inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty],
868
+ outputs=[chatbot]
869
+ ).then(
870
+ lambda: "",
871
+ outputs=[msg]
872
+ )
 
 
873
 
874
+ click_event = submit_btn.click(
875
+ chat_stream,
876
+ inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty],
877
+ outputs=[chatbot]
878
  ).then(
879
+ lambda: "",
880
+ outputs=[msg]
881
  )
882
 
883
+ # Stop button
884
+ stop_btn.click(
885
+ fn=stop_gen,
886
+ inputs=None,
887
+ outputs=None,
888
+ cancels=[submit_event, click_event]
889
+ )
890
 
891
+ clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])
 
892
 
893
+ def retry_last(history, max_tok, temp, topk, topp, rep_pen):
894
+ if not history:
895
+ return history
896
+ last_user_msg = history[-1][0]
897
+ history = history[:-1]
898
+ for update in chat_stream(last_user_msg, history, max_tok, temp, topk, topp, rep_pen):
899
+ yield update
900
+
901
+ retry_event = retry_btn.click(
902
+ retry_last,
903
+ inputs=[chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty],
904
+ outputs=[chatbot]
905
+ )
906
+
907
+ stop_btn.click(
908
+ fn=stop_gen,
909
+ inputs=None,
910
+ outputs=None,
911
+ cancels=[retry_event]
912
+ )
913
 
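Note: stopping is handled twice over: the `stop_generation` flag makes `generate_stream` break out on its next loop iteration, while `cancels=[...]` tells Gradio to kill the queued event itself (this needs the queue enabled, which the launch block below does). The pattern in isolation, with a trivial generator:

```python
import time
import gradio as gr

def slow_count():
    for i in range(100):
        time.sleep(0.1)
        yield str(i)

with gr.Blocks() as demo:
    out = gr.Textbox()
    start = gr.Button("Start")
    stop = gr.Button("Stop")
    run = start.click(slow_count, outputs=out)
    stop.click(fn=None, cancels=[run])  # cancel the in-flight event

demo.queue()
demo.launch()
```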
+ # Launch
  if __name__ == "__main__":
+     demo.queue(max_size=20)
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         show_error=True
+     )