basilboy committed on
Commit
cd1fae2
·
verified ·
1 Parent(s): b9ca465

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -81
app.py CHANGED
@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
8
  from safetensors.torch import load_file as load_sft
9
  from huggingface_hub import snapshot_download
10
 
11
- torch.set_default_dtype(torch.float32)
12
 
13
  # ===============================================
14
  # Default config (from your training notes)
@@ -62,9 +62,10 @@ class AttnBlock(nn.Module):
62
  return Qh2, Kh2
63
 
64
  def forward(self, x, rope, radius):
 
65
  if x.dtype != self.norm1.weight.dtype:
66
  x = x.to(self.norm1.weight.dtype)
67
-
68
  h = self.norm1(x)
69
  B, S, E = h.shape
70
  cos, sin = rope
@@ -130,6 +131,11 @@ class CNA(nn.Module):
130
  h = self.tok_emb(x)
131
  else:
132
  h = x
 
 
 
 
 
133
  B, S, E = h.shape
134
  hd = self.embed_dim // self.num_heads
135
  cos, sin = self._rope_seq(S, hd, h.device, h.dtype)
@@ -139,8 +145,7 @@ class CNA(nn.Module):
139
 
140
  # ===============================================
141
  # Helpers
142
- #
143
-
144
  def to_batch2(ids_like) -> torch.Tensor:
