the-puzzler committed on
Commit
22b6693
·
1 Parent(s): 6fb6b5d
Files changed (1) hide show
  1. app.py +248 -46
app.py CHANGED
@@ -119,7 +119,7 @@ class CNA(nn.Module):
119
  return self.proj(h)
120
 
121
  # -----------------------------
122
- # Helpers (trimmed to Strategy 1)
123
  # -----------------------------
124
  def infer_expansion_factor_from_state(state, embed_dim):
125
  for key in ("blocks.0.mlp.0.weight", "blocks.0.mlp.2.weight"):
@@ -132,18 +132,84 @@ def infer_expansion_factor_from_state(state, embed_dim):
132
  return 4
133
 
134
@torch.no_grad()
def decode(ids, tokenizer, max_chars=220):
    """Decode token ids into a single-line string, truncated to max_chars."""
    text = tokenizer.decode(ids.tolist(), skip_special_tokens=True)
    text = text.replace("\n", " ")
    if len(text) > max_chars:
        return text[:max_chars] + "…"
    return text
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  @torch.no_grad()
141
  def model_logits(model, x):
142
  return model(x)
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  # -----------------------------
145
  # Load checkpoint & build model
146
  # -----------------------------
 
 
 
147
  def load_model(ckpt_path: str):
148
  if not os.path.exists(ckpt_path):
