Keeby-smilyai commited on
Commit
9580f69
·
verified ·
1 Parent(s): 5a3d225

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +284 -573
app.py CHANGED
@@ -15,19 +15,21 @@ import queue
15
  import hashlib
16
  import sqlite3
17
  from datetime import datetime
18
- from dataclasses import dataclass, field
19
- from typing import List, Dict, Optional
20
  import uuid
21
 
22
  # ==============================================================================
23
- # GPU/CPU Optimization
24
  # ==============================================================================
25
  tf.config.threading.set_inter_op_parallelism_threads(2)
26
  tf.config.threading.set_intra_op_parallelism_threads(4)
27
  tf.config.optimizer.set_jit(True)
28
 
 
 
 
29
  # ==============================================================================
30
- # Database Setup
31
  # ==============================================================================
32
  def init_db():
33
  conn = sqlite3.connect('sam_tasks.db', check_same_thread=False)
@@ -53,11 +55,10 @@ def init_db():
53
  tokens_per_sec REAL DEFAULT 0,
54
  FOREIGN KEY (user_id) REFERENCES users(id))''')
55
 
56
- # Create admin account
57
  admin_pass = hashlib.sha256("admin123".encode()).hexdigest()
58
  try:
59
- c.execute("INSERT INTO users (username, password_hash) VALUES (?, ?)",
60
- ("admin", admin_pass))
61
  conn.commit()
62
  except sqlite3.IntegrityError:
63
  pass
@@ -69,7 +70,7 @@ db_conn = init_db()
69
  db_lock = threading.Lock()
70
 
71
  # ==============================================================================
72
- # Model Architecture (Compact)
73
  # ==============================================================================
74
  @keras.saving.register_keras_serializable()
75
  class RotaryEmbedding(keras.layers.Layer):
@@ -124,11 +125,6 @@ class RMSNorm(keras.layers.Layer):
124
  def call(self, x):
125
  variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
126
  return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
127
-
128
- def get_config(self):
129
- config = super().get_config()
130
- config.update({"epsilon": self.epsilon})
131
- return config
132
 
133
  @keras.saving.register_keras_serializable()
134
  class TransformerBlock(keras.layers.Layer):
@@ -138,693 +134,408 @@ class TransformerBlock(keras.layers.Layer):
138
  self.n_heads = n_heads
139
  self.ff_dim = ff_dim
140
  self.dropout_rate = dropout
141
- self.max_len = max_len
142
- self.rope_theta = rope_theta
143
  self.head_dim = d_model // n_heads
144
- self.layer_idx = layer_idx
145
 
146
  self.pre_attn_norm = RMSNorm()
147
  self.pre_ffn_norm = RMSNorm()
 
148
  self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
149
  self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
150
  self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
151
  self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
152
- self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
153
  self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
154
  self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
155
  self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
156
  self.dropout = keras.layers.Dropout(dropout)
157
-
158
- def call(self, x, training=None):
159
- B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
160
- dtype = x.dtype
161
 
162
  res = x
163
  y = self.pre_attn_norm(x)
164
 
165
- q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
166
- k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
167
- v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
- q, k = self.rope(q, k)
170
 
171
- scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
172
- mask = tf.where(tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
173
- tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
174
- scores += mask
175
- attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
 
 
 
176
 
177
- attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
178
  x = res + self.dropout(self.out_proj(attn), training=training)
179
 
180
  res = x
181
  y = self.pre_ffn_norm(x)
182
  ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
183
 
184
- return res + self.dropout(ffn, training=training)
185
-
186
  def get_config(self):
187
- config = super().get_config()
188
- config.update({
189
- "d_model": self.d_model, "n_heads": self.n_heads, "ff_dim": self.ff_dim,
190
- "dropout": self.dropout_rate, "max_len": self.max_len,
191
- "rope_theta": self.rope_theta, "layer_idx": self.layer_idx
192
- })
193
- return config
194
 
195
  @keras.saving.register_keras_serializable()
196
  class SAM1Model(keras.Model):
197
- def __init__(self, **kwargs):
198
- super().__init__()
199
- if 'config' in kwargs and isinstance(kwargs['config'], dict):
200
- self.cfg = kwargs['config']
201
- elif 'vocab_size' in kwargs:
202
- self.cfg = kwargs
203
- else:
204
- self.cfg = kwargs.get('cfg', kwargs)
205
-
206
- self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
207
- ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
208
- block_args = {
209
- 'd_model': self.cfg['d_model'], 'n_heads': self.cfg['n_heads'],
210
- 'ff_dim': ff_dim, 'dropout': self.cfg['dropout'],
211
- 'max_len': self.cfg['max_len'], 'rope_theta': self.cfg['rope_theta']
212
- }
213
-
214
- self.blocks = [TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
215
- for i in range(self.cfg['n_layers'])]
216
- self.norm = RMSNorm(name="final_norm")
217
- self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
218
-
219
- def call(self, input_ids, training=None):
220
  x = self.embed(input_ids)
221
- for block in self.blocks:
222
- x = block(x, training=training)
223
- return self.lm_head(self.norm(x))
224
-
225
- def get_config(self):
226
- base_config = super().get_config()
227
- base_config['config'] = self.cfg
228
- return base_config
229
 
230
  # ==============================================================================
231
- # KV Cache for SAM-Z (Ultra-Fast)
232
  # ==============================================================================
233
- @dataclass
234
- class KVCache:
235
- k_cache: List[tf.Tensor] = field(default_factory=list)
236
- v_cache: List[tf.Tensor] = field(default_factory=list)
237
-
238
- def update(self, layer_idx: int, k: tf.Tensor, v: tf.Tensor):
239
- if layer_idx >= len(self.k_cache):
240
- self.k_cache.append(k)
241
- self.v_cache.append(v)
242
- else:
243
- self.k_cache[layer_idx] = tf.concat([self.k_cache[layer_idx], k], axis=2)
244
- self.v_cache[layer_idx] = tf.concat([self.v_cache[layer_idx], v], axis=2)
245
- return self.k_cache[layer_idx], self.v_cache[layer_idx]
246
-
247
- def clear(self):
248
- self.k_cache.clear()
249
- self.v_cache.clear()
250
 
251
- # ==============================================================================
252
- # Load Models
253
- # ==============================================================================
254
- print("🚀 Loading SAM Models...")
255
 
256
- # SAM-X-1 (Reasoning with thinking)
257
- print("\n📦 Loading SAM-X-1-Large...")
258
  samx_weights = hf_hub_download("Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5")
 
259
  samx_config_path = hf_hub_download("Smilyai-labs/Sam-1-large-it-0002", "config.json")
260
 
261
  with open(samx_config_path, 'r') as f:
262
- samx_cfg = json.load(f)
263
-
264
- samx_model_cfg = {
265
- 'vocab_size': samx_cfg['vocab_size'],
266
- 'd_model': samx_cfg['hidden_size'],
267
- 'n_layers': samx_cfg['num_hidden_layers'],
268
- 'n_heads': samx_cfg['num_attention_heads'],
269
- 'ff_mult': samx_cfg['intermediate_size'] / samx_cfg['hidden_size'],
270
- 'max_len': samx_cfg['max_position_embeddings'],
271
  'dropout': 0.0,
272
- 'rope_theta': samx_cfg['rope_theta']
273
- }
274
-
275
- samx_model = SAM1Model(config=samx_model_cfg)
276
- dummy = tf.zeros((1, 1), dtype=tf.int32)
277
- _ = samx_model(dummy)
278
  samx_model.load_weights(samx_weights)
279
  samx_model.trainable = False
280
 
281
- @tf.function(jit_compile=True)
282
- def samx_predict(inputs):
283
- return samx_model(inputs, training=False)
284
-
285
- print("✅ SAM-X-1 loaded")
286
-
287
- # SAM-Z-1 (Fast with KV cache)
288
- print("\n📦 Loading SAM-Z-1...")
289
  samz_weights = hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "ckpt.weights.h5")
290
  samz_config_path = hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "config.json")
291
 
292
  with open(samz_config_path, 'r') as f:
293
- samz_cfg = json.load(f)
294
-
295
- samz_model_cfg = {
296
- 'vocab_size': samz_cfg['vocab_size'],
297
- 'd_model': samz_cfg['hidden_size'],
298
- 'n_layers': samz_cfg['num_hidden_layers'],
299
- 'n_heads': samz_cfg['num_attention_heads'],
300
- 'ff_mult': samz_cfg['intermediate_size'] / samz_cfg['hidden_size'],
301
- 'max_len': samz_cfg['max_position_embeddings'],
302
  'dropout': 0.0,
303
- 'rope_theta': samz_cfg['rope_theta']
304
- }
305
-
306
- samz_model = SAM1Model(config=samz_model_cfg)
307
- _ = samz_model(dummy)
308
  samz_model.load_weights(samz_weights)
309
  samz_model.trainable = False
310
 
311
- @tf.function(jit_compile=True)
312
- def samz_predict(inputs):
313
- return samz_model(inputs, training=False)
314
-
315
- print("✅ SAM-Z-1 loaded")
316
-
317
  # Tokenizer
318
- tokenizer_path = hf_hub_download("Smilyai-labs/Sam-1x-instruct", "tokenizer.json")
319
- tokenizer = Tokenizer.from_file(tokenizer_path)
320
  eos_token_id = 50256
321
 
322
- print(f"✅ Tokenizer ready (vocab: {tokenizer.get_vocab_size()})")
 
 
 
 
 
 
 
 
 
323
 
324
  # ==============================================================================
325
- # Background Task Processing
326
  # ==============================================================================
327
  task_queue = queue.Queue()
328
- active_tasks: Dict[str, Dict] = {}
329
  task_lock = threading.Lock()
330
 
331
- def create_task(user_id: int, model_name: str, prompt: str) -> str:
332
  task_id = str(uuid.uuid4())
333
-
334
  with db_lock:
335
  c = db_conn.cursor()
336
- c.execute("""INSERT INTO tasks (id, user_id, model_name, prompt, status)
337
- VALUES (?, ?, ?, ?, ?)""",
338
- (task_id, user_id, model_name, prompt, "queued"))
339
  db_conn.commit()
340
-
341
- with task_lock:
342
- active_tasks[task_id] = {
343
- 'status': 'queued',
344
- 'progress': 0,
345
- 'result': '',
346
- 'tokens_generated': 0,
347
- 'tokens_per_sec': 0.0
348
- }
349
-
350
- task_queue.put((task_id, user_id, model_name, prompt))
351
  return task_id
352
 
353
- def update_task_status(task_id: str, status: str, progress: int = 0,
354
- result: str = '', tokens: int = 0, tps: float = 0.0):
355
- with task_lock:
356
- if task_id in active_tasks:
357
- active_tasks[task_id].update({
358
- 'status': status,
359
- 'progress': progress,
360
- 'result': result,
361
- 'tokens_generated': tokens,
362
- 'tokens_per_sec': tps
363
- })
364
-
365
  with db_lock:
366
  c = db_conn.cursor()
367
  c.execute("""UPDATE tasks SET status=?, progress=?, result=?,
368
- tokens_generated=?, tokens_per_sec=?
369
- WHERE id=?""",
370
  (status, progress, result, tokens, tps, task_id))
371
-
372
- if status == 'completed':
373
- c.execute("UPDATE tasks SET completed_at=? WHERE id=?",
374
- (datetime.now().isoformat(), task_id))
375
-
376
  db_conn.commit()
377
 
378
- def generate_with_samx(prompt: str, task_id: str, max_tokens: int = 512):
379
- """SAM-X-1: Reasoning model with <think> tags"""
380
- input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
381
- generated = input_ids.copy()
382
- result = ""
383
 
 
 
384
  start_time = time.time()
385
 
386
- for step in range(max_tokens):
387
- logits = samx_predict(tf.constant([generated], dtype=tf.int32))
388
- next_logits = logits[0, -1, :].numpy()
389
-
390
- # Temperature sampling
391
- next_logits = next_logits / 0.7
392
- probs = tf.nn.softmax(next_logits).numpy()
393
- next_token = np.random.choice(len(probs), p=probs)
394
-
395
- if next_token == eos_token_id:
396
- break
397
-
398
- generated.append(int(next_token))
399
-
400
- # Decode periodically
401
- if step % 10 == 0 or step == max_tokens - 1:
402
- result = tokenizer.decode(generated[len(input_ids):])
403
- elapsed = time.time() - start_time
404
- tps = len(generated[len(input_ids):]) / elapsed if elapsed > 0 else 0
405
- progress = int((step / max_tokens) * 100)
406
-
407
- update_task_status(task_id, 'processing', progress, result,
408
- len(generated[len(input_ids):]), tps)
409
-
410
- # Final result
411
- result = tokenizer.decode(generated[len(input_ids):])
412
- elapsed = time.time() - start_time
413
- tps = len(generated[len(input_ids):]) / elapsed if elapsed > 0 else 0
414
 
415
- update_task_status(task_id, 'completed', 100, result,
416
- len(generated[len(input_ids):]), tps)
417
-
418
- def generate_with_samz(prompt: str, task_id: str, max_tokens: int = 512):
419
- """SAM-Z-1: Fast model with KV cache"""
420
- input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
421
- generated = input_ids.copy()
422
- result = ""
423
- kv_cache = KVCache()
424
 
425
- start_time = time.time()
 
426
 
 
 
 
427
  for step in range(max_tokens):
428
- # Use KV cache for speed
429
- if step == 0:
430
- current_input = generated
431
- else:
432
- current_input = [generated[-1]]
433
-
434
- logits = samz_predict(tf.constant([current_input], dtype=tf.int32))
435
- next_logits = logits[0, -1, :].numpy()
436
 
437
- # Fast sampling
438
- next_logits = next_logits / 0.8
439
- top_k = np.argpartition(next_logits, -40)[-40:]
440
- top_k_logits = next_logits[top_k]
441
- probs = tf.nn.softmax(top_k_logits).numpy()
442
- next_token = top_k[np.random.choice(len(probs), p=probs)]
443
 
444
  if next_token == eos_token_id:
445
  break
446
-
447
  generated.append(int(next_token))
448
 
449
- # Decode periodically
450
- if step % 15 == 0 or step == max_tokens - 1:
451
- result = tokenizer.decode(generated[len(input_ids):])
452
  elapsed = time.time() - start_time
453
- tps = len(generated[len(input_ids):]) / elapsed if elapsed > 0 else 0
454
- progress = int((step / max_tokens) * 100)
455
-
456
- update_task_status(task_id, 'processing', progress, result,
457
- len(generated[len(input_ids):]), tps)
458
-
459
- # Final result
460
- result = tokenizer.decode(generated[len(input_ids):])
461
  elapsed = time.time() - start_time
462
- tps = len(generated[len(input_ids):]) / elapsed if elapsed > 0 else 0
463
-
464
- update_task_status(task_id, 'completed', 100, result,
465
- len(generated[len(input_ids):]), tps)
466
 
467
- def task_worker():
468
- """Background worker thread"""
469
- print("🔧 Task worker started")
470
-
471
  while True:
472
  try:
473
- task_id, user_id, model_name, prompt = task_queue.get(timeout=1)
474
-
475
- print(f"⚙️ Processing task {task_id[:8]}... ({model_name})")
476
-
477
- update_task_status(task_id, 'processing', 0)
478
 
479
  try:
480
- if 'SAM-X' in model_name or 'Large' in model_name:
481
- generate_with_samx(prompt, task_id)
482
  else:
483
- generate_with_samz(prompt, task_id)
484
-
485
- print(f"✅ Task {task_id[:8]} completed")
486
  except Exception as e:
487
- print(f"❌ Task {task_id[:8]} failed: {e}")
488
- update_task_status(task_id, 'failed', 0, f"Error: {str(e)}")
489
-
490
  task_queue.task_done()
491
-
492
  except queue.Empty:
493
  continue
494
 
495
- # Start worker threads (2 workers for parallel processing)
496
  for _ in range(2):
497
- worker = threading.Thread(target=task_worker, daemon=True)
498
- worker.start()
499
-
500
- # ==============================================================================
501
- # User Management
502
- # ==============================================================================
503
- def hash_password(password: str) -> str:
504
- return hashlib.sha256(password.encode()).hexdigest()
505
-
506
- def create_user(username: str, password: str):
507
- with db_lock:
508
- try:
509
- c = db_conn.cursor()
510
- c.execute("INSERT INTO users (username, password_hash) VALUES (?, ?)",
511
- (username, hash_password(password)))
512
- db_conn.commit()
513
- return True, "Account created!"
514
- except sqlite3.IntegrityError:
515
- return False, "Username exists!"
516
-
517
- def authenticate(username: str, password: str):
518
- with db_lock:
519
- c = db_conn.cursor()
520
- c.execute("SELECT id, password_hash FROM users WHERE username=?", (username,))
521
- result = c.fetchone()
522
-
523
- if result and result[1] == hash_password(password):
524
- return True, result[0]
525
- return False, None
526
-
527
- def get_user_tasks(user_id: int):
528
- with db_lock:
529
- c = db_conn.cursor()
530
- c.execute("""SELECT id, model_name, prompt, status, progress,
531
- tokens_generated, tokens_per_sec, created_at
532
- FROM tasks WHERE user_id=?
533
- ORDER BY created_at DESC LIMIT 50""",
534
- (user_id,))
535
- return c.fetchall()
536
-
537
- def get_user_active_tasks(user_id: int):
538
- with db_lock:
539
- c = db_conn.cursor()
540
- c.execute("""SELECT COUNT(*) FROM tasks
541
- WHERE user_id=? AND status IN ('queued', 'processing')""",
542
- (user_id,))
543
- return c.fetchone()[0]
544
 
545
  # ==============================================================================
546
- # Gradio UI
547
  # ==============================================================================
548
  css = """
549
- .container { max-width: 1400px; margin: 0 auto; }
550
- .task-card {
551
- background: white;
552
- border: 2px solid #e5e7eb;
553
- border-radius: 12px;
554
- padding: 16px;
555
- margin: 8px 0;
556
- }
557
- .status-queued { color: #f59e0b; }
558
- .status-processing { color: #3b82f6; }
559
- .status-completed { color: #10b981; }
560
- .status-failed { color: #ef4444; }
561
- .progress-bar {
562
- height: 8px;
563
- background: #e5e7eb;
564
- border-radius: 4px;
565
- overflow: hidden;
566
- margin: 8px 0;
567
  }
568
- .progress-fill {
569
- height: 100%;
570
- background: linear-gradient(90deg, #10b981, #059669);
571
- transition: width 0.3s;
 
 
 
572
  }
573
  """
574
 
575
- with gr.Blocks(css=css, title="SAM Background Processor") as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
  user_id_state = gr.State(None)
577
 
578
- gr.Markdown("# 🚀 SAM Multi-Task Processor")
579
- gr.Markdown("Submit up to 5 background tasks. No need to stay on page!")
580
 
581
- # Auth
582
  with gr.Group(visible=True) as auth_group:
583
- gr.Markdown("### 🔐 Sign In / Sign Up")
584
- auth_username = gr.Textbox(label="Username", placeholder="username")
585
- auth_password = gr.Textbox(label="Password", type="password")
586
- auth_btn = gr.Button("Continue", variant="primary")
587
  auth_msg = gr.Markdown("")
588
-
589
- # Main UI
590
  with gr.Group(visible=False) as main_group:
591
- with gr.Row():
592
- gr.Markdown("### 🤖 Create Task")
593
- user_display = gr.Markdown("")
594
 
595
  with gr.Row():
596
- with gr.Column(scale=2):
597
- model_choice = gr.Radio(
598
- choices=["SAM-X-1-Large (Reasoning)", "SAM-Z-1 (Fast)"],
599
- value="SAM-Z-1 (Fast)",
600
- label="Model"
601
- )
602
- prompt_input = gr.Textbox(
603
- label="Prompt",
604
- placeholder="Enter your prompt...",
605
- lines=4
606
- )
607
- submit_btn = gr.Button("🚀 Submit Task", variant="primary", size="lg")
608
- task_msg = gr.Markdown("")
609
-
610
  with gr.Column(scale=1):
611
- gr.Markdown("### ℹ️ Info")
612
- gr.Markdown("""
613
- - **SAM-X-1**: Reasoning model with `<think>` tags
614
- - **SAM-Z-1**: Ultra-fast direct responses
615
- - Max 5 concurrent tasks
616
- - Results saved to database
617
- - Background processing
618
- """)
619
-
620
- gr.Markdown("---")
621
-
622
- with gr.Row():
623
- gr.Markdown("### 📋 Your Tasks")
624
- refresh_btn = gr.Button("🔄 Refresh", size="sm")
625
-
626
- tasks_display = gr.HTML("")
627
-
628
- auto_refresh = gr.Checkbox(label="Auto-refresh every 3 seconds", value=True)
629
-
630
- # Auth handler
631
- def handle_auth(username, password):
632
- if len(username) < 3 or len(password) < 6:
633
- return None, "❌ Invalid credentials", gr.update(), gr.update()
634
-
635
- success, user_id = authenticate(username, password)
636
-
637
- if not success:
638
- success, msg = create_user(username, password)
639
- if success:
640
- success, user_id = authenticate(username, password)
641
-
642
- if success:
643
- return (
644
- user_id,
645
- f"✅ Welcome, **{username}**!",
646
- gr.update(visible=False),
647
- gr.update(visible=True)
648
- )
649
-
650
- return None, "❌ Authentication failed", gr.update(), gr.update()
651
-
652
- # Submit task
653
- def submit_task(user_id, model, prompt):
654
- if not user_id:
655
- return "❌ Please sign in", ""
656
-
657
- if not prompt.strip():
658
- return "❌ Prompt required", ""
659
-
660
- active_count = get_user_active_tasks(user_id)
661
- if active_count >= 5:
662
- return f"❌ Max 5 active tasks (you have {active_count})", ""
663
-
664
- task_id = create_task(user_id, model, prompt)
665
- return f"✅ Task submitted! ID: `{task_id[:8]}...`", ""
666
 
667
- # Render tasks
668
- def render_tasks(user_id):
669
- if not user_id:
670
- return ""
671
-
672
- tasks = get_user_tasks(user_id)
673
-
674
- if not tasks:
675
- return "<div style='text-align: center; padding: 40px; color: #9ca3af;'>No tasks yet</div>"
676
-
677
- html = ""
678
- for task in tasks:
679
- task_id, model, prompt, status, progress, tokens, tps, created = task
680
-
681
- status_class = f"status-{status}"
682
 
683
- html += f"""
684
- <div class="task-card">
685
- <div style="display: flex; justify-content: space-between; margin-bottom: 8px;">
686
- <strong>Task: {task_id[:8]}...</strong>
687
- <span class="{status_class}">●{status.upper()}</span>
688
- </div>
689
- <div><strong>Model:</strong> {model}</div>
690
- <div><strong>Prompt:</strong> {prompt[:100]}{'...' if len(prompt) > 100 else ''}</div>
691
- <div class="progress-bar">
692
- <div class="progress-fill" style="width: {progress}%"></div>
693
- </div>
694
- <div style="font-size: 12px; color: #6b7280;">
695
- Progress: {progress}% | Tokens: {tokens} | Speed: {tps:.1f} tok/s
696
- </div>
697
- </div>
698
- """
699
 
 
 
 
 
 
 
 
700
  return html
 
 
 
701
 
702
- # Get task result
703
- def get_task_result(user_id, task_id_short):
704
- if not user_id or not task_id_short:
705
- return "❌ Invalid request"
706
 
 
 
707
  with db_lock:
708
  c = db_conn.cursor()
709
- c.execute("""SELECT result, status FROM tasks
710
- WHERE user_id=? AND id LIKE ?""",
711
- (user_id, f"{task_id_short}%"))
712
- result = c.fetchone()
713
-
714
- if result:
715
- if result[1] == 'completed':
716
- return f"### ✅ Result\n\n{result[0]}"
717
- elif result[1] == 'failed':
718
- return f"### ❌ Failed\n\n{result[0]}"
719
- else:
720
- return f"### Status: {result[1]}"
721
- return "❌ Task not found"
722
-
723
- # Event handlers
724
- auth_btn.click(
725
- handle_auth,
726
- [auth_username, auth_password],
727
- [user_id_state, auth_msg, auth_group, main_group]
728
- )
729
-
730
- submit_btn.click(
731
- submit_task,
732
- [user_id_state, model_choice, prompt_input],
733
- [task_msg, prompt_input]
734
- ).then(
735
- render_tasks,
736
- [user_id_state],
737
- [tasks_display]
738
- )
739
-
740
- refresh_btn.click(
741
- render_tasks,
742
- [user_id_state],
743
- [tasks_display]
744
- )
745
-
746
- # Auto-refresh timer
747
- def auto_refresh_tasks(user_id, enabled):
748
- if enabled and user_id:
749
- return render_tasks(user_id)
750
- return gr.update()
751
-
752
- # Poll every 3 seconds when auto-refresh enabled
753
- demo.load(
754
- lambda: None,
755
- None,
756
- None,
757
- every=3
758
- )
759
-
760
- # Update user display on load
761
- def update_user_display(user_id):
762
- if user_id:
763
- with db_lock:
764
- c = db_conn.cursor()
765
- c.execute("SELECT username FROM users WHERE id=?", (user_id,))
766
- result = c.fetchone()
767
- if result:
768
- active = get_user_active_tasks(user_id)
769
- return f"**User:** {result[0]} | **Active:** {active}/5"
770
- return ""
771
-
772
- # Periodic refresh
773
- refresh_timer = gr.Timer(3)
774
-
775
- @refresh_timer.tick
776
- def timer_refresh(user_id, auto_enabled):
777
- if auto_enabled and user_id:
778
- return render_tasks(user_id), update_user_display(user_id)
779
- return gr.update(), gr.update()
780
-
781
- refresh_timer.tick(
782
- timer_refresh,
783
- [user_id_state, auto_refresh],
784
- [tasks_display, user_display]
785
- )
786
-
787
- # View full result (expandable)
788
- with gr.Accordion("🔍 View Task Result", open=False):
789
- result_task_id = gr.Textbox(
790
- label="Task ID (first 8 chars)",
791
- placeholder="e.g., 3f7a9b2c"
792
- )
793
- view_result_btn = gr.Button("View Result", variant="primary")
794
- result_display = gr.Markdown("")
795
 
796
- view_result_btn.click(
797
- get_task_result,
798
- [user_id_state, result_task_id],
799
- [result_display]
800
  )
801
 
802
- # Initial load
803
- def on_auth_success(user_id):
804
- if user_id:
805
- return render_tasks(user_id), update_user_display(user_id)
806
- return "", ""
807
 
808
- user_id_state.change(
809
- on_auth_success,
810
- [user_id_state],
811
- [tasks_display, user_display]
812
- )
813
 
814
  if __name__ == "__main__":
815
- print("\n" + "="*80)
816
- print("🚀 SAM BACKGROUND PROCESSOR".center(80))
817
- print("="*80)
818
- print(f"✅ 2 worker threads active")
819
- print(f"✅ Max 5 tasks per user")
820
- print(f"✅ Background processing enabled")
821
- print(f"✅ Database: sam_tasks.db")
822
- print("="*80 + "\n")
823
-
824
- demo.queue(max_size=50)
825
- demo.launch(
826
- server_name="0.0.0.0",
827
- server_port=7860,
828
- share=False,
829
- show_error=True
830
- )
 
15
  import hashlib
16
  import sqlite3
17
  from datetime import datetime
18
+ from typing import List, Dict, Optional, Tuple, Any
 
19
  import uuid
20
 
21
  # ==============================================================================
22
+ # 1. Hardware & System Setup
23
  # ==============================================================================
24
  tf.config.threading.set_inter_op_parallelism_threads(2)
25
  tf.config.threading.set_intra_op_parallelism_threads(4)
26
  tf.config.optimizer.set_jit(True)
27
 
28
+ print(f"🚀 SmilyAI System Initializing...")
29
+ print(f"📱 TensorFlow Version: {tf.__version__}")
30
+
31
  # ==============================================================================
32
+ # 2. Database (State Management)
33
  # ==============================================================================
34
  def init_db():
35
  conn = sqlite3.connect('sam_tasks.db', check_same_thread=False)
 
55
  tokens_per_sec REAL DEFAULT 0,
56
  FOREIGN KEY (user_id) REFERENCES users(id))''')