145
  """
146
  Normalize ids_like (list, [[...]], tensor) to int64 shape [1, S].
@@ -155,7 +160,6 @@ def to_batch2(ids_like) -> torch.Tensor:
155
  x = x.view(1, -1) # fallback reshape
156
  return x
157
 
158
-
159
  def infer_expansion_factor_from_state(state, embed_dim):
160
  for key in ("blocks.0.mlp.0.weight", "blocks.0.mlp.2.weight"):
161
  if key in state:
@@ -250,36 +254,24 @@ def sample_from_logits(logits_row, temperature=1.0, current_token=None, exclude_
250
 
251
  # ===============================================
252
  # Weight loading (file / folder / HF Hub)
253
- # Handles weights-only .pt (state_dict) as well.
254
  # ===============================================
255
  DEFAULT_CKPT = os.environ.get("CKPT_PATH", "ckpt_latest.pt")
256
  DEFAULT_WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", "weights_latest")
257
 
258
  def _read_config_from_dict_or_infer(state, cfg):
259
- # start from provided cfg merged over defaults
260
  merged = {**DEFAULT_CONF, **(cfg or {})}
261
-
262
- # infer from weights if available
263
  if "tok_emb.weight" in state:
264
  merged["embed_dim"] = state["tok_emb.weight"].shape[1]
265
- # infer num_blocks by scanning keys
266
  block_idxs = [int(m.group(1)) for k in state.keys() for m in [re.match(r"blocks\.(\d+)\.", k)] if m]
267
  if block_idxs:
268
  merged["num_blocks"] = max(block_idxs) + 1
269
-
270
- # num_heads, radius, expansion_factor often aren't inferable; keep merged defaults
271
- # expansion_factor can be inferred from MLP shapes if present
272
  if "blocks.0.mlp.0.weight" in state or "blocks.0.mlp.2.weight" in state:
273
  merged["expansion_factor"] = infer_expansion_factor_from_state(state, merged["embed_dim"])
274
-
275
- # tokenizer
276
  if not merged.get("tokenizer_name"):
277
  merged["tokenizer_name"] = "gpt2"
278
-
279
  return merged
280
 
281
  def _is_state_dict(obj):
282
- # A reasonable heuristic: a dict whose values are Tensors (and keys look like module names)
283
  if isinstance(obj, dict) and obj:
284
  sample_val = next(iter(obj.values()))
285
  return isinstance(sample_val, torch.Tensor)
@@ -287,14 +279,12 @@ def _is_state_dict(obj):
287
 
288
  def _load_state_from_pt(path: str):
289
  obj = torch.load(path, map_location="cpu")
290
- # Case A: legacy payload with {"model": state_dict, "config": {...}}
291
  if isinstance(obj, dict) and "model" in obj and isinstance(obj["model"], dict):
292
  state = obj["model"]
293
  cfg = obj.get("config", {}) or {}
294
  if "tokenizer_name" in obj:
295
  cfg = {**cfg, "tokenizer_name": obj["tokenizer_name"]}
296
  return state, cfg
297
- # Case B: weights-only state_dict (your case)
298
  if _is_state_dict(obj):
299
  return obj, {}
300
  raise ValueError(f"Unsupported .pt format at {path}: expected a state_dict or a payload with 'model'.")
@@ -402,8 +392,8 @@ def load_model(source: str):
402
  nn.init.zeros_(model.proj.bias)
403
  else:
404
  model.load_state_dict(state, strict=True)
405
-
406
- # hard-cast ALL params & buffers to float32 (handles weights-only .pt that saved as float64)
407
  model = model.to(torch.float32)
408
  with torch.no_grad():
409
  for p in model.parameters():
@@ -412,7 +402,7 @@ def load_model(source: str):
412
  for _, buf in model.named_buffers():
413
  if buf.dtype.is_floating_point:
414
  buf.data = buf.data.float()
415
-
416
  model.eval()
417
  return model, tokenizer, conf["radius"]
418
 
@@ -427,11 +417,10 @@ def _auto_default_source():
427
  for name in ["weights_latest.pt", "ckpt_latest.pt"]:
428
  if os.path.isfile(name):
429
  return name
430
- # first .pt or .safetensors in repo root
431
  for f in sorted(os.listdir(".")):
432
  if f.endswith(".pt") or f.endswith(".safetensors"):
433
  return f
434
- return "weights_latest.pt" # sane default for your case
435
 
436
  def ensure_model(source_path_or_repo):
437
  src = source_path_or_repo or _auto_default_source()
@@ -467,33 +456,73 @@ def init_random(src, seqlen, seed):
467
  txt = decode(x[0], model_cache["tokenizer"])
468
  return x.tolist(), txt, f"Initialized random sequence (len={int(seqlen)})"
469
 
470
- def init_from_text(src, seqlen, text, seed, pad_mode):
471
- ensure_model(src)
472
- rnd = random.Random(seed)
473
- x = to_fixed_len_ids(text or "", model_cache["tokenizer"], int(seqlen), pad_mode=pad_mode, rnd=rnd)
474
- txt = decode(x[0], model_cache["tokenizer"])
475
- return x.tolist(), txt, "Initialized from text"
 
 
 
 
 
 
 
 
 
 
476
 
477
- def append_text(src, state_ids, seqlen, text_to_append, seed):
478
- ensure_model(src)
 
 
 
 
479
  tok = model_cache["tokenizer"]
480
- rnd = random.Random(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  S = int(seqlen)
482
- if state_ids is None or len(state_ids) == 0:
483
- x = to_fixed_len_ids(text_to_append or "", tok, S, pad_mode="random", rnd=rnd)
484
- else:
485
- x = to_batch2(state_ids) # <-- normalize
486
- extra = tok.encode(text_to_append or "", add_special_tokens=False)
487
- x = torch.cat([x, torch.tensor(extra, dtype=torch.long).unsqueeze(0)], dim=1)
488
- if x.shape[1] > S:
489
- x = x[:, :S]
490
- elif x.shape[1] < S:
491
- need = S - x.shape[1]
492
- V = tok.vocab_size
493
- pad = torch.tensor([rnd.randrange(V) for _ in range(need)], dtype=torch.long).unsqueeze(0)
494
- x = torch.cat([x, pad], dim=1)
495
- txt = decode(x[0], tok)
496
- return x.tolist(), txt, "Appended text and resized to target length"
497
 
498
  def apply_noise(src, state_ids, seqlen, indices_csv, add_left, add_right, seed):
499
  ensure_model(src)
@@ -503,29 +532,28 @@ def apply_noise(src, state_ids, seqlen, indices_csv, add_left, add_right, seed):
503
  V = tok.vocab_size
504
  base = torch.randint(0, V, (1, S))
505
  else:
506
- base = to_batch2(state_ids) # <-- normalize
507
  x = apply_noise_ops(base, tok, indices_csv, int(add_left or 0), int(add_right or 0), S, seed=seed)
508
  txt = decode(x[0], tok)
509
- return x.tolist(), txt, "Applied noise brush / prepend / append"
510
 
511
  def step_once(src, state_ids, mode, temperature, exclude_current):
512
  ensure_model(src)
513
  tok = model_cache["tokenizer"]
514
  if state_ids is None or len(state_ids) == 0:
515
  return None, "", "No sequence to step — initialize first."
516
- x = to_batch2(state_ids) # <-- instead of torch.tensor(...).unsqueeze(0)
517
  x = step_strategy1(model_cache["model"], x, mode=mode, temperature=temperature, exclude_current=exclude_current)
518
  txt = decode(x[0], tok)
519
  return x.tolist(), txt, f"Stepped 1 iteration ({mode})"
520
 
521
-
522
  def live_denoise(src, state_ids, steps, snap_every, seed, mode, temperature, exclude_current):
523
  ensure_model(src)
524
  tok = model_cache["tokenizer"]
525
  if state_ids is None or len(state_ids) == 0:
526
  return
527
  random.seed(seed); torch.manual_seed(seed)
528
- x = to_batch2(state_ids) # <-- normalize
529
  total = int(steps); snap = max(1, int(snap_every))
530
  for t in range(1, total + 1):
531
  x = step_strategy1(model_cache["model"], x, mode=mode, temperature=temperature, exclude_current=exclude_current)
@@ -534,19 +562,22 @@ def live_denoise(src, state_ids, steps, snap_every, seed, mode, temperature, exc
534
  yield x.tolist(), txt, f"Live denoise… step {t}/{total} ({mode})"
535
 
536
  # ===============================================
537
- # UI
538
  # ===============================================
539
  with gr.Blocks(title="CNA — Interactive Denoising") as demo:
540
  gr.Markdown(
541
  """
