Keeby-smilyai committed on
Commit 0feb44a · verified · 1 Parent(s): 10fd5b9

Update app.py

Files changed (1):
  1. app.py +592 -673
app.py CHANGED
@@ -1,8 +1,3 @@
- """
- SAM-Z-1 Production API with Gradio UI
- OpenAI-compatible API interface for Hugging Face Spaces
- """
-
  import gradio as gr
  import tensorflow as tf
  import keras
@@ -12,23 +7,23 @@ import os
  from tokenizers import Tokenizer
  import numpy as np
  import time
- from typing import Dict, Any, List

  # ============================================================================
- # Configuration
  # ============================================================================

  MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow"
  CACHE_DIR = "./model_cache"

- # Global model storage
- model = None
- tokenizer = None
- config = None
- eos_token_id = None
-
  # ============================================================================
- # Model Architecture (same as original)
  # ============================================================================

  @keras.saving.register_keras_serializable()
@@ -41,14 +36,18 @@ class RotaryEmbedding(keras.layers.Layer):
          self.built_cache = False

      def build(self, input_shape):
          super().build(input_shape)

      def _build_cache(self):
          if not self.built_cache:
              inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
              t = tf.range(self.max_len, dtype=tf.float32)
              freqs = tf.einsum("i,j->ij", t, inv_freq)
              emb = tf.concat([freqs, freqs], axis=-1)
              self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
              self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
              self.built_cache = True
@@ -58,13 +57,17 @@ class RotaryEmbedding(keras.layers.Layer):
          return tf.concat([-x2, x1], axis=-1)

      def call(self, q, k):
          self._build_cache()
          seq_len = tf.shape(q)[2]
          dtype = q.dtype
          cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
          sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
          q_rotated = (q * cos) + (self.rotate_half(q) * sin)
          k_rotated = (k * cos) + (self.rotate_half(k) * sin)
          return q_rotated, k_rotated

      def get_config(self):
@@ -107,20 +110,25 @@ class TransformerBlock(keras.layers.Layer):

          self.pre_attn_norm = RMSNorm()
          self.pre_ffn_norm = RMSNorm()
          self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
          self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
          self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
          self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
          self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
          self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
          self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
          self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
          self.dropout = keras.layers.Dropout(dropout)

      def call(self, x, training=None):
          B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
          dtype = x.dtype

          res = x
          y = self.pre_attn_norm(x)
@@ -129,7 +137,9 @@ class TransformerBlock(keras.layers.Layer):
          v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])

          q, k = self.rope(q, k)
          scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
          mask = tf.where(
              tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
              tf.constant(-1e9, dtype=dtype),
@@ -137,9 +147,11 @@ class TransformerBlock(keras.layers.Layer):
          )
          scores += mask
          attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
          attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
          x = res + self.dropout(self.out_proj(attn), training=training)

          res = x
          y = self.pre_ffn_norm(x)
          ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
@@ -149,9 +161,13 @@ class TransformerBlock(keras.layers.Layer):
      def get_config(self):
          config = super().get_config()
          config.update({
-             "d_model": self.d_model, "n_heads": self.n_heads, "ff_dim": self.ff_dim,
-             "dropout": self.dropout_rate, "max_len": self.max_len,
-             "rope_theta": self.rope_theta, "layer_idx": self.layer_idx
          })
          return config
@@ -171,20 +187,28 @@ class SAM1Model(keras.Model):

          ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
          block_args = {
-             'd_model': self.cfg['d_model'], 'n_heads': self.cfg['n_heads'],
-             'ff_dim': ff_dim, 'dropout': self.cfg['dropout'],
-             'max_len': self.cfg['max_len'], 'rope_theta': self.cfg['rope_theta']
          }

-         self.blocks = [TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
-                        for i in range(self.cfg['n_layers'])]
          self.norm = RMSNorm(name="final_norm")
          self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

      def call(self, input_ids, training=None):
          x = self.embed(input_ids)
          for block in self.blocks:
              x = block(x, training=training)
          return self.lm_head(self.norm(x))

      def get_config(self):
@@ -192,43 +216,60 @@ class SAM1Model(keras.Model):
          base_config['config'] = self.cfg
          return base_config

- # ============================================================================
- # Model Loading
- # ============================================================================
-
- print("🚀 Loading SAM-Z-1 Model for API...")

  config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)

  try:
      weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR)
      use_checkpoint = True
-     print("✅ Found checkpoint weights")
- except:
      model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
      use_checkpoint = False
-     print("✅ Found saved model")

  with open(config_path, 'r') as f:
      config = json.load(f)

- eos_token_id = config.get('eos_token_id', 50256)
-
- # Create tokenizer
- print("📦 Creating tokenizer...")
  from transformers import AutoTokenizer
  hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
- hf_tokenizer.add_special_tokens({
-     "additional_special_tokens": ["<|im_start|>", "<|im_end|>", "<think>", "<think/>"]
- })
  os.makedirs("./temp_tokenizer", exist_ok=True)
  hf_tokenizer.save_pretrained("./temp_tokenizer")
  tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")

- # Load model
  if use_checkpoint:
-     print("📦 Building model and loading weights...")
      model_config = {
          'vocab_size': config['vocab_size'],
          'd_model': config['hidden_size'],
@@ -236,49 +277,119 @@ if use_checkpoint:
          'n_heads': config['num_attention_heads'],
          'ff_mult': config['intermediate_size'] / config['hidden_size'],
          'max_len': config['max_position_embeddings'],
-         'dropout': 0.1,
          'rope_theta': config['rope_theta']
      }
      model = SAM1Model(config=model_config)
      dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
      _ = model(dummy_input, training=False)
      model.load_weights(weights_path)
  else:
-     model = keras.models.load_model(model_path, compile=False)

  @tf.function(reduce_retracing=True)
  def fast_forward(input_tensor):
      return model(input_tensor, training=False)

- print(f"✅ Model loaded: {config['num_hidden_layers']} layers, ~313M params")

  # ============================================================================
- # Generation Engine
  # ============================================================================

- def generate_tokens(
-     input_ids: List[int],
      max_tokens: int = 512,
      temperature: float = 0.8,
      top_k: int = 40,
      top_p: float = 0.9,
      repetition_penalty: float = 1.1
  ):
-     """Generator that yields tokens one at a time"""
      if len(input_ids) > config['max_position_embeddings'] - max_tokens:
          input_ids = input_ids[-(config['max_position_embeddings'] - max_tokens):]

      input_tensor = tf.constant([input_ids], dtype=tf.int32)
      token_freq = {}

      for step in range(max_tokens):
          logits = fast_forward(input_tensor)
          next_token_logits = logits[0, -1, :].numpy()

-         # Temperature
          next_token_logits = next_token_logits / temperature

-         # Repetition penalty
          if repetition_penalty != 1.0:
              for token_id, freq in token_freq.items():
                  if token_id < len(next_token_logits):
@@ -290,14 +401,16 @@ def generate_tokens(
          top_k_logits = next_token_logits[top_k_indices]
          top_k_probs = tf.nn.softmax(top_k_logits).numpy()

-         # Top-p sampling
          if top_p < 1.0:
              sorted_indices = np.argsort(top_k_probs)[::-1]
              cumsum = np.cumsum(top_k_probs[sorted_indices])
              cutoff_idx = np.searchsorted(cumsum, top_p)
              nucleus_indices = sorted_indices[:cutoff_idx + 1]
              nucleus_logits = top_k_logits[nucleus_indices]
              nucleus_probs = tf.nn.softmax(nucleus_logits).numpy()
              sampled_idx = np.random.choice(len(nucleus_probs), p=nucleus_probs)
              next_token_id = int(top_k_indices[nucleus_indices[sampled_idx]])
          else:
@@ -307,159 +420,196 @@ def generate_tokens(
          probs = tf.nn.softmax(next_token_logits).numpy()
          next_token_id = np.random.choice(len(probs), p=probs)

          if next_token_id == eos_token_id:
              break

          token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1

-         yield next_token_id

          input_tensor = tf.concat([input_tensor, [[next_token_id]]], axis=1)

          if input_tensor.shape[1] > config['max_position_embeddings']:
              input_tensor = input_tensor[:, -config['max_position_embeddings']:]

  # ============================================================================
- # API Functions - FIXED FOR GRADIO
  # ============================================================================

- def chat_completion_api(
-     messages_json: str,
-     max_tokens: int,
-     temperature: float,
-     top_p: float,
-     top_k: int,
-     repetition_penalty: float,
-     stream: bool
- ) -> str:
-     """OpenAI-style chat completion API"""
-     try:
-         messages = json.loads(messages_json)
-
-         # Format messages
-         prompt = ""
-         for msg in messages:
-             role = msg.get("role", "user")
-             content = msg.get("content", "")
-
-             if role == "system":
-                 prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
-             elif role == "user":
-                 prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
-             elif role == "assistant":
-                 prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
-
-         prompt += "<|im_start|>assistant\n"
-
-         # Tokenize
-         input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
-
-         start_time = time.time()
-         token_count = 0
-         response_text = ""
-
-         for token_id in generate_tokens(
-             input_ids, max_tokens, temperature, top_k, top_p, repetition_penalty
-         ):
-             token_text = tokenizer.decode([token_id])
-             response_text += token_text
-             token_count += 1
-
-             if "<|im_end|>" in response_text:
-                 response_text = response_text.split("<|im_end|>")[0]
-                 break
-
-         elapsed = time.time() - start_time
-
-         result = {
-             "id": f"chatcmpl-{int(time.time())}",
-             "object": "chat.completion",
-             "created": int(time.time()),
-             "model": "sam-z-1",
-             "choices": [{
-                 "index": 0,
-                 "message": {
-                     "role": "assistant",
-                     "content": response_text.strip()
-                 },
-                 "finish_reason": "stop"
-             }],
-             "usage": {
-                 "prompt_tokens": len(input_ids),
-                 "completion_tokens": token_count,
-                 "total_tokens": len(input_ids) + token_count
-             },
-             "stats": {
-                 "elapsed_sec": round(elapsed, 2),
-                 "tokens_per_sec": round(token_count / elapsed if elapsed > 0 else 0, 1)
-             }
-         }
-
-         return json.dumps(result, indent=2)

-     except Exception as e:
-         return json.dumps({"error": str(e)}, indent=2)

- def text_completion_api(
-     prompt: str,
      max_tokens: int,
      temperature: float,
-     top_p: float,
      top_k: int,
-     repetition_penalty: float,
-     stream: bool
- ) -> str:
-     """OpenAI-style text completion API"""
-     try:
-         input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
-
-         start_time = time.time()
-         token_count = 0
-         response_text = ""
-
-         for token_id in generate_tokens(
-             input_ids, max_tokens, temperature, top_k, top_p, repetition_penalty
-         ):
-             token_text = tokenizer.decode([token_id])
-             response_text += token_text
-             token_count += 1
-
-         elapsed = time.time() - start_time
-
-         result = {
-             "id": f"cmpl-{int(time.time())}",
-             "object": "text_completion",
-             "created": int(time.time()),
-             "model": "sam-z-1",
-             "choices": [{
-                 "text": response_text,
-                 "index": 0,
-                 "finish_reason": "stop"
-             }],
-             "usage": {
-                 "prompt_tokens": len(input_ids),
-                 "completion_tokens": token_count,
-                 "total_tokens": len(input_ids) + token_count
-             },
-             "stats": {
-                 "elapsed_sec": round(elapsed, 2),
-                 "tokens_per_sec": round(token_count / elapsed if elapsed > 0 else 0, 1)
-             }
-         }
-
-         return json.dumps(result, indent=2)

-     except Exception as e:
-         return json.dumps({"error": str(e)}, indent=2)

  # ============================================================================
- # Gradio UI with API Routes
  # ============================================================================

- custom_css = """
- .api-container {
-     max-width: 1400px;
-     margin: auto;
  }

  .header {
@@ -471,525 +621,294 @@ custom_css = """
      margin-bottom: 2rem;
  }

- .endpoint-card {
      background: #f8f9fa;
-     padding: 1.5rem;
      border-radius: 8px;
      border-left: 4px solid #667eea;
      margin: 1rem 0;
  }
  """

- with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="SAM-Z-1 API") as demo:
-     gr.HTML("""
-     <div class="header">
-         <h1>🚀 SAM-Z-1 API Server</h1>
-         <p>OpenAI-Compatible API for SAM-Z-1 Language Model</p>
-         <p style="font-size: 0.9rem; opacity: 0.9;">
-             313M Parameters • 768D • 16 Layers • TensorFlow Optimized
-         </p>
-     </div>
-     """)
-
-     with gr.Tabs():
-         # ========== Chat Completion Tab ==========
-         with gr.Tab("💬 Chat Completion"):
-             gr.Markdown("""
-             ### Chat Completions API
-             OpenAI-compatible chat completion endpoint
-             """)
-
-             with gr.Row():
-                 with gr.Column(scale=1):
-                     messages_input = gr.Code(
-                         label="Messages (JSON)",
-                         language="json",
-                         value=json.dumps([
-                             {"role": "user", "content": "Hello! Who are you?"}
-                         ], indent=2),
-                         lines=10
-                     )
-
-                     with gr.Row():
-                         chat_max_tokens = gr.Slider(50, 1024, 512, step=50, label="Max Tokens")
-                         chat_temperature = gr.Slider(0.1, 2.0, 0.8, step=0.1, label="Temperature")
-
-                     with gr.Row():
-                         chat_top_p = gr.Slider(0.1, 1.0, 0.9, step=0.05, label="Top P")
-                         chat_top_k = gr.Slider(1, 100, 40, step=1, label="Top K")
-
-                     chat_rep_penalty = gr.Slider(1.0, 2.0, 1.1, step=0.1, label="Repetition Penalty")
-                     chat_stream = gr.Checkbox(label="Stream Response (Not implemented in UI)", value=False)
-
-                     chat_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
-
-                 with gr.Column(scale=1):
-                     chat_output = gr.Code(
-                         label="API Response (JSON)",
-                         language="json",
-                         lines=20
-                     )
-
-             gr.Markdown("""
-             ### Python Example with Gradio Client
-             ```python
-             from gradio_client import Client
-
-             client = Client("YOUR-SPACE-URL")
-
-             messages = [
-                 {"role": "user", "content": "Hello! Who are you?"}
-             ]
-
-             result = client.predict(
-                 messages_json=json.dumps(messages),
-                 max_tokens=512,
-                 temperature=0.8,
-                 top_p=0.9,
-                 top_k=40,
-                 repetition_penalty=1.1,
-                 stream=False,
-                 api_name="/chat_completions"
             )
-
-             print(result)
-             ```
-             """)
-
-         # ========== Text Completion Tab ==========
-         with gr.Tab("📝 Text Completion"):
-             gr.Markdown("""
-             ### Text Completions API
-             OpenAI-compatible text completion endpoint
-             """)

              with gr.Row():
-                 with gr.Column(scale=1):
-                     prompt_input = gr.Textbox(
-                         label="Prompt",
-                         placeholder="Once upon a time...",
-                         lines=5
-                     )
-
-                     with gr.Row():
-                         text_max_tokens = gr.Slider(50, 1024, 512, step=50, label="Max Tokens")
-                         text_temperature = gr.Slider(0.1, 2.0, 0.8, step=0.1, label="Temperature")
-
-                     with gr.Row():
-                         text_top_p = gr.Slider(0.1, 1.0, 0.9, step=0.05, label="Top P")
-                         text_top_k = gr.Slider(1, 100, 40, step=1, label="Top K")
-
-                     text_rep_penalty = gr.Slider(1.0, 2.0, 1.1, step=0.1, label="Repetition Penalty")
-                     text_stream = gr.Checkbox(label="Stream Response (Not implemented in UI)", value=False)
-
-                     text_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
-
-                 with gr.Column(scale=1):
-                     text_output = gr.Code(
-                         label="API Response (JSON)",
-                         language="json",
-                         lines=20
-                     )

-             gr.Markdown("""
-             ### Python Example with Gradio Client
-             ```python
-             from gradio_client import Client
-
-             client = Client("YOUR-SPACE-URL")
-
-             result = client.predict(
-                 prompt="Once upon a time",
-                 max_tokens=512,
-                 temperature=0.8,
-                 top_p=0.9,
-                 top_k=40,
-                 repetition_penalty=1.1,
-                 stream=False,
-                 api_name="/text_completions"
             )
-
-             print(result)
-             ```
-             """)
-
-         # ========== Documentation Tab ==========
-         with gr.Tab("📖 Documentation"):
-             gr.Markdown("""
-             # SAM-Z-1 API Documentation
-
-             ## Model Information
-             - **Model**: SAM-Z-1 (Direct Response Model)
-             - **Parameters**: ~313M
-             - **Architecture**: Transformer with RoPE, SwiGLU, RMSNorm
-             - **Context Length**: {config['max_position_embeddings']} tokens
-             - **Vocabulary Size**: {config['vocab_size']}
-
-             ## Using the API

-             ### Method 1: Gradio Client (Recommended)
-
-             Install the Gradio client:
-             ```bash
-             pip install gradio_client
-             ```
-
-             **Chat Completion:**
-             ```python
-             from gradio_client import Client
-             import json
-
-             client = Client("https://YOUR-SPACE.hf.space")
-
-             messages = [
-                 {{"role": "user", "content": "What is Python?"}}
-             ]
-
-             result = client.predict(
-                 messages_json=json.dumps(messages),
-                 max_tokens=512,
-                 temperature=0.8,
-                 top_p=0.9,
-                 top_k=40,
-                 repetition_penalty=1.1,
-                 stream=False,
-                 api_name="/chat_completions"
             )
-
-             response = json.loads(result)
-             print(response["choices"][0]["message"]["content"])
-             ```

-             **Text Completion:**
-             ```python
-             result = client.predict(
-                 prompt="Once upon a time",
-                 max_tokens=512,
-                 temperature=0.8,
-                 top_p=0.9,
-                 top_k=40,
-                 repetition_penalty=1.1,
-                 stream=False,
-                 api_name="/text_completions"
             )
-
-             response = json.loads(result)
-             print(response["choices"][0]["text"])
-             ```
-
-             ### Method 2: Direct HTTP Requests
-
-             **Chat Completion:**
-             ```python
-             import requests
-             import json
-
-             url = "https://YOUR-SPACE.hf.space/call/chat_completions"
-
-             payload = {{
-                 "data": [
-                     json.dumps([{{"role": "user", "content": "Hello!"}}]),  # messages_json
-                     512,    # max_tokens
-                     0.8,    # temperature
-                     0.9,    # top_p
-                     40,     # top_k
-                     1.1,    # repetition_penalty
-                     False   # stream
-                 ]
-             }}
-
-             response = requests.post(url, json=payload)
-             print(response.json())
-             ```
-
-             ## API Endpoints
-
-             ### Chat Completions
-             - **API Name**: `/chat_completions`
-             - **URL**: `https://YOUR-SPACE.hf.space/call/chat_completions`
-
-             **Parameters:**
-             1. `messages_json` (str): JSON string of messages array
-             2. `max_tokens` (int): Maximum tokens to generate (50-1024)
-             3. `temperature` (float): Sampling temperature (0.1-2.0)
-             4. `top_p` (float): Nucleus sampling threshold (0.1-1.0)
-             5. `top_k` (int): Top-K sampling (1-100)
-             6. `repetition_penalty` (float): Penalty for repetition (1.0-2.0)
-             7. `stream` (bool): Stream response (UI only, not functional)
-
-             ### Text Completions
-             - **API Name**: `/text_completions`
-             - **URL**: `https://YOUR-SPACE.hf.space/call/text_completions`

-             **Parameters:**
-             1. `prompt` (str): Text prompt
-             2. `max_tokens` (int): Maximum tokens to generate
-             3. `temperature` (float): Sampling temperature
-             4. `top_p` (float): Nucleus sampling threshold
-             5. `top_k` (int): Top-K sampling
-             6. `repetition_penalty` (float): Penalty for repetition
-             7. `stream` (bool): Stream response (UI only)
-
-             ## Response Format
-
-             **Chat Completion Response:**
-             ```json
-             {{
-                 "id": "chatcmpl-1234567890",
-                 "object": "chat.completion",
-                 "created": 1234567890,
-                 "model": "sam-z-1",
-                 "choices": [{{
-                     "index": 0,
-                     "message": {{
-                         "role": "assistant",
-                         "content": "Response text here"
-                     }},
-                     "finish_reason": "stop"
-                 }}],
-                 "usage": {{
-                     "prompt_tokens": 10,
-                     "completion_tokens": 20,
-                     "total_tokens": 30
-                 }},
-                 "stats": {{
-                     "elapsed_sec": 1.5,
-                     "tokens_per_sec": 13.3
-                 }}
-             }}
-             ```
-
-             **Text Completion Response:**
-             ```json
-             {{
-                 "id": "cmpl-1234567890",
-                 "object": "text_completion",
-                 "created": 1234567890,
-                 "model": "sam-z-1",
-                 "choices": [{{
-                     "text": "Completion text here",
-                     "index": 0,
-                     "finish_reason": "stop"
-                 }}],
-                 "usage": {{
-                     "prompt_tokens": 5,
-                     "completion_tokens": 15,
-                     "total_tokens": 20
-                 }},
-                 "stats": {{
-                     "elapsed_sec": 1.2,
-                     "tokens_per_sec": 12.5
-                 }}
-             }}
-             ```
-
-             ## Complete Example Script
-
-             ```python
-             #!/usr/bin/env python3
-             """
-             SAM-Z-1 API Client Example
-             """
-             from gradio_client import Client
-             import json
-
-             # Initialize client
-             client = Client("https://YOUR-SPACE.hf.space")
-
-             def chat(message, history=[]):
-                 \"\"\"Send a chat message\"\"\"
-                 messages = history + [{{"role": "user", "content": message}}]
-
-                 result = client.predict(
-                     messages_json=json.dumps(messages),
-                     max_tokens=512,
-                     temperature=0.8,
-                     top_p=0.9,
-                     top_k=40,
-                     repetition_penalty=1.1,
-                     stream=False,
-                     api_name="/chat_completions"
-                 )
-
-                 response = json.loads(result)
-                 assistant_msg = response["choices"][0]["message"]["content"]
-
-                 # Update history
-                 history.append({{"role": "user", "content": message}})
-                 history.append({{"role": "assistant", "content": assistant_msg}})
-
-                 return assistant_msg, history
-
-             def complete(prompt):
-                 \"\"\"Complete text\"\"\"
-                 result = client.predict(
-                     prompt=prompt,
-                     max_tokens=512,
-                     temperature=0.8,
-                     top_p=0.9,
-                     top_k=40,
-                     repetition_penalty=1.1,
-                     stream=False,
-                     api_name="/text_completions"
-                 )
-
-                 response = json.loads(result)
-                 return response["choices"][0]["text"]
-
-             # Example usage
-             if __name__ == "__main__":
-                 # Chat example
-                 print("=== Chat Example ===")
-                 history = []
-
-                 response, history = chat("Hello! Who are you?", history)
-                 print(f"Assistant: {{response}}\\n")
-
-                 response, history = chat("What can you help me with?", history)
-                 print(f"Assistant: {{response}}\\n")
-
-                 # Text completion example
-                 print("\\n=== Text Completion Example ===")
-                 completion = complete("Once upon a time in a distant galaxy")
-                 print(f"Completion: {{completion}}")
-             ```
-
-             ## Parameters Guide
-
-             ### Temperature (0.1 - 2.0)
-             - **Low (0.1-0.5)**: More focused, deterministic, factual
-             - **Medium (0.6-0.9)**: Balanced creativity and coherence
-             - **High (1.0-2.0)**: More creative, diverse, unpredictable
-
-             ### Top-P (0.1 - 1.0)
-             - Controls diversity via nucleus sampling
-             - **0.9** (default): Good balance
-             - Lower values = more focused
-             - Higher values = more diverse
-
-             ### Top-K (1 - 100)
-             - Limits vocabulary to top K tokens
-             - **40** (default): Good balance
-             - Lower values = more focused
-             - Higher values = more diverse
-
-             ### Repetition Penalty (1.0 - 2.0)
-             - **1.0**: No penalty
-             - **1.1** (default): Slight penalty
-             - **1.5+**: Strong penalty (use if model repeats)
-
-             ## Rate Limits & Performance

-             - **Concurrent Requests**: Supported via Gradio queue
-             - **Average Speed**: 10-20 tokens/sec on CPU
-             - **Context Window**: {config['max_position_embeddings']} tokens
-             - **Queue Size**: Up to 20 concurrent requests

-             ## Error Handling

-             ```python
-             try:
-                 result = client.predict(
-                     messages_json=json.dumps(messages),
-                     max_tokens=512,
-                     temperature=0.8,
-                     top_p=0.9,
-                     top_k=40,
-                     repetition_penalty=1.1,
-                     stream=False,
-                     api_name="/chat_completions"
-                 )
-                 response = json.loads(result)
-
-                 if "error" in response:
-                     print(f"API Error: {{response['error']}}")
-                 else:
-                     print(response["choices"][0]["message"]["content"])

-             except Exception as e:
-                 print(f"Request failed: {{e}}")
-             ```
-
-             ## Troubleshooting
-
-             **Connection Issues:**
-             - Verify Space URL is correct
-             - Check if Space is running
-             - Ensure gradio_client is installed
-
-             **Slow Responses:**
-             - Reduce `max_tokens`
-             - Lower `top_k` value
-             - Use shorter prompts
-
-             **Repetitive Output:**
-             - Increase `repetition_penalty` (try 1.2-1.5)
-             - Adjust `temperature` higher
-             - Use `top_p` sampling
-
-             **Incoherent Output:**
-             - Lower `temperature` (try 0.5-0.7)
-             - Reduce `top_k` (try 20-30)
-             - Ensure prompt is clear and well-formatted
-
-             ## Chat Template Format
-
-             The model uses ChatML format:
-             ```
-             <|im_start|>system
-             System message here<|im_end|>
-             <|im_start|>user
-             User message here<|im_end|>
-             <|im_start|>assistant
-             Assistant response here<|im_end|>
-             ```
-
-             ## Tips for Best Results
-
-             1. **Use clear, specific prompts**
-             2. **Lower temperature for factual tasks**
-             3. **Higher temperature for creative tasks**
-             4. **Adjust repetition penalty if model repeats phrases**
-             5. **Keep context under {config['max_position_embeddings']} tokens**
-             6. **Use system messages to set behavior**
-
-             ## Model Capabilities
-
-             ✅ General conversation
-             ✅ Question answering
-             ✅ Code generation
-             ✅ Creative writing
-             ✅ Text completion
-             ✅ Instruction following
-
-             ❌ Does NOT use reasoning tokens (`<think>` tags)
-             ❌ Not fine-tuned for specific domains
-
-             ---
-
-             **Model**: SAM-Z-1 | **API Version**: 1.0
-             **Support**: Open an issue on the Space for bugs or questions
-             """)
-
-     # ========== API Routes - MUST USE api_name parameter ==========
-     chat_btn.click(
-         fn=chat_completion_api,
-         inputs=[
-             messages_input, chat_max_tokens, chat_temperature,
-             chat_top_p, chat_top_k, chat_rep_penalty, chat_stream
          ],
-         outputs=[chat_output],
-         api_name="chat_completions"  # This creates /call/chat_completions endpoint
      )

-     text_btn.click(
-         fn=text_completion_api,
-         inputs=[
-             prompt_input, text_max_tokens, text_temperature,
-             text_top_p, text_top_k, text_rep_penalty, text_stream
-         ],
-         outputs=[text_output],
-         api_name="text_completions"  # This creates /call/text_completions endpoint
      )

  # Launch

  import gradio as gr
  import tensorflow as tf
  import keras

  from tokenizers import Tokenizer
  import numpy as np
  import time

  # ============================================================================
+ # 🎊 FESTIVE MODE TOGGLE 🎊
  # ============================================================================
+ FESTIVE = True  # Set to False for production-only mode
+
+ # ============================================================================
+ # Configuration & Model Loading
+ # ============================================================================
+
+ print("🚀 Loading SAM-Z-1 Model...")

  MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow"
  CACHE_DIR = "./model_cache"

  # ============================================================================
+ # Model Architecture Definitions (FIXED for model loading)
  # ============================================================================

  @keras.saving.register_keras_serializable()

          self.built_cache = False

      def build(self, input_shape):
+         # Use the ORIGINAL training code - compute cache on first call, not in build
          super().build(input_shape)

      def _build_cache(self):
+         """Build RoPE cache on first forward pass"""
          if not self.built_cache:
              inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
              t = tf.range(self.max_len, dtype=tf.float32)
              freqs = tf.einsum("i,j->ij", t, inv_freq)
              emb = tf.concat([freqs, freqs], axis=-1)
+
+             # Store as numpy arrays to avoid graph issues
              self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
              self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
              self.built_cache = True

          return tf.concat([-x2, x1], axis=-1)

      def call(self, q, k):
+         # Build cache on first call (avoids build-time issues)
          self._build_cache()
+
          seq_len = tf.shape(q)[2]
          dtype = q.dtype
          cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
          sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
+
          q_rotated = (q * cos) + (self.rotate_half(q) * sin)
          k_rotated = (k * cos) + (self.rotate_half(k) * sin)
+
          return q_rotated, k_rotated

      def get_config(self):

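A side note on the rotary math behind `_build_cache` and `rotate_half` above: in the concat-half layout, each pair (x_i, x_{i+dim/2}) is rotated by a position-dependent angle, so the transform is a pure rotation. A minimal NumPy sketch (standalone illustration, not part of the commit) that mirrors the same construction and checks the norm-preserving property:

```python
import numpy as np

dim, pos, theta = 8, 5, 10000.0
inv_freq = 1.0 / theta ** (np.arange(0, dim, 2) / dim)   # same inv_freq as _build_cache
ang = pos * inv_freq
cos = np.cos(np.concatenate([ang, ang]))
sin = np.sin(np.concatenate([ang, ang]))

q = np.random.randn(dim)
x1, x2 = q[:dim // 2], q[dim // 2:]
rotate_half = np.concatenate([-x2, x1])                  # mirrors RotaryEmbedding.rotate_half
q_rot = q * cos + rotate_half * sin
print(np.allclose(np.linalg.norm(q_rot), np.linalg.norm(q)))  # True: a pure rotation
```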

          self.pre_attn_norm = RMSNorm()
          self.pre_ffn_norm = RMSNorm()
+
          self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
          self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
          self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
          self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
+
          self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
+
          self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
          self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
          self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
+
          self.dropout = keras.layers.Dropout(dropout)

      def call(self, x, training=None):
          B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
          dtype = x.dtype

+         # Attention
          res = x
          y = self.pre_attn_norm(x)

          v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])

          q, k = self.rope(q, k)
+
          scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
+
          mask = tf.where(
              tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
              tf.constant(-1e9, dtype=dtype),
          )
          scores += mask
          attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
+
          attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
          x = res + self.dropout(self.out_proj(attn), training=training)

+         # FFN (SwiGLU)
          res = x
          y = self.pre_ffn_norm(x)
          ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))

      def get_config(self):
          config = super().get_config()
          config.update({
+             "d_model": self.d_model,
+             "n_heads": self.n_heads,
+             "ff_dim": self.ff_dim,
+             "dropout": self.dropout_rate,
+             "max_len": self.max_len,
+             "rope_theta": self.rope_theta,
+             "layer_idx": self.layer_idx
          })
          return config

          ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
          block_args = {
+             'd_model': self.cfg['d_model'],
+             'n_heads': self.cfg['n_heads'],
+             'ff_dim': ff_dim,
+             'dropout': self.cfg['dropout'],
+             'max_len': self.cfg['max_len'],
+             'rope_theta': self.cfg['rope_theta']
          }

+         self.blocks = []
+         for i in range(self.cfg['n_layers']):
+             block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
+             self.blocks.append(block)
+
          self.norm = RMSNorm(name="final_norm")
          self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

      def call(self, input_ids, training=None):
          x = self.embed(input_ids)
+
          for block in self.blocks:
              x = block(x, training=training)
+
          return self.lm_head(self.norm(x))

      def get_config(self):
          base_config['config'] = self.cfg
          return base_config

+ print("✅ Model architecture registered")

+ # Download model files
  config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)

+ # Try to download checkpoint weights first (more reliable)
  try:
      weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR)
+     print("✅ Found checkpoint weights (ckpt.weights.h5)")
      use_checkpoint = True
+ except Exception as e:
+     print(f"⚠️ Checkpoint not found, falling back to model.keras: {e}")
      model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
      use_checkpoint = False

+ # Load config
  with open(config_path, 'r') as f:
      config = json.load(f)

+ # Create tokenizer from scratch
+ print("📦 Creating tokenizer from GPT-2 base...")
  from transformers import AutoTokenizer
+
  hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")

+ # Add custom tokens to match model's vocab size
+ custom_tokens = ["<|im_start|>", "<|im_end|>", "<think>", "<think/>"]
+ hf_tokenizer.add_special_tokens({"additional_special_tokens": custom_tokens})
+
+ # Save and reload as tokenizers format
  os.makedirs("./temp_tokenizer", exist_ok=True)
  hf_tokenizer.save_pretrained("./temp_tokenizer")
  tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")

+ print(f"✅ Tokenizer created with vocab size: {tokenizer.get_vocab_size()}")
+ print(f"   Custom tokens added: {custom_tokens}")
+ print(f"   Model vocab size: {config.get('vocab_size', 'unknown')}")
+
+ # Verify vocab sizes match
+ if tokenizer.get_vocab_size() != config.get('vocab_size'):
+     print(f"⚠️ WARNING: Tokenizer vocab ({tokenizer.get_vocab_size()}) != Model vocab ({config.get('vocab_size')})")
+     print(f"   Model was trained with these tokens, but SAM-Z-1 doesn't use <think> tags in generation")
+
+ eos_token_id = config.get('eos_token_id', 50256)
+
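A quick way to confirm the save/reload round trip kept the four specials as single tokens, using the `tokenizer` built above (an illustrative check, not part of the commit; the ids assume GPT-2's 50257-token base, so the specials should land at 50257-50260):

```python
for tok in ["<|im_start|>", "<|im_end|>", "<think>", "<think/>"]:
    ids = tokenizer.encode(tok).ids
    print(tok, ids)  # each special should come back as exactly one id
```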
+ # ==============================================================================
+ # Load Model - Priority: checkpoint weights > saved model
+ # ==============================================================================
+ print("\n🔄 Loading model...")
+
  if use_checkpoint:
+     print("📦 Building model from config and loading checkpoint weights...")
+
+     # Build model from scratch with config
      model_config = {
          'vocab_size': config['vocab_size'],
          'd_model': config['hidden_size'],
          'n_heads': config['num_attention_heads'],
          'ff_mult': config['intermediate_size'] / config['hidden_size'],
          'max_len': config['max_position_embeddings'],
+         'dropout': 0.1,  # Default dropout
          'rope_theta': config['rope_theta']
      }
+
      model = SAM1Model(config=model_config)
+
+     # Build model by running a dummy forward pass
      dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
      _ = model(dummy_input, training=False)
+
+     print(f"✅ Model architecture built: {model.count_params():,} parameters")
+
+     # Load checkpoint weights
+     print(f"📥 Loading checkpoint weights from: {weights_path}")
      model.load_weights(weights_path)
+     print("✅ Checkpoint weights loaded successfully!")
+
  else:
+     print("📦 Loading full saved model...")
+     try:
+         model = keras.models.load_model(model_path, compile=False)
+         print("✅ Model loaded successfully")
+     except Exception as e:
+         print(f"❌ Failed to load model: {e}")
+         print("\n🔄 Trying alternative: building from config + loading weights...")
+
+         # Fallback to building model
+         model_config = {
+             'vocab_size': config['vocab_size'],
+             'd_model': config['hidden_size'],
+             'n_layers': config['num_hidden_layers'],
+             'n_heads': config['num_attention_heads'],
+             'ff_mult': config['intermediate_size'] / config['hidden_size'],
+             'max_len': config['max_position_embeddings'],
+             'dropout': 0.1,
+             'rope_theta': config['rope_theta']
+         }
+
+         model = SAM1Model(config=model_config)
+         dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
+         _ = model(dummy_input, training=False)
+
+         # Try to load weights from model.keras
+         try:
+             temp_model = keras.models.load_model(model_path, compile=False)
+             model.set_weights(temp_model.get_weights())
+             print("✅ Weights transferred successfully")
+         except:
+             print("❌ Could not load weights - model may not work correctly!")
+             raise

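The build-then-load pattern used above (instantiate from config, run a dummy batch so every variable is created, then `load_weights`) is the standard recipe for restoring subclassed Keras models whose variable shapes are only known after the first call. A stripped-down sketch of the same idea, using a hypothetical `TinyModel` rather than the repo's class:

```python
import tensorflow as tf
import keras

class TinyModel(keras.Model):
    def __init__(self, vocab=100, dim=16, **kwargs):
        super().__init__(**kwargs)
        self.embed = keras.layers.Embedding(vocab, dim)
        self.head = keras.layers.Dense(vocab, use_bias=False)

    def call(self, ids, training=None):
        return self.head(self.embed(ids))

m = TinyModel()
_ = m(tf.zeros((1, 8), dtype=tf.int32), training=False)   # dummy pass creates the variables
m.save_weights("tiny.weights.h5")

m2 = TinyModel()
_ = m2(tf.zeros((1, 8), dtype=tf.int32), training=False)  # build first so shapes exist...
m2.load_weights("tiny.weights.h5")                        # ...then the load succeeds
```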
+ # Create optimized inference function
  @tf.function(reduce_retracing=True)
  def fast_forward(input_tensor):
+     """TF-optimized forward pass for faster generation"""
      return model(input_tensor, training=False)

+ print(f"✅ Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
+ print(f"✅ TF function optimization enabled for faster inference")
+
+ # Global stop flag
+ stop_generation = False

  # ============================================================================
+ # Generation Function with Streaming & Stop Button
  # ============================================================================

+ def generate_stream(
+     prompt: str,
      max_tokens: int = 512,
      temperature: float = 0.8,
      top_k: int = 40,
      top_p: float = 0.9,
      repetition_penalty: float = 1.1
  ):
+     """Generate text with streaming output and stop support"""
+     global stop_generation
+     stop_generation = False
+
+     # Tokenize prompt
+     input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
+
+     if len(input_ids) == 0:
+         yield "⚠️ Empty prompt after tokenization"
+         return
+
      if len(input_ids) > config['max_position_embeddings'] - max_tokens:
          input_ids = input_ids[-(config['max_position_embeddings'] - max_tokens):]

      input_tensor = tf.constant([input_ids], dtype=tf.int32)
+     generated_text = ""
+     token_count = 0
+
+     # Track token frequencies for repetition penalty
      token_freq = {}

+     start_time = time.time()
+
      for step in range(max_tokens):
+         # Check stop flag
+         if stop_generation:
+             generated_text += "\n\n*[Generation stopped by user]*"
+             yield generated_text
+             break
+
+         # Get logits using optimized TF function
          logits = fast_forward(input_tensor)
          next_token_logits = logits[0, -1, :].numpy()

+         # Apply temperature
          next_token_logits = next_token_logits / temperature

+         # Apply repetition penalty
          if repetition_penalty != 1.0:
              for token_id, freq in token_freq.items():
                  if token_id < len(next_token_logits):

          top_k_logits = next_token_logits[top_k_indices]
          top_k_probs = tf.nn.softmax(top_k_logits).numpy()

+         # Top-p (nucleus) sampling
          if top_p < 1.0:
              sorted_indices = np.argsort(top_k_probs)[::-1]
              cumsum = np.cumsum(top_k_probs[sorted_indices])
              cutoff_idx = np.searchsorted(cumsum, top_p)
              nucleus_indices = sorted_indices[:cutoff_idx + 1]
+
              nucleus_logits = top_k_logits[nucleus_indices]
              nucleus_probs = tf.nn.softmax(nucleus_logits).numpy()
+
              sampled_idx = np.random.choice(len(nucleus_probs), p=nucleus_probs)
              next_token_id = int(top_k_indices[nucleus_indices[sampled_idx]])
          else:

              probs = tf.nn.softmax(next_token_logits).numpy()
              next_token_id = np.random.choice(len(probs), p=probs)

+         # Stop on EOS
          if next_token_id == eos_token_id:
              break

+         # Update token frequency
          token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1

+         # Decode and yield
+         token_text = tokenizer.decode([next_token_id])
+         generated_text += token_text
+         token_count += 1

+         # Yield progressive output
+         yield generated_text
+
+         # Update input
          input_tensor = tf.concat([input_tensor, [[next_token_id]]], axis=1)

+         # Truncate if too long
          if input_tensor.shape[1] > config['max_position_embeddings']:
              input_tensor = input_tensor[:, -config['max_position_embeddings']:]
+
+     # Calculate stats
+     elapsed = time.time() - start_time
+     tokens_per_sec = token_count / elapsed if elapsed > 0 else 0
+
+     # Add generation stats
+     if token_count > 0 and not stop_generation:
+         generated_text += f"\n\n*[Generated {token_count} tokens in {elapsed:.1f}s ({tokens_per_sec:.1f} tok/s)]*"
+
+     yield generated_text

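The loop above chains temperature scaling, top-k filtering, and top-p (nucleus) truncation before sampling. The same decision rule, condensed into a pure-NumPy helper (a sketch mirroring the code above, not the committed implementation):

```python
import numpy as np

def sample_token(logits, temperature=0.8, top_k=40, top_p=0.9):
    logits = logits / temperature
    top_k_idx = np.argpartition(logits, -top_k)[-top_k:]      # indices of the k largest logits
    top_k_logits = logits[top_k_idx]
    probs = np.exp(top_k_logits - top_k_logits.max())
    probs /= probs.sum()                                      # softmax over the top-k
    order = np.argsort(probs)[::-1]
    cutoff = np.searchsorted(np.cumsum(probs[order]), top_p)  # smallest nucleus covering top_p
    keep = order[:cutoff + 1]
    kept = probs[keep] / probs[keep].sum()                    # renormalize inside the nucleus
    return int(top_k_idx[keep[np.random.choice(len(kept), p=kept)]])

print(sample_token(np.random.randn(50257)))
```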
  # ============================================================================
+ # Chat Interface Logic
  # ============================================================================

+ def format_chat_prompt(message: str, history: list) -> str:
+     """Format message history into chat prompt"""
+     prompt = ""
+
+     # Add history
+     for user_msg, assistant_msg in history:
+         prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
+         if assistant_msg:
+             prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
+
+     # Add current message
+     prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
+
+     return prompt

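For reference, the exact ChatML-style string `format_chat_prompt` builds for a one-turn history (derived directly from the function body above):

```python
history = [("Hi", "Hello! How can I help?")]
print(format_chat_prompt("Who are you?", history))
# <|im_start|>user
# Hi<|im_end|>
# <|im_start|>assistant
# Hello! How can I help?<|im_end|>
# <|im_start|>user
# Who are you?<|im_end|>
# <|im_start|>assistant
```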
+ def chat_stream(
+     message: str,
+     history: list,
      max_tokens: int,
      temperature: float,
      top_k: int,
+     top_p: float,
+     repetition_penalty: float
+ ):
+     """Streaming chat response"""
+     if not message.strip():
+         yield history
+         return
+
+     # Format prompt
+     prompt = format_chat_prompt(message, history)
+
+     # Generate with streaming
+     partial_response = ""
+     for generated in generate_stream(
+         prompt,
+         max_tokens=max_tokens,
+         temperature=temperature,
+         top_k=top_k,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty
+     ):
+         partial_response = generated
+
+         # Stop at end tags
+         if "<|im_end|>" in partial_response:
+             partial_response = partial_response.split("<|im_end|>")[0]
+
+         # Update history
+         yield history + [[message, partial_response.strip()]]
+
+ def stop_gen():
+     """Stop generation callback"""
+     global stop_generation
+     stop_generation = True
+     return None

  # ============================================================================
+ # Gradio UI
  # ============================================================================

+ # Festive CSS
+ festive_css = """
+ .gradio-container {
+     max-width: 1200px !important;
+     margin: auto !important;
+ }
+
+ .header {
+     text-align: center;
+     padding: 2rem;
+     background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
+     color: white;
+     border-radius: 12px;
+     margin-bottom: 2rem;
+     box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);
+     animation: pulse 2s ease-in-out infinite;
+ }
+
+ @keyframes pulse {
+     0%, 100% { transform: scale(1); }
+     50% { transform: scale(1.02); }
+ }
+
+ .header h1 {
+     font-size: 2.8rem;
+     margin-bottom: 0.5rem;
+     font-weight: 700;
+     text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
+ }
+
+ .header p {
+     font-size: 1.1rem;
+     opacity: 0.95;
+ }
+
+ .celebration {
+     font-size: 2rem;
+     margin: 0.5rem;
+     animation: bounce 1s ease infinite;
+ }
+
+ @keyframes bounce {
+     0%, 100% { transform: translateY(0); }
+     50% { transform: translateY(-10px); }
+ }
+
+ .stats-card {
+     background: linear-gradient(135deg, #ffecd2 0%, #fcb69f 100%);
+     padding: 1.5rem;
+     border-radius: 12px;
+     border-left: 4px solid #f5576c;
+     margin: 1rem 0;
+     box-shadow: 0 4px 16px rgba(252, 182, 159, 0.3);
+ }
+
+ .twin-badge {
+     display: inline-block;
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     color: white;
+     padding: 0.5rem 1rem;
+     border-radius: 20px;
+     font-weight: bold;
+     margin: 0.5rem;
+     box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
+ }
+
+ footer {
+     text-align: center;
+     padding: 2rem;
+     color: #666;
+     border-top: 1px solid #eee;
+     margin-top: 2rem;
+ }
+
+ .confetti {
+     position: fixed;
+     width: 10px;
+     height: 10px;
+     background: #f5576c;
+     position: absolute;
+     animation: confetti-fall 3s linear infinite;
+ }
+
+ @keyframes confetti-fall {
+     to { transform: translateY(100vh) rotate(360deg); }
+ }
+ """
+
+ # Production CSS
+ production_css = """
+ .gradio-container {
+     max-width: 1200px !important;
+     margin: auto !important;
  }

  .header {
      margin-bottom: 2rem;
  }

+ .header h1 {
+     font-size: 2.5rem;
+     margin-bottom: 0.5rem;
+     font-weight: 700;
+ }
+
+ .header p {
+     font-size: 1.1rem;
+     opacity: 0.95;
+ }
+
+ .stats-card {
      background: #f8f9fa;
+     padding: 1rem;
      border-radius: 8px;
      border-left: 4px solid #667eea;
      margin: 1rem 0;
  }
+
+ footer {
+     text-align: center;
+     padding: 2rem;
+     color: #666;
+     border-top: 1px solid #eee;
+     margin-top: 2rem;
+ }
  """

+ # Select CSS based on mode
+ custom_css = festive_css if FESTIVE else production_css
+
+ # Build interface
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
+     # Header
+     if FESTIVE:
+         gr.HTML("""
+         <div class="header">
+             <div class="celebration">🎉 🎊 ✨ 🎈 🎆</div>
+             <img src="https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/yBUDdaTze1L84NaDSpZGf.jpeg"
+                  alt="SAM-Z-1"
+                  style="max-width: 400px; border-radius: 12px; margin: 1rem auto; display: block; box-shadow: 0 8px 24px rgba(0,0,0,0.2);">
+             <h1>🤖 SAM-Z-1 Chat 🤖</h1>
+             <p><strong>LATEST RELEASE!</strong> Our <strong>Best</strong> non-reasoning model</p>
+             <div class="twin-badge">Twin of SAM-X-1 (Reasoning Model)</div>
+             <p style="font-size: 0.9rem; margin-top: 1rem;">
+                 768D • 16 Layers • 12 Heads • ~313M Parameters • Trained on TPU v5e-8
+             </p>
+             <div class="celebration">🚀 💫 🎯 ⚡ 🔥</div>
+         </div>
+         """)
+     else:
+         gr.HTML("""
+         <div class="header">
+             <img src="https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/yBUDdaTze1L84NaDSpZGf.jpeg"
+                  alt="SAM-Z-1"
+                  style="max-width: 300px; border-radius: 12px; margin: 1rem auto; display: block; box-shadow: 0 4px 16px rgba(0,0,0,0.15);">
+             <h1>🤖 SAM-Z-1 Chat</h1>
+             <p>Fast, direct responses without reasoning overhead</p>
+             <p style="font-size: 0.9rem; margin-top: 0.5rem;">
+                 768D • 16 Layers • 12 Heads • Trained on TPU v5e-8
+             </p>
+         </div>
+         """)
+
+     with gr.Row():
+         with gr.Column(scale=4):
+             # Chat interface with bot avatar
+             chatbot = gr.Chatbot(
+                 height=600,
+                 show_label=False,
+                 avatar_images=(
+                     None,
+                     "https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/KtiMi-aDUOOeN--YNT-Fu.jpeg"
+                 ),
+                 bubble_full_width=False
              )

              with gr.Row():
+                 msg = gr.Textbox(
+                     placeholder="Type your message here..." if not FESTIVE else "Ask me anything! I'm the fast twin! ⚡",
+                     show_label=False,
+                     scale=8,
+                     container=False
+                 )
+                 submit_btn = gr.Button("Send 🚀" if FESTIVE else "Send", variant="primary", scale=1)
+                 stop_btn = gr.Button("⏹️ Stop", variant="stop", scale=1)

+             with gr.Row():
+                 clear_btn = gr.Button("🗑️ Clear Chat", size="sm")
+                 retry_btn = gr.Button("🔄 Retry", size="sm")
+
+         with gr.Column(scale=1):
+             gr.Markdown("### ⚙️ Generation Settings")
+
+             max_tokens = gr.Slider(
+                 minimum=50,
+                 maximum=1024,
+                 value=512,
+                 step=50,
+                 label="Max Tokens",
+                 info="Maximum length of response"
              )

+             temperature = gr.Slider(
+                 minimum=0.1,
+                 maximum=2.0,
+                 value=0.8,
+                 step=0.1,
+                 label="Temperature",
+                 info="Higher = more creative"
              )

+             top_k = gr.Slider(
+                 minimum=1,
+                 maximum=100,
+                 value=40,
+                 step=1,
+                 label="Top-K",
+                 info="Sample from top K tokens"
              )

+             top_p = gr.Slider(
+                 minimum=0.1,
+                 maximum=1.0,
+                 value=0.9,
+                 step=0.05,
+                 label="Top-P",
+                 info="Nucleus sampling threshold"
+             )

+             repetition_penalty = gr.Slider(
+                 minimum=1.0,
+                 maximum=2.0,
+                 value=1.1,
+                 step=0.1,
+                 label="Repetition Penalty",
+                 info="Penalize repeated tokens"
+             )

+             gr.Markdown("---")

+             # Model info
+             if FESTIVE:
+                 gr.Markdown(f"""
+                 ### 🎊 SAM-Z-1 Model Info

+                 **🎯 The Fast Twin!**
+
+                 **Type:** Direct Response Model
+                 **Parameters:** ~313M
+                 **Context:** {config['max_position_embeddings']} tokens
+                 **Vocab:** {config['vocab_size']}
+                 **Speed:** ⚡ Optimized with TF Functions
+
+                 **Twin Model:**
+                 - **SAM-X-1**: Reasoning model (uses `<think>` tags)
+                 - **SAM-Z-1**: Fast model (no thinking, direct answers! 🎉)
+
+                 **Note:** Model includes `<think>` tokens in vocab but doesn't use them. Training used same tokenizer as SAM-X-1.
+
+                 **Architecture:**
+                 - RoPE positional encoding
+                 - SwiGLU activation
+                 - RMSNorm layers
+                 - No bias terms (efficient!)
+
+                 **Training:**
+                 - Trained from scratch
+                 - TPU v5e-8 (8 cores)
+                 - Mixed precision (bfloat16)
+                 - Cosine decay schedule
+                 """)
+             else:
+                 gr.Markdown(f"""
+                 ### 📊 Model Info
+
+                 **Architecture:** SAM-Z-1 (Direct Response)
+                 **Parameters:** ~313M
+                 **Context:** {config['max_position_embeddings']} tokens
+                 **Vocab:** {config['vocab_size']}
+
+                 **Twin Models:**
+                 - SAM-X-1: Reasoning model (uses `<think>` tags)
+                 - SAM-Z-1: Direct response model (no thinking)
+
+                 **Note:** Vocab includes `<think>` tokens but model doesn't use them in generation.
+
+                 **Features:**
+                 - RoPE positional encoding
+                 - SwiGLU activation
+                 - RMSNorm layers
+                 - TF-optimized inference
+                 """)
+
+     # Example prompts
+     gr.Examples(
+         examples=[
+             "Hi! What can you do?",
+             "Explain quantum computing in simple terms",
+             "Write a short poem about AI",
+             "What's the capital of France?",
+             "How do I learn programming?",
+             "Tell me an interesting fact about space",
+             "What's the difference between you and SAM-X-1?",
+             "Why are you called the fast twin?",
          ],
+         inputs=msg,
+         label="💡 Try these examples" if not FESTIVE else "🎯 Try these examples!"
      )

+     # Footer
+     if FESTIVE:
+         gr.HTML("""
+         <footer>
+             <p style="font-size: 1.2rem;"><strong>🎉 SAM-Z-1 - LATEST RELEASE! 🎉</strong></p>
+             <p><strong>The Fast Twin</strong> - Direct responses without reasoning overhead</p>
+             <p style="font-size: 0.9rem; color: #999; margin-top: 0.5rem;">
+                 Trained from scratch on TPU v5e-8 • Built with TensorFlow & Gradio
+             </p>
+             <p style="font-size: 0.9rem; color: #999;">
+                 Twin of SAM-X-1 (reasoning model) • Same architecture, different training objective
+             </p>
+             <div style="margin-top: 1rem; font-size: 1.5rem;">
+                 ⚡ 🚀 💫 ✨ 🎯
+             </div>
+         </footer>
+         """)
+     else:
+         gr.HTML("""
+         <footer>
+             <p><strong>SAM-Z-1</strong> - Direct response language model</p>
+             <p style="font-size: 0.9rem; color: #999;">
+                 Trained from scratch on TPU v5e-8 • Built with TensorFlow & Gradio
+             </p>
+             <p style="font-size: 0.9rem; color: #999;">
+                 Twin of SAM-X-1 (reasoning model)
+             </p>
+         </footer>
+         """)
+
+     # Event handlers
+     submit_event = msg.submit(
+         chat_stream,
+         inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty],
+         outputs=[chatbot]
+     ).then(
+         lambda: "",
+         outputs=[msg]
+     )
+
+     click_event = submit_btn.click(
+         chat_stream,
+         inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty],
+         outputs=[chatbot]
+     ).then(
+         lambda: "",
+         outputs=[msg]
+     )
+
+     # Stop button
+     stop_btn.click(
+         fn=stop_gen,
+         inputs=None,
+         outputs=None,
+         cancels=[submit_event, click_event]
+     )
+
+     clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])
+
+     def retry_last(history, max_tok, temp, topk, topp, rep_pen):
+         if not history:
+             return history
+         last_user_msg = history[-1][0]
+         history = history[:-1]
+         for update in chat_stream(last_user_msg, history, max_tok, temp, topk, topp, rep_pen):
+             yield update
+
+     retry_event = retry_btn.click(
+         retry_last,
+         inputs=[chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty],
+         outputs=[chatbot]
+     )
+
+     stop_btn.click(
+         fn=stop_gen,
+         inputs=None,
+         outputs=None,
+         cancels=[retry_event]
      )

  # Launch
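The view cuts off at the launch section. For a streaming Blocks app like this one, the usual closing lines enable the request queue before starting the server (a sketch of the common pattern; the queue size is illustrative, echoing the 20-request queue the old docs advertised, and these are not necessarily the committed lines):

```python
demo.queue(max_size=20)  # generator-based events need the queue so streamed updates reach the browser
demo.launch()
```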