57
 
58
+ # Admin account
59
  admin_pass = hashlib.sha256("admin123".encode()).hexdigest()
60
  try:
61
+ c.execute("INSERT INTO users (username, password_hash) VALUES (?, ?)", ("admin", admin_pass))
 
62
  conn.commit()
63
  except sqlite3.IntegrityError:
64
  pass
 
70
  db_lock = threading.Lock()
71
 
72
  # ==============================================================================
73
+ # 3. Model Architecture (Enhanced with KV Cache)
74
  # ==============================================================================
75
  @keras.saving.register_keras_serializable()
76
  class RotaryEmbedding(keras.layers.Layer):
 
125
  def call(self, x):
126
  variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
127
  return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
 
 
 
 
 
128
 
129
  @keras.saving.register_keras_serializable()
130
  class TransformerBlock(keras.layers.Layer):
 
134
  self.n_heads = n_heads
135
  self.ff_dim = ff_dim
136
  self.dropout_rate = dropout
 
 
137
  self.head_dim = d_model // n_heads
138
+ self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
139
 
140
  self.pre_attn_norm = RMSNorm()
141
  self.pre_ffn_norm = RMSNorm()
142
+
143
  self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
144
  self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
145
  self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
146
  self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
147
+
148
  self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
149
  self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