149
  raise FileNotFoundError(
@@ -196,68 +262,204 @@ def load_model(ckpt_path: str):
196
  model.eval()
197
  return model, tokenizer, int(radius)
198
 
 
 
 
 
 
199
  # -----------------------------
200
- # Simplest sampling: Strategy 1
201
  # -----------------------------
202
@torch.no_grad()
def strategy_random_argmax(model, tokenizer, seqlen=100, steps=200, snap_every=20, seed=0, max_chars=220):
    """Iteratively refine a random sequence via Strategy 1.

    Each step picks one uniformly random position and replaces it with the
    argmax token from the model's logits there. Returns a list of
    (step, decoded text) snapshots, including the initial state at step 0.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    vocab = tokenizer.vocab_size
    x = torch.randint(0, vocab, (1, seqlen))
    snapshots = [(0, decode(x[0].cpu(), tokenizer, max_chars))]
    for step in range(1, steps + 1):
        pos = int(torch.randint(0, seqlen, (1,)))
        token_logits = model_logits(model, x)[0, pos]  # logits for that position
        x[0, pos] = int(torch.argmax(token_logits).item())
        if step % snap_every == 0 or step == steps:
            snapshots.append((step, decode(x[0].cpu(), tokenizer, max_chars)))
    return snapshots
215
 
216
  # -----------------------------
217
- # Gradio UI
218
  # -----------------------------
219
- DEFAULT_CKPT = os.environ.get("CKPT_PATH", "ckpt_latest.pt")
 
 
 
 
 
 
220
 
221
- model_cache = {"model": None, "tokenizer": None, "radius": None, "ckpt": None}
222
def ensure_model(ckpt_path):
    """(Re)load the model into the module cache when it is missing or the
    requested checkpoint path differs from the cached one."""
    stale = model_cache["model"] is None or model_cache["ckpt"] != ckpt_path
    if stale:
        m, tok, rad = load_model(ckpt_path)
        model_cache.update({"model": m, "tokenizer": tok, "radius": rad, "ckpt": ckpt_path})
 
226
 
227
def run_demo(ckpt_path, seqlen, steps, snap_every, seed, max_chars):
    """Run Strategy-1 sampling and return (snapshot log, final text)."""
    ensure_model(ckpt_path or DEFAULT_CKPT)
    snaps = strategy_random_argmax(
        model_cache["model"],
        model_cache["tokenizer"],
        seqlen=seqlen,
        steps=steps,
        snap_every=snap_every,
        seed=seed,
        max_chars=max_chars,
    )
    # One line per snapshot, right-aligned step counter for readability.
    log = "\n".join(f"t={t:>3}: {txt}" for (t, txt) in snaps)
    final_text = snaps[-1][1] if snaps else ""
    return log, final_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
with gr.Blocks(title="CNA — Simple Sampling (Random Position • Argmax)") as demo:
    gr.Markdown(
        """
        # CNA — Simple Sampling (Strategy 1)
        This Space loads your checkpoint and runs the **random position → argmax** update for a fixed-length sequence.
        - Put your checkpoint at `ckpt_latest.pt` (repo root), or set a custom path below.
        """
    )
    # Checkpoint selection
    with gr.Row():
        ckpt = gr.Textbox(value=DEFAULT_CKPT, label="Checkpoint path", placeholder="ckpt_latest.pt")
    # Sampling controls
    with gr.Row():
        seqlen = gr.Slider(10, 512, value=100, step=1, label="Sequence length (S)")
        steps = gr.Slider(10, 1000, value=200, step=1, label="Steps")
        snap_every = gr.Slider(1, 200, value=20, step=1, label="Snapshot every N steps")
    with gr.Row():
        seed = gr.Slider(0, 10_000, value=0, step=1, label="Seed")
        max_chars = gr.Slider(32, 1000, value=220, step=1, label="Max chars per snapshot")
    run_btn = gr.Button("Run")
    # Outputs
    with gr.Row():
        log_out = gr.Textbox(lines=18, label="Snapshots")
        final_out = gr.Textbox(lines=6, label="Final text (last snapshot)")

    run_btn.click(run_demo, [ckpt, seqlen, steps, snap_every, seed, max_chars], [log_out, final_out])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
  demo.queue(concurrency_count=1).launch()
 
119
  return self.proj(h)
120
 
121
  # -----------------------------
122
+ # Helpers
123
  # -----------------------------
124
  def infer_expansion_factor_from_state(state, embed_dim):
125
  for key in ("blocks.0.mlp.0.weight", "blocks.0.mlp.2.weight"):
 
132
  return 4
133
 
134
@torch.no_grad()
def decode(ids, tokenizer, max_chars=1000):
    """Decode token ids into a single-line string, truncated to max_chars."""
    text = tokenizer.decode(ids.tolist(), skip_special_tokens=True)
    text = text.replace("\n", " ")
    if len(text) > max_chars:
        return text[:max_chars] + "…"
    return text
139
 
140
def to_fixed_len_ids(text, tokenizer, seqlen, pad_mode="random", rnd=None):
    """Encode *text* and coerce the result to exactly ``seqlen`` tokens.

    Longer encodings are truncated; shorter ones are padded with the EOS
    token (when ``pad_mode == "eos"`` and the tokenizer defines one) or
    with uniformly random token ids. Returns a [1, seqlen] long tensor.
    """
    if rnd is None:
        rnd = random.Random()
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    vocab = tokenizer.vocab_size
    if len(token_ids) >= seqlen:
        token_ids = token_ids[:seqlen]
    elif pad_mode == "eos" and tokenizer.eos_token_id is not None:
        token_ids = token_ids + [tokenizer.eos_token_id] * (seqlen - len(token_ids))
    else:
        # Random padding — also the fallback when no EOS token exists.
        token_ids = token_ids + [rnd.randrange(vocab) for _ in range(seqlen - len(token_ids))]
    return torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
155
+
156
@torch.no_grad()
def model_logits(model, x):
    """Forward pass with gradient tracking disabled; returns the model output."""
    logits = model(x)
    return logits
159
 
160
def _parse_positions(indices_csv):
    """Parse a brush spec like ``"0, 5, 6-10"`` into a set of int positions.

    Accepts None/empty input; malformed entries are skipped rather than
    failing the whole request (best-effort, mirrors the lenient UI input).
    """
    positions = set()
    for part in (indices_csv or "").split(","):
        part = part.strip()
        if not part:
            continue
        try:
            if "-" in part:
                a, b = part.split("-", 1)
                lo, hi = sorted((int(a), int(b)))
                positions.update(range(lo, hi + 1))
            else:
                positions.add(int(part))
        except ValueError:
            # Non-numeric entry: ignore it (was a bare `except:` before,
            # which also swallowed KeyboardInterrupt/SystemExit).
            continue
    return positions


def apply_noise_ops(x, tokenizer, indices_csv, add_noise_left, add_noise_right, seqlen, seed=0):
    """Noise selected positions and optionally prepend/append random tokens.

    Parameters
    ----------
    x : long tensor of shape [1, S] — token ids; the input is not mutated.
    tokenizer : provides ``vocab_size`` used to sample random tokens.
    indices_csv : str or None — positions to randomize, e.g. ``"0, 5, 6-10"``.
    add_noise_left / add_noise_right : int — random tokens to prepend/append
        before re-fitting to ``seqlen``.
    seqlen : int — target length of the returned tensor.
    seed : int — seeds a local RNG so results are reproducible.

    Returns a new [1, seqlen] long tensor.
    """
    rnd = random.Random(seed)
    V = tokenizer.vocab_size
    x = x.clone()

    # Noise brush: randomize the requested in-range positions.
    for j in _parse_positions(indices_csv):
        if 0 <= j < seqlen:
            x[0, j] = rnd.randrange(V)

    # Optionally surround the sequence with fresh random tokens.
    if add_noise_left > 0:
        prefix = torch.tensor([rnd.randrange(V) for _ in range(add_noise_left)], dtype=torch.long).unsqueeze(0)
        x = torch.cat([prefix, x], dim=1)
    if add_noise_right > 0:
        suffix = torch.tensor([rnd.randrange(V) for _ in range(add_noise_right)], dtype=torch.long).unsqueeze(0)
        x = torch.cat([x, suffix], dim=1)

    # Re-fit to the target length: trim the tail, or pad with random tokens.
    if x.shape[1] > seqlen:
        x = x[:, :seqlen]
    elif x.shape[1] < seqlen:
        need = seqlen - x.shape[1]
        pad = torch.tensor([rnd.randrange(V) for _ in range(need)], dtype=torch.long).unsqueeze(0)
        x = torch.cat([x, pad], dim=1)
    return x
206
+
207
# -----------------------------
# Load checkpoint & build model
# -----------------------------
# Checkpoint location; overridable via the CKPT_PATH environment variable.
DEFAULT_CKPT = os.environ.get("CKPT_PATH", "ckpt_latest.pt")
# Process-wide cache so repeated UI callbacks reuse one loaded model/tokenizer.
model_cache = {"model": None, "tokenizer": None, "radius": None, "ckpt": None}
212
+
213
  def load_model(ckpt_path: str):
214
  if not os.path.exists(ckpt_path):
215
  raise FileNotFoundError(
 
262
  model.eval()
263
  return model, tokenizer, int(radius)
264
 
265
def ensure_model(ckpt_path):
    """(Re)load the model into the module cache when it is missing or the
    requested checkpoint path differs from the cached one."""
    stale = model_cache["model"] is None or model_cache["ckpt"] != ckpt_path
    if stale:
        m, tok, rad = load_model(ckpt_path)
        model_cache.update({"model": m, "tokenizer": tok, "radius": rad, "ckpt": ckpt_path})
269
+
270
  # -----------------------------
271
+ # Strategy 1 core step
272
  # -----------------------------
273
@torch.no_grad()
def step_strategy1(model, x):
    """One iteration: pick a uniformly random position and overwrite it with
    the argmax token from the model's logits there. Mutates and returns x."""
    seq_len = x.shape[1]
    pos = int(torch.randint(0, seq_len, (1,)).item())
    position_logits = model_logits(model, x)[0, pos]  # logits over the vocab
    best_token = int(torch.argmax(position_logits).item())
    x[0, pos] = best_token
    return x
 
 
 
 
 
281
 
282
  # -----------------------------
283
+ # Gradio logic
284
  # -----------------------------
285
def init_random(ckpt_path, seqlen, seed):
    """Create a fresh uniformly-random sequence of length seqlen.

    Returns (ids, decoded text, status) for the Gradio outputs."""
    ensure_model(ckpt_path or DEFAULT_CKPT)
    random.seed(seed)
    torch.manual_seed(seed)
    tok = model_cache["tokenizer"]
    x = torch.randint(0, tok.vocab_size, (1, seqlen))
    return x.tolist(), decode(x[0], tok), f"Initialized random sequence (len={seqlen})"
292
 
293
def init_from_text(ckpt_path, seqlen, text, seed, pad_mode):
    """Initialize the sequence from user text, padded/truncated to seqlen.

    Returns (ids, decoded text, status) for the Gradio outputs."""
    ensure_model(ckpt_path or DEFAULT_CKPT)
    tok = model_cache["tokenizer"]
    x = to_fixed_len_ids(text or "", tok, seqlen, pad_mode=pad_mode, rnd=random.Random(seed))
    return x.tolist(), decode(x[0], tok), "Initialized from text"
299
 
300
def append_text(ckpt_path, state_ids, seqlen, text_to_append, seed):
    """Append encoded user text to the current sequence, then resize to seqlen.

    If there is no current sequence, one is initialized from the text with
    random padding. Returns (ids, decoded text, status) for the Gradio outputs.
    """
    ensure_model(ckpt_path or DEFAULT_CKPT)
    tok = model_cache["tokenizer"]
    rnd = random.Random(seed)
    if state_ids is None or len(state_ids) == 0:
        x = to_fixed_len_ids(text_to_append or "", tok, seqlen, pad_mode="random", rnd=rnd)
    else:
        # BUG FIX: state is stored via Tensor.tolist() on a [1, S] tensor, i.e.
        # a nested [[...]] list, but the old code unconditionally unsqueezed,
        # producing a [1, 1, S] tensor. Accept both flat and nested forms.
        x = torch.tensor(state_ids, dtype=torch.long)
        if x.dim() == 1:
            x = x.unsqueeze(0)
        # Append the encoded text.
        extra = tok.encode(text_to_append or "", add_special_tokens=False)
        x = torch.cat([x, torch.tensor(extra, dtype=torch.long).unsqueeze(0)], dim=1)
        # Force back to the target length: trim, or pad with random tokens.
        if x.shape[1] > seqlen:
            x = x[:, :seqlen]
        elif x.shape[1] < seqlen:
            need = seqlen - x.shape[1]
            pad = torch.tensor([rnd.randrange(tok.vocab_size) for _ in range(need)], dtype=torch.long).unsqueeze(0)
            x = torch.cat([x, pad], dim=1)
    return x.tolist(), decode(x[0], tok), "Appended text and resized to target length"
321
+
322
def apply_noise(ckpt_path, state_ids, seqlen, indices_csv, add_left, add_right, seed):
    """Apply the noise brush and/or prepend/append random tokens.

    Starts from a fresh random sequence when no state exists. Returns
    (ids, decoded text, status) for the Gradio outputs.
    """
    ensure_model(ckpt_path or DEFAULT_CKPT)
    tok = model_cache["tokenizer"]
    if state_ids is None or len(state_ids) == 0:
        # No current sequence: start from uniform random ids.
        base = torch.randint(0, tok.vocab_size, (1, seqlen))
    else:
        # BUG FIX: state is stored via Tensor.tolist() on a [1, S] tensor
        # (nested [[...]]); the old unconditional unsqueeze made it [1, 1, S].
        # Accept both flat and nested forms.
        base = torch.tensor(state_ids, dtype=torch.long)
        if base.dim() == 1:
            base = base.unsqueeze(0)
    x = apply_noise_ops(base, tok, indices_csv, int(add_left), int(add_right), seqlen, seed=seed)
    return x.tolist(), decode(x[0], tok), "Applied noise brush / prepend / append"
334
+
335
def step_once(ckpt_path, state_ids):
    """Run a single Strategy-1 update on the current sequence.

    Returns (ids, decoded text, status); reports a message instead of
    stepping when no sequence has been initialized.
    """
    ensure_model(ckpt_path or DEFAULT_CKPT)
    tok = model_cache["tokenizer"]
    if state_ids is None or len(state_ids) == 0:
        return None, "", "No sequence to step — initialize first."
    # BUG FIX: state is stored via Tensor.tolist() on a [1, S] tensor
    # (nested [[...]]); the old unconditional unsqueeze made it [1, 1, S].
    # Accept both flat and nested forms.
    x = torch.tensor(state_ids, dtype=torch.long)
    if x.dim() == 1:
        x = x.unsqueeze(0)
    x = step_strategy1(model_cache["model"], x)
    return x.tolist(), decode(x[0], tok), "Stepped 1 iteration"
344
+
345
def live_denoise(ckpt_path, state_ids, steps, snap_every, seed):
    """Generator for live updates.

    Yields (ids, text, status) every ``snap_every`` steps and on the final
    step so Gradio can stream intermediate results.
    """
    ensure_model(ckpt_path or DEFAULT_CKPT)
    tok = model_cache["tokenizer"]
    if state_ids is None or len(state_ids) == 0:
        # BUG FIX: the old bare `return` yielded nothing, leaving the user
        # with no feedback; surface an explicit status message instead.
        yield state_ids, "", "No sequence to denoise — initialize first."
        return
    random.seed(seed)
    torch.manual_seed(seed)
    # BUG FIX: state is stored via Tensor.tolist() on a [1, S] tensor
    # (nested [[...]]); the old unconditional unsqueeze made it [1, 1, S].
    # Accept both flat and nested forms.
    x = torch.tensor(state_ids, dtype=torch.long)
    if x.dim() == 1:
        x = x.unsqueeze(0)
    total = int(steps)
    snap = max(1, int(snap_every))
    for t in range(1, total + 1):
        x = step_strategy1(model_cache["model"], x)
        if (t % snap == 0) or (t == total):
            # Final state is always emitted (t == total), so no trailing yield needed.
            yield x.tolist(), decode(x[0], tok), f"Live denoise… step {t}/{total}"
363
 
364
# -----------------------------
# UI
# -----------------------------
with gr.Blocks(title="CNA — Interactive Denoising (Strategy 1)") as demo:
    gr.Markdown(
        """
        # CNA — Interactive Denoising (Strategy 1)
        - **Mode 1:** Randomize then watch it **denoise live** (random-position → argmax).
        - **Mode 2:** Initialize from **your text**.
        - **Noise Brush:** Select positions (e.g., `0, 5, 10-20`), and/or add random noise tokens at **start**/**end**.
        - **Append:** Add your text to the current sequence.
        """
    )

    # Global settings shared by every mode.
    with gr.Row():
        ckpt = gr.Textbox(value=DEFAULT_CKPT, label="Checkpoint path")
        seqlen = gr.Slider(10, 512, value=100, step=1, label="Sequence length (S)")
        seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")

    # Hidden session state holding the current token ids.
    ids_state = gr.State(value=None)

    # Shared displays updated by every action.
    with gr.Row():
        current_text = gr.Textbox(lines=8, label="Current text", interactive=False)
        status = gr.Markdown("Ready.")

    gr.Markdown("## Mode 1 · Random → Denoise Live")
    with gr.Row():
        btn_random = gr.Button("Initialize Random")
        steps = gr.Slider(1, 2000, value=200, step=1, label="Denoise steps (N)")
        snap_every = gr.Slider(1, 100, value=5, step=1, label="Update every K steps")
    with gr.Row():
        btn_step_once = gr.Button("Step Once")
        btn_live = gr.Button("Denoise Live (streaming)")

    gr.Markdown("## Mode 2 · Initialize From Your Text")
    with gr.Row():
        init_text = gr.Textbox(lines=4, label="Initial text")
    with gr.Row():
        pad_mode = gr.Radio(choices=["random", "eos"], value="random", label="Pad mode (if text shorter than S)")
        btn_init_text = gr.Button("Initialize From Text")

    gr.Markdown("## Noise Brush · Select Positions + Prepend/Append Noise")
    with gr.Row():
        indices_csv = gr.Textbox(label="Positions to noise (e.g., 0, 5, 10-20)", placeholder="Leave empty to skip")
    with gr.Row():
        add_left = gr.Number(value=0, precision=0, label="Noise tokens to add at START")
        add_right = gr.Number(value=0, precision=0, label="Noise tokens to add at END")
        btn_apply_noise = gr.Button("Apply Noise Brush / Prepend / Append")

    gr.Markdown("## Append Text")
    with gr.Row():
        append_box = gr.Textbox(lines=3, label="Text to append")
        btn_append = gr.Button("Append to Current Sequence")

    # --- Wiring: every action writes (ids, text, status) ---
    shared_outputs = [ids_state, current_text, status]

    btn_random.click(init_random, [ckpt, seqlen, seed], shared_outputs)
    btn_init_text.click(init_from_text, [ckpt, seqlen, init_text, seed, pad_mode], shared_outputs)
    btn_apply_noise.click(apply_noise, [ckpt, ids_state, seqlen, indices_csv, add_left, add_right, seed], shared_outputs)
    btn_append.click(append_text, [ckpt, ids_state, seqlen, append_box, seed], shared_outputs)
    btn_step_once.click(step_once, [ckpt, ids_state], shared_outputs)
    # live_denoise is a generator, so this click streams intermediate states.
    btn_live.click(live_denoise, [ckpt, ids_state, steps, snap_every, seed], shared_outputs, show_progress=True)

demo.queue(concurrency_count=1).launch()