542
  # CNA — Interactive Denoising (Strategy 1)
543
- - **Weights source** can be: a `.pt` **weights-only state_dict** (e.g., `weights_latest.pt`), a folder of shards, or a **Hub repo id**.
544
- - Update rule per step: **argmax** or **sample** (temperature + option to exclude current token).
545
- - Tools: Random init, Init from text, Noise brush (select indices, prepend/append noise), Append text, Live denoise.
546
  """
547
  )
548
 
549
- default_source = _auto_default_source()
 
 
 
550
  with gr.Row():
551
  src = gr.Textbox(value=default_source, label="Weights (file / folder / HF repo id)")
552
  seqlen = gr.Slider(10, 512, value=100, step=1, label="Sequence length (S)")
@@ -555,10 +586,10 @@ with gr.Blocks(title="CNA — Interactive Denoising") as demo:
555
  ids_state = gr.State(value=None)
556
 
557
  with gr.Row():
558
- current_text = gr.Textbox(lines=8, label="Current text", interactive=False)
559
  status = gr.Markdown("Ready.")
560
 
561
- gr.Markdown("## Mode 1 · Random → Denoise Live")
562
  with gr.Row():
563
  btn_random = gr.Button("Initialize Random")
564
  steps = gr.Slider(1, 2000, value=200, step=1, label="Denoise steps (N)")
@@ -571,32 +602,50 @@ with gr.Blocks(title="CNA — Interactive Denoising") as demo:
571
  btn_step_once = gr.Button("Step Once")
572
  btn_live = gr.Button("Denoise Live (streaming)")
573
 
574
- gr.Markdown("## Mode 2 · Initialize From Your Text")
575
  with gr.Row():
576
- init_text = gr.Textbox(lines=4, label="Initial text")
577
- with gr.Row():
578
- pad_mode = gr.Radio(choices=["random", "eos"], value="random", label="Pad mode (if text shorter than S)")
579
- btn_init_text = gr.Button("Initialize From Text")
580
-
581
- gr.Markdown("## Noise Brush · Select Positions + Prepend/Append Noise")
582
- with gr.Row():
583
- indices_csv = gr.Textbox(label="Positions to noise (e.g., 0, 5, 10-20)", placeholder="Leave empty to skip")
584
  with gr.Row():
585
  add_left = gr.Number(value=0, precision=0, label="Noise tokens to add at START")
586
  add_right = gr.Number(value=0, precision=0, label="Noise tokens to add at END")
587
- btn_apply_noise = gr.Button("Apply Noise Brush / Prepend / Append")
588
-
589
- gr.Markdown("## Append Text")
590
- with gr.Row():
591
- append_box = gr.Textbox(lines=3, label="Text to append")
592
- btn_append = gr.Button("Append to Current Sequence")
593
 
594
- # Wiring
595
  btn_random.click(init_random, [src, seqlen, seed], [ids_state, current_text, status])
596
- btn_init_text.click(init_from_text, [src, seqlen, init_text, seed, pad_mode], [ids_state, current_text, status])
597
- btn_apply_noise.click(apply_noise, [src, ids_state, seqlen, indices_csv, add_left, add_right, seed], [ids_state, current_text, status])
598
- btn_append.click(append_text, [src, ids_state, seqlen, append_box, seed], [ids_state, current_text, status])
599
- btn_step_once.click(step_once, [src, ids_state, update_mode, temperature, exclude_current], [ids_state, current_text, status])
600
- btn_live.click(live_denoise, [src, ids_state, steps, snap_every, seed, update_mode, temperature, exclude_current], [ids_state, current_text, status], show_progress=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
601
 
602
  demo.queue().launch()
 
8
  from safetensors.torch import load_file as load_sft
9
  from huggingface_hub import snapshot_download
10
 
11
+ torch.set_default_dtype(torch.float32)
12
 
13
  # ===============================================
14
  # Default config (from your training notes)
 
62
  return Qh2, Kh2
63
 
64
  def forward(self, x, rope, radius):
65
+ # keep LN inputs & params same dtype
66
  if x.dtype != self.norm1.weight.dtype:
67
  x = x.to(self.norm1.weight.dtype)
68
+
69
  h = self.norm1(x)
70
  B, S, E = h.shape
71
  cos, sin = rope
 
131
  h = self.tok_emb(x)
132
  else:
133
  h = x
134
+ # ensure embeddings/activations dtype follows model dtype
135
+ target_dtype = next(self.parameters()).dtype
136
+ if h.dtype != target_dtype:
137
+ h = h.to(target_dtype)
138
+
139
  B, S, E = h.shape
140
  hd = self.embed_dim // self.num_heads
141
  cos, sin = self._rope_seq(S, hd, h.device, h.dtype)
 
145
 
146
  # ===============================================
147
  # Helpers
148
+ # ===============================================
 
149
  def to_batch2(ids_like) -> torch.Tensor:
150
  """