150
  self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
151
  self.dropout = keras.layers.Dropout(dropout)
152
+
153
+ def call(self, x, cache=None, training=None):
154
+ # Shape: [Batch, Time, Dim]
155
+ B, T = tf.shape(x)[0], tf.shape(x)[1]
156
 
157
  res = x
158
  y = self.pre_attn_norm(x)
159
 
160
+ # Projections
161
+ q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
162
+ k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
163
+ v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
164
+
165
+ # --- KV CACHE UPDATE ---
166
+ if cache is not None:
167
+ old_k, old_v = cache
168
+ k = tf.concat([old_k, k], axis=1)
169
+ v = tf.concat([old_v, v], axis=1)
170
+
171
+ new_cache = (k, v)
172
+
173
+ # RoPE & Attention
174
+ q = tf.transpose(q, [0, 2, 1, 3]) # [B, Heads, T, HeadDim]
175
+ k_rot = tf.transpose(k, [0, 2, 1, 3])
176
+
177
+ q_rot, k_rot = self.rope(q, k_rot)
178
 
179
+ scores = tf.matmul(q_rot, k_rot, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, x.dtype))
180
 
181
+ # Masking (Only needed if sequence length > 1)
182
+ if T > 1:
183
+ mask = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
184
+ mask = (1.0 - mask) * -1e9
185
+ scores += mask
186
+
187
+ attn = tf.matmul(tf.nn.softmax(scores, axis=-1), tf.transpose(v, [0, 2, 1, 3]))
188
+ attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, self.d_model])
189
 
 
190
  x = res + self.dropout(self.out_proj(attn), training=training)
