the-puzzler committed on
Commit
515a8b4
·
1 Parent(s): fbcf0db

added different update rules: argmax or sampling from logits

Browse files
Files changed (1) hide show
  1. app.py +83 -12
app.py CHANGED
@@ -121,6 +121,32 @@ class CNA(nn.Module):
121
  # -----------------------------
122
  # Helpers
123
  # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def infer_expansion_factor_from_state(state, embed_dim):
125
  for key in ("blocks.0.mlp.0.weight", "blocks.0.mlp.2.weight"):
126
  if key in state:
@@ -271,14 +297,34 @@ def ensure_model(ckpt_path):
271
  # Strategy 1 core step
272
  # -----------------------------
273
  @torch.no_grad()
274
- def step_strategy1(model, x):
275
- """One iteration: choose random position, set to argmax(logits)."""
 
 
 
 
 
 
 
276
  S = x.shape[1]
277
  pos = int(torch.randint(0, S, (1,)).item())
278
  logits_pos = model_logits(model, x)[0, pos] # [V]
279
- x[0, pos] = int(torch.argmax(logits_pos).item())
 
 
 
 
 
 
 
 
 
 
 
 
280
  return x
281
 
 
282
  # -----------------------------
283
  # Gradio logic
284
  # -----------------------------
@@ -332,17 +378,23 @@ def apply_noise(ckpt_path, state_ids, seqlen, indices_csv, add_left, add_right,
332
  txt = decode(x[0], tok)
333
  return x.tolist(), txt, "Applied noise brush / prepend / append"
334
 
335
- def step_once(ckpt_path, state_ids):
336
  ensure_model(ckpt_path or DEFAULT_CKPT)
337
  tok = model_cache["tokenizer"]
338
  if state_ids is None or len(state_ids) == 0:
339
  return None, "", "No sequence to step — initialize first."
340
  x = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
341
- x = step_strategy1(model_cache["model"], x)
 
 
 
 
 
342
  txt = decode(x[0], tok)
343
- return x.tolist(), txt, "Stepped 1 iteration"
344
 
345
- def live_denoise(ckpt_path, state_ids, steps, snap_every, seed):
 
346
  """
347
  Generator for live updates. Yields (ids, text, status) every snap_every steps and on completion.
348
  """
@@ -355,11 +407,16 @@ def live_denoise(ckpt_path, state_ids, steps, snap_every, seed):
355
  total = int(steps)
356
  snap = max(1, int(snap_every))
357
  for t in range(1, total + 1):
358
- x = step_strategy1(model_cache["model"], x)
 
 
 
 
 
359
  if (t % snap == 0) or (t == total):
360
  txt = decode(x[0], tok)
361
- yield x.tolist(), txt, f"Live denoise… step {t}/{total}"
362
- # final yield already done in loop
363
 
364
  # -----------------------------
365
  # UI
@@ -390,6 +447,20 @@ with gr.Blocks(title="CNA — Interactive Denoising (Strategy 1)") as demo:
390
  status = gr.Markdown("Ready.")
391
 
392
  gr.Markdown("## Mode 1 · Random → Denoise Live")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  with gr.Row():
394
  btn_random = gr.Button("Initialize Random")
395
  steps = gr.Slider(1, 2000, value=200, step=1, label="Denoise steps (N)")
@@ -450,14 +521,14 @@ with gr.Blocks(title="CNA — Interactive Denoising (Strategy 1)") as demo:
450
  # Single step
451
  btn_step_once.click(
452
  step_once,
453
- [ckpt, ids_state],
454
  [ids_state, current_text, status]
455
  )
456
 
457
  # Live denoise (streaming)
458
  btn_live.click(
459
  live_denoise,
460
- [ckpt, ids_state, steps, snap_every, seed],
461
  [ids_state, current_text, status],
462
  show_progress=True
463
  )
 
121
  # -----------------------------
122
  # Helpers
123
  # -----------------------------
124
+ @torch.no_grad()
125
+ def sample_from_logits(logits_row: torch.Tensor, temperature: float = 1.0,
126
+ current_token: int | None = None, exclude_current: bool = True) -> int:
127
+ """
128
+ Sample a token from logits_row using softmax with temperature.
129
+ If exclude_current=True and current_token is provided, set its prob to 0 (then renormalize).
130
+ """
131
+ if temperature <= 0:
132
+ # safety: treat as argmax
133
+ return int(torch.argmax(logits_row).item())
134
+
135
+ scaled = logits_row / float(temperature)
136
+ probs = torch.softmax(scaled, dim=-1)
137
+
138
+ if exclude_current and current_token is not None:
139
+ probs = probs.clone()
140
+ probs[current_token] = 0.0
141
+ s = probs.sum()
142
+ if s.item() <= 0:
143
+ # fallback to argmax if everything got zeroed
144
+ return int(torch.argmax(logits_row).item())
145
+ probs = probs / s
146
+
147
+ return int(torch.multinomial(probs, num_samples=1).item())
148
+
149
+
150
  def infer_expansion_factor_from_state(state, embed_dim):
151
  for key in ("blocks.0.mlp.0.weight", "blocks.0.mlp.2.weight"):
152
  if key in state:
 
297
  # Strategy 1 core step
298
  # -----------------------------
299
  @torch.no_grad()
300
def step_strategy1(model, x, mode: str = "argmax",
                   temperature: float = 1.0,
                   exclude_current: bool = True):
    """
    Run one update: pick a random position along dim 1 of ``x`` and
    overwrite ``x[0, pos]`` in place with a new token.

    - mode="sample": the token comes from ``sample_from_logits`` using
      ``temperature`` and, optionally, excluding the token currently there.
    - any other mode (including the default "argmax"): the token is the
      argmax of the logits at that position.

    Returns the same tensor ``x`` that was passed in.
    """
    seq_len = x.shape[1]
    idx = int(torch.randint(0, seq_len, (1,)).item())
    row = model_logits(model, x)[0, idx]  # [V]

    if mode != "sample":
        # default / fallback: argmax
        x[0, idx] = int(torch.argmax(row).item())
        return x

    present = int(x[0, idx].item())
    x[0, idx] = sample_from_logits(
        row,
        temperature=float(temperature),
        current_token=present,
        exclude_current=bool(exclude_current),
    )
    return x
326
 
327
+
328
  # -----------------------------
329
  # Gradio logic
330
  # -----------------------------
 
378
  txt = decode(x[0], tok)
379
  return x.tolist(), txt, "Applied noise brush / prepend / append"
380
 
381
def step_once(ckpt_path, state_ids, mode, temperature, exclude_current):
    """
    Apply a single ``step_strategy1`` update to the stored sequence.

    Loads the model via ``ensure_model`` (using ``DEFAULT_CKPT`` when
    ``ckpt_path`` is falsy), then returns ``(ids, text, status)``. When
    ``state_ids`` is None or empty, returns ``(None, "", <hint message>)``
    without stepping.

    NOTE(review): the returned ids are ``x.tolist()`` of the tensor that had a
    leading dim added by ``unsqueeze(0)``, and the UI wires this output back
    into the same state passed as ``state_ids`` — confirm the nesting depth
    stays stable across repeated clicks.
    """
    ensure_model(ckpt_path or DEFAULT_CKPT)
    tokenizer = model_cache["tokenizer"]

    has_sequence = state_ids is not None and len(state_ids) != 0
    if not has_sequence:
        return None, "", "No sequence to step — initialize first."

    batch = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
    step_kwargs = {
        "mode": mode,
        "temperature": temperature,
        "exclude_current": exclude_current,
    }
    batch = step_strategy1(model_cache["model"], batch, **step_kwargs)

    text = decode(batch[0], tokenizer)
    return batch.tolist(), text, f"Stepped 1 iteration ({mode})"
395
 
396
+ def live_denoise(ckpt_path, state_ids, steps, snap_every, seed,
397
+ mode, temperature, exclude_current):
398
  """
399
  Generator for live updates. Yields (ids, text, status) every snap_every steps and on completion.
400
  """
 
407
  total = int(steps)
408
  snap = max(1, int(snap_every))
409
  for t in range(1, total + 1):
410
+ x = step_strategy1(
411
+ model_cache["model"], x,
412
+ mode=mode,
413
+ temperature=temperature,
414
+ exclude_current=exclude_current
415
+ )
416
  if (t % snap == 0) or (t == total):
417
  txt = decode(x[0], tok)
418
+ yield x.tolist(), txt, f"Live denoise… step {t}/{total} ({mode})"
419
+
420
 
421
  # -----------------------------
422
  # UI
 
447
  status = gr.Markdown("Ready.")
448
 
449
  gr.Markdown("## Mode 1 · Random → Denoise Live")
450
+ with gr.Row():
451
+ update_mode = gr.Radio(
452
+ choices=["argmax", "sample"],
453
+ value="argmax",
454
+ label="Update rule"
455
+ )
456
+ temperature = gr.Slider(
457
+ minimum=0.0, maximum=5.0, value=1.0, step=0.05,
458
+ label="Temperature (sampling)"
459
+ )
460
+ exclude_current = gr.Checkbox(
461
+ value=True,
462
+ label="Exclude current token when sampling"
463
+ )
464
  with gr.Row():
465
  btn_random = gr.Button("Initialize Random")
466
  steps = gr.Slider(1, 2000, value=200, step=1, label="Denoise steps (N)")
 
521
  # Single step
522
  btn_step_once.click(
523
  step_once,
524
+ [ckpt, ids_state, update_mode, temperature, exclude_current],
525
  [ids_state, current_text, status]
526
  )
527
 
528
  # Live denoise (streaming)
529
  btn_live.click(
530
  live_denoise,
531
+ [ckpt, ids_state, steps, snap_every, seed, update_mode, temperature, exclude_current],
532
  [ids_state, current_text, status],
533
  show_progress=True
534
  )