151
  Normalize ids_like (list, [[...]], tensor) to int64 shape [1, S].
 
160
  x = x.view(1, -1) # fallback reshape
161
  return x
162
 
 
163
  def infer_expansion_factor_from_state(state, embed_dim):
164
  for key in ("blocks.0.mlp.0.weight", "blocks.0.mlp.2.weight"):
165
  if key in state:
 
254
 
255
  # ===============================================
256
  # Weight loading (file / folder / HF Hub)
 
257
  # ===============================================
258
  DEFAULT_CKPT = os.environ.get("CKPT_PATH", "ckpt_latest.pt")
259
  DEFAULT_WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", "weights_latest")
260
 
261
  def _read_config_from_dict_or_infer(state, cfg):
 
262
  merged = {**DEFAULT_CONF, **(cfg or {})}
 
 
263
  if "tok_emb.weight" in state:
264
  merged["embed_dim"] = state["tok_emb.weight"].shape[1]
 
265
  block_idxs = [int(m.group(1)) for k in state.keys() for m in [re.match(r"blocks\.(\d+)\.", k)] if m]
266
  if block_idxs:
267
  merged["num_blocks"] = max(block_idxs) + 1
 
 
 
268
  if "blocks.0.mlp.0.weight" in state or "blocks.0.mlp.2.weight" in state:
269
  merged["expansion_factor"] = infer_expansion_factor_from_state(state, merged["embed_dim"])
 
 
270
  if not merged.get("tokenizer_name"):
271
  merged["tokenizer_name"] = "gpt2"
 
272
  return merged
273
 
274
  def _is_state_dict(obj):
 
275
  if isinstance(obj, dict) and obj:
276
  sample_val = next(iter(obj.values()))
277
  return isinstance(sample_val, torch.Tensor)
 
279
 
280
  def _load_state_from_pt(path: str):
281
  obj = torch.load(path, map_location="cpu")
 
282
  if isinstance(obj, dict) and "model" in obj and isinstance(obj["model"], dict):
283
  state = obj["model"]
284
  cfg = obj.get("config", {}) or {}
285
  if "tokenizer_name" in obj:
286
  cfg = {**cfg, "tokenizer_name": obj["tokenizer_name"]}
287
  return state, cfg
 
288
  if _is_state_dict(obj):
289
  return obj, {}
290
  raise ValueError(f"Unsupported .pt format at {path}: expected a state_dict or a payload with 'model'.")
 
392
  nn.init.zeros_(model.proj.bias)
393
  else:
394
  model.load_state_dict(state, strict=True)
395
+
396
+ # enforce float32 across params & buffers
397
  model = model.to(torch.float32)
398
  with torch.no_grad():
399
  for p in model.parameters():
 
402
  for _, buf in model.named_buffers():
403
  if buf.dtype.is_floating_point:
404
  buf.data = buf.data.float()
405
+
406
  model.eval()
407
  return model, tokenizer, conf["radius"]