191
 
192
  res = x
193
  y = self.pre_ffn_norm(x)
194
  ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
195
 
196
+ return res + self.dropout(ffn, training=training), new_cache
197
+
198
  def get_config(self):
199
+ return super().get_config()
 
 
 
 
 
 
200
 
201
  @keras.saving.register_keras_serializable()
202
  class SAM1Model(keras.Model):
203
+ def __init__(self, config, **kwargs):
204
+ super().__init__(**kwargs)
205
+ self.cfg = config
206
+ self.embed = keras.layers.Embedding(config['vocab_size'], config['d_model'])
207
+
208
+ ff_dim = int(config['d_model'] * config['ff_mult'])
209
+ self.blocks = [
210
+ TransformerBlock(
211
+ d_model=config['d_model'], n_heads=config['n_heads'], ff_dim=ff_dim,
212
+ dropout=config['dropout'], max_len=config['max_len'],
213
+ rope_theta=config['rope_theta'], name=f"blk_{i}"
214
+ ) for i in range(config['n_layers'])
215
+ ]
216
+ self.norm = RMSNorm()
217
+ self.lm_head = keras.layers.Dense(config['vocab_size'], use_bias=False)
218
+
219
+ def call(self, input_ids, cache=None, training=None):
 
 
 
 
 
 
220
  x = self.embed(input_ids)
221
+ new_caches = []
222
+
223
+ for i, block in enumerate(self.blocks):
224
+ layer_cache = cache[i] if cache is not None else None
225
+ x, updated_cache = block(x, cache=layer_cache, training=training)
226
+ new_caches.append(updated_cache)
227
+
228
+ return self.lm_head(self.norm(x)), new_caches
229
 
230
  # ==============================================================================
231
+ # 4. Load Models
232
  # ==============================================================================
233
+ print("\n📦 Loading SAM Models with KV Cache...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
+ # Dummy input for initialization
236
+ dummy_in = tf.zeros((1, 1), dtype=tf.int32)
 
 
237
 
238
+ # --- SAM-X-1 (Reasoning) ---
239
+ print("🔹 Loading SAM-X-1 (Reasoning)...")
240
  samx_weights = hf_hub_download("Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5")
241
+ # UPDATED CONFIG PATH as requested
242
  samx_config_path = hf_hub_download("Smilyai-labs/Sam-1-large-it-0002", "config.json")
243
 
244
  with open(samx_config_path, 'r') as f:
245
+ cfg_x = json.load(f)
246
+
247
+ samx_model = SAM1Model({
248
+ 'vocab_size': cfg_x['vocab_size'],
249
+ 'd_model': cfg_x['hidden_size'],
250
+ 'n_layers': cfg_x['num_hidden_layers'],
251
+ 'n_heads': cfg_x['num_attention_heads'],
252
+ 'ff_mult': cfg_x['intermediate_size'] / cfg_x['hidden_size'],
253
+ 'max_len': cfg_x['max_position_embeddings'],
254
  'dropout': 0.0,
255
+ 'rope_theta': cfg_x['rope_theta']
256
+ })
257
+ _ = samx_model(dummy_in) # Build
 
 
 
258
  samx_model.load_weights(samx_weights)
259
  samx_model.trainable = False
260
 
261
+ # --- SAM-Z-1 (Fast) ---
262
+ print("🔹 Loading SAM-Z-1 (Speed)...")
 
 
 
 
 
 
263
  samz_weights = hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "ckpt.weights.h5")