408
 
 
417
  for name in ["weights_latest.pt", "ckpt_latest.pt"]:
418
  if os.path.isfile(name):
419
  return name
 
420
  for f in sorted(os.listdir(".")):
421
  if f.endswith(".pt") or f.endswith(".safetensors"):
422
  return f
423
+ return "weights_latest.pt"
424
 
425
  def ensure_model(source_path_or_repo):
426
  src = source_path_or_repo or _auto_default_source()
 
456
  txt = decode(x[0], model_cache["tokenizer"])
457
  return x.tolist(), txt, f"Initialized random sequence (len={int(seqlen)})"
458
 
459
+ def to_ranges(indices):
460
+ """Compress a sorted list of token indices into 'a-b' CSV."""
461
+ if not indices:
462
+ return ""
463
+ indices = sorted(set(indices))
464
+ ranges = []
465
+ start = prev = indices[0]
466
+ for i in indices[1:]:
467
+ if i == prev + 1:
468
+ prev = i
469
+ else:
470
+ ranges.append((start, prev))
471
+ start = prev = i
472
+ ranges.append((start, prev))
473
+ parts = [f"{a}-{b}" if a != b else f"{a}" for a, b in ranges]
474
+ return ", ".join(parts)
475
 
476
+ def capture_selection(text, seqlen, current_ids, evt: gr.SelectData | None = None):
477
+ """
478
+ Map highlighted character span in `text` to token index ranges using tokenizer offsets.
479
+ Auto-fills the indices box so you can 'Noise Selection'.
480
+ """
481
+ ensure_model(None)
482
  tok = model_cache["tokenizer"]
483
+
484
+ if not text:
485
+ return gr.update(), "No text to select from."
486
+
487
+ # Try to read (start, end) from the event payload
488
+ start, end = None, None
489
+ if evt is not None:
490
+ try:
491
+ # gradio SelectData for Textbox exposes .index = (start_char, end_char)
492
+ start, end = evt.index
493
+ except Exception:
494
+ pass
495
+ # Fallback: nothing selected
496
+ if start is None or end is None or start == end:
497
+ return gr.update(), "No selection detected (drag to highlight)."
498
+
499
+ # Bound the indices defensively
500
+ start = max(0, min(len(text), int(start)))
501
+ end = max(0, min(len(text), int(end)))
502
+
503
+ # Get per-token char offsets from the fast tokenizer
504
+ enc = tok(text, add_special_tokens=False, return_offsets_mapping=True)
505
+ offsets = enc["offset_mapping"] # list of (s,e) per token
506
+ token_idxs = []
507
+ for i, (s, e) in enumerate(offsets):
508
+ if s is None or e is None:
509
+ continue
510
+ # overlap if token span intersects [start, end)
511
+ if max(s, start) < min(e, end):
512
+ token_idxs.append(i)
513
+
514
+ if not token_idxs:
515
+ return gr.update(), "Selection didn't hit any tokens (maybe whitespace)."
516
+
517
+ # Clip to current sequence length (so we don't index beyond S)
518
  S = int(seqlen)
519
+ token_idxs = [i for i in token_idxs if i < S]
520
+
521
+ if not token_idxs:
522
+ return gr.update(), "Selected span maps beyond current sequence length."
523
+
524
+ indices_csv = to_ranges(token_idxs)
525
+ return indices_csv, f"Selected chars [{start}:{end}) tokens {indices_csv}"
 
 
 
 
 
 
 
 
526
 
527
  def apply_noise(src, state_ids, seqlen, indices_csv, add_left, add_right, seed):
528
  ensure_model(src)
 
532
  V = tok.vocab_size
533
  base = torch.randint(0, V, (1, S))
534
  else:
535
+ base = to_batch2(state_ids)
536
  x = apply_noise_ops(base, tok, indices_csv, int(add_left or 0), int(add_right or 0), S, seed=seed)
537
  txt = decode(x[0], tok)
538
+ return x.tolist(), txt, "Applied noise"
539
 
540
  def step_once(src, state_ids, mode, temperature, exclude_current):
541
  ensure_model(src)
542
  tok = model_cache["tokenizer"]
543
  if state_ids is None or len(state_ids) == 0:
544
  return None, "", "No sequence to step — initialize first."
545
+ x = to_batch2(state_ids)
546
  x = step_strategy1(model_cache["model"], x, mode=mode, temperature=temperature, exclude_current=exclude_current)
547
  txt = decode(x[0], tok)