264
  samz_config_path = hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "config.json")
265
 
266
  with open(samz_config_path, 'r') as f:
267
+ cfg_z = json.load(f)
268
+
269
+ samz_model = SAM1Model({
270
+ 'vocab_size': cfg_z['vocab_size'],
271
+ 'd_model': cfg_z['hidden_size'],
272
+ 'n_layers': cfg_z['num_hidden_layers'],
273
+ 'n_heads': cfg_z['num_attention_heads'],
274
+ 'ff_mult': cfg_z['intermediate_size'] / cfg_z['hidden_size'],
275
+ 'max_len': cfg_z['max_position_embeddings'],
276
  'dropout': 0.0,
277
+ 'rope_theta': cfg_z['rope_theta']
278
+ })
279
+ _ = samz_model(dummy_in) # Build
 
 
280
  samz_model.load_weights(samz_weights)
281
  samz_model.trainable = False
282
 
 
 
 
 
 
 
283
  # Tokenizer
284
+ tok_path = hf_hub_download("Smilyai-labs/Sam-1x-instruct", "tokenizer.json")
285
+ tokenizer = Tokenizer.from_file(tok_path)
286
  eos_token_id = 50256
287
 
288
+ # JIT Compiled Prediction Steps (Separate for safety)
289
+ @tf.function(jit_compile=True)
290
+ def predict_x(ids, cache):
291
+ return samx_model(ids, cache=cache, training=False)
292
+
293
+ @tf.function(jit_compile=True)
294
+ def predict_z(ids, cache):
295
+ return samz_model(ids, cache=cache, training=False)
296
+
297
+ print("✅ Models Loaded & JIT Compiled")
298
 
299
  # ==============================================================================
300
+ # 5. Task Queue & Workers
301
  # ==============================================================================
302
  task_queue = queue.Queue()
303
+ active_tasks = {}
304
  task_lock = threading.Lock()
305
 
306
+ def create_task(user_id, model, prompt):
307
  task_id = str(uuid.uuid4())
 
308
  with db_lock:
309
  c = db_conn.cursor()
310
+ c.execute("INSERT INTO tasks (id, user_id, model_name, prompt, status) VALUES (?,?,?,?,?)",
311
+ (task_id, user_id, model, prompt, 'queued'))
 
312
  db_conn.commit()
313
+ task_queue.put((task_id, model, prompt))
 
 
 
 
 
 
 
 
 
 
314
  return task_id
315
 
316
+ def update_db_status(task_id, status, progress, result, tokens, tps):
 
 
 
 
 
 
 
 
 
 
 
317
  with db_lock:
318
  c = db_conn.cursor()
319
  c.execute("""UPDATE tasks SET status=?, progress=?, result=?,
320
+ tokens_generated=?, tokens_per_sec=? WHERE id=?""",
 
321
  (status, progress, result, tokens, tps, task_id))
322
+ if status in ['completed', 'failed']:
323
+ c.execute("UPDATE tasks SET completed_at=? WHERE id=?", (datetime.now().isoformat(), task_id))
 
 
 
324
  db_conn.commit()
325
 
326
+ def generate_stream(task_id, model_func, prompt, max_tokens=1024):
327
+ """Universal generator using KV Cache"""
 
 
 
328
 
329
+ # 1. Prefill Phase
330
+ input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
331
  start_time = time.time()
332
 
333
+ # Process generic prompt to get initial cache
334
+ # Note: We must treat 'None' cache as a special case in the TF function usually,
335
+ # or just pass generic list of None in Eager, but TF function expects tensors.
336
+ # For simplicity in this script, we run prefill in eager or adapt the loop.
337
+ # Here we do the first pass:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
+ current_ids = tf.constant([input_ids], dtype=tf.int32)
340
+ logits, kv_cache = model_func(current_ids, cache=None)
 
 
 
 
 
 
 
341
 
342
+ next_token = np.argmax(logits[0, -1, :].numpy())
343
+ generated = [int(next_token)]
344
 
345
+ update_db_status(task_id, 'processing', 0, tokenizer.decode(generated), 0, 0)
346
+
347
+ # 2. Decode Phase (Token by token)
348
  for step in range(max_tokens):
349
+ input_tensor = tf.constant([[generated[-1]]], dtype=tf.int32)
350
+ logits, kv_cache = model_func(input_tensor, cache=kv_cache)
 
 
 
 
 
 
351
 
352
+ # Sample
353
+ next_logits = logits[0, -1, :].numpy() / 0.7
354
+ probs = tf.nn.softmax(next_logits).numpy()
355
+ next_token = np.random.choice(len(probs), p=probs)
 
 
356
 
357
  if next_token == eos_token_id:
358
  break
359
+
360
  generated.append(int(next_token))
361
 
362
+ # Update DB every 3 tokens for smooth streaming UI
363
+ if step % 3 == 0:
364
+ text = tokenizer.decode(generated)
365
  elapsed = time.time() - start_time
366
+ tps = len(generated) / elapsed if elapsed > 0 else 0
367
+ prog = int((step / max_tokens) * 100)
368
+ update_db_status(task_id, 'processing', prog, text, len(generated), tps)
369
+
370
+ # Final Update
371
+ text = tokenizer.decode(generated)
 
 
372
  elapsed = time.time() - start_time
373
+ tps = len(generated) / elapsed
374
+ update_db_status(task_id, 'completed', 100, text, len(generated), tps)
 
 
375
 
376
+ def worker():
377
+ print("👷 Worker thread started")
 
 
378
  while True:
379
  try:
380
+ task_id, model_name, prompt = task_queue.get(timeout=1)
381
+ print(f"⚙️ Processing {task_id[:8]} with {model_name}")
 
 
 
382
 
383
  try:
384
+ if "SAM-X" in model_name:
385
+ generate_stream(task_id, predict_x, prompt)
386
  else:
387
+ generate_stream(task_id, predict_z, prompt)
 
 
388
  except Exception as e:
389
+ print(f"❌ Error: {e}")
390
+ update_db_status(task_id, 'failed', 0, f"Error: {str(e)}", 0, 0)
391
+
392
  task_queue.task_done()
 
393
  except queue.Empty:
394
  continue
395
 
396
+ # Start 2 Workers
397
  for _ in range(2):
398
+ t = threading.Thread(target=worker, daemon=True)
399
+ t.start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
 
401
  # ==============================================================================
402
+ # 6. Gradio UI with Streaming & Thinking
403
  # ==============================================================================
404
  css = """
405
+ .container { max-width: 1200px; margin: 0 auto; }
406
+ .task-card {
407
+ border: 1px solid #e5e7eb; padding: 15px; margin-bottom: 10px; border-radius: 8px;
408
+ background: white; box-shadow: 0 1px 3px rgba(0,0,0,0.1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
  }
410
+ .status-processing { color: #2563eb; font-weight: bold; animation: pulse 1.5s infinite; }
411
+ .status-completed { color: #059669; font-weight: bold; }
412
+ @keyframes pulse { 0% { opacity: 1; } 50% { opacity: 0.5; } 100% { opacity: 1; } }
413
+ .thought-box {
414
+ background-color: #f0f9ff; border-left: 4px solid #0ea5e9;
415
+ padding: 10px; margin: 10px 0; font-family: monospace; font-size: 0.9em;
416
+ color: #0c4a6e;
417
  }
418
  """