548
  return x.tolist(), txt, f"Stepped 1 iteration ({mode})"
549
 
 
550
  def live_denoise(src, state_ids, steps, snap_every, seed, mode, temperature, exclude_current):
551
  ensure_model(src)
552
  tok = model_cache["tokenizer"]
553
  if state_ids is None or len(state_ids) == 0:
554
  return
555
  random.seed(seed); torch.manual_seed(seed)
556
+ x = to_batch2(state_ids)
557
  total = int(steps); snap = max(1, int(snap_every))
558
  for t in range(1, total + 1):
559
  x = step_strategy1(model_cache["model"], x, mode=mode, temperature=temperature, exclude_current=exclude_current)
 
562
  yield x.tolist(), txt, f"Live denoise… step {t}/{total} ({mode})"
563
 
564
  # ===============================================
565
+ # UI (single mode)
566
  # ===============================================
567
  with gr.Blocks(title="CNA — Interactive Denoising") as demo:
568
  gr.Markdown(
569
  """
570
  # CNA — Interactive Denoising (Strategy 1)
571
+ - **Weights source**: `.pt` weights-only (e.g., `weights_latest.pt`), a folder of shards, or a **Hub repo id**.
572
+ - Update rule per step: **argmax** or **sample** (temperature + exclude current).
573
+ - Tools: Random init, **drag to select** in the text box → *Noise Selection*, manual indices, prepend/append noise, live denoise.
574
  """
575
  )
576
 
577
+ default_source = os.environ.get("WEIGHTS_SOURCE", None)
578
+ if default_source is None:
579
+ default_source = _auto_default_source()
580
+
581
  with gr.Row():
582
  src = gr.Textbox(value=default_source, label="Weights (file / folder / HF repo id)")
583
  seqlen = gr.Slider(10, 512, value=100, step=1, label="Sequence length (S)")
 
586
  ids_state = gr.State(value=None)
587
 
588
  with gr.Row():
589
+ current_text = gr.Textbox(lines=8, label="Current text", interactive=True)
590
  status = gr.Markdown("Ready.")
591
 
592
+ gr.Markdown("### Initialize & Denoise")
593
  with gr.Row():
594
  btn_random = gr.Button("Initialize Random")
595
  steps = gr.Slider(1, 2000, value=200, step=1, label="Denoise steps (N)")
 
602
  btn_step_once = gr.Button("Step Once")
603
  btn_live = gr.Button("Denoise Live (streaming)")
604
 
605
+ gr.Markdown("### Noise Selection or Manual Indices")
606
  with gr.Row():
607
+ indices_csv = gr.Textbox(label="Positions to noise (auto-filled from selection, or enter like `0, 5, 10-20`)")
 
 
 
 
 
 
 
608
  with gr.Row():
609
  add_left = gr.Number(value=0, precision=0, label="Noise tokens to add at START")
610
  add_right = gr.Number(value=0, precision=0, label="Noise tokens to add at END")
611
+ btn_noise_selection = gr.Button("Noise Selection")
612
+ btn_apply_noise = gr.Button("Apply Noise (from indices)")
 
 
 
 
613
 
614
+ # --- Wiring ---
615
  btn_random.click(init_random, [src, seqlen, seed], [ids_state, current_text, status])
616
+
617
+ # Select in text auto-compute token indices into indices_csv
618
+ current_text.select(
619
+ capture_selection,
620
+ [current_text, seqlen, ids_state],
621
+ [indices_csv, status]
622
+ )
623
+
624
+ # “Noise Selection” just applies whatever is in indices_csv
625
+ btn_noise_selection.click(
626
+ apply_noise,
627
+ [src, ids_state, seqlen, indices_csv, 0, 0, seed],
628
+ [ids_state, current_text, status]
629
+ )
630
+
631
+ # Manual indices + prepend/append noise
632
+ btn_apply_noise.click(
633
+ apply_noise,
634
+ [src, ids_state, seqlen, indices_csv, add_left, add_right, seed],
635
+ [ids_state, current_text, status]
636
+ )
637
+
638
+ btn_step_once.click(
639
+ step_once,
640
+ [src, ids_state, update_mode, temperature, exclude_current],
641
+ [ids_state, current_text, status]
642
+ )
643
+
644
+ btn_live.click(
645
+ live_denoise,
646
+ [src, ids_state, steps, snap_every, seed, update_mode, temperature, exclude_current],
647
+ [ids_state, current_text, status],
648
+ show_progress=True
649
+ )
650
 
651
  demo.queue().launch()