419
 
420
+ def format_output(text):
421
+ if not text: return ""
422
+ # Parse <think> tags for SAM-X
423
+ if "<think>" in text:
424
+ parts = text.split("<think>")
425
+ pre = parts[0]
426
+ remainder = parts[1]
427
+ if "</think>" in remainder:
428
+ thought, ans = remainder.split("</think>")
429
+ return f"{pre}<div class='thought-box'>🧠 <b>Thinking Process:</b><br>{thought}</div>{ans}"
430
+ else:
431
+ return f"{pre}<div class='thought-box'>🧠 <b>Thinking...</b><br>{remainder}</div>"
432
+ return text.replace("\n", "<br>")
433
+
434
+ with gr.Blocks(css=css, title="SmilyAI Studio") as demo:
435
  user_id_state = gr.State(None)
436
 
437
+ gr.Markdown("# 🧠 SmilyAI Studio")
 
438
 
439
+ # --- Auth Section ---
440
  with gr.Group(visible=True) as auth_group:
441
+ gr.Markdown("### Login")
442
+ u_in = gr.Textbox(label="Username")
443
+ p_in = gr.Textbox(label="Password", type="password")
444
+ login_btn = gr.Button("Login / Register", variant="primary")
445
  auth_msg = gr.Markdown("")
446
+
447
+ # --- Main Interface ---
448
  with gr.Group(visible=False) as main_group:
449
+ gr.Markdown(f"### 🚀 New Inference Task")
 
 
450
 
451
  with gr.Row():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
  with gr.Column(scale=1):
453
+ model_sel = gr.Radio(["SAM-X-1 (Reasoning)", "SAM-Z-1 (Fast)"], label="Model", value="SAM-Z-1 (Fast)")
454
+ prompt_in = gr.Textbox(label="Prompt", lines=4, placeholder="Enter query...")
455
+ sub_btn = gr.Button("Generate", variant="primary")
456
+
457
+ with gr.Column(scale=1):
458
+ gr.Markdown("### 📡 Live Monitor")
459
+ monitor_id = gr.Textbox(label="Task ID", placeholder="Click a task below to copy ID")
460
+ watch_btn = gr.Button("Open Stream")
461
+ stream_out = gr.HTML(label="Output", min_height=300)
462
+
463
+ gr.Markdown("### 📋 Task History")
464
+ refresh_btn = gr.Button("🔄 Refresh List")
465
+ task_list = gr.HTML()
466
+
467
+ # --- Logic ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
 
469
+ def login(u, p):
470
+ if not u or not p: return None, "Enter details", gr.update(), gr.update()
471
+ hashed = hashlib.sha256(p.encode()).hexdigest()
472
+ with db_lock:
473
+ c = db_conn.cursor()
474
+ c.execute("SELECT id FROM users WHERE username=? AND password_hash=?", (u, hashed))
475
+ res = c.fetchone()
476
+ if not res:
477
+ try:
478
+ c.execute("INSERT INTO users (username, password_hash) VALUES (?,?)", (u, hashed))
479
+ db_conn.commit()
480
+ c.execute("SELECT id FROM users WHERE username=?", (u,))
481
+ res = c.fetchone()
482
+ except: return None, "Error", gr.update(), gr.update()
 
483
 
484
+ return res[0], f"Welcome {u}", gr.update(visible=False), gr.update(visible=True)
485
+
486
+ def submit(uid, mod, p):
487
+ if not uid: return "Please login"
488
+ tid = create_task(uid, mod, p)
489
+ return gr.update(value=""), tid # Clear prompt, set monitor ID
490
+
491
+ def get_tasks(uid):
492
+ if not uid: return ""
493
+ with db_lock:
494
+ c = db_conn.cursor()
495
+ c.execute("SELECT id, model_name, status, progress, created_at FROM tasks WHERE user_id=? ORDER BY created_at DESC LIMIT 10", (uid,))
496
+ rows = c.fetchall()
 
 
 
497
 
498
+ html = ""
499
+ for r in rows:
500
+ cls = f"status-{r[2]}"
501
+ html += f"""<div class='task-card' onclick="navigator.clipboard.writeText('{r[0]}')">
502
+ <b>{r[1]}</b> | <span class='{cls}'>{r[2].upper()}</span> | {r[3]}%
503
+ <br><small>ID: {r[0]}</small>
504
+ </div>"""
505
  return html
506
+
507
+ # Streaming Logic
508
+ timer = gr.Timer(0.5, active=False)
509
 
510
+ def start_watch(tid):
511
+ if not tid: return gr.update(active=False)
512
+ return gr.update(active=True)
 
513
 
514
+ def update_stream(uid, tid):
515
+ if not uid or not tid: return "Select a task...", gr.update(active=False)
516
  with db_lock:
517
  c = db_conn.cursor()
518
+ c.execute("SELECT result, status FROM tasks WHERE id=?", (tid,))
519
+ res = c.fetchone()
520
+
521
+ if not res: return "Task not found", gr.update(active=False)
522
+
523
+ formatted = format_output(res[0])
524
+ is_active = res[1] in ['queued', 'processing']
525
+
526
+ return formatted, gr.update(active=is_active)
527
+
528
+ # Wiring
529
+ login_btn.click(login, [u_in, p_in], [user_id_state, auth_msg, auth_group, main_group])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
 
531
+ sub_btn.click(submit, [user_id_state, model_sel, prompt_in], [prompt_in, monitor_id]).then(
532
+ get_tasks, [user_id_state], [task_list]
 
 
533
  )
534
 
535
+ refresh_btn.click(get_tasks, [user_id_state], [task_list])
 
 
 
 
536
 
537
+ watch_btn.click(start_watch, [monitor_id], [timer])
538
+ timer.tick(update_stream, [user_id_state, monitor_id], [stream_out, timer])
 
 
 
539
 
540
  if __name__ == "__main__":
541